From 179082a35f42a115f673d6c335c37df695909ea1 Mon Sep 17 00:00:00 2001 From: "Arun C. Murthy" <124712100+acmatscale@users.noreply.github.com> Date: Tue, 18 Jul 2023 11:29:46 -0700 Subject: [PATCH 001/425] Enhancements to hosting docs. (#124) * Enhancements to hosting docs. * Fixed capitalization. --- docs/guides/completions.md | 1 + docs/guides/self_hosting.md | 6 +++--- docs/pricing.md | 15 ++++++++------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index e5b6fdde..ff2b72ac 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -11,6 +11,7 @@ that can be used for producing completions to prompts. An example API call looks as follows: +=== "Completion call in Python" ```python from llmengine import Completion diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 39fa5bd4..027428c9 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -1,7 +1,7 @@ -# [Experimental] Self Hosting +# Self Hosting _[Experimental]_ **This guide is currently highly experimental. Instructions are subject to change as we improve support for self-hosting.** -We provide a Helm chart that deploys LLM Engine to an [Elastic Kubernetes Cluster](https://aws.amazon.com/eks/). This Helm chart should be configured to connect to dependencies (such as a PostgreSQL database) that you may already have available in your environment. +We provide a Helm chart that deploys LLM Engine to an [Elastic Kubernetes Cluster](https://aws.amazon.com/eks/) (EKS) in [AWS](https://aws.amazon.com/). This Helm chart should be configured to connect to dependencies (such as a PostgreSQL database) that you may already have available in your environment. The only portions of the Helm chart that are production ready are the parts that configure and manage LLM Server itself (not PostgreSQL, IAM, etc.) @@ -74,7 +74,7 @@ The LLM Engine server will an IAM role to perform various AWS operations. This r | `sqs:ListQueues` | `*` | | `ecr:BatchGetImage`, `ecr:DescribeImages`, `ecr:GetDownloadUrlForLayer`, `ecr:ListImages` | `${ecr_repository_arn}` | -# Helm Chart +## Helm Chart Now that all dependencies have been installed and configured, we can run the provided Helm chart. The values in the Helm chart will need to correspond with the resources described in the Dependencies section. Ensure that Helm V3 is installed [instructions](https://helm.sh/docs/intro/install/) and can connect to the EKS cluster. Users should be able to install the chart with `helm install llm-engine llm-engine -f llm-engine/values_sample.yaml -n `. diff --git a/docs/pricing.md b/docs/pricing.md index 1929cc8c..b61923b9 100644 --- a/docs/pricing.md +++ b/docs/pricing.md @@ -1,15 +1,16 @@ # Pricing -LLM Engine is being offered initially as a free preview. LLM Engine is an open-source project and free self-hosting will always be an option. +LLM Engine is an open-source project and free [self-hosting](../guides/self_hosting) will always be an option. -## Hosted Models +A hosted option for LLM Engine is being offered initially as a free preview via [Scale](https://scale.com/) [Spellbook](https://spellbook.scale.com/). -Once the limited preview period has ended, billing for hosted models will be managed through Scale's [Spellbook](https://spellbook.scale.com/settings) product. +## Self-Hosted Models -Scale Spellbook leverages usage-based spending, billed to a credit card. +We are committed to supporting the open-source community. [Self-hosting](../guides/self_hosting) LLM Engine will remain free and open-source. -Scale will share usage-based pricing before completing the limited preview to all users. +We would love [contributions](../contributing) from the community make this even more amazing! +## Hosted Models -## Self-Hosted Models +Once the limited preview period has ended, billing for hosted models will be managed through the Scale [Spellbook](https://spellbook.scale.com/settings) product. -We are committed to supporting the open-source community. Self-hosting LLM Engine will remain free and open-source. +Scale Spellbook leverages usage-based spending, billed to a credit card. Details on usage-based pricing will be shared with everyone before completing the limited preview. From c4531c9185dab44910a8537f952d5a529ea3e118 Mon Sep 17 00:00:00 2001 From: "Ray (Jui-Tse) Hung" <135046452+ruizehung-scale@users.noreply.github.com> Date: Tue, 18 Jul 2023 11:56:25 -0700 Subject: [PATCH 002/425] Fix CI run_unit_tests_server (#123) * Update the unit test command in server/tests/README.md * Avoid calling get_kubernetes_cluster_version() in circleci --- .../gateways/resources/k8s_endpoint_resource_delegate.py | 4 ++-- server/tests/README.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 7ec3d19d..b6c4d2d2 100644 --- a/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -132,7 +132,7 @@ def get_kubernetes_autoscaling_client(): # pragma: no cover else: _kubernetes_autoscaling_api = None if not _kubernetes_autoscaling_api: - cluster_version = get_kubernetes_cluster_version() + cluster_version = get_kubernetes_cluster_version() if not CIRCLECI else "1.26" # For k8s cluster versions 1.23 - 1.25 we need to use the v2beta2 api # For 1.26+ v2beta2 has been deperecated and merged into v2 if version.parse(cluster_version) >= version.parse("1.26"): @@ -1081,7 +1081,7 @@ async def _create_or_update_resources( ModelEndpointType.SYNC, ModelEndpointType.STREAMING, }: - cluster_version = get_kubernetes_cluster_version() + cluster_version = get_kubernetes_cluster_version() if not CIRCLECI else "1.26" # For k8s cluster versions 1.23 - 1.25 we need to use the v2beta2 api # For 1.26+ v2beta2 has been deperecated and merged into v2 if version.parse(cluster_version) >= version.parse("1.26"): diff --git a/server/tests/README.md b/server/tests/README.md index e519b96b..2c442472 100644 --- a/server/tests/README.md +++ b/server/tests/README.md @@ -2,6 +2,6 @@ ```shell pushd ../ -PYTHONPATH=llm_engine WORKSPACE=. python3 -m pytest llm_engine/tests --cov=llm_engine +PYTHONPATH=llm_engine_server WORKSPACE=. python3 -m pytest tests --cov=llm_engine_server popd ``` From e92bbe4b8aab4c4736bdca5ad8e1b3477f9ff0b3 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 18 Jul 2023 19:36:39 -0400 Subject: [PATCH 003/425] Updates to cookbook (#127) * try lower lr + # epochs * update cookbook --- docs/examples/finetuning.ipynb | 90 +++++++++++----------------------- 1 file changed, 29 insertions(+), 61 deletions(-) diff --git a/docs/examples/finetuning.ipynb b/docs/examples/finetuning.ipynb index 48669318..5f8a436d 100644 --- a/docs/examples/finetuning.ipynb +++ b/docs/examples/finetuning.ipynb @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -79,7 +79,7 @@ "\"From: dougb@comm.mot.com (Doug Bank)\\nSubject: Re: Info needed for Cleveland tickets\\nReply-To: dougb@ecs.comm.mot.com\\nOrganization: Motorola Land Mobile Products Sector\\nDistribution: usa\\nNntp-Posting-Host: 145.1.146.35\\nLines: 17\\n\\nIn article <1993Apr1.234031.4950@leland.Stanford.EDU>, bohnert@leland.Stanford.EDU (matthew bohnert) writes:\\n\\n|> I'm going to be in Cleveland Thursday, April 15 to Sunday, April 18.\\n|> Does anybody know if the Tribe will be in town on those dates, and\\n|> if so, who're they playing and if tickets are available?\\n\\nThe tribe will be in town from April 16 to the 19th.\\nThere are ALWAYS tickets available! (Though they are playing Toronto,\\nand many Toronto fans make the trip to Cleveland as it is easier to\\nget tickets in Cleveland than in Toronto. Either way, I seriously\\ndoubt they will sell out until the end of the season.)\\n\\n-- \\nDoug Bank Private Systems Division\\ndougb@ecs.comm.mot.com Motorola Communications Sector\\ndougb@nwu.edu Schaumburg, Illinois\\ndougb@casbah.acns.nwu.edu 708-576-8207\"" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -90,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ "Name: count, dtype: int64" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -125,7 +125,7 @@ "Name: count, dtype: int64" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -143,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -167,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -234,7 +234,7 @@ "4 baseball Prompt: Subject: Let it be Known\\nFrom: \n", " \n", " \n", - " \n", - " 974\n", - " From: maX <maX@maxim.rinaco.msk.su>\\nSubject: ...\n", - " hockey\n", - " \n", - " \n", - " \n", - " 988\n", - " From: jca2@cec1.wustl.edu (Joseph Charles Achk...\n", - " hockey\n", - " NHL\n", - " \n", - " \n", - " 997\n", - " From: apland@mala.bc.ca (Ron Apland)\\nSubject:...\n", - " hockey\n", - " \n", - " \n", " \n", "\n", "" ], "text/plain": [ - " raw_prompt response \\\n", - "974 From: maX \\nSubject: ... hockey \n", - "988 From: jca2@cec1.wustl.edu (Joseph Charles Achk... hockey \n", - "997 From: apland@mala.bc.ca (Ron Apland)\\nSubject:... hockey \n", - "\n", - " predicted_response \n", - "974 \n", - "988 NHL \n", - "997 " + "Empty DataFrame\n", + "Columns: [raw_prompt, response, predicted_response]\n", + "Index: []" ] }, - "execution_count": 23, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } From babc4bd1d031b422cb5b72f1245341ea21ed56ad Mon Sep 17 00:00:00 2001 From: Russell Kaplan Date: Tue, 18 Jul 2023 16:56:37 -0700 Subject: [PATCH 004/425] Update README.md (#126) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 51feafdb..68a4f3ef 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ⚡ LLM Engine ⚡ -**The open source engine for fine-tuning large language models**. +**The open source engine for fine-tuning and serving large language models**. Scale's LLM Engine is the easiest way to customize and serve LLMs. In LLM Engine, models can be accessed via Scale's hosted version or by using the Helm charts in this repository to run model inference and fine-tuning in your own infrastructure. From 0f570d345463647659a9b971f4a64a4c2c4ed20e Mon Sep 17 00:00:00 2001 From: "Ray (Jui-Tse) Hung" <135046452+ruizehung-scale@users.noreply.github.com> Date: Tue, 18 Jul 2023 17:09:18 -0700 Subject: [PATCH 005/425] Play With It section in self hosting doc (#128) * Add a Play With It section in docs/guides/self_hosting.md to document to how test sending a request and getting response via pod port forwarding * Change llm-engine-image-cache pod age to 18m * pod name -> pod names * - 5000 -> - NAMESPACE_YOU_INSTALL_LLM_ENGINE -> NAMESPACE_WHERE_LLM_ENGINE_IS_INSTALLED --- docs/guides/self_hosting.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 027428c9..23f66579 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -115,3 +115,29 @@ Below are the configurations to specify in the `values_sample.yaml` file. | config.values.llm_engine.cache_redis_url | The full url for the redis cluster you wish to connect | Yes | | config.values.llm_engine.s3_file_llm_fine_tuning_job_repository | The S3 URI for the S3 bucket/key that you wish to save fine-tuned assets | Yes | | config.values.datadog_trace_enabled | Whether to enable datadog tracing, datadog must be installed in the cluster | No | + +## Play With It +Once `helm install` succeeds, you can forward port `5000` from a `llm-engine` pod and test sending requests to it. + +First, see a list of pods in the namespace that you performed `helm install` in: +``` +$ kubectl get pods -n +NAME READY STATUS RESTARTS AGE +llm-engine-668679554-9q4wj 1/1 Running 0 18m +llm-engine-668679554-xfhxx 1/1 Running 0 18m +llm-engine-cacher-5f8b794585-fq7dj 1/1 Running 0 18m +llm-engine-endpoint-builder-5cd6bf5bbc-sm254 1/1 Running 0 18m +llm-engine-image-cache-a10-sw4pg 1/1 Running 0 18m +``` +Note the pod names you see may be different. + +Forward a port from a `llm-engine` pod: +``` +$ kubectl port-forward pod/llm-engine- 5000:5000 -n +``` + +Then, try sending a request to get LLM model endpoints for `test-user-id`. You should get a response with empty list: +``` +$ curl -X GET -H "Content-Type: application/json" -u "test-user-id:" "http://localhost:5000/v1/llm/model-endpoints" +{"model_endpoints":[]}% +``` \ No newline at end of file From e3ba70be81e7b289032e523def02d2c3aed9013a Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 18 Jul 2023 18:35:52 -0700 Subject: [PATCH 006/425] Documentation update for llama-v2 (#129) * Documentation update for llama-v2 * fix name * Update finetuning.ipynb --- clients/python/README.md | 4 +- clients/python/llmengine/completion.py | 8 ++-- clients/python/llmengine/fine_tuning.py | 4 +- clients/python/llmengine/model.py | 10 ++--- docs/examples/finetuning.ipynb | 4 +- docs/guides/completions.md | 51 +++++++++++++------------ docs/guides/fine_tuning.md | 44 +++++++++++---------- docs/guides/rate_limits.md | 26 +++++++------ docs/model_zoo.md | 1 + 9 files changed, 79 insertions(+), 73 deletions(-) diff --git a/clients/python/README.md b/clients/python/README.md index 9d4b7dbc..21befcb4 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -15,7 +15,7 @@ pip install scale-llm-engine ### Usage If you are using LLM Engine, you can get your API key from -[https://spellbook.scale.com/settings](https://spellbook.scale.com/settings). +[https://spellbook.scale.com/settings](https://spellbook.scale.com/settings). Set the `SCALE_API_KEY` environment variable to your API key. If you are using your own infrastructure, you can set the @@ -26,7 +26,7 @@ self-hosted `llmengine` endpoint. from llmengine import Completion response = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 6c6f2039..661ac30e 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -78,7 +78,7 @@ async def acreate( async def main(): response = await Completion.acreate( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -108,7 +108,7 @@ async def main(): async def main(): stream = await Completion.acreate( - model="llama-7b", + model="llama-2-7b", prompt="why is the sky blue?", max_new_tokens=5, temperature=0.2, @@ -224,7 +224,7 @@ def create( from llmengine import Completion response = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -250,7 +250,7 @@ def create( from llmengine import Completion stream = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="why is the sky blue?", max_new_tokens=5, temperature=0.2, diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index 62c147a6..d2c3d96f 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -71,7 +71,7 @@ def create( will be formatted like `"[model].[suffix].[YYYY-MM-DD-HH-MM-SS]"`. If absent, the fine-tuned model name will be formatted `"[model].[YYYY-MM-DD-HH-MM-SS]"`. For example, if `suffix` is `"my-experiment"`, the fine-tuned model name could be - `"llama-7b.my-experiment.2023-07-17-23-01-50"`. + `"llama-2-7b.my-experiment.2023-07-17-23-01-50"`. Returns: CreateFineTuneResponse: an object that contains the ID of the created fine-tuning job @@ -114,7 +114,7 @@ def create( from llmengine import FineTune response = FineTune.create( - model="llama-7b", + model="llama-2-7b", training_file="https://my-bucket.s3.us-west-2.amazonaws.com/path/to/training-file.csv", ) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index bbf15843..0b21496b 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -41,7 +41,7 @@ def get( ```python from llmengine import Model - response = Model.get("llama-7b.suffix.2023-07-18-12-00-00") + response = Model.get("llama-2-7b.suffix.2023-07-18-12-00-00") print(response.json()) ``` @@ -50,7 +50,7 @@ def get( ```json { "id": null, - "name": "llama-7b.suffix.2023-07-18-12-00-00", + "name": "llama-2-7b.suffix.2023-07-18-12-00-00", "model_name": null, "source": "hugging_face", "inference_framework": "text_generation_inference", @@ -92,7 +92,7 @@ def list(cls) -> ListLLMEndpointsResponse: "model_endpoints": [ { "id": null, - "name": "llama-7b.suffix.2023-07-18-12-00-00", + "name": "llama-2-7b.suffix.2023-07-18-12-00-00", "model_name": null, "source": "hugging_face", "inference_framework": "text_generation_inference", @@ -103,7 +103,7 @@ def list(cls) -> ListLLMEndpointsResponse: }, { "id": null, - "name": "llama-7b", + "name": "llama-2-7b", "model_name": null, "source": "hugging_face", "inference_framework": "text_generation_inference", @@ -163,7 +163,7 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse: ```python from llmengine import Model - response = Model.delete("llama-7b.suffix.2023-07-18-12-00-00") + response = Model.delete("llama-2-7b.suffix.2023-07-18-12-00-00") print(response.json()) ``` diff --git a/docs/examples/finetuning.ipynb b/docs/examples/finetuning.ipynb index 5f8a436d..819b85fa 100644 --- a/docs/examples/finetuning.ipynb +++ b/docs/examples/finetuning.ipynb @@ -304,7 +304,7 @@ "outputs": [], "source": [ "create_fine_tune_response = FineTune.create(\n", - " model=\"llama-7b\",\n", + " model=\"llama-2-7b\",\n", " training_file=\"https://scale-demo-datasets.s3.us-west-2.amazonaws.com/sports/sports_training_dataset.csv\",\n", " validation_file=None,\n", " hyperparameters={\"epochs\": \"1\", \"lr\": \"0.0002\"},\n", @@ -355,7 +355,7 @@ "metadata": {}, "outputs": [], "source": [ - "your_fine_tuned_model = \"llama-7b.my-first-finetune.2023-07-18-20-28-50\" # Note: you will have a different model!" + "your_fine_tuned_model = \"llama-2-7b.my-first-finetune.2023-07-18-20-28-50\" # Note: you will have a different model!" ] }, { diff --git a/docs/guides/completions.md b/docs/guides/completions.md index ff2b72ac..c9747889 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -1,10 +1,10 @@ -Language Models are trained to predict natural language and provide text outputs as a response -to their inputs. The inputs are called _prompts_ and outputs are referred to as _completions_. -LLMs take the input _prompts_ and chunk them into smaller units called _tokens_ to process and -generate language. Tokens may include trailing spaces and even sub-words. This process is +Language Models are trained to predict natural language and provide text outputs as a response +to their inputs. The inputs are called _prompts_ and outputs are referred to as _completions_. +LLMs take the input _prompts_ and chunk them into smaller units called _tokens_ to process and +generate language. Tokens may include trailing spaces and even sub-words. This process is language dependent. -Scale's LLM Engine provides access to open source language models (see [Model Zoo](../../model_zoo)) +Scale's LLM Engine provides access to open source language models (see [Model Zoo](../../model_zoo)) that can be used for producing completions to prompts. ## Completion API call @@ -16,7 +16,7 @@ An example API call looks as follows: from llmengine import Completion response = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -29,11 +29,11 @@ print(response.output.text) # ________ and I am a ________ ``` -- **model:** The LLM you want to use (see [Model Zoo](../../model_zoo)). -- **prompt:** The main input for the LLM to respond to. -- **max_new_tokens:** The maximum number of tokens to generate in the chat completion. -- **temperature:** The sampling temperature to use. Higher values make the output more random, -while lower values will make it more focused and deterministic. +- **model:** The LLM you want to use (see [Model Zoo](../../model_zoo)). +- **prompt:** The main input for the LLM to respond to. +- **max_new_tokens:** The maximum number of tokens to generate in the chat completion. +- **temperature:** The sampling temperature to use. Higher values make the output more random, + while lower values will make it more focused and deterministic. See the full [Completion API reference documentation](../../api/python_client/#llmengine.Completion) to learn more. @@ -42,11 +42,11 @@ See the full [Completion API reference documentation](../../api/python_client/#l An example Completion API response looks as follows: === "Response in JSON" - ```python +`python >>> print(response.json()) - ``` - Example output: - ```json + ` +Example output: +`json { "request_id": "c4bf0732-08e0-48a8-8b44-dfe8d4702fb0", "output": { @@ -54,20 +54,19 @@ An example Completion API response looks as follows: "num_completion_tokens": 10 } } - ``` + ` === "Response in Python" - ```python +`python >>> print(response.output.text) - ``` - Example output: - ``` - _______ and I am a _______ - ``` + ` +Example output: +` _______ and I am a _______ + ` ## Token streaming -The Completions API supports token streaming to reduce _perceived_ latency for certain -applications. When streaming, tokens will be sent as data-only +The Completions API supports token streaming to reduce _perceived_ latency for certain +applications. When streaming, tokens will be sent as data-only [server-side events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format). To enable token streaming, pass `stream=True` to either [Completion.create](../../api/python_client/#llmengine.completion.Completion.create) or [Completion.acreate](../../api/python_client/#llmengine.completion.Completion.acreate). @@ -75,6 +74,7 @@ To enable token streaming, pass `stream=True` to either [Completion.create](../. An example of token streaming using the synchronous Completions API looks as follows: === "Token streaming with synchronous API in python" + ```python import sys @@ -102,13 +102,14 @@ to utilize async processing. The function signatures are otherwise identical. An example of async Completions looks as follows: === "Completions with asynchronous API in python" + ```python import asyncio from llmengine import Completion async def main(): response = await Completion.acreate( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index ab83d0d1..93ff7655 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -22,16 +22,16 @@ The training data for fine-tuning should consist of prompt and response pairs. As a rule of thumb, you should expect to see linear improvements in your fine-tuned model's quality with each doubling of the dataset size. Having high-quality data is also essential to improving performance. For every linear increase in the error rate in your training data, you may encounter a roughly quadratic increase in your fine-tuned model's error rate. -High quality data is critical to achieve improved model performance, and in several cases will require _experts_ to -generate and prepare data - the breadth and diversity of the data is highly critical. Scale's Data Engine can help +High quality data is critical to achieve improved model performance, and in several cases will require _experts_ to +generate and prepare data - the breadth and diversity of the data is highly critical. Scale's Data Engine can help prepare such high quality, diverse data sets - more information [here](https://scale.com/rlhf). ## Preparing data + Your data must be formatted as a CSV file that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. Here is an example script to create a 50-row CSV of properly formatted data for fine-tuning an airline question answering bot -
Creating a sample dataset @@ -98,9 +98,11 @@ with open('customer_service_data.csv', 'w', newline='') as file: writer.writerow(["prompt", "response"]) writer.writerows(data) ``` +
## Making your data accessible to LLM Engine + Currently, data needs to be uploaded to a publicly accessible web URL so that it can be read for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. Support for privately sharing data with the LLM Engine API is coming shortly. For quick @@ -110,31 +112,31 @@ files in a public manner. An example Github Gist can be found you can use the URL given when you click the “Raw” button ([URL](https://gist.githubusercontent.com/tigss/7cec73251a37de72756a3b15eace9965/raw/85d9742890e1e6b0c06468507292893b820c13c9/llm_sample_data.csv)). - ## Launching the fine-tune -Once you have uploaded your data, you can use the LLM Engine's [FineTune.Create](../../api/python_client/#llmengine.fine_tuning.FineTune.create) API to launch a fine-tune. You will need to specify which base model to fine-tune, the locations of the training file and optional validation data file, an optional set of hyperparameters to customize the fine-tuning behavior, and an optional suffix to append to the name of the fine-tune. For sequences longer than the native + +Once you have uploaded your data, you can use the LLM Engine's [FineTune.Create](../../api/python_client/#llmengine.fine_tuning.FineTune.create) API to launch a fine-tune. You will need to specify which base model to fine-tune, the locations of the training file and optional validation data file, an optional set of hyperparameters to customize the fine-tuning behavior, and an optional suffix to append to the name of the fine-tune. For sequences longer than the native `max_seq_length` of the model, the sequences will be truncated. -If you specify a suffix, the fine-tune will be named `model.suffix.`. If you do not, -the fine-tune will be named `model.`. The timestamp will be the time the fine-tune was +If you specify a suffix, the fine-tune will be named `model.suffix.`. If you do not, +the fine-tune will be named `model.`. The timestamp will be the time the fine-tune was launched.
Hyper-parameters for fine-tune -* `lr`: Peak learning rate used during fine-tuning. It decays with a cosine schedule afterward. (Default: 2e-3) -* `warmup_ratio`: Ratio of training steps used for learning rate warmup. (Default: 0.03) -* `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) -* `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001) +- `lr`: Peak learning rate used during fine-tuning. It decays with a cosine schedule afterward. (Default: 2e-3) +- `warmup_ratio`: Ratio of training steps used for learning rate warmup. (Default: 0.03) +- `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) +- `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001)
-=== "Create a fine-tune in python" +=== "Create a fine-tune in python" ```python from llmengine import FineTune response = FineTune.create( - model="llama-7b", + model="llama-2-7b", training_file="s3://my-bucket/path/to/training-file.csv", ) @@ -147,14 +149,14 @@ Once the fine-tune is launched, you can also [get the status of your fine-tune]( ## Making inference calls to your fine-tune -Once your fine-tune is finished, you will be able to start making inference requests to the -model. You can use the `fine_tuned_model` returned from your +Once your fine-tune is finished, you will be able to start making inference requests to the +model. You can use the `fine_tuned_model` returned from your [FineTune.get](../../api/python_client/#llmengine.fine_tuning.FineTune.get) -API call to reference your fine-tuned model in the Completions API. Alternatively, you can list -available LLMs with `Model.list` in order to find the name of your fine-tuned model. See the -[Completion API](../../api/python_client/#llmengine.Completion) for more details. You can then -use that name to direct your completion requests. You must wait until your fine-tune is complete -before you can plug it into the Completions API. You can check the status of your fine-tune with +API call to reference your fine-tuned model in the Completions API. Alternatively, you can list +available LLMs with `Model.list` in order to find the name of your fine-tuned model. See the +[Completion API](../../api/python_client/#llmengine.Completion) for more details. You can then +use that name to direct your completion requests. You must wait until your fine-tune is complete +before you can plug it into the Completions API. You can check the status of your fine-tune with [FineTune.get](../../api/python_client/#llmengine.fine_tuning.FineTune.get). === "Inference with a fine-tuned model in python" @@ -163,7 +165,7 @@ before you can plug it into the Completions API. You can check the status of you from llmengine import Completion response = Completion.create( - model="llama-7b.airlines.2023-07-17-08-30-45", + model="llama-2-7b.airlines.2023-07-17-08-30-45", prompt="Do you offer in-flight Wi-fi?", max_new_tokens=100, temperature=0.2, diff --git a/docs/guides/rate_limits.md b/docs/guides/rate_limits.md index 1224f19f..2aa59dd4 100644 --- a/docs/guides/rate_limits.md +++ b/docs/guides/rate_limits.md @@ -18,25 +18,26 @@ will return HTTP 429 on an as-needed basis. ## Retrying with exponential backoff -One easy way to avoid rate limit errors is to automatically retry requests with a random exponential backoff. -Retrying with exponential backoff means performing a short sleep when a rate limit error is hit, then retrying the -unsuccessful request. If the request is still unsuccessful, the sleep length is increased and the process is repeated. +One easy way to avoid rate limit errors is to automatically retry requests with a random exponential backoff. +Retrying with exponential backoff means performing a short sleep when a rate limit error is hit, then retrying the +unsuccessful request. If the request is still unsuccessful, the sleep length is increased and the process is repeated. This continues until the request is successful or until a maximum number of retries is reached. This approach has many benefits: -* Automatic retries means you can recover from rate limit errors without crashes or missing data -* Exponential backoff means that your first retries can be tried quickly, while still benefiting from longer delays if your first few retries fail -* Adding random jitter to the delay helps retries from all hitting at the same time. +- Automatic retries means you can recover from rate limit errors without crashes or missing data +- Exponential backoff means that your first retries can be tried quickly, while still benefiting from longer delays if your first few retries fail +- Adding random jitter to the delay helps retries from all hitting at the same time. Below are a few example solutions **for Python** that use exponential backoff. ### Example #1: Using the `tenacity` library -Tenacity is an Apache 2.0 licensed general-purpose retrying library, written in Python, to simplify the task of adding -retry behavior to just about anything. To add exponential backoff to your requests, you can use the tenacity.retry -decorator. The below example uses the tenacity.wait_random_exponential function to add random exponential backoff to a +Tenacity is an Apache 2.0 licensed general-purpose retrying library, written in Python, to simplify the task of adding +retry behavior to just about anything. To add exponential backoff to your requests, you can use the tenacity.retry +decorator. The below example uses the tenacity.wait_random_exponential function to add random exponential backoff to a request. === "Exponential backoff in python" + ```python import llmengine from tenacity import ( @@ -49,14 +50,15 @@ from tenacity import ( def completion_with_backoff(**kwargs): return llmengine.Completion.create(**kwargs) -completion_with_backoff(model="llama-7b", prompt="Why is the sky blue?") +completion_with_backoff(model="llama-2-7b", prompt="Why is the sky blue?") ``` ### Example #2: Using the `backoff` library -[Backoff](https://github.com/litl/backoff) is another python library that provides function decorators which can be used to wrap a function such that it will be retried until some condition is met. +[Backoff](https://github.com/litl/backoff) is another python library that provides function decorators which can be used to wrap a function such that it will be retried until some condition is met. === "Decorators for backoff and retry in python" + ```python import llmengine import backoff @@ -65,5 +67,5 @@ import backoff def completion_with_backoff(**kwargs): return llmengine.Completion.create(**kwargs) -completions_with_backoff(model="llama-7b", prompt="Why is the sky blue?") +completions_with_backoff(model="llama-2-7b", prompt="Why is the sky blue?") ``` diff --git a/docs/model_zoo.md b/docs/model_zoo.md index f4e6875f..62fb2b48 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -5,6 +5,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | Model Name | Inference APIs Available | Fine-tuning APIs Available | | --------------------- | ------------------------ | -------------------------- | | `llama-7b` | ✅ | ✅ | +| `llama-2-7b` | ✅ | ✅ | | `falcon-7b` | ✅ | | | `falcon-7b-instruct` | ✅ | | | `falcon-40b` | ✅ | | From 823baa172b91f76bc0ef186bdf410ea5c1326b60 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 18 Jul 2023 21:38:18 -0400 Subject: [PATCH 007/425] Update cookbook for llamav2 (#130) * make code changes * update cookbook with numbers --- docs/examples/finetuning.ipynb | 48 ++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/examples/finetuning.ipynb b/docs/examples/finetuning.ipynb index 819b85fa..16573392 100644 --- a/docs/examples/finetuning.ipynb +++ b/docs/examples/finetuning.ipynb @@ -299,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -316,9 +316,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BatchJobStatus.RUNNING\n" + ] + } + ], "source": [ "# Wait for fine tune to complete\n", "\n", @@ -351,11 +359,11 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "your_fine_tuned_model = \"llama-2-7b.my-first-finetune.2023-07-18-20-28-50\" # Note: you will have a different model!" + "your_fine_tuned_model = \"llama-2-7b.my-first-fine-tune.2023-07-19-00-48-07\" # Note: you will have a different model!" ] }, { @@ -374,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -412,7 +420,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -492,7 +500,7 @@ "954 hockey " ] }, - "execution_count": 13, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -503,16 +511,16 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "1.0" + "0.98" ] }, - "execution_count": 14, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -524,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -554,17 +562,25 @@ " \n", " \n", " \n", + " \n", + " 974\n", + " From: maX <maX@maxim.rinaco.msk.su>\\nSubject: ...\n", + " hockey\n", + " baseball\n", + " \n", " \n", "\n", "" ], "text/plain": [ - "Empty DataFrame\n", - "Columns: [raw_prompt, response, predicted_response]\n", - "Index: []" + " raw_prompt response \\\n", + "974 From: maX \\nSubject: ... hockey \n", + "\n", + " predicted_response \n", + "974 baseball " ] }, - "execution_count": 15, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } From 07a23f69a2f4bfe4e20e630d2cfca2c2d4c3d3e5 Mon Sep 17 00:00:00 2001 From: "Arun C. Murthy" <124712100+acmatscale@users.noreply.github.com> Date: Tue, 18 Jul 2023 19:42:51 -0700 Subject: [PATCH 008/425] Doc enhancements to use llama-2 (#131) --- .gitignore | 1 + clients/python/llmengine/fine_tuning.py | 4 +-- docs/getting_started.md | 4 +-- docs/guides/completions.md | 35 +++++++++++-------------- docs/index.md | 6 ++--- 5 files changed, 23 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index d5bec1e7..276b0676 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +tags *.cache *.pt *.pkl diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index d2c3d96f..d8ffa84d 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -283,7 +283,7 @@ def get_events(cls, fine_tune_id: str) -> GetFineTuneEventsResponse: Returns: GetFineTuneEventsResponse: an object that contains the list of events for the fine-tuning job - Example: + === "Getting events for fine-tuning jobs in Python" ```python from llmengine import FineTune @@ -291,7 +291,7 @@ def get_events(cls, fine_tune_id: str) -> GetFineTuneEventsResponse: print(response.json()) ``` - JSON Response: + === "Response in JSON" ```json { "events": diff --git a/docs/getting_started.md b/docs/getting_started.md index 23ef79f3..ead931e2 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -48,7 +48,7 @@ With your API key set, you can now send LLM Engine requests using the Python cli from llmengine import Completion response = Completion.create( - model="falcon-7b-instruct", + model="llama-2-7b", prompt="I'm opening a pancake restaurant that specializes in unique pancake shapes, colors, and flavors. List 3 quirky names I could name my restaurant.", max_new_tokens=100, temperature=0.2, @@ -66,7 +66,7 @@ import sys from llmengine import Completion stream = Completion.create( - model="falcon-7b-instruct", + model="llama-2-7b", prompt="Give me a 200 word summary on the current economic events in the US.", max_new_tokens=1000, temperature=0.2, diff --git a/docs/guides/completions.md b/docs/guides/completions.md index c9747889..eb16b94e 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -42,26 +42,21 @@ See the full [Completion API reference documentation](../../api/python_client/#l An example Completion API response looks as follows: === "Response in JSON" -`python - >>> print(response.json()) - ` -Example output: -`json - { - "request_id": "c4bf0732-08e0-48a8-8b44-dfe8d4702fb0", - "output": { - "text": "_______ and I am a _______", - "num_completion_tokens": 10 - } - } - ` + ```python + >>> print(response.json()) + { + "request_id": "c4bf0732-08e0-48a8-8b44-dfe8d4702fb0", + "output": { + "text": "_______ and I am a _______", + "num_completion_tokens": 10 + } + } + ``` === "Response in Python" -`python - >>> print(response.output.text) - ` -Example output: -` _______ and I am a _______ - ` + ```python + >>> print(response.output.text) + _______ and I am a _______ + ``` ## Token streaming @@ -81,7 +76,7 @@ import sys from llmengine import Completion stream = Completion.create( - model="falcon-7b-instruct", + model="llama-2-7b", prompt="Give me a 200 word summary on the current economic events in the US.", max_new_tokens=1000, temperature=0.2, diff --git a/docs/index.md b/docs/index.md index fcf8cf3e..798519ee 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,11 +30,11 @@ Kubernetes. ### Key Features **Ready-to-use APIs for your favorite models**: Deploy and serve -open source foundation models - including LLaMA, MPT, and Falcon. +open source foundation models - including Llama-2, MPT, and Falcon. Use Scale-hosted models or deploy to your own infrastructure. -**Fine-tune your favorite models**: Fine-tune open-source foundation -models like LLaMA, MPT, etc. with your own data for optimized performance. +**Fine-tune the best open-source models**: Fine-tune open-source foundation +models like Llama-2, MPT, etc. with your own data for optimized performance. **Optimized Inference**: LLM Engine provides inference APIs for streaming responses and dynamically batching inputs for higher throughput From 10fee6cfb7eab0a8daba5645cf02b024b5946920 Mon Sep 17 00:00:00 2001 From: "Arun C. Murthy" <124712100+acmatscale@users.noreply.github.com> Date: Tue, 18 Jul 2023 21:20:48 -0700 Subject: [PATCH 009/425] Clarified current self-hosting features. (#134) * Clarified current self-hosting features. * Fixed link --- docs/index.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/index.md b/docs/index.md index 798519ee..01d4f84e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,10 +48,11 @@ auto-scaling deployment with simple APIs. ### Features Coming Soon -**Kubernetes Installation Documentation**: We are working hard to document the installation and -maintenance of inference and fine-tuning functionality on your infrastructure. -For now, our documentation covers using our client libraries to access Scale's -hosted infrastructure. +**Kubernetes Installation Enhancements**: We are working hard to enhance the +installation and maintenance of inference and fine-tuning functionality on +your infrastructure. For now, our documentation covers _experimental_ libraries +to [deploy language models on your infrastructure](guides/self_hosting) +and libraries to access Scale's [hosted infrastructure](https://spellbook.scale.com). **Fast Cold-Start Times**: To prevent GPUs from idling, LLM Engine automatically scales your model to zero when it's not in use and scales up From 45fb9c1bc93f61f46fb794d3014a38b74aba9446 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 18 Jul 2023 21:59:37 -0700 Subject: [PATCH 010/425] Add more models to model zoo (#135) --- docs/model_zoo.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 62fb2b48..9cfcbc02 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -6,6 +6,9 @@ Scale hosts the following models in the LLM Engine Model Zoo: | --------------------- | ------------------------ | -------------------------- | | `llama-7b` | ✅ | ✅ | | `llama-2-7b` | ✅ | ✅ | +| `llama-2-7b-chat` | ✅ | | +| `llama-2-13b` | ✅ | | +| `llama-2-13b-chat` | ✅ | | | `falcon-7b` | ✅ | | | `falcon-7b-instruct` | ✅ | | | `falcon-40b` | ✅ | | From ffa1093af3f1ec71d8003adb377a29f9001df754 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 19 Jul 2023 09:35:51 -0700 Subject: [PATCH 011/425] Add back Model.create (#125) * Add back Model.create * format --- clients/python/llmengine/data_types.py | 12 ++ clients/python/llmengine/model.py | 173 ++++++++++++++++++++++++- 2 files changed, 184 insertions(+), 1 deletion(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 90cd201c..97ebb009 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -113,6 +113,18 @@ class GetModelEndpointResponse(BaseModel): public_inference: Optional[bool] = Field(default=None) +class PostInferenceHooks(str, Enum): + """ + Post-inference hooks are functions that are called after inference is complete. + + Attributes: + CALLBACK: The callback hook is called with the inference response and the task ID. + """ + + # INSIGHT = "insight" + CALLBACK: str = "callback" + + class CreateLLMEndpointRequest(BaseModel): name: str diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 0b21496b..1b6c9dba 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -1,8 +1,17 @@ -from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine +from typing import Dict, List, Optional + +from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine, assert_self_hosted from llmengine.data_types import ( + CreateLLMEndpointRequest, + CreateLLMEndpointResponse, DeleteLLMEndpointResponse, GetLLMEndpointResponse, + GpuType, ListLLMEndpointsResponse, + LLMInferenceFramework, + LLMSource, + ModelEndpointType, + PostInferenceHooks, ) @@ -15,6 +24,168 @@ class Model(APIEngine): See [Model Zoo](../../model_zoo) for the list of publicly available base models. """ + @classmethod + @assert_self_hosted + def create( + cls, + # LLM specific fields + model: str, + inference_framework_image_tag: str, + source: LLMSource = LLMSource.HUGGING_FACE, + inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + num_shards: int = 4, + # General endpoint fields + cpus: int = 32, + memory: str = "192Gi", + storage: str = "96Gi", + gpus: int = 4, + min_workers: int = 0, + max_workers: int = 1, + per_worker: int = 10, + endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING, + gpu_type: Optional[str] = "nvidia-ampere-a10", + high_priority: Optional[bool] = False, + post_inference_hooks: Optional[List[PostInferenceHooks]] = None, + default_callback_url: Optional[str] = None, + public_inference: Optional[bool] = True, + labels: Optional[Dict[str, str]] = None, + ) -> CreateLLMEndpointResponse: + """ + Create an LLM model. Note: This feature is only available for self-hosted users. + Args: + model (`str`): + Name of the model + + inference_framework_image_tag (`str`): + Image tag for the inference framework + + source (`LLMSource`): + Source of the LLM. Currently only HuggingFace is supported + + inference_framework (`LLMInferenceFramework`): + Inference framework for the LLM. Currently only DeepSpeed is supported + + num_shards (`int`): + Number of shards for the LLM. When bigger than 1, LLM will be sharded + to multiple GPUs. Number of GPUs must be larger than num_shards. + + cpus (`int`): + Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater + than or equal to 1 + + memory (`str`): + Amount of memory each worker should get, e.g. "4Gi", "512Mi", etc. This must + be a positive amount of memory + + storage (`str`): + Amount of local ephemeral storage each worker should get, e.g. "4Gi", + "512Mi", etc. This must be a positive amount of storage + + gpus (`int`): + Number of gpus each worker should get, e.g. 0, 1, etc. + + min_workers (`int`): + The minimum number of workers. Must be greater than or equal to 0. This + should be determined by computing the minimum throughput of your workload and + dividing it by the throughput of a single worker. This field must be at least ``1`` + for synchronous endpoints + + max_workers (`int`): + The maximum number of workers. Must be greater than or equal to 0, + and as well as greater than or equal to ``min_workers``. This should be determined by + computing the maximum throughput of your workload and dividing it by the throughput + of a single worker + + per_worker (`int`): + The maximum number of concurrent requests that an individual worker can + service. Launch automatically scales the number of workers for the endpoint so that + each worker is processing ``per_worker`` requests, subject to the limits defined by + ``min_workers`` and ``max_workers`` + + - If the average number of concurrent requests per worker is lower than + ``per_worker``, then the number of workers will be reduced. - Otherwise, + if the average number of concurrent requests per worker is higher than + ``per_worker``, then the number of workers will be increased to meet the elevated + traffic. + + Here is our recommendation for computing ``per_worker``: + + 1. Compute ``min_workers`` and ``max_workers`` per your minimum and maximum + throughput requirements. 2. Determine a value for the maximum number of + concurrent requests in the workload. Divide this number by ``max_workers``. Doing + this ensures that the number of workers will "climb" to ``max_workers``. + + endpoint_type (`ModelEndpointType`): + ``"sync"``, ``"async"`` or ``"streaming"``. + + gpu_type (`Optional[str]`): + If specifying a non-zero number of gpus, this controls the type of gpu + requested. Here are the supported values: + + - ``nvidia-tesla-t4`` + - ``nvidia-ampere-a10`` + + high_priority (`Optional[bool]`): + Either ``True`` or ``False``. Enabling this will allow the created + endpoint to leverage the shared pool of prewarmed nodes for faster spinup time + + post_inference_hooks (`Optional[List[PostInferenceHooks]]`): + List of hooks to trigger after inference tasks are served + + default_callback_url (`Optional[str]`): + The default callback url to use for async endpoints. + This can be overridden in the task parameters for each individual task. + post_inference_hooks must contain "callback" for the callback to be triggered + + public_inference (`Optional[bool]`): + If ``True``, this endpoint will be available to all user IDs for + inference + + + labels (`Optional[Dict[str, str]]`): + An optional dictionary of key/value pairs to associate with this endpoint + Returns: + CreateLLMEndpointResponse: creation task ID of the created Model. + """ + post_inference_hooks_strs = None + if post_inference_hooks is not None: + post_inference_hooks_strs = [] + for hook in post_inference_hooks: + if isinstance(hook, PostInferenceHooks): + post_inference_hooks_strs.append(hook.value) + else: + post_inference_hooks_strs.append(hook) + + request = CreateLLMEndpointRequest( + name=model, + model_name=model, + source=source, + inference_framework=inference_framework, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + cpus=cpus, + endpoint_type=ModelEndpointType(endpoint_type), + gpus=gpus, + gpu_type=GpuType(gpu_type) if gpu_type is not None else None, + labels=labels or {}, + max_workers=max_workers, + memory=memory, + metadata={}, + min_workers=min_workers, + per_worker=per_worker, + high_priority=high_priority, + post_inference_hooks=post_inference_hooks_strs, + default_callback_url=default_callback_url, + storage=storage, + public_inference=public_inference, + ) + response = cls.post_sync( + resource_name="v1/llm/model-endpoints", + data=request.dict(), + timeout=DEFAULT_TIMEOUT, + ) + return CreateLLMEndpointResponse.parse_obj(response) + @classmethod def get( cls, From 97f64ecef54fc683826fe1368b26b83c749719ad Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 19 Jul 2023 16:24:17 -0400 Subject: [PATCH 012/425] Update allowed types on hyperparameter values (#140) --- clients/python/llmengine/fine_tuning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index d8ffa84d..29be75ce 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Union from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine from llmengine.data_types import ( @@ -28,7 +28,7 @@ def create( model: str, training_file: str, validation_file: Optional[str] = None, - hyperparameters: Optional[Dict[str, str]] = None, + hyperparameters: Optional[Dict[str, Union[str, int, float]]] = None, suffix: Optional[str] = None, ) -> CreateFineTuneResponse: """ From 767cbc45556a80e6d775918c2e20ade855049c91 Mon Sep 17 00:00:00 2001 From: "Ray (Jui-Tse) Hung" <135046452+ruizehung-scale@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:46:12 -0700 Subject: [PATCH 013/425] Can create LLM endpoint (#132) --- charts/llm-engine/templates/_helpers.tpl | 2 +- .../service_template_config_map.yaml | 8 +++---- charts/llm-engine/values_sample.yaml | 16 ++++++++++++- .../use_cases/llm_model_endpoint_use_cases.py | 2 +- .../service_template_config_map_circleci.yaml | 24 +++++++++---------- 5 files changed, 33 insertions(+), 19 deletions(-) diff --git a/charts/llm-engine/templates/_helpers.tpl b/charts/llm-engine/templates/_helpers.tpl index 04c8168f..08af45f4 100644 --- a/charts/llm-engine/templates/_helpers.tpl +++ b/charts/llm-engine/templates/_helpers.tpl @@ -344,7 +344,7 @@ volumeMounts: {{- define "llmEngine.forwarderVolumeMounts" }} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /home/user/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config diff --git a/charts/llm-engine/templates/service_template_config_map.yaml b/charts/llm-engine/templates/service_template_config_map.yaml index 87b992cf..08ce1424 100644 --- a/charts/llm-engine/templates/service_template_config_map.yaml +++ b/charts/llm-engine/templates/service_template_config_map.yaml @@ -180,7 +180,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -221,9 +221,9 @@ data: - ddtrace-run - python - -m - - llm_engine.inference.forwarding.http_forwarder + - server.llm_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/service--http_forwarder.yaml + - /workspace/server/llm_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers @@ -266,7 +266,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility diff --git a/charts/llm-engine/values_sample.yaml b/charts/llm-engine/values_sample.yaml index 7b2cbbf0..06d70362 100644 --- a/charts/llm-engine/values_sample.yaml +++ b/charts/llm-engine/values_sample.yaml @@ -1,7 +1,7 @@ # This is a YAML-formatted file. # tag [required] is the LLM Engine docker image tag -tag: 1defd4f9c5376149e27673e154731a0c7820fe5d +tag: 41ecada1b51ce3a46bbc3190a36ed7890db370d3 # context is a user-specified deployment tag. Can be used to context: production image: @@ -171,6 +171,20 @@ imageCache: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" + - name: a100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-ampere-a100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: t4 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-tesla-t4 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" # celeryBrokerType specifies the celery broker type for async endpoints (coming soon) celeryBrokerType: sqs diff --git a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 482a4519..8de7ef72 100644 --- a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -227,7 +227,7 @@ async def create_text_generation_inference_bundle( schema_location="TBA", flavor=StreamingEnhancedRunnableImageFlavor( flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, - repository="text-generation-inference", # TODO: let user choose repo + repository="ghcr.io/huggingface/text-generation-inference", # TODO: let user choose repo tag=framework_image_tag, command=command, streaming_command=command, diff --git a/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 41ffe75b..3f2e519f 100644 --- a/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -114,7 +114,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -383,7 +383,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -805,7 +805,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -1071,7 +1071,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -1473,9 +1473,9 @@ data: - ddtrace-run - python - -m - - llm_engine.inference.forwarding.http_forwarder + - server.llm_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/service--http_forwarder.yaml + - /workspace/server/llm_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers @@ -1712,7 +1712,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -1987,7 +1987,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -2421,7 +2421,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -2693,7 +2693,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -3107,9 +3107,9 @@ data: - ddtrace-run - python - -m - - llm_engine.inference.forwarding.http_forwarder + - server.llm_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/service--http_forwarder.yaml + - /workspace/server/llm_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers From 9c94ec51adcd3095268e917479c28e78f38b3dde Mon Sep 17 00:00:00 2001 From: jihan-yin <78386805+jihan-yin@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:59:00 -0700 Subject: [PATCH 014/425] refetch API key (#142) Co-authored-by: Ubuntu --- clients/python/llmengine/api_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index fb0ec830..0d82c184 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -9,14 +9,14 @@ from aiohttp import ClientSession, ClientTimeout from llmengine.errors import parse_error -SCALE_API_KEY = os.getenv("SCALE_API_KEY") SPELLBOOK_API_URL = "https://api.spellbook.scale.com" LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) DEFAULT_TIMEOUT: int = 10 def get_api_key() -> str: - return SCALE_API_KEY or "root" + env_api_key = os.getenv("SCALE_API_KEY") + return env_api_key or "root" def assert_self_hosted(func): @@ -32,7 +32,7 @@ def inner(*args, **kwargs): class APIEngine: @classmethod def validate_api_key(cls): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH and not SCALE_API_KEY: + if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH and not get_api_key(): raise ValueError( "You must set SCALE_API_KEY in your environment to to use the LLM Engine API." ) From 5b9e5d2abc3e2f7a642a80f837f5ac27d5819b3a Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 19 Jul 2023 20:29:41 -0400 Subject: [PATCH 015/425] Add fine_tuned_model field to Get/ListFineTuneResponse (#145) --- clients/python/llmengine/data_types.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 97ebb009..ffcd4f5c 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -381,6 +381,16 @@ class GetFineTuneResponse(BaseModel): The ID of the FineTune. """ + fine_tuned_model: Optional[str] = Field( + default=None, + description="Name of the resulting fine-tuned model. This can be plugged into the " + "Completion API once the fine-tune is complete", + ) + """ + The name of the resulting fine-tuned model. This can be plugged into the Completion API + once the fine-tune is complete. + """ + status: BatchJobStatus = Field(..., description="Status of the requested job.") """ The status of the FineTune job. From 73614862514761b40f67620e3748d24f7e4c586c Mon Sep 17 00:00:00 2001 From: Utsav Garg <110483261+gargutsav@users.noreply.github.com> Date: Thu, 20 Jul 2023 00:04:28 -0700 Subject: [PATCH 016/425] Update fine_tuning.md (#147) Add details about training and validation file inputs for fine-tuning. --- clients/python/llmengine/fine_tuning.py | 2 +- docs/guides/fine_tuning.md | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index 29be75ce..c280ac39 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -51,7 +51,7 @@ def create( The name of the base model to fine-tune. See [Model Zoo](../../model_zoo) for the list of available models to fine-tune. training_file (`str`): - Publicly accessible URL to a CSV file for training. + Publicly accessible URL to a CSV file for training. When no validation_file is provided, one will automatically be created using a 10% split of the training_file data. validation_file (`Optional[str]`): Publicly accessible URL to a CSV file for validation. The validation file is used to compute metrics which let LLM Engine pick the best fine-tuned checkpoint, which will be used for inference when fine-tuning is complete. diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index 93ff7655..b4d5d9ef 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -28,7 +28,7 @@ prepare such high quality, diverse data sets - more information [here](https://s ## Preparing data -Your data must be formatted as a CSV file that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. +Your data must be formatted as a CSV file that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. LLM Engine supports fine-tuning with a training and validation dataset. If only a training dataset is provided, 10% of the data is randomly split to be used as validation. Here is an example script to create a 50-row CSV of properly formatted data for fine-tuning an airline question answering bot @@ -138,6 +138,7 @@ from llmengine import FineTune response = FineTune.create( model="llama-2-7b", training_file="s3://my-bucket/path/to/training-file.csv", + validation_file="s3://my-bucket/path/to/validation-file.csv", ) print(response.json()) From c09785138537f088b6a76057f798936040fd715a Mon Sep 17 00:00:00 2001 From: William Song Date: Thu, 20 Jul 2023 10:53:10 -0700 Subject: [PATCH 017/425] bump pip package version from 0.0.0.beta3 -> 0.0.0.beta4 (#154) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 97324d59..1d882f5c 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta3" +__version__ = "0.0.0.beta4" from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 9d1db1f1..f8c28795 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta3" +version = "0.0.0.beta4" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 4d14ebd3..d0cb52ab 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta3", + version="0.0.0.beta4", packages=find_packages(), ) From b2588eac37af8cc1a72bcb91dc659593ffbf502e Mon Sep 17 00:00:00 2001 From: jihan-yin <78386805+jihan-yin@users.noreply.github.com> Date: Thu, 20 Jul 2023 11:53:54 -0700 Subject: [PATCH 018/425] Example notebook for fine-tuning Llama-2 7B on ScienceQA (#148) Example notebook for fine-tuning Llama-2 7B on ScienceQA #148 --- examples/finetune_llama_2_on_science_qa.ipynb | 219 ++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 examples/finetune_llama_2_on_science_qa.ipynb diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb new file mode 100644 index 00000000..9b4f77a4 --- /dev/null +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -0,0 +1,219 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8d3a4214", + "metadata": {}, + "source": [ + "# Finetune on ScienceQA\n", + "Let's use LLM Engine to fine-tune Llama-2 on ScienceQA!" + ] + }, + { + "cell_type": "markdown", + "id": "a3dc2a56", + "metadata": {}, + "source": [ + "# Data Preparation\n", + "Let's load in the dataset using Huggingface and view the features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e06ac39e", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from smart_open import smart_open\n", + "import pandas as pd\n", + "\n", + "dataset = load_dataset('derek-thomas/ScienceQA')\n", + "dataset['train'].features" + ] + }, + { + "cell_type": "markdown", + "id": "1cbe8a58", + "metadata": {}, + "source": [ + "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b0eb8ad", + "metadata": {}, + "outputs": [], + "source": [ + "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", + "def format_options(options, choice_prefixes):\n", + " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", + "\n", + "def format_prompt(r, choice_prefixes):\n", + " options = format_options(r['choices'], choice_prefixes)\n", + " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", + "\n", + "def format_label(r, choice_prefixes):\n", + " return choice_prefixes[r['answer']]\n", + "\n", + "def convert_dataset(ds):\n", + " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", + " return df\n", + "\n", + "save_to_s3 = False\n", + "df_train = convert_dataset(dataset['train'])\n", + "if save_to_s3:\n", + " train_url = 's3://...'\n", + " val_url = 's3://...'\n", + " df_train = convert_dataset(dataset['train'])\n", + " with smart_open(train_url, 'wb') as f:\n", + " df_train.to_csv(f)\n", + "\n", + " df_val = convert_dataset(dataset['validation'])\n", + " with smart_open(val_url, 'wb') as f:\n", + " df_val.to_csv(f)\n", + "else:\n", + " # Gists of the already processed datasets\n", + " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", + " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", + " \n", + "df_train" + ] + }, + { + "cell_type": "markdown", + "id": "e2fc8d76", + "metadata": {}, + "source": [ + "# Fine-tune\n", + "Now, we can fine-tune the model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4905d447", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['SCALE_API_KEY'] = 'xxx'\n", + "\n", + "from llmengine import FineTune\n", + "\n", + "response = FineTune.create(\n", + " model=\"llama-2-7b\",\n", + " training_file=train_url,\n", + " validation_file=val_url,\n", + " hyperparameters={\n", + " 'lr':2e-4,\n", + " },\n", + " suffix='science-qa-llama'\n", + ")\n", + "run_id = response.fine_tune_id" + ] + }, + { + "cell_type": "markdown", + "id": "55074457", + "metadata": {}, + "source": [ + "We can sleep until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840938dd", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "while True:\n", + " job_status = FineTune.get(run_id).status\n", + " print(job_status)\n", + " if job_status == 'SUCCESS':\n", + " break\n", + " time.sleep(60)\n", + " \n", + "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "id": "31278c6d", + "metadata": {}, + "source": [ + "# Inference and Evaluation\n", + "Let's evaluate the new fine-tuned model by running inference against it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b9d7643", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from llmengine import Completion\n", + "\n", + "# Helper function to get outputs for fine-tuned model with retries\n", + "def get_output(prompt: str, num_retry: int = 5):\n", + " for _ in range(num_retry):\n", + " try:\n", + " response = Completion.create(\n", + " model=fine_tuned_model, \n", + " prompt=prompt, \n", + " max_new_tokens=1, \n", + " temperature=0.01\n", + " )\n", + " return response.output.text.strip()\n", + " except Exception as e:\n", + " print(e)\n", + " return \"\"\n", + "\n", + "# Read the test data\n", + "test = pd.read_csv(val_url)\n", + "\n", + "test[\"prediction\"] = test[\"prompt\"].apply(get_output)\n", + "print(f\"Accuracy: {(test['response'] == test['prediction']).mean() * 100:.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f2f3f43", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Environment (conda_pytorch_p38)", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 826d3c4a82176d9ada532664e5452fa1337553da Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 20 Jul 2023 21:19:16 -0400 Subject: [PATCH 019/425] Bump version to 0.0.0b5 (#158) * bump setup version to 0.0.0.beta5 * bump in more places --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 1d882f5c..d7567145 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta4" +__version__ = "0.0.0.beta5" from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index f8c28795..7ca3992d 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta4" +version = "0.0.0.beta5" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index d0cb52ab..dbf97fb6 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta4", + version="0.0.0.beta5", packages=find_packages(), ) From d9acb5b01a5731b95b0fd638f8b09af102b33ad9 Mon Sep 17 00:00:00 2001 From: "Ray (Jui-Tse) Hung" <135046452+ruizehung-scale@users.noreply.github.com> Date: Fri, 21 Jul 2023 10:15:32 -0700 Subject: [PATCH 020/425] =?UTF-8?q?Add=20llm=20endpoint=20creation=20and?= =?UTF-8?q?=20inference=20sample=20code=20to=20self=20hosting=20d=E2=80=A6?= =?UTF-8?q?=20(#153)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add llm endpoint creation and inference sample code to self hosting doc play with it section * Small fix on sending list llm endpoints request * Update doc --- docs/guides/self_hosting.md | 64 +++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 23f66579..8c6c963b 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -136,8 +136,68 @@ Forward a port from a `llm-engine` pod: $ kubectl port-forward pod/llm-engine- 5000:5000 -n ``` -Then, try sending a request to get LLM model endpoints for `test-user-id`. You should get a response with empty list: +Then, try sending a request to get LLM model endpoints for `test-user-id`: ``` $ curl -X GET -H "Content-Type: application/json" -u "test-user-id:" "http://localhost:5000/v1/llm/model-endpoints" -{"model_endpoints":[]}% +``` + +You should get the following response: +``` +{"model_endpoints":[]} +``` + +Next, let's create a LLM endpoint using llama-7b: +``` +$ curl -X POST 'http://localhost:5000/v1/llm/model-endpoints' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "llama-7b", + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "0.9.3", + "num_shards": 4, + "endpoint_type": "streaming", + "cpus": 32, + "gpus": 4, + "memory": "40Gi", + "storage": "40Gi", + "gpu_type": "nvidia-ampere-a10", + "min_workers": 1, + "max_workers": 12, + "per_worker": 1, + "labels": {}, + "metadata": {} + }' \ + -u test_user_id: +``` + +It should output something like: +``` +{"endpoint_creation_task_id":"8d323344-b1b5-497d-a851-6d6284d2f8e4"} +``` + +Wait a few minutes for the endpoint to be ready. You can tell that it's ready by listing pods and checking that all containers in the llm endpoint pod are ready: +``` +$ kubectl get pods -n +NAME READY STATUS RESTARTS AGE +llm-engine-endpoint-id-end-cismpd08agn003rr2kc0-7f86ff64f9qj9xp 2/2 Running 1 (4m41s ago) 7m26s +``` +Note the endpoint name could be different. + +Then, you can send an inference request to the endppoint: +``` +$ curl -X POST 'http://localhost:5000/v1/llm/completions-sync?model_endpoint_name=llama-7b' \ + -H 'Content-Type: application/json' \ + -d '{ + "prompts": ["Tell me a joke about AI"], + "max_new_tokens": 30, + "temperature": 0.1 + }' \ + -u test-user-id: +``` + +You should get a response similar to: +``` +{"status":"SUCCESS","outputs":[{"text":". Tell me a joke about AI. Tell me a joke about AI. Tell me a joke about AI. Tell me","num_completion_tokens":30}],"traceback":null} ``` \ No newline at end of file From b77dff22305089dfdb4fc8a8a9477a5bd1fec054 Mon Sep 17 00:00:00 2001 From: William Song Date: Fri, 21 Jul 2023 17:17:45 -0700 Subject: [PATCH 021/425] iterate helm install (#161) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f9243eef..1afcd413 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -90,7 +90,7 @@ jobs: name: Install helm chart command: | cd $HOME/project/charts - helm install llm-engine llm-engine --values llm-engine/values_sample.yaml + helm install llm-engine llm-engine --values llm-engine/values_circleci.yaml --set tag=$CIRCLE_SHA1 executors: ubuntu-large: From 06bdd5c8bb3904470a009c564549073c3e2c74c7 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:37:52 -0700 Subject: [PATCH 022/425] updating fine_tune_id to id (#174) * updating fine_tune_id to id * changing version to 0.0.0.b6 --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 4 ++-- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- server/llm_engine_server/common/dtos/llms.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index d7567145..faed4b0d 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta5" +__version__ = "0.0.0.beta6" from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index ffcd4f5c..b3fcbf58 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -355,7 +355,7 @@ class CreateFineTuneResponse(BaseModel): Response object for creating a FineTune. """ - fine_tune_id: str = Field(..., description="ID of the created fine-tuning job.") + id: str = Field(..., description="ID of the created fine-tuning job.") """ The ID of the FineTune. """ @@ -376,7 +376,7 @@ class GetFineTuneResponse(BaseModel): Response object for retrieving a FineTune. """ - fine_tune_id: str = Field(..., description="ID of the requested job.") + id: str = Field(..., description="ID of the requested job.") """ The ID of the FineTune. """ diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 7ca3992d..b316dea0 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta5" +version = "0.0.0.beta6" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index dbf97fb6..c78a082c 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta5", + version="0.0.0.beta6", packages=find_packages(), ) diff --git a/server/llm_engine_server/common/dtos/llms.py b/server/llm_engine_server/common/dtos/llms.py index 17d6f13d..2739dc1f 100644 --- a/server/llm_engine_server/common/dtos/llms.py +++ b/server/llm_engine_server/common/dtos/llms.py @@ -155,11 +155,11 @@ class CreateFineTuneJobRequest(BaseModel): class CreateFineTuneJobResponse(BaseModel): - fine_tune_id: str + id: str class GetFineTuneJobResponse(BaseModel): - fine_tune_id: str + id: str status: BatchJobStatus From e4d0bf04373f2557cca11675195e8696def729fc Mon Sep 17 00:00:00 2001 From: "Ray (Jui-Tse) Hung" Date: Tue, 25 Jul 2023 17:01:41 -0700 Subject: [PATCH 023/425] Update unit test instruction (#176) --- server/tests/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/tests/README.md b/server/tests/README.md index 2c442472..c9ddba03 100644 --- a/server/tests/README.md +++ b/server/tests/README.md @@ -1,7 +1,7 @@ -## To Run Tests: +## To Run Unit Tests: + +Inside `server/` folder, run ```shell -pushd ../ PYTHONPATH=llm_engine_server WORKSPACE=. python3 -m pytest tests --cov=llm_engine_server -popd -``` +``` \ No newline at end of file From cdb96d0810bc86a51491e186555a4f68ac1c7320 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 25 Jul 2023 19:43:29 -0700 Subject: [PATCH 024/425] Add llm-engine suffix to Spellbook URL (#173) --- clients/python/llmengine/api_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index 0d82c184..16fd3a72 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -9,7 +9,7 @@ from aiohttp import ClientSession, ClientTimeout from llmengine.errors import parse_error -SPELLBOOK_API_URL = "https://api.spellbook.scale.com" +SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine" LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) DEFAULT_TIMEOUT: int = 10 From 7d6a9c3476ccd49360c71c8c83f6c06f9cee800d Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 26 Jul 2023 10:29:41 -0700 Subject: [PATCH 025/425] adding required dependency installs to scienceqa (#177) --- examples/finetune_llama_2_on_science_qa.ipynb | 485 ++++++++++-------- 1 file changed, 268 insertions(+), 217 deletions(-) diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb index 9b4f77a4..9812258b 100644 --- a/examples/finetune_llama_2_on_science_qa.ipynb +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -1,219 +1,270 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "8d3a4214", - "metadata": {}, - "source": [ - "# Finetune on ScienceQA\n", - "Let's use LLM Engine to fine-tune Llama-2 on ScienceQA!" - ] + "cells": [ + { + "cell_type": "markdown", + "id": "8d3a4214", + "metadata": { + "id": "8d3a4214" + }, + "source": [ + "# Finetune on ScienceQA\n", + "Let's use LLM Engine to fine-tune Llama-2 on ScienceQA!" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Packages Required\n", + "For this demo, we'll be using the `scale-llm-engine` package and `datasets` from Huggingface.\n" + ], + "metadata": { + "id": "XK6VpTnOL4OV" + }, + "id": "XK6VpTnOL4OV" + }, + { + "cell_type": "code", + "source": [ + "!pip install scale-llm-engine\n", + "!pip install datasets" + ], + "metadata": { + "id": "S5u6DdInMEQ7" + }, + "id": "S5u6DdInMEQ7", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "a3dc2a56", + "metadata": { + "id": "a3dc2a56" + }, + "source": [ + "# Data Preparation\n", + "Let's load in the dataset using Huggingface and view the features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e06ac39e", + "metadata": { + "id": "e06ac39e" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from smart_open import smart_open\n", + "import pandas as pd\n", + "\n", + "dataset = load_dataset('derek-thomas/ScienceQA')\n", + "dataset['train'].features" + ] + }, + { + "cell_type": "markdown", + "id": "1cbe8a58", + "metadata": { + "id": "1cbe8a58" + }, + "source": [ + "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b0eb8ad", + "metadata": { + "id": "0b0eb8ad" + }, + "outputs": [], + "source": [ + "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", + "def format_options(options, choice_prefixes):\n", + " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", + "\n", + "def format_prompt(r, choice_prefixes):\n", + " options = format_options(r['choices'], choice_prefixes)\n", + " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", + "\n", + "def format_label(r, choice_prefixes):\n", + " return choice_prefixes[r['answer']]\n", + "\n", + "def convert_dataset(ds):\n", + " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", + " return df\n", + "\n", + "save_to_s3 = False\n", + "df_train = convert_dataset(dataset['train'])\n", + "if save_to_s3:\n", + " train_url = 's3://...'\n", + " val_url = 's3://...'\n", + " df_train = convert_dataset(dataset['train'])\n", + " with smart_open(train_url, 'wb') as f:\n", + " df_train.to_csv(f)\n", + "\n", + " df_val = convert_dataset(dataset['validation'])\n", + " with smart_open(val_url, 'wb') as f:\n", + " df_val.to_csv(f)\n", + "else:\n", + " # Gists of the already processed datasets\n", + " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", + " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", + "\n", + "df_train" + ] + }, + { + "cell_type": "markdown", + "id": "e2fc8d76", + "metadata": { + "id": "e2fc8d76" + }, + "source": [ + "# Fine-tune\n", + "Now, we can fine-tune the model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4905d447", + "metadata": { + "id": "4905d447" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ['SCALE_API_KEY'] = 'xxx'\n", + "\n", + "from llmengine import FineTune\n", + "\n", + "response = FineTune.create(\n", + " model=\"llama-2-7b\",\n", + " training_file=train_url,\n", + " validation_file=val_url,\n", + " hyperparameters={\n", + " 'lr':2e-4,\n", + " },\n", + " suffix='science-qa-llama'\n", + ")\n", + "run_id = response.fine_tune_id" + ] + }, + { + "cell_type": "markdown", + "id": "55074457", + "metadata": { + "id": "55074457" + }, + "source": [ + "We can sleep until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840938dd", + "metadata": { + "id": "840938dd" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "while True:\n", + " job_status = FineTune.get(run_id).status\n", + " print(job_status)\n", + " if job_status == 'SUCCESS':\n", + " break\n", + " time.sleep(60)\n", + "\n", + "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "id": "31278c6d", + "metadata": { + "id": "31278c6d" + }, + "source": [ + "# Inference and Evaluation\n", + "Let's evaluate the new fine-tuned model by running inference against it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b9d7643", + "metadata": { + "id": "3b9d7643" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from llmengine import Completion\n", + "\n", + "# Helper function to get outputs for fine-tuned model with retries\n", + "def get_output(prompt: str, num_retry: int = 5):\n", + " for _ in range(num_retry):\n", + " try:\n", + " response = Completion.create(\n", + " model=fine_tuned_model,\n", + " prompt=prompt,\n", + " max_new_tokens=1,\n", + " temperature=0.01\n", + " )\n", + " return response.output.text.strip()\n", + " except Exception as e:\n", + " print(e)\n", + " return \"\"\n", + "\n", + "# Read the test data\n", + "test = pd.read_csv(val_url)\n", + "\n", + "test[\"prediction\"] = test[\"prompt\"].apply(get_output)\n", + "print(f\"Accuracy: {(test['response'] == test['prediction']).mean() * 100:.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f2f3f43", + "metadata": { + "id": "9f2f3f43" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Environment (conda_pytorch_p38)", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "colab": { + "provenance": [] + } }, - { - "cell_type": "markdown", - "id": "a3dc2a56", - "metadata": {}, - "source": [ - "# Data Preparation\n", - "Let's load in the dataset using Huggingface and view the features." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e06ac39e", - "metadata": {}, - "outputs": [], - "source": [ - "from datasets import load_dataset\n", - "from smart_open import smart_open\n", - "import pandas as pd\n", - "\n", - "dataset = load_dataset('derek-thomas/ScienceQA')\n", - "dataset['train'].features" - ] - }, - { - "cell_type": "markdown", - "id": "1cbe8a58", - "metadata": {}, - "source": [ - "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0b0eb8ad", - "metadata": {}, - "outputs": [], - "source": [ - "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", - "def format_options(options, choice_prefixes):\n", - " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", - "\n", - "def format_prompt(r, choice_prefixes):\n", - " options = format_options(r['choices'], choice_prefixes)\n", - " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", - "\n", - "def format_label(r, choice_prefixes):\n", - " return choice_prefixes[r['answer']]\n", - "\n", - "def convert_dataset(ds):\n", - " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", - " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", - " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", - " return df\n", - "\n", - "save_to_s3 = False\n", - "df_train = convert_dataset(dataset['train'])\n", - "if save_to_s3:\n", - " train_url = 's3://...'\n", - " val_url = 's3://...'\n", - " df_train = convert_dataset(dataset['train'])\n", - " with smart_open(train_url, 'wb') as f:\n", - " df_train.to_csv(f)\n", - "\n", - " df_val = convert_dataset(dataset['validation'])\n", - " with smart_open(val_url, 'wb') as f:\n", - " df_val.to_csv(f)\n", - "else:\n", - " # Gists of the already processed datasets\n", - " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", - " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", - " \n", - "df_train" - ] - }, - { - "cell_type": "markdown", - "id": "e2fc8d76", - "metadata": {}, - "source": [ - "# Fine-tune\n", - "Now, we can fine-tune the model using LLM Engine." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4905d447", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['SCALE_API_KEY'] = 'xxx'\n", - "\n", - "from llmengine import FineTune\n", - "\n", - "response = FineTune.create(\n", - " model=\"llama-2-7b\",\n", - " training_file=train_url,\n", - " validation_file=val_url,\n", - " hyperparameters={\n", - " 'lr':2e-4,\n", - " },\n", - " suffix='science-qa-llama'\n", - ")\n", - "run_id = response.fine_tune_id" - ] - }, - { - "cell_type": "markdown", - "id": "55074457", - "metadata": {}, - "source": [ - "We can sleep until the job completes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "840938dd", - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "\n", - "while True:\n", - " job_status = FineTune.get(run_id).status\n", - " print(job_status)\n", - " if job_status == 'SUCCESS':\n", - " break\n", - " time.sleep(60)\n", - " \n", - "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" - ] - }, - { - "cell_type": "markdown", - "id": "31278c6d", - "metadata": {}, - "source": [ - "# Inference and Evaluation\n", - "Let's evaluate the new fine-tuned model by running inference against it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3b9d7643", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from llmengine import Completion\n", - "\n", - "# Helper function to get outputs for fine-tuned model with retries\n", - "def get_output(prompt: str, num_retry: int = 5):\n", - " for _ in range(num_retry):\n", - " try:\n", - " response = Completion.create(\n", - " model=fine_tuned_model, \n", - " prompt=prompt, \n", - " max_new_tokens=1, \n", - " temperature=0.01\n", - " )\n", - " return response.output.text.strip()\n", - " except Exception as e:\n", - " print(e)\n", - " return \"\"\n", - "\n", - "# Read the test data\n", - "test = pd.read_csv(val_url)\n", - "\n", - "test[\"prediction\"] = test[\"prompt\"].apply(get_output)\n", - "print(f\"Accuracy: {(test['response'] == test['prediction']).mean() * 100:.2f}%\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f2f3f43", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Environment (conda_pytorch_p38)", - "language": "python", - "name": "conda_pytorch_p38" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From f4c6307de304f9457e62684022659f097d2e6eb5 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 26 Jul 2023 10:38:49 -0700 Subject: [PATCH 026/425] fix fine_tune_id to be id in scienceqa example (#179) --- examples/finetune_llama_2_on_science_qa.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb index 9812258b..1b8f0ce5 100644 --- a/examples/finetune_llama_2_on_science_qa.ipynb +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -153,7 +153,7 @@ " },\n", " suffix='science-qa-llama'\n", ")\n", - "run_id = response.fine_tune_id" + "run_id = response.id" ] }, { @@ -267,4 +267,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 29078959aea78281a0068de6d5b4a154e9e4b08e Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 26 Jul 2023 10:40:00 -0700 Subject: [PATCH 027/425] Bump version to 0.0.0.beta7 (#178) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index faed4b0d..cbb60297 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta6" +__version__ = "0.0.0.beta7" from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index b316dea0..0b9759db 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta6" +version = "0.0.0.beta7" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index c78a082c..67b6d4f0 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta6", + version="0.0.0.beta7", packages=find_packages(), ) From f2233724213a0c9f2a56e7b766abfb3702366741 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 27 Jul 2023 15:55:21 -0700 Subject: [PATCH 028/425] updating api key settings so that we can set api key without environment var (#180) --- clients/python/llmengine/api_engine.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index 16fd3a72..e6279095 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -13,8 +13,15 @@ LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) DEFAULT_TIMEOUT: int = 10 +api_key = None + +def set_api_key(key): + global api_key + api_key = key def get_api_key() -> str: + if api_key is not None: + return api_key env_api_key = os.getenv("SCALE_API_KEY") return env_api_key or "root" From d2b84f3ed6e18162994826d705b195c9628dad91 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 27 Jul 2023 17:41:01 -0700 Subject: [PATCH 029/425] Correct A100 tag (#183) --- clients/python/llmengine/data_types.py | 2 +- server/llm_engine_server/domain/entities/gpu_type.py | 2 +- server/tests/unit/common/test_batch_jobs_dtos.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index b3fcbf58..cb7ad263 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -29,7 +29,7 @@ class GpuType(str, Enum): NVIDIA_TESLA_T4 = "nvidia-tesla-t4" NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" - NVIDIA_AMPERE_A100 = "nvidia-a100" + NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" class ModelEndpointType(str, Enum): diff --git a/server/llm_engine_server/domain/entities/gpu_type.py b/server/llm_engine_server/domain/entities/gpu_type.py index 99cfd1b4..a8c4ade4 100644 --- a/server/llm_engine_server/domain/entities/gpu_type.py +++ b/server/llm_engine_server/domain/entities/gpu_type.py @@ -6,4 +6,4 @@ class GpuType(str, Enum): NVIDIA_TESLA_T4 = "nvidia-tesla-t4" NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" - NVIDIA_AMPERE_A100 = "nvidia-a100" + NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" diff --git a/server/tests/unit/common/test_batch_jobs_dtos.py b/server/tests/unit/common/test_batch_jobs_dtos.py index b5f704f0..f6eb384e 100644 --- a/server/tests/unit/common/test_batch_jobs_dtos.py +++ b/server/tests/unit/common/test_batch_jobs_dtos.py @@ -24,10 +24,10 @@ def test_create_docker_image_batch_job_resource_requests_merge_requests(): # Test merging default = CreateDockerImageBatchJobResourceRequests(cpus=0.5) override = CreateDockerImageBatchJobResourceRequests( - memory="100Mi", gpus=1, gpu_type="nvidia-a100", storage="10Gi" + memory="100Mi", gpus=1, gpu_type="nvidia-ampere-a100", storage="10Gi" ) expected = CreateDockerImageBatchJobResourceRequests( - cpus=0.5, memory="100Mi", gpus=1, gpu_type="nvidia-a100", storage="10Gi" + cpus=0.5, memory="100Mi", gpus=1, gpu_type="nvidia-ampere-a100", storage="10Gi" ) actual = CreateDockerImageBatchJobResourceRequests.merge_requests(default, override) assert expected == actual From e044a9625587cf9be1ba03240ffff6169af44244 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 28 Jul 2023 10:42:15 -0700 Subject: [PATCH 030/425] Support checkpoint_path for endpoint creation (#181) * Support checkpoint_path for endpoint creation * comment * comments --- clients/python/llmengine/api_engine.py | 5 +++++ clients/python/llmengine/data_types.py | 23 ++++++++++++++++++----- clients/python/llmengine/model.py | 22 ++++++++++++++++++++-- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index e6279095..cfd87e88 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -51,6 +51,7 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: os.path.join(LLM_ENGINE_BASE_PATH, resource_name), timeout=timeout, headers={"x-api-key": api_key}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -67,6 +68,7 @@ def put( json=data, timeout=timeout, headers={"x-api-key": api_key}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -80,6 +82,7 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: os.path.join(LLM_ENGINE_BASE_PATH, resource_name), timeout=timeout, headers={"x-api-key": api_key}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -94,6 +97,7 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di json=data, timeout=timeout, headers={"x-api-key": api_key}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -110,6 +114,7 @@ def post_stream( json=data, timeout=timeout, headers={"x-api-key": api_key}, + auth=(api_key, ""), stream=True, ) if response.status_code != 200: diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index cb7ad263..49cb7645 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -131,17 +131,27 @@ class CreateLLMEndpointRequest(BaseModel): # LLM specific fields model_name: str source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED + inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE inference_framework_image_tag: str - num_shards: int + num_shards: int = 1 """ - Number of shards to distribute the model onto GPUs. + Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models + """ + + quantize: Optional[Quantization] = None + """ + Quantization for the LLM. Only affects behavior for text-generation-inference models + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models """ # General endpoint fields metadata: Dict[str, Any] # TODO: JSON type post_inference_hooks: Optional[List[str]] - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC + endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING cpus: CpuSpecificationType gpus: int memory: StorageSpecificationType @@ -156,7 +166,10 @@ class CreateLLMEndpointRequest(BaseModel): high_priority: Optional[bool] default_callback_url: Optional[HttpUrl] default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] = True # LLM endpoints are public by default. + public_inference: Optional[bool] = True + """ + Whether the endpoint can be used for inference for all users. LLM endpoints are public by default. + """ class CreateLLMEndpointResponse(BaseModel): diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 1b6c9dba..1c854eba 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -12,6 +12,7 @@ LLMSource, ModelEndpointType, PostInferenceHooks, + Quantization, ) @@ -28,12 +29,15 @@ class Model(APIEngine): @assert_self_hosted def create( cls, + name: str, # LLM specific fields model: str, inference_framework_image_tag: str, source: LLMSource = LLMSource.HUGGING_FACE, inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE, num_shards: int = 4, + quantize: Optional[Quantization] = None, + checkpoint_path: Optional[str] = None, # General endpoint fields cpus: int = 32, memory: str = "192Gi", @@ -53,8 +57,11 @@ def create( """ Create an LLM model. Note: This feature is only available for self-hosted users. Args: + name (`str`): + Name of the endpoint + model (`str`): - Name of the model + Name of the base model inference_framework_image_tag (`str`): Image tag for the inference framework @@ -68,6 +75,15 @@ def create( num_shards (`int`): Number of shards for the LLM. When bigger than 1, LLM will be sharded to multiple GPUs. Number of GPUs must be larger than num_shards. + Only affects behavior for text-generation-inference models + + quantize (`Optional[Quantization]`): + Quantization for the LLM. Only affects behavior for text-generation-inference models + + checkpoint_path (`Optional[str]`): + Path to the checkpoint for the LLM. For now we only support loading a tar file from AWS S3. + Safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be slower). + Only affects behavior for text-generation-inference models cpus (`int`): Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater @@ -157,12 +173,14 @@ def create( post_inference_hooks_strs.append(hook) request = CreateLLMEndpointRequest( - name=model, + name=name, model_name=model, source=source, inference_framework=inference_framework, inference_framework_image_tag=inference_framework_image_tag, num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, cpus=cpus, endpoint_type=ModelEndpointType(endpoint_type), gpus=gpus, From 40f0980b68f58f98d9eb66212160acbe2ae83e70 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Mon, 31 Jul 2023 00:02:25 +0900 Subject: [PATCH 031/425] Fix typo in roles.py (#150) seperated -> separated Co-authored-by: Phil Chen <92065453+phil-scale@users.noreply.github.com> --- server/llm_engine_server/core/aws/roles.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/llm_engine_server/core/aws/roles.py b/server/llm_engine_server/core/aws/roles.py index 52c71d48..65548087 100644 --- a/server/llm_engine_server/core/aws/roles.py +++ b/server/llm_engine_server/core/aws/roles.py @@ -196,14 +196,14 @@ def parse_arn_string(arn: str) -> ArnData: if not 2 <= len(bits) <= 3: raise ValueError( f"Invalid format for AWS ARN string: {arn} -- " - f"Expecting either 2 or 3 parts seperated by '/'" + f"Expecting either 2 or 3 parts separated by '/'" ) account_and_source: List[str] = bits[0].split("::") if len(account_and_source) != 2: raise ValueError( f"Expecting ARN string to have 2 parts in the first '/' part, " - f"seperated by '::'. Instead found {account_and_source} from " + f"separated by '::'. Instead found {account_and_source} from " f"arn={arn}" ) @@ -234,8 +234,8 @@ def parse_arn_string(arn: str) -> ArnData: except ValueError as err: raise ValueError( "ARN format invalid: expecting account ID to appear as 2nd to last " - "value seperated by ':' within the first value seperated by '/' and " - "second value seperated by '::' -- " + "value separated by ':' within the first value separated by '/' and " + "second value separated by '::' -- " f"arn={arn} and expecting {account_str} to be account ID" ) from err From 2701948128a73996b6c93075238e3b39f3ed1a49 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 30 Jul 2023 08:45:16 -0700 Subject: [PATCH 032/425] Bump aiohttp from 3.8.4 to 3.8.5 in /clients/python (#151) * fix lint * Bump aiohttp from 3.8.4 to 3.8.5 in /clients/python Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.8.4 to 3.8.5. - [Release notes](https://github.com/aio-libs/aiohttp/releases) - [Changelog](https://github.com/aio-libs/aiohttp/blob/v3.8.5/CHANGES.rst) - [Commits](https://github.com/aio-libs/aiohttp/compare/v3.8.4...v3.8.5) --- updated-dependencies: - dependency-name: aiohttp dependency-type: direct:production ... Signed-off-by: dependabot[bot] * fix types * fix * fix * fix --------- Signed-off-by: dependabot[bot] Co-authored-by: Phil Chen Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- clients/python/llmengine/api_engine.py | 6 +- clients/python/poetry.lock | 176 +++++++++--------- .../use_cases/llm_fine_tuning_use_cases.py | 6 +- 3 files changed, 95 insertions(+), 93 deletions(-) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index cfd87e88..adcba342 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -15,12 +15,14 @@ api_key = None -def set_api_key(key): + +def set_api_key(key): global api_key api_key = key + def get_api_key() -> str: - if api_key is not None: + if api_key is not None: return api_key env_api_key = os.getenv("SCALE_API_KEY") return env_api_key or "root" diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index 99b869f9..bee5a561 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -2,98 +2,98 @@ [[package]] name = "aiohttp" -version = "3.8.4" +version = "3.8.5" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.6" files = [ - {file = "aiohttp-3.8.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ce45967538fb747370308d3145aa68a074bdecb4f3a300869590f725ced69c1"}, - {file = "aiohttp-3.8.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b744c33b6f14ca26b7544e8d8aadff6b765a80ad6164fb1a430bbadd593dfb1a"}, - {file = "aiohttp-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a45865451439eb320784918617ba54b7a377e3501fb70402ab84d38c2cd891b"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a86d42d7cba1cec432d47ab13b6637bee393a10f664c425ea7b305d1301ca1a3"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee3c36df21b5714d49fc4580247947aa64bcbe2939d1b77b4c8dcb8f6c9faecc"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:176a64b24c0935869d5bbc4c96e82f89f643bcdf08ec947701b9dbb3c956b7dd"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c844fd628851c0bc309f3c801b3a3d58ce430b2ce5b359cd918a5a76d0b20cb5"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5393fb786a9e23e4799fec788e7e735de18052f83682ce2dfcabaf1c00c2c08e"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e4b09863aae0dc965c3ef36500d891a3ff495a2ea9ae9171e4519963c12ceefd"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:adfbc22e87365a6e564c804c58fc44ff7727deea782d175c33602737b7feadb6"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:147ae376f14b55f4f3c2b118b95be50a369b89b38a971e80a17c3fd623f280c9"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:eafb3e874816ebe2a92f5e155f17260034c8c341dad1df25672fb710627c6949"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6cc15d58053c76eacac5fa9152d7d84b8d67b3fde92709195cb984cfb3475ea"}, - {file = "aiohttp-3.8.4-cp310-cp310-win32.whl", hash = "sha256:59f029a5f6e2d679296db7bee982bb3d20c088e52a2977e3175faf31d6fb75d1"}, - {file = "aiohttp-3.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:fe7ba4a51f33ab275515f66b0a236bcde4fb5561498fe8f898d4e549b2e4509f"}, - {file = "aiohttp-3.8.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3d8ef1a630519a26d6760bc695842579cb09e373c5f227a21b67dc3eb16cfea4"}, - {file = "aiohttp-3.8.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b3f2e06a512e94722886c0827bee9807c86a9f698fac6b3aee841fab49bbfb4"}, - {file = "aiohttp-3.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a80464982d41b1fbfe3154e440ba4904b71c1a53e9cd584098cd41efdb188ef"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b631e26df63e52f7cce0cce6507b7a7f1bc9b0c501fcde69742130b32e8782f"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3f43255086fe25e36fd5ed8f2ee47477408a73ef00e804cb2b5cba4bf2ac7f5e"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4d347a172f866cd1d93126d9b239fcbe682acb39b48ee0873c73c933dd23bd0f"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3fec6a4cb5551721cdd70473eb009d90935b4063acc5f40905d40ecfea23e05"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80a37fe8f7c1e6ce8f2d9c411676e4bc633a8462844e38f46156d07a7d401654"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d1e6a862b76f34395a985b3cd39a0d949ca80a70b6ebdea37d3ab39ceea6698a"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd468460eefef601ece4428d3cf4562459157c0f6523db89365202c31b6daebb"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:618c901dd3aad4ace71dfa0f5e82e88b46ef57e3239fc7027773cb6d4ed53531"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:652b1bff4f15f6287550b4670546a2947f2a4575b6c6dff7760eafb22eacbf0b"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80575ba9377c5171407a06d0196b2310b679dc752d02a1fcaa2bc20b235dbf24"}, - {file = "aiohttp-3.8.4-cp311-cp311-win32.whl", hash = "sha256:bbcf1a76cf6f6dacf2c7f4d2ebd411438c275faa1dc0c68e46eb84eebd05dd7d"}, - {file = "aiohttp-3.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:6e74dd54f7239fcffe07913ff8b964e28b712f09846e20de78676ce2a3dc0bfc"}, - {file = "aiohttp-3.8.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:880e15bb6dad90549b43f796b391cfffd7af373f4646784795e20d92606b7a51"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb96fa6b56bb536c42d6a4a87dfca570ff8e52de2d63cabebfd6fb67049c34b6"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a6cadebe132e90cefa77e45f2d2f1a4b2ce5c6b1bfc1656c1ddafcfe4ba8131"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f352b62b45dff37b55ddd7b9c0c8672c4dd2eb9c0f9c11d395075a84e2c40f75"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ab43061a0c81198d88f39aaf90dae9a7744620978f7ef3e3708339b8ed2ef01"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9cb1565a7ad52e096a6988e2ee0397f72fe056dadf75d17fa6b5aebaea05622"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:1b3ea7edd2d24538959c1c1abf97c744d879d4e541d38305f9bd7d9b10c9ec41"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:7c7837fe8037e96b6dd5cfcf47263c1620a9d332a87ec06a6ca4564e56bd0f36"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:3b90467ebc3d9fa5b0f9b6489dfb2c304a1db7b9946fa92aa76a831b9d587e99"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:cab9401de3ea52b4b4c6971db5fb5c999bd4260898af972bf23de1c6b5dd9d71"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d1f9282c5f2b5e241034a009779e7b2a1aa045f667ff521e7948ea9b56e0c5ff"}, - {file = "aiohttp-3.8.4-cp36-cp36m-win32.whl", hash = "sha256:5e14f25765a578a0a634d5f0cd1e2c3f53964553a00347998dfdf96b8137f777"}, - {file = "aiohttp-3.8.4-cp36-cp36m-win_amd64.whl", hash = "sha256:4c745b109057e7e5f1848c689ee4fb3a016c8d4d92da52b312f8a509f83aa05e"}, - {file = "aiohttp-3.8.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:aede4df4eeb926c8fa70de46c340a1bc2c6079e1c40ccf7b0eae1313ffd33519"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ddaae3f3d32fc2cb4c53fab020b69a05c8ab1f02e0e59665c6f7a0d3a5be54f"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4eb3b82ca349cf6fadcdc7abcc8b3a50ab74a62e9113ab7a8ebc268aad35bb9"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bcb89336efa095ea21b30f9e686763f2be4478f1b0a616969551982c4ee4c3b"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c08e8ed6fa3d477e501ec9db169bfac8140e830aa372d77e4a43084d8dd91ab"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c6cd05ea06daca6ad6a4ca3ba7fe7dc5b5de063ff4daec6170ec0f9979f6c332"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7a00a9ed8d6e725b55ef98b1b35c88013245f35f68b1b12c5cd4100dddac333"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:de04b491d0e5007ee1b63a309956eaed959a49f5bb4e84b26c8f5d49de140fa9"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:40653609b3bf50611356e6b6554e3a331f6879fa7116f3959b20e3528783e699"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:dbf3a08a06b3f433013c143ebd72c15cac33d2914b8ea4bea7ac2c23578815d6"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:854f422ac44af92bfe172d8e73229c270dc09b96535e8a548f99c84f82dde241"}, - {file = "aiohttp-3.8.4-cp37-cp37m-win32.whl", hash = "sha256:aeb29c84bb53a84b1a81c6c09d24cf33bb8432cc5c39979021cc0f98c1292a1a"}, - {file = "aiohttp-3.8.4-cp37-cp37m-win_amd64.whl", hash = "sha256:db3fc6120bce9f446d13b1b834ea5b15341ca9ff3f335e4a951a6ead31105480"}, - {file = "aiohttp-3.8.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fabb87dd8850ef0f7fe2b366d44b77d7e6fa2ea87861ab3844da99291e81e60f"}, - {file = "aiohttp-3.8.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:91f6d540163f90bbaef9387e65f18f73ffd7c79f5225ac3d3f61df7b0d01ad15"}, - {file = "aiohttp-3.8.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d265f09a75a79a788237d7f9054f929ced2e69eb0bb79de3798c468d8a90f945"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d89efa095ca7d442a6d0cbc755f9e08190ba40069b235c9886a8763b03785da"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4dac314662f4e2aa5009977b652d9b8db7121b46c38f2073bfeed9f4049732cd"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe11310ae1e4cd560035598c3f29d86cef39a83d244c7466f95c27ae04850f10"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ddb2a2026c3f6a68c3998a6c47ab6795e4127315d2e35a09997da21865757f8"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e75b89ac3bd27d2d043b234aa7b734c38ba1b0e43f07787130a0ecac1e12228a"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6e601588f2b502c93c30cd5a45bfc665faaf37bbe835b7cfd461753068232074"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a5d794d1ae64e7753e405ba58e08fcfa73e3fad93ef9b7e31112ef3c9a0efb52"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a1f4689c9a1462f3df0a1f7e797791cd6b124ddbee2b570d34e7f38ade0e2c71"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3032dcb1c35bc330134a5b8a5d4f68c1a87252dfc6e1262c65a7e30e62298275"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8189c56eb0ddbb95bfadb8f60ea1b22fcfa659396ea36f6adcc521213cd7b44d"}, - {file = "aiohttp-3.8.4-cp38-cp38-win32.whl", hash = "sha256:33587f26dcee66efb2fff3c177547bd0449ab7edf1b73a7f5dea1e38609a0c54"}, - {file = "aiohttp-3.8.4-cp38-cp38-win_amd64.whl", hash = "sha256:e595432ac259af2d4630008bf638873d69346372d38255774c0e286951e8b79f"}, - {file = "aiohttp-3.8.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5a7bdf9e57126dc345b683c3632e8ba317c31d2a41acd5800c10640387d193ed"}, - {file = "aiohttp-3.8.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22f6eab15b6db242499a16de87939a342f5a950ad0abaf1532038e2ce7d31567"}, - {file = "aiohttp-3.8.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7235604476a76ef249bd64cb8274ed24ccf6995c4a8b51a237005ee7a57e8643"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea9eb976ffdd79d0e893869cfe179a8f60f152d42cb64622fca418cd9b18dc2a"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92c0cea74a2a81c4c76b62ea1cac163ecb20fb3ba3a75c909b9fa71b4ad493cf"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:493f5bc2f8307286b7799c6d899d388bbaa7dfa6c4caf4f97ef7521b9cb13719"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a63f03189a6fa7c900226e3ef5ba4d3bd047e18f445e69adbd65af433add5a2"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10c8cefcff98fd9168cdd86c4da8b84baaa90bf2da2269c6161984e6737bf23e"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bca5f24726e2919de94f047739d0a4fc01372801a3672708260546aa2601bf57"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:03baa76b730e4e15a45f81dfe29a8d910314143414e528737f8589ec60cf7391"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8c29c77cc57e40f84acef9bfb904373a4e89a4e8b74e71aa8075c021ec9078c2"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:03543dcf98a6619254b409be2d22b51f21ec66272be4ebda7b04e6412e4b2e14"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17b79c2963db82086229012cff93ea55196ed31f6493bb1ccd2c62f1724324e4"}, - {file = "aiohttp-3.8.4-cp39-cp39-win32.whl", hash = "sha256:34ce9f93a4a68d1272d26030655dd1b58ff727b3ed2a33d80ec433561b03d67a"}, - {file = "aiohttp-3.8.4-cp39-cp39-win_amd64.whl", hash = "sha256:41a86a69bb63bb2fc3dc9ad5ea9f10f1c9c8e282b471931be0268ddd09430b04"}, - {file = "aiohttp-3.8.4.tar.gz", hash = "sha256:bf2e1a9162c1e441bf805a1fd166e249d574ca04e03b34f97e2928769e91ab5c"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"}, + {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"}, + {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"}, + {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"}, + {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"}, + {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"}, + {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"}, + {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"}, + {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"}, + {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"}, + {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"}, + {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"}, + {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"}, + {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"}, + {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"}, + {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"}, + {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"}, + {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"}, + {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"}, + {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"}, + {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"}, + {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"}, + {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"}, + {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"}, + {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"}, + {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"}, + {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"}, + {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"}, ] [package.dependencies] diff --git a/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 34ef2172..78ca4bf1 100644 --- a/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -28,7 +28,7 @@ async def execute( hyperparameters=request.hyperparameters, ) return CreateFineTuneJobResponse( - fine_tune_id=fine_tune_id, + id=fine_tune_id, ) @@ -44,7 +44,7 @@ async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneJobResponse if di_batch_job is None: raise ObjectNotFoundException return GetFineTuneJobResponse( - fine_tune_id=di_batch_job.id, + id=fine_tune_id, status=di_batch_job.status, ) @@ -60,7 +60,7 @@ async def execute(self, user: User) -> ListFineTuneJobResponse: return ListFineTuneJobResponse( jobs=[ GetFineTuneJobResponse( - fine_tune_id=job.id, + id=job.id, status=job.status, ) for job in di_batch_jobs From 61dc135c707cf522fcad6814be1baa63f0cf812c Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 31 Jul 2023 09:35:51 -0700 Subject: [PATCH 033/425] Add llama 2 70B in model zoo (#185) * Add llama 2 70B in model zoo * black --- docs/model_zoo.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 9cfcbc02..da07c287 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -9,6 +9,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `llama-2-7b-chat` | ✅ | | | `llama-2-13b` | ✅ | | | `llama-2-13b-chat` | ✅ | | +| `llama-2-70b` | ✅ | | | `falcon-7b` | ✅ | | | `falcon-7b-instruct` | ✅ | | | `falcon-40b` | ✅ | | From 0f908dc6d2f59f3abc74264b21ba44ce460a2f42 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Tue, 1 Aug 2023 17:36:53 +0800 Subject: [PATCH 034/425] add CNAME (#191) --- docs/CNAME | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/CNAME diff --git a/docs/CNAME b/docs/CNAME new file mode 100644 index 00000000..bd01b7f9 --- /dev/null +++ b/docs/CNAME @@ -0,0 +1 @@ +llm-engine.scale.com From 24f6a3206218af7aec073d3e8801bbbc350406b9 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 1 Aug 2023 16:28:27 -0700 Subject: [PATCH 035/425] File API functions (#160) --- clients/python/llmengine/__init__.py | 14 +- clients/python/llmengine/api_engine.py | 17 ++ clients/python/llmengine/data_types.py | 44 ++++ clients/python/llmengine/file.py | 193 ++++++++++++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- docs/api/data_types.md | 10 + docs/api/python_client.md | 9 + server/llm_engine_server/api/batch_jobs_v1.py | 1 - .../domain/use_cases/batch_job_use_cases.py | 1 - ...eaming_model_endpoint_inference_gateway.py | 1 - ...ocker_image_batch_job_bundle_repository.py | 1 - 12 files changed, 288 insertions(+), 7 deletions(-) create mode 100644 clients/python/llmengine/file.py diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index cbb60297..25dd9652 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta7" +__version__ = "0.0.0.beta8" from typing import Sequence @@ -25,12 +25,18 @@ CompletionSyncResponse, CreateFineTuneRequest, CreateFineTuneResponse, + DeleteFileResponse, DeleteLLMEndpointResponse, + GetFileContentResponse, + GetFileResponse, GetFineTuneResponse, GetLLMEndpointResponse, + ListFilesResponse, ListFineTunesResponse, ListLLMEndpointsResponse, + UploadFileResponse, ) +from llmengine.file import File from llmengine.fine_tuning import FineTune from llmengine.model import Model @@ -43,11 +49,17 @@ "CompletionSyncResponse", "CreateFineTuneRequest", "CreateFineTuneResponse", + "DeleteFileResponse", "DeleteLLMEndpointResponse", + "GetFileContentResponse", + "File", "FineTune", + "GetFileResponse", "GetFineTuneResponse", "GetLLMEndpointResponse", + "ListFilesResponse", "ListFineTunesResponse", "ListLLMEndpointsResponse", "Model", + "UploadFileResponse", ) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index adcba342..089138b7 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -3,6 +3,7 @@ import json import os from functools import wraps +from io import BufferedReader from typing import Any, AsyncIterable, Dict, Iterator, Optional import requests @@ -138,6 +139,22 @@ def post_stream( except json.JSONDecodeError: raise ValueError(f"Invalid JSON payload: {payload_data}") + @classmethod + def post_file( + cls, resource_name: str, files: Dict[str, BufferedReader], timeout: int + ) -> Dict[str, Any]: + api_key = get_api_key() + response = requests.post( + os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + files=files, + timeout=timeout, + headers={"x-api-key": api_key}, + ) + if response.status_code != 200: + raise parse_error(response.status_code, response.content) + payload = response.json() + return payload + @classmethod async def apost_sync( cls, resource_name: str, data: Dict[str, Any], timeout: int diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 49cb7645..63aefbbd 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -453,3 +453,47 @@ class GetFineTuneEventsResponse(BaseModel): """ events: List[LLMFineTuneEvent] = Field(..., description="List of fine-tuning events.") + + +class UploadFileResponse(BaseModel): + """Response object for uploading a file.""" + + id: str = Field(..., description="ID of the uploaded file.") + """ID of the uploaded file.""" + + +class GetFileResponse(BaseModel): + """Response object for retrieving a file.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + + filename: str = Field(..., description="File name.") + """File name.""" + + size: int = Field(..., description="Length of the file, in characters.") + """Length of the file, in characters.""" + + +class ListFilesResponse(BaseModel): + """Response object for listing files.""" + + files: List[GetFileResponse] = Field(..., description="List of file IDs, names, and sizes.") + """List of file IDs, names, and sizes.""" + + +class DeleteFileResponse(BaseModel): + """Response object for deleting a file.""" + + deleted: bool = Field(..., description="Whether deletion was successful.") + """Whether deletion was successful.""" + + +class GetFileContentResponse(BaseModel): + """Response object for retrieving a file's content.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + + content: str = Field(..., description="File content.") + """File content.""" diff --git a/clients/python/llmengine/file.py b/clients/python/llmengine/file.py new file mode 100644 index 00000000..c3e7ae48 --- /dev/null +++ b/clients/python/llmengine/file.py @@ -0,0 +1,193 @@ +from io import BufferedReader + +from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine +from llmengine.data_types import ( + DeleteFileResponse, + GetFileContentResponse, + GetFileResponse, + ListFilesResponse, + UploadFileResponse, +) + + +class File(APIEngine): + """ + File API. This API is used to upload private files to LLM engine so that fine-tunes can access them for training and validation data. + + Functions are provided to upload, get, list, and delete files, as well as to get the contents of a file. + """ + + @classmethod + def upload(cls, file: BufferedReader) -> UploadFileResponse: + """ + Uploads a file to LLM engine. + + Args: + file (`BufferedReader`): + A file opened with open(file_path, "r") + + Returns: + UploadFileResponse: an object that contains the ID of the uploaded file + + === "Uploading file in Python" + ```python + from llmengine import File + + response = File.upload(open("training_dataset.csv", "r")) + + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "id": "file-abc123" + } + ``` + """ + files = {"file": file} + response = cls.post_file( + resource_name="v1/files", + files=files, + timeout=DEFAULT_TIMEOUT, + ) + return UploadFileResponse.parse_obj(response) + + @classmethod + def get(cls, file_id: str) -> GetFileResponse: + """ + Get file metadata, including filename and size. + + Args: + file_id (`str`): + ID of the file + + Returns: + GetFileResponse: an object that contains the ID, filename, and size of the requested file + + === "Getting metadata about file in Python" + ```python + from llmengine import File + + response = File.get( + file_id="file-abc123", + ) + + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "id": "file-abc123", + "filename": "training_dataset.csv", + "size": 100 + } + ``` + """ + response = cls._get(f"v1/files/{file_id}", timeout=DEFAULT_TIMEOUT) + return GetFileResponse.parse_obj(response) + + @classmethod + def list(cls) -> ListFilesResponse: + """ + List metadata about all files, e.g. their filenames and sizes. + + Returns: + ListFilesResponse: an object that contains a list of all files and their filenames and sizes + + === "Listing files in Python" + ```python + from llmengine import File + + response = File.list() + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "files": [ + { + "id": "file-abc123", + "filename": "training_dataset.csv", + "size": 100 + }, + { + "id": "file-def456", + "filename": "validation_dataset.csv", + "size": 50 + } + ] + } + ``` + """ + response = cls._get("v1/files", timeout=30) + return ListFilesResponse.parse_obj(response) + + @classmethod + def delete(cls, file_id: str) -> DeleteFileResponse: + """ + Deletes a file. + + Args: + file_id (`str`): + ID of the file + + Returns: + DeleteFileResponse: an object that contains whether the deletion was successful + + === "Deleting file in Python" + ```python + from llmengine import File + + response = File.delete(file_id="file-abc123") + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "deleted": true + } + ``` + """ + response = cls._delete( + f"v1/files/{file_id}", + timeout=DEFAULT_TIMEOUT, + ) + return DeleteFileResponse.parse_obj(response) + + @classmethod + def download(cls, file_id: str) -> GetFileContentResponse: + """ + Get contents of a file, as a string. (If the uploaded file is in binary, a string encoding will be returned.) + + Args: + file_id (`str`): + ID of the file + + Returns: + GetFileContentResponse: an object that contains the ID and content of the file + + === "Getting file content in Python" + ```python + from llmengine import File + + response = File.get_content(file_id="file-abc123") + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "id": "file-abc123", + "content": "Hello world!" + } + ``` + """ + response = cls._get( + f"v1/files/{file_id}/content", + timeout=DEFAULT_TIMEOUT, + ) + return GetFileContentResponse.parse_obj(response) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 0b9759db..6f229a8b 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta7" +version = "0.0.0.beta8" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 67b6d4f0..625c8984 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta7", + version="0.0.0.beta8", packages=find_packages(), ) diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 55f33028..2d53a3bf 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -43,3 +43,13 @@ ::: llmengine.ListLLMEndpointsResponse ::: llmengine.DeleteLLMEndpointResponse + +::: llmengine.UploadFileResponse + +::: llmengine.GetFileResponse + +::: llmengine.GetFileContentResponse + +::: llmengine.ListFilesResponse + +::: llmengine.DeleteFileResponse diff --git a/docs/api/python_client.md b/docs/api/python_client.md index 820b2e56..e1a8b1f2 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -21,3 +21,12 @@ - get - list - delete + +::: llmengine.File + selection: + members: + - upload + - get + - get_content + - list + - delete diff --git a/server/llm_engine_server/api/batch_jobs_v1.py b/server/llm_engine_server/api/batch_jobs_v1.py index 86f46ff9..ac83425f 100644 --- a/server/llm_engine_server/api/batch_jobs_v1.py +++ b/server/llm_engine_server/api/batch_jobs_v1.py @@ -125,7 +125,6 @@ async def create_docker_image_batch_job( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> CreateDockerImageBatchJobV1Response: - add_trace_resource_name("batch_jobs_di_create") logger.info(f"POST /docker-image-batch-jobs with {request} for {auth}") try: diff --git a/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py b/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py index e6710313..7b5f7520 100644 --- a/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py +++ b/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py @@ -175,7 +175,6 @@ def __init__( async def execute( self, user: User, request: CreateDockerImageBatchJobV1Request ) -> CreateDockerImageBatchJobV1Response: - if request.docker_image_batch_job_bundle_id is not None: batch_bundle = await self.docker_image_batch_job_bundle_repository.get_docker_image_batch_job_bundle( request.docker_image_batch_job_bundle_id diff --git a/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 0c812e8f..1bbcbc11 100644 --- a/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -72,7 +72,6 @@ def __init__(self, use_asyncio: bool): async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): errored = False if self.use_asyncio: - async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: aio_resp = await aioclient.post( request_url, diff --git a/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py b/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py index 81d172ad..579d13d0 100644 --- a/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py +++ b/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py @@ -106,7 +106,6 @@ async def test_list_docker_image_batch_job_bundles( test_api_key: str, test_api_key_team: str, ): - orm_docker_image_batch_job_bundle_1_v2.created_by = test_api_key_team orm_docker_image_batch_job_bundle_1_v2.owner = test_api_key_team docker_image_batch_job_bundle_1_v2.created_by = test_api_key_team From 582dd0ca537cc14469bfefda0524c938cafb2dd4 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 1 Aug 2023 18:47:21 -0700 Subject: [PATCH 036/425] Fix File documentation (#192) --- clients/python/llmengine/file.py | 2 +- docs/api/python_client.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/file.py b/clients/python/llmengine/file.py index c3e7ae48..097fed1f 100644 --- a/clients/python/llmengine/file.py +++ b/clients/python/llmengine/file.py @@ -174,7 +174,7 @@ def download(cls, file_id: str) -> GetFileContentResponse: ```python from llmengine import File - response = File.get_content(file_id="file-abc123") + response = File.download(file_id="file-abc123") print(response.json()) ``` diff --git a/docs/api/python_client.md b/docs/api/python_client.md index e1a8b1f2..8b3fdc1f 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -27,6 +27,6 @@ members: - upload - get - - get_content + - download - list - delete From 12d0e9af3f7f51b02b1d35a73ab74db83ef9eb50 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:14:00 +0800 Subject: [PATCH 037/425] Deploy docs from CI (#190) --- .circleci/config.yml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1afcd413..77c2a7d1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -10,6 +10,11 @@ workflows: - integration_tests - build_image - build_docs + - deploy_docs: + filters: + branches: + only: + - main jobs: run_unit_tests_python_client: @@ -62,6 +67,25 @@ jobs: name: Build Docs command: | mkdocs build --strict + deploy_docs: + docker: + - image: python:3.8-bookworm + resource_class: small + parallelism: 1 + steps: + - add_ssh_keys: # gives write access to CircleCI worker + fingerprints: + - "76:0c:1b:9e:e3:6a:c3:5c:6f:24:91:ef:7c:54:d2:7a" + - checkout # checkout source code to working directory + - environment_setup + - install_client + - python/install-packages: + pkg-manager: pip + pip-dependency-file: requirements-docs.txt + - run: + name: Deploy Docs + command: | + mkdocs gh-deploy build_image: executor: ubuntu-large steps: From 4a7cb4bec0eaad74f866d031cdcdd9275a278034 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 3 Aug 2023 21:33:27 -0700 Subject: [PATCH 038/425] Some improvements to completions APIs (#194) * Some improvements to completions APIs * fix black * isort --- clients/python/llmengine/completion.py | 46 ++++++++++++++++++++++---- clients/python/llmengine/data_types.py | 33 ++++++++++++++++-- 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 661ac30e..dcd979a3 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -1,4 +1,4 @@ -from typing import AsyncIterable, Iterator, Union +from typing import AsyncIterable, Iterator, List, Optional, Union from llmengine.api_engine import APIEngine from llmengine.data_types import ( @@ -29,6 +29,8 @@ async def acreate( prompt: str, max_new_tokens: int = 20, temperature: float = 0.2, + stop_sequences: Optional[List[str]] = None, + return_token_log_probs: Optional[bool] = False, timeout: int = 10, stream: bool = False, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: @@ -57,8 +59,16 @@ async def acreate( [Model Zoo](../../model_zoo) for information on each supported model's context length. temperature (float): - What sampling temperature to use, in the range `(0, 1]`. Higher values like 0.8 will make the output + What sampling temperature to use, in the range `[0, 1]`. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + When temperature is 0 greedy sampling is used. + + stop_sequences (Optional[List[str]]): + One or more sequences where the API will stop generating tokens for the current completion. + + return_token_log_probs (Optional[bool]): + Whether to return the log probabilities of generated tokens. + When True, the response will include a list of tokens and their log probabilities. timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -150,6 +160,8 @@ async def _acreate_stream( prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, timeout=timeout, ) @@ -165,7 +177,11 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse: return CompletionSyncResponse.parse_obj(response) return await _acreate_sync( - prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, ) @classmethod @@ -175,6 +191,8 @@ def create( prompt: str, max_new_tokens: int = 20, temperature: float = 0.2, + stop_sequences: Optional[List[str]] = None, + return_token_log_probs: Optional[bool] = False, timeout: int = 10, stream: bool = False, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: @@ -204,8 +222,16 @@ def create( [Model Zoo](../../model_zoo) for information on each supported model's context length. temperature (float): - What sampling temperature to use, in the range `(0, 1]`. Higher values like 0.8 will make the output + What sampling temperature to use, in the range `[0, 1]`. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + When temperature is 0 greedy sampling is used. + + stop_sequences (Optional[List[str]]): + One or more sequences where the API will stop generating tokens for the current completion. + + return_token_log_probs (Optional[bool]): + Whether to return the log probabilities of generated tokens. + When True, the response will include a list of tokens and their log probabilities. timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -284,12 +310,20 @@ def _create_stream(**kwargs): yield CompletionStreamResponse.parse_obj(chunk) return _create_stream( - prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, ) else: data = CompletionSyncV1Request( - prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, ).dict() response = cls.post_sync( resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 63aefbbd..baa9e087 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -182,7 +182,8 @@ class GetLLMEndpointResponse(BaseModel): """ id: Optional[str] = Field( - default=None, description="(For self-hosted users) The autogenerated ID of the model." + default=None, + description="(For self-hosted users) The autogenerated ID of the model.", ) """(For self-hosted users) The autogenerated ID of the model.""" @@ -259,7 +260,25 @@ class CompletionSyncV1Request(BaseModel): prompt: str = Field(..., min_length=1) max_new_tokens: int = Field(..., gt=0) - temperature: float = Field(..., gt=0.0) + temperature: float = Field(..., ge=0.0) + stop_sequences: Optional[List[str]] = Field(default=None) + return_token_log_probs: Optional[bool] = Field(default=False) + + +class TokenOutput(BaseModel): + """ + Detailed token information. + """ + + token: str + """ + The token text. + """ + + log_prob: float + """ + The log probability of the token. + """ class CompletionOutput(BaseModel): @@ -273,6 +292,9 @@ class CompletionOutput(BaseModel): num_completion_tokens: int """Number of tokens in the completion.""" + tokens: Optional[List[TokenOutput]] = None + """Detailed token information.""" + class CompletionSyncResponse(BaseModel): """ @@ -299,7 +321,9 @@ class CompletionStreamV1Request(BaseModel): prompt: str = Field(..., min_length=1) max_new_tokens: int = Field(..., gt=0) - temperature: float = Field(..., gt=0.0) + temperature: float = Field(..., ge=0.0) + stop_sequences: Optional[List[str]] = Field(default=None) + return_token_log_probs: Optional[bool] = Field(default=False) class CompletionStreamOutput(BaseModel): @@ -312,6 +336,9 @@ class CompletionStreamOutput(BaseModel): num_completion_tokens: Optional[int] = None """Number of tokens in the completion.""" + token: Optional[TokenOutput] = None + """Detailed token information.""" + class CompletionStreamResponse(BaseModel): """ From ab29df2838b288684e34885760ea49c02878f948 Mon Sep 17 00:00:00 2001 From: William Song Date: Fri, 4 Aug 2023 13:02:38 -0700 Subject: [PATCH 039/425] bump beta8 -> beta9 (#195) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 25dd9652..73e94c20 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta8" +__version__ = "0.0.0.beta9" from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 6f229a8b..e1f7b301 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta8" +version = "0.0.0.beta9" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 625c8984..ed773d7d 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta8", + version="0.0.0.beta9", packages=find_packages(), ) From c74bcd84bbe9b6f1a037cdfe683b983722547b8b Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 7 Aug 2023 10:55:57 -0700 Subject: [PATCH 040/425] Link to HF greedy search (#198) * Link to HF greedy search * fix --- clients/python/llmengine/completion.py | 4 ++-- docs/guides/completions.md | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index dcd979a3..8178867c 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -61,7 +61,7 @@ async def acreate( temperature (float): What sampling temperature to use, in the range `[0, 1]`. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. - When temperature is 0 greedy sampling is used. + When temperature is 0 [greedy search](https://huggingface.co/docs/transformers/generation_strategies#greedy-search) is used. stop_sequences (Optional[List[str]]): One or more sequences where the API will stop generating tokens for the current completion. @@ -224,7 +224,7 @@ def create( temperature (float): What sampling temperature to use, in the range `[0, 1]`. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. - When temperature is 0 greedy sampling is used. + When temperature is 0 [greedy search](https://huggingface.co/docs/transformers/generation_strategies#greedy-search) is used. stop_sequences (Optional[List[str]]): One or more sequences where the API will stop generating tokens for the current completion. diff --git a/docs/guides/completions.md b/docs/guides/completions.md index eb16b94e..a5ea9a06 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -34,6 +34,7 @@ print(response.output.text) - **max_new_tokens:** The maximum number of tokens to generate in the chat completion. - **temperature:** The sampling temperature to use. Higher values make the output more random, while lower values will make it more focused and deterministic. + When temperature is 0 [greedy search](https://huggingface.co/docs/transformers/generation_strategies#greedy-search) is used. See the full [Completion API reference documentation](../../api/python_client/#llmengine.Completion) to learn more. From 53f9d88c54d5ab6116b3b801fa6c8a66d8cd6b48 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 7 Aug 2023 14:16:20 -0700 Subject: [PATCH 041/425] Integrate finetune with wandb (#199) * Integrate finetune with wandb * model zoo update * hyperparameter --- clients/python/llmengine/data_types.py | 10 ++++++++++ clients/python/llmengine/fine_tuning.py | 10 +++++++++- docs/model_zoo.md | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index baa9e087..e17e755a 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -383,6 +383,16 @@ class CreateFineTuneRequest(BaseModel): ) """Hyperparameters to pass in to training job.""" + wandb_config: Optional[Dict[str, Any]] = Field( + default=None, description="Configuration for Weights and Biases." + ) + """ + A dict of configuration parameters for Weights & Biases. See [Weights & Biases](https://docs.wandb.ai/ref/python/init) for more information. + Set `hyperparameter["report_to"]` to `wandb` to enable automatic finetune metrics logging. + Must include `api_key` field which is the wandb API key. + Also supports setting `base_url` to use a custom Weights & Biases server. + """ + suffix: Optional[str] = Field( default=None, description="Optional user-provided identifier suffix for the fine-tuned model.", diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index c280ac39..d0a0ef86 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine from llmengine.data_types import ( @@ -29,6 +29,7 @@ def create( training_file: str, validation_file: Optional[str] = None, hyperparameters: Optional[Dict[str, Union[str, int, float]]] = None, + wandb_config: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, ) -> CreateFineTuneResponse: """ @@ -66,6 +67,12 @@ def create( * `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) * `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001) + wandb_config (`Optional[Dict[str, Any]]`): + A dict of configuration parameters for Weights & Biases. See [Weights & Biases](https://docs.wandb.ai/ref/python/init) for more information. + Set `hyperparameter["report_to"]` to `wandb` to enable automatic finetune metrics logging. + Must include `api_key` field which is the wandb API key. + Also supports setting `base_url` to use a custom Weights & Biases server. + suffix (`Optional[str]`): A string that will be added to your fine-tuned model name. If present, the entire fine-tuned model name will be formatted like `"[model].[suffix].[YYYY-MM-DD-HH-MM-SS]"`. If absent, the @@ -134,6 +141,7 @@ def create( training_file=training_file, validation_file=validation_file, hyperparameters=hyperparameters, + wandb_config=wandb_config, suffix=suffix, ) response = cls.post_sync( diff --git a/docs/model_zoo.md b/docs/model_zoo.md index da07c287..264196a6 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -10,6 +10,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `llama-2-13b` | ✅ | | | `llama-2-13b-chat` | ✅ | | | `llama-2-70b` | ✅ | | +| `llama-2-70b-chat` | ✅ | | | `falcon-7b` | ✅ | | | `falcon-7b-instruct` | ✅ | | | `falcon-40b` | ✅ | | From 7fb42a6df83bdc5e75783c24d5057373ceb3f9a0 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 7 Aug 2023 14:28:54 -0700 Subject: [PATCH 042/425] Bump version (#201) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 73e94c20..66efafb4 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta9" +__version__ = "0.0.0.beta10" from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index e1f7b301..adf2ed20 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta9" +version = "0.0.0.beta10" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index ed773d7d..1baaf486 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta9", + version="0.0.0.beta10", packages=find_packages(), ) From 9a1d567bfa6a32e4df9a27cfc68a16a8a8264673 Mon Sep 17 00:00:00 2001 From: "Ray (Jui-Tse) Hung" <135046452+ruizehung-scale@users.noreply.github.com> Date: Tue, 8 Aug 2023 13:32:11 -0700 Subject: [PATCH 043/425] Add documentation on pointing llmengine client to self-hosted infrastructure (#200) * Add documentation on pointing llmengine client to self-hosted infrastructure * url -> URL * Add code sample for setting LLM_ENGINE_BASE_PATH * Small wording fix * Wording fix * Update dns_host_domain --- charts/llm-engine/values_sample.yaml | 2 +- docs/guides/self_hosting.md | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/charts/llm-engine/values_sample.yaml b/charts/llm-engine/values_sample.yaml index 06d70362..70d740cf 100644 --- a/charts/llm-engine/values_sample.yaml +++ b/charts/llm-engine/values_sample.yaml @@ -96,7 +96,7 @@ config: # k8s_cluster_name [required] is the name of the k8s cluster k8s_cluster_name: main_cluster # dns_host_domain [required] is the domain name of the k8s cluster - dns_host_domain: domain.llm-engine.com + dns_host_domain: llm-engine.domain.com # default_region [required] is the default AWS region for various resources (e.g ECR) default_region: us-east-1 # aws_account_id [required] is the AWS account ID for various resources (e.g ECR) diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 8c6c963b..0c446191 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -200,4 +200,12 @@ $ curl -X POST 'http://localhost:5000/v1/llm/completions-sync?model_endpoint_nam You should get a response similar to: ``` {"status":"SUCCESS","outputs":[{"text":". Tell me a joke about AI. Tell me a joke about AI. Tell me a joke about AI. Tell me","num_completion_tokens":30}],"traceback":null} +``` + +### Pointing LLM Engine client to use self-hosted infrastructure +The `llmengine` client makes requests to Scale AI's hosted infrastructure by default. You can have `llmengine` client make requests to your own self-hosted infrastructure by setting the `LLM_ENGINE_BASE_PATH` environment variable to the URL of the `llm-engine` service. + +The exact URL of `llm-engine` service depends on your Kubernetes cluster networking setup. The domain is specified at `config.values.infra.dns_host_domain` in the helm chart values config file. Using `charts/llm-engine/values_sample.yaml` as an example, you would do: +```bash +export LLM_ENGINE_BASE_PATH=https://llm-engine.domain.com ``` \ No newline at end of file From 7198a58ac51eea4a6be6ad3074d4d0cfe254d3ef Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 8 Aug 2023 14:59:07 -0700 Subject: [PATCH 044/425] adding status field to model get response (#202) --- clients/python/llmengine/data_types.py | 3 +++ clients/python/llmengine/model.py | 1 + docs/api/data_types.md | 1 + docs/guides/endpoint_creation.md | 15 +++++++++++++++ 4 files changed, 20 insertions(+) create mode 100644 docs/guides/endpoint_creation.md diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index e17e755a..c0dc185f 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -201,6 +201,9 @@ class GetLLMEndpointResponse(BaseModel): source: LLMSource = Field(description="The source of the model, e.g. Hugging Face.") """The source of the model, e.g. Hugging Face.""" + status: ModelEndpointStatus = Field(description="The status of the model.") + """The status of the model (can be one of "READY", "UPDATE_PENDING", "UPDATE_IN_PROGRESS", "UPDATE_FAILED", "DELETE_IN_PROGRESS").""" + inference_framework: LLMInferenceFramework = Field( description="The inference framework used by the model." ) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 1c854eba..b96afa1b 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -242,6 +242,7 @@ def get( "name": "llama-2-7b.suffix.2023-07-18-12-00-00", "model_name": null, "source": "hugging_face", + "status": "READY", "inference_framework": "text_generation_inference", "inference_framework_tag": null, "num_shards": null, diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 2d53a3bf..b932fa70 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -35,6 +35,7 @@ - inference_framework - id - model_name + - status - inference_framework_tag - num_shards - quantize diff --git a/docs/guides/endpoint_creation.md b/docs/guides/endpoint_creation.md new file mode 100644 index 00000000..2a51d0bc --- /dev/null +++ b/docs/guides/endpoint_creation.md @@ -0,0 +1,15 @@ +When creating a model endpoint, you can periodically poll the model status field to +track the status of your model endpoint. In general, you'll need to wait after the +model creation step for the model endpoint to be ready and available for use. +An example is provided below: + +*Assuming the user has created a model named "llama-2-7b.suffix.2023-07-18-12-00-00"* +``` +model_name = "llama-2-7b.suffix.2023-07-18-12-00-00" +response = Model.get(model_name) +while response.status != "READY": + time.sleep(60) + response = Model.get(model_name) +``` + +Once the endpoint status is ready, you can use your newly created model for inference. \ No newline at end of file From b20ded1945dcf8d91086c7863dfbcff87cf29083 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 8 Aug 2023 16:17:24 -0700 Subject: [PATCH 045/425] Add integrations doc page (#203) * Add integrations doc page * fix * update * comments --- docs/guides/fine_tuning.md | 2 ++ docs/integrations.md | 24 ++++++++++++++++++++++++ mkdocs.yml | 1 + 3 files changed, 27 insertions(+) create mode 100644 docs/integrations.md diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index b4d5d9ef..192d196d 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -146,6 +146,8 @@ print(response.json()) See the [Model Zoo](../../model_zoo) to see which models have fine-tuning support. +See [Integrations](../integrations.md) to see how to track fine-tuning metrics. + Once the fine-tune is launched, you can also [get the status of your fine-tune](../../api/python_client/#llmengine.fine_tuning.FineTune.get). You can also [list events that your fine-tune produces](../../api/python_client/#llmengine.fine_tuning.FineTune.get_events). ## Making inference calls to your fine-tune diff --git a/docs/integrations.md b/docs/integrations.md new file mode 100644 index 00000000..60674387 --- /dev/null +++ b/docs/integrations.md @@ -0,0 +1,24 @@ +# Integrations + +## Weights & Biases + +LLM Engine integrates with Weights & Biases to track metrics during fine tuning. To enable: + +```python +from llmengine import FineTune + +response = FineTune.create( + model="llama-2-7b", + training_file="s3://my-bucket/path/to/training-file.csv", + validation_file="s3://my-bucket/path/to/validation-file.csv", + hyperparameters={"report_to": "wandb"}, + wandb_config={"api_key":"key", "project":"fine-tune project"} +) +``` + +Configs to specify: + +- (Required) Set `hyperparameters.report_to` to `wandb` to enables automatic metrics tracking. +- (Required) Set `wandb_config.api_key` to the API key. +- (Optional) Set `wandb_config.base_url` to use a custom Weights & Biases server. +- `wandb_config` also accepts keys from [wandb.init()](https://docs.wandb.ai/ref/python/init). diff --git a/mkdocs.yml b/mkdocs.yml index 45a7329c..a24b4763 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,6 +46,7 @@ nav: - "API Reference": api/python_client.md - "Data Type Reference": api/data_types.md - "Error handling": api/error_handling.md + - "Integrations": integrations.md - "Pricing": pricing.md - "Contributing": contributing.md # - "FAQ": faq.md From 9364e5b6eb9c431d130e629aa19a76d66b766afd Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Mon, 14 Aug 2023 15:11:44 -0700 Subject: [PATCH 046/425] Update docs to reflect maximum suffix length (#207) --- clients/python/llmengine/data_types.py | 4 ++-- clients/python/llmengine/fine_tuning.py | 7 ++++--- docs/guides/fine_tuning.md | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index c0dc185f..5effe216 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -398,9 +398,9 @@ class CreateFineTuneRequest(BaseModel): suffix: Optional[str] = Field( default=None, - description="Optional user-provided identifier suffix for the fine-tuned model.", + description="Optional user-provided identifier suffix for the fine-tuned model. Can be up to 28 characters long.", ) - """Optional user-provided identifier suffix for the fine-tuned model.""" + """Optional user-provided identifier suffix for the fine-tuned model. Can be up to 28 characters long.""" class CreateFineTuneResponse(BaseModel): diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index d0a0ef86..6a19dc3b 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -75,10 +75,11 @@ def create( suffix (`Optional[str]`): A string that will be added to your fine-tuned model name. If present, the entire fine-tuned model name - will be formatted like `"[model].[suffix].[YYYY-MM-DD-HH-MM-SS]"`. If absent, the - fine-tuned model name will be formatted `"[model].[YYYY-MM-DD-HH-MM-SS]"`. + will be formatted like `"[model].[suffix].[YYMMDD-HHMMSS]"`. If absent, the + fine-tuned model name will be formatted `"[model].[YYMMDD-HHMMSS]"`. For example, if `suffix` is `"my-experiment"`, the fine-tuned model name could be - `"llama-2-7b.my-experiment.2023-07-17-23-01-50"`. + `"llama-2-7b.my-experiment.230717-230150"`. + Note: `suffix` must be between 1 and 28 characters long, and can only contain alphanumeric characters and hyphens. Returns: CreateFineTuneResponse: an object that contains the ID of the created fine-tuning job diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index 192d196d..b5d43673 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -119,7 +119,7 @@ Once you have uploaded your data, you can use the LLM Engine's [FineTune.Create] If you specify a suffix, the fine-tune will be named `model.suffix.`. If you do not, the fine-tune will be named `model.`. The timestamp will be the time the fine-tune was -launched. +launched. Note: the suffix must only contain alphanumeric characters and hyphens, and be at most 28 characters long.
Hyper-parameters for fine-tune From d55c1b2171194a22a907ed8a28c762f604268ddb Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 15 Aug 2023 13:27:49 -0700 Subject: [PATCH 047/425] adding download api to launch client, updating example (#196) --- clients/python/llmengine/__init__.py | 4 + clients/python/llmengine/data_types.py | 22 ++ clients/python/llmengine/model.py | 50 ++++ docs/api/data_types.md | 4 + docs/api/python_client.md | 1 + examples/download_a_finetuned_model.ipynb | 348 ++++++++++++++++++++++ 6 files changed, 429 insertions(+) create mode 100644 examples/download_a_finetuned_model.ipynb diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 66efafb4..380ef213 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -34,6 +34,8 @@ ListFilesResponse, ListFineTunesResponse, ListLLMEndpointsResponse, + ModelDownloadRequest, + ModelDownloadResponse, UploadFileResponse, ) from llmengine.file import File @@ -51,6 +53,8 @@ "CreateFineTuneResponse", "DeleteFileResponse", "DeleteLLMEndpointResponse", + "ModelDownloadRequest", + "ModelDownloadResponse", "GetFileContentResponse", "File", "FineTune", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 5effe216..34eaf0f9 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -495,6 +495,28 @@ class GetFineTuneEventsResponse(BaseModel): events: List[LLMFineTuneEvent] = Field(..., description="List of fine-tuning events.") +class ModelDownloadRequest(BaseModel): + """ + Request object for downloading a model. + """ + + model_name: str = Field(..., description="Name of the model to download.") + download_format: Optional[str] = Field( + default="hugging_face", + description="Desired return format for downloaded model weights (default=hugging_face).", + ) + + +class ModelDownloadResponse(BaseModel): + """ + Response object for downloading a model. + """ + + urls: Dict[str, str] = Field( + ..., description="Dictionary of (file_name, url) pairs to download the model from." + ) + + class UploadFileResponse(BaseModel): """Response object for uploading a file.""" diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index b96afa1b..cd2191e3 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -10,6 +10,8 @@ ListLLMEndpointsResponse, LLMInferenceFramework, LLMSource, + ModelDownloadRequest, + ModelDownloadResponse, ModelEndpointType, PostInferenceHooks, Quantization, @@ -366,3 +368,51 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse: """ response = cls._delete(f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT) return DeleteLLMEndpointResponse.parse_obj(response) + + @classmethod + def download( + cls, + model_name: str, + download_format: str = "hugging_face", + ) -> ModelDownloadResponse: + """ + Download a fine-tuned model. + + This API can be used to download the resulting model from a fine-tuning job. + It takes the `model_name` and `download_format` as parameter and returns a + response object which contains a list of urls associated with the fine-tuned model. + The user can then download these urls to obtain the fine-tuned model. If called + on a nonexistent model, an error will be thrown. + + Args: + model_name (`str`): + name of the fine-tuned model + download_format (`str`): + download format requested (default=hugging_face) + Returns: + DownloadModelResponse: an object that contains a dictionary of filenames, urls from which to download the model weights. + The urls are presigned urls that grant temporary access and expire after an hour. + + === "Downloading model in Python" + ```python + from llmengine import Model + + response = Model.download("llama-2-7b.suffix.2023-07-18-12-00-00", download_format="hugging_face") + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "urls": {"my_model_file": 'https://url-to-my-model-weights'} + } + ``` + """ + + request = ModelDownloadRequest(model_name=model_name, download_format=download_format) + response = cls.post_sync( + resource_name="v1/llm/model-endpoints/download", + data=request.dict(), + timeout=DEFAULT_TIMEOUT, + ) + return ModelDownloadResponse.parse_obj(response) diff --git a/docs/api/data_types.md b/docs/api/data_types.md index b932fa70..0663607f 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -45,6 +45,10 @@ ::: llmengine.DeleteLLMEndpointResponse +::: llmengine.ModelDownloadRequest + +::: llmengine.ModelDownloadResponse + ::: llmengine.UploadFileResponse ::: llmengine.GetFileResponse diff --git a/docs/api/python_client.md b/docs/api/python_client.md index 8b3fdc1f..427ae8b6 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -21,6 +21,7 @@ - get - list - delete + - download ::: llmengine.File selection: diff --git a/examples/download_a_finetuned_model.ipynb b/examples/download_a_finetuned_model.ipynb new file mode 100644 index 00000000..da548b12 --- /dev/null +++ b/examples/download_a_finetuned_model.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8d3a4214", + "metadata": { + "id": "8d3a4214" + }, + "source": [ + "# Download a FineTuned Model \n", + "This notebook demonstrates how to download a finetuned model that you've created using LLM Engine and add it to huggingface!\n", + "\n", + "**This notebook is an extension of the previous finetuning notebook on ScienceQA**" + ] + }, + { + "cell_type": "markdown", + "id": "XK6VpTnOL4OV", + "metadata": { + "id": "XK6VpTnOL4OV" + }, + "source": [ + "# Packages Required\n", + "For this demo, we'll be using the `scale-llm-engine` package, the `datasets` package for downloading our finetuning dataset, `transformers`, and `huggingface_hub` for uploading our model to huggingface.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "S5u6DdInMEQ7", + "metadata": { + "id": "S5u6DdInMEQ7" + }, + "outputs": [], + "source": [ + "!pip install scale-llm-engine\n", + "!pip install transformers\n", + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "id": "a3dc2a56", + "metadata": { + "id": "a3dc2a56" + }, + "source": [ + "# Data Preparation\n", + "Let's load in the dataset using Huggingface and view the features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e06ac39e", + "metadata": { + "id": "e06ac39e" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from smart_open import smart_open\n", + "import pandas as pd\n", + "\n", + "dataset = load_dataset('derek-thomas/ScienceQA')\n", + "dataset['train'].features" + ] + }, + { + "cell_type": "markdown", + "id": "1cbe8a58", + "metadata": { + "id": "1cbe8a58" + }, + "source": [ + "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b0eb8ad", + "metadata": { + "id": "0b0eb8ad" + }, + "outputs": [], + "source": [ + "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", + "def format_options(options, choice_prefixes):\n", + " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", + "\n", + "def format_prompt(r, choice_prefixes):\n", + " options = format_options(r['choices'], choice_prefixes)\n", + " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", + "\n", + "def format_label(r, choice_prefixes):\n", + " return choice_prefixes[r['answer']]\n", + "\n", + "def convert_dataset(ds):\n", + " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", + " return df\n", + "\n", + "save_to_s3 = False\n", + "df_train = convert_dataset(dataset['train'])\n", + "if save_to_s3:\n", + " train_url = 's3://...'\n", + " val_url = 's3://...'\n", + " df_train = convert_dataset(dataset['train'])\n", + " with smart_open(train_url, 'wb') as f:\n", + " df_train.to_csv(f)\n", + "\n", + " df_val = convert_dataset(dataset['validation'])\n", + " with smart_open(val_url, 'wb') as f:\n", + " df_val.to_csv(f)\n", + "else:\n", + " # Gists of the already processed datasets\n", + " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", + " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", + "\n", + "df_train" + ] + }, + { + "cell_type": "markdown", + "id": "e2fc8d76", + "metadata": { + "id": "e2fc8d76" + }, + "source": [ + "# Fine-tune\n", + "Now, we can fine-tune the model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4905d447", + "metadata": { + "id": "4905d447" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ['SCALE_API_KEY'] = 'xxx'\n", + "\n", + "from llmengine import FineTune\n", + "\n", + "response = FineTune.create(\n", + " model=\"llama-2-7b\",\n", + " training_file=train_url,\n", + " validation_file=val_url,\n", + " hyperparameters={\n", + " 'lr':2e-4,\n", + " },\n", + " suffix='science-qa-llama'\n", + ")\n", + "run_id = response.id" + ] + }, + { + "cell_type": "markdown", + "id": "55074457", + "metadata": { + "id": "55074457" + }, + "source": [ + "We can sleep until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840938dd", + "metadata": { + "id": "840938dd" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "while True:\n", + " job_status = FineTune.get(run_id).status\n", + " print(job_status)\n", + " if job_status == 'SUCCESS':\n", + " break\n", + " time.sleep(60)\n", + "\n", + "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "id": "31278c6d", + "metadata": { + "id": "31278c6d" + }, + "source": [ + "# Downloading our Finetuned model \n", + "Let's download the weights for the new fine-tuned model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f2f3f43", + "metadata": { + "id": "9f2f3f43" + }, + "outputs": [], + "source": [ + "from llmengine import Model\n", + "\n", + "response = Model.download(FineTune.get(run_id).fine_tune_model, download_format=\"hugging_face\")\n", + "print(response.urls)" + ] + }, + { + "cell_type": "markdown", + "id": "ae9cbdf3", + "metadata": {}, + "source": [ + "We now have a dictionary of filenames and urls that point to the file(s) where our finetuned model lives. We can download the associated finetuned model either synchronously or asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc363e48", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import requests\n", + "\n", + "def download_files(url_dict, directory):\n", + " \"\"\"\n", + " Download files from given URLs to specified directory.\n", + " \n", + " Parameters:\n", + " - url_dict: Dictionary of {file_name: url} pairs.\n", + " - directory: Directory to save the files.\n", + " \"\"\"\n", + " if not os.path.exists(directory):\n", + " os.makedirs(directory)\n", + " \n", + " for file_name, url in url_dict.items():\n", + " response = requests.get(url, stream=True)\n", + " response.raise_for_status() # Raise an exception for HTTP errors\n", + " file_path = os.path.join(directory, file_name)\n", + " \n", + " with open(file_path, 'wb') as file:\n", + " for chunk in response.iter_content(chunk_size=8192):\n", + " file.write(chunk)\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "000e1633", + "metadata": {}, + "outputs": [], + "source": [ + "output_directory = \"YOUR_MODEL_DIR\"\n", + "download_files(response.urls, output_directory) " + ] + }, + { + "cell_type": "markdown", + "id": "e4e87233", + "metadata": {}, + "source": [ + "Lastly, we can upload our downloaded model to the huggingface hub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7c8ee18", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install huggingface-hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "328efd19", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from huggingface_hub import Repository\n", + "\n", + "HF_USERNAME = \"YOUR_HUGGINGFACE_USERNAME\"\n", + "HF_TOKEN = \"YOUR_HUGGINGFACE_TOKEN\"\n", + "\n", + "def upload_to_huggingface(directory, model_name):\n", + " \"\"\"\n", + " Upload files from a directory to the Hugging Face Hub as a new model.\n", + "\n", + " Parameters:\n", + " - directory: Directory containing the files to be uploaded.\n", + " - model_name: Name of the new model.\n", + " - token: Your Hugging Face authentication token.\n", + " \"\"\"\n", + " \n", + " # Create a repository with the given name\n", + " repo = Repository(directory, clone_from=f\"{HF_USERNAME}/{model_name}\", use_auth_token=HF_TOKEN)\n", + " \n", + " # Commit and push files\n", + " repo.push_to_hub()\n", + "\n", + "model_name = \"my-new-model\"\n", + " \n", + "upload_to_huggingface(output_directory, model_name, HF_TOKEN)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Environment (conda_pytorch_p38)", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 79e9bbb09d232b1d996578b9caf5690c6035e105 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 15 Aug 2023 15:19:27 -0700 Subject: [PATCH 048/425] Ianmacleod/update download docs (#210) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/model.py | 8 ++++---- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- docs/guides/endpoint_creation.md | 8 +++++--- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 380ef213..3dab0728 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta10" +__version__ = "0.0.0.beta11" from typing import Sequence diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index cd2191e3..26bbcf2d 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -380,9 +380,9 @@ def download( This API can be used to download the resulting model from a fine-tuning job. It takes the `model_name` and `download_format` as parameter and returns a - response object which contains a list of urls associated with the fine-tuned model. - The user can then download these urls to obtain the fine-tuned model. If called - on a nonexistent model, an error will be thrown. + response object which contains a dictonary of filename, url pairs associated + with the fine-tuned model. The user can then download these urls to obtain + the fine-tuned model. If called on a nonexistent model, an error will be thrown. Args: model_name (`str`): @@ -404,7 +404,7 @@ def download( === "Response in JSON" ```json { - "urls": {"my_model_file": 'https://url-to-my-model-weights'} + "urls": {"my_model_file": "https://url-to-my-model-weights"} } ``` """ diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index adf2ed20..4adfdb19 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta10" +version = "0.0.0.beta11" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 1baaf486..b8559ba8 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta10", + version="0.0.0.beta11", packages=find_packages(), ) diff --git a/docs/guides/endpoint_creation.md b/docs/guides/endpoint_creation.md index 2a51d0bc..e16602b7 100644 --- a/docs/guides/endpoint_creation.md +++ b/docs/guides/endpoint_creation.md @@ -3,11 +3,13 @@ track the status of your model endpoint. In general, you'll need to wait after t model creation step for the model endpoint to be ready and available for use. An example is provided below: -*Assuming the user has created a model named "llama-2-7b.suffix.2023-07-18-12-00-00"* + ``` -model_name = "llama-2-7b.suffix.2023-07-18-12-00-00" +model_name = "test_deploy" +model = Model.create(name=model_name, model="llama-2-7b", inference_frame_image_tag="0.9.4") response = Model.get(model_name) -while response.status != "READY": +while response.status.name != "READY": + print(response.status.name) time.sleep(60) response = Model.get(model_name) ``` From c898821764ba0c0f2ac8361086e3d78eb2734a79 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 16 Aug 2023 17:05:20 -0700 Subject: [PATCH 049/425] Faster s5cmd download (#212) --- .../use_cases/llm_model_endpoint_use_cases.py | 42 +++++++++++++++++-- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 8de7ef72..08b25354 100644 --- a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1,4 +1,5 @@ import json +import os from dataclasses import asdict from typing import Any, AsyncIterable, Dict, Optional @@ -193,9 +194,40 @@ async def create_text_generation_inference_bundle( command = [] if checkpoint_path is not None: if checkpoint_path.startswith("s3://"): - command = ["bash", "launch_s3_model.sh", checkpoint_path, str(num_shards)] + base_path = checkpoint_path.split("/")[-1] + final_weights_folder = "model_files" + subcommands = [] + + s5cmd = "s5cmd" + subcommands.append( + f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}" + ) + + if base_path.endswith(".tar"): + # If the checkpoint file is a tar file, extract it into final_weights_folder + subcommands.extend( + [ + f"{s5cmd} cp {checkpoint_path} .", + f"mkdir -p {final_weights_folder}", + f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}", + ] + ) + else: + subcommands.append( + f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + + subcommands.append( + f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} --num-shard {num_shards} --port 5005" + ) + if quantize: - command = command + [f"'--quantize {str(quantize)}'"] + subcommands[-1] = subcommands[-1] + f" --quantize {quantize}" + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] else: raise ObjectHasInvalidValueException( f"Not able to load checkpoint path {checkpoint_path}." @@ -599,7 +631,8 @@ async def execute( inference_request = EndpointPredictV1Request(args=args) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: @@ -632,7 +665,8 @@ async def execute( } inference_request = EndpointPredictV1Request(args=tgi_args) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: From 14635b0be69c5331b90f2a50f263b57e9558a714 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:46:09 -0700 Subject: [PATCH 050/425] delete server and copy over model-engine-server (#214) --- .circleci/config.yml | 16 +- .pre-commit-config.yaml | 6 +- .../templates/cacher_deployment.yaml | 2 +- .../templates/database_init_job.yaml | 2 +- .../endpoint_builder_deployment.yaml | 2 +- .../templates/gateway_deployment.yaml | 2 +- .../templates/llm_engine_init_job.yaml | 2 +- .../service_template_config_map.yaml | 12 +- {server => model-engine}/Dockerfile | 29 +- .../model_engine_server}/__init__.py | 0 .../model_engine_server}/api/__init__.py | 0 model-engine/model_engine_server/api/app.py | 53 + .../model_engine_server}/api/batch_jobs_v1.py | 55 +- .../model_engine_server}/api/dependencies.py | 122 +- .../api/docker_image_batch_job_bundles_v1.py | 27 +- .../model_engine_server/api/files_v1.py | 127 ++ .../model_engine_server}/api/llms_v1.py | 161 +- .../api/model_bundles_v1.py | 16 +- .../api/model_bundles_v2.py | 14 +- .../api/model_endpoints_docs_v1.py | 8 +- .../api/model_endpoints_v1.py | 16 +- .../model_engine_server}/api/tasks_v1.py | 27 +- .../model_engine_server/api/triggers_v1.py | 176 ++ .../model_engine_server}/api/worker.py | 8 +- .../model_engine_server}/common/__init__.py | 0 .../model_engine_server}/common/config.py | 27 +- .../model_engine_server/common/constants.py | 10 + .../common/datadog_utils.py | 19 + .../common/dtos/__init__.py | 0 .../common/dtos/batch_jobs.py | 11 +- .../common/dtos/docker_repository.py | 0 .../common/dtos/endpoint_builder.py | 3 +- .../model_engine_server/common/dtos/files.py | 47 + .../model_engine_server/common/dtos/llms.py | 230 +++ .../common/dtos/model_bundles.py | 2 +- .../common/dtos/model_endpoints.py | 4 +- .../common/dtos/resource_manager.py | 2 +- .../model_engine_server}/common/dtos/tasks.py | 2 +- .../common/dtos/triggers.py | 51 + .../model_engine_server/common/env_vars.py | 77 + .../model_engine_server}/common/errors.py | 0 .../model_engine_server}/common/io.py | 2 +- .../endpoint_predict_payload.py | 0 .../common/resource_limits.py | 11 +- .../common/serialization_utils.py | 0 .../common/service_requests.py | 4 +- .../model_engine_server/common/settings.py | 114 ++ .../model_engine_server}/common/types.py | 1 + .../model_engine_server}/core/__init__.py | 0 .../core/auth/__init__.py | 0 .../core/auth/authentication_repository.py | 7 + .../auth/fake_authentication_repository.py | 6 +- .../model_engine_server}/core/aws/__init__.py | 0 .../model_engine_server}/core/aws/roles.py | 23 +- .../model_engine_server}/core/aws/secrets.py | 12 +- .../core/aws/storage_client.py | 13 +- .../core/celery/__init__.py | 0 .../model_engine_server}/core/celery/app.py | 28 +- .../model_engine_server}/core/celery/s3.py | 0 .../model_engine_server}/core/config.py | 41 +- .../core/configs/default.yaml | 2 +- .../core/docker/__init__.py | 0 .../core/docker/docker_image.py | 20 +- .../model_engine_server}/core/docker/ecr.py | 19 +- .../core/docker/kaniko_template.yaml | 2 +- .../core/docker/kaniko_template_circleci.yaml | 0 .../core/docker/remote_build.py | 36 +- .../core/domain_exceptions.py | 0 .../core/fake_notification_gateway.py | 2 +- .../model_engine_server}/core/loggers.py | 23 +- .../core/notification_gateway.py | 0 .../core/utils/__init__.py | 0 .../model_engine_server}/core/utils/env.py | 0 .../model_engine_server}/core/utils/format.py | 0 .../model_engine_server}/core/utils/git.py | 0 .../core/utils/python_utils.py | 2 +- .../model_engine_server}/core/utils/timer.py | 2 +- .../model_engine_server}/core/utils/url.py | 7 +- .../model_engine_server}/db/__init__.py | 0 .../model_engine_server}/db/base.py | 49 +- .../db/endpoint_row_lock.py | 10 +- .../model_engine_server}/db/local_setup.py | 6 +- .../db/models/__init__.py | 3 +- .../db/models/common/__init__.py | 0 .../db/models/common/query.py | 0 .../db/models/common/record.py | 6 +- .../db/models/constants.py | 0 .../db/models/exceptions.py | 0 .../db/models/hosted_model_inference.py | 52 +- .../model_engine_server}/db/models/model.py | 23 +- .../db/models/utils/__init__.py | 0 .../db/models/utils/misc.py | 0 .../model_engine_server}/domain/__init__.py | 0 .../domain/authorization/__init__.py | 0 .../live_authorization_module.py | 34 +- .../domain/entities/__init__.py | 15 +- .../domain/entities/batch_job_entity.py | 10 +- .../domain/entities/common_types.py | 1 + .../docker_image_batch_job_bundle_entity.py | 4 +- .../domain/entities/file_entity.py | 15 + .../domain/entities/gpu_type.py | 3 +- .../domain/entities/llm_entity.py | 0 .../domain/entities/llm_fine_tune_entity.py | 11 +- .../domain/entities/model_bundle_entity.py | 18 +- .../domain/entities/model_endpoint_entity.py | 14 +- .../domain/entities/owned_entity.py | 0 .../domain/entities/trigger_entity.py | 20 + .../model_engine_server}/domain/exceptions.py | 28 +- .../domain/gateways/__init__.py | 6 + .../async_model_endpoint_inference_gateway.py | 4 +- .../domain/gateways/cron_job_gateway.py | 98 ++ .../docker_image_batch_job_gateway.py | 8 +- .../domain/gateways/file_storage_gateway.py | 94 + .../domain/gateways/llm_artifact_gateway.py | 15 + .../model_endpoints_schema_gateway.py | 2 +- .../gateways/model_primitive_gateway.py | 4 +- .../gateways/monitoring_metrics_gateway.py | 20 + ...eaming_model_endpoint_inference_gateway.py | 2 +- .../sync_model_endpoint_inference_gateway.py | 2 +- .../domain/gateways/task_queue_gateway.py | 2 +- .../domain/repositories/__init__.py | 4 + ...ocker_image_batch_job_bundle_repository.py | 6 +- .../domain/repositories/docker_repository.py | 2 +- .../llm_fine_tune_events_repository.py | 16 + .../repositories/model_bundle_repository.py | 4 +- .../domain/repositories/trigger_repository.py | 96 ++ .../domain/services/__init__.py | 2 + .../domain/services/batch_job_service.py | 4 +- .../services/endpoint_builder_service.py | 2 +- .../services/llm_fine_tuning_service.py | 40 + .../services/llm_model_endpoint_service.py | 4 +- .../domain/services/model_endpoint_service.py | 10 +- .../domain/use_cases/__init__.py | 0 .../use_cases/async_inference_use_cases.py | 18 +- .../domain/use_cases/batch_job_use_cases.py | 76 +- ...docker_image_batch_job_bundle_use_cases.py | 16 +- .../domain/use_cases/file_use_cases.py | 97 ++ .../use_cases/llm_fine_tuning_use_cases.py | 217 +++ .../use_cases/llm_model_endpoint_use_cases.py | 285 +++- .../use_cases/model_bundle_use_cases.py | 44 +- .../use_cases/model_endpoint_use_cases.py | 103 +- .../model_endpoints_schema_use_cases.py | 12 +- .../streaming_inference_use_cases.py | 18 +- .../use_cases/sync_inference_use_cases.py | 19 +- .../domain/use_cases/trigger_use_cases.py | 243 +++ .../entrypoints/__init__.py | 0 .../entrypoints/init_database.py | 6 +- .../entrypoints/init_spellbook_models.py | 11 +- .../entrypoints/k8s_cache.py | 54 +- .../start_batch_job_orchestration.py | 41 +- ...t_docker_image_batch_job_init_container.py | 17 +- .../entrypoints/start_fastapi_server.py | 4 +- .../inference/__init__.py | 0 .../inference/async_inference/__init__.py | 0 .../inference/async_inference/celery.py | 14 +- .../inference/async_inference/tasks.py | 28 +- .../inference/async_inference/vpa.yaml | 0 .../inference/base.Dockerfile | 6 +- .../model_engine_server}/inference/common.py | 22 +- ...-runnable-img-converted-from-artifact.yaml | 21 + .../inference/configs/service--forwarder.yaml | 19 + .../configs/service--http_forwarder.yaml | 5 +- .../inference_monitoring_metrics_gateway.py | 2 +- .../domain/gateways/usage_metrics_gateway.py | 28 + .../inference/download_and_inject_bundle.py | 2 +- .../inference/forwarding/__init__.py | 0 .../inference/forwarding/celery_forwarder.py | 170 ++ .../inference/forwarding/forwarding.py | 104 +- .../inference/forwarding/http_forwarder.py | 170 ++ .../inference/infra/__init__.py | 0 .../inference/infra/gateways/__init__.py | 0 ...og_inference_monitoring_metrics_gateway.py | 6 +- .../gateways/fake_usage_metrics_gateway.py | 10 + .../inference/inject_bundle.Dockerfile | 2 +- .../model_engine_server/inference/limits.conf | 2 + .../inference/post_inference_hooks.py | 67 +- .../inference/pytorch_or_tf.base.Dockerfile | 70 + .../inference/pytorch_or_tf.user.Dockerfile | 10 + .../inference/requirements_base.txt | 14 +- .../inference/service_requests.py | 17 +- .../inference/sync_inference/__init__.py | 0 .../inference/sync_inference/constants.py | 0 .../sync_inference/destination_rule.yaml | 0 .../sync_inference/fastapi_server.py | 16 +- .../inference/sync_inference/server.py | 96 ++ .../sync_inference/start_fastapi_server.py | 7 +- .../sync_inference/virtual_service.yaml | 0 .../inference/sync_inference/vpa.yaml | 0 .../inference/user.Dockerfile | 8 + .../model_engine_server}/infra/__init__.py | 0 .../infra/gateways/__init__.py | 8 +- .../infra/gateways/aiohttp_sse_client.py | 0 .../batch_job_orchestration_gateway.py | 2 +- .../gateways/batch_job_progress_gateway.py | 2 +- .../gateways/celery_task_queue_gateway.py | 30 +- .../datadog_monitoring_metrics_gateway.py | 38 + .../gateways/fake_model_primitive_gateway.py | 4 +- .../fake_monitoring_metrics_gateway.py | 17 +- .../infra/gateways/filesystem_gateway.py | 0 .../infra/gateways/k8s_resource_parser.py | 0 ..._async_model_endpoint_inference_gateway.py | 8 +- .../live_batch_job_orchestration_gateway.py | 18 +- .../live_batch_job_progress_gateway.py | 15 +- .../infra/gateways/live_cron_job_gateway.py | 159 ++ .../live_docker_image_batch_job_gateway.py | 104 +- .../live_model_endpoint_infra_gateway.py | 24 +- .../live_model_endpoints_schema_gateway.py | 29 +- ...eaming_model_endpoint_inference_gateway.py | 20 +- ...e_sync_model_endpoint_inference_gateway.py | 18 +- .../gateways/model_endpoint_infra_gateway.py | 6 +- .../infra/gateways/resources/__init__.py | 0 .../resources/endpoint_resource_gateway.py | 6 +- .../fake_sqs_endpoint_resource_delegate.py | 2 +- .../gateways/resources/image_cache_gateway.py | 26 +- .../k8s_endpoint_resource_delegate.py | 176 +- .../gateways/resources/k8s_resource_types.py | 377 ++-- .../live_endpoint_resource_gateway.py | 18 +- .../live_sqs_endpoint_resource_delegate.py | 10 +- .../sqs_endpoint_resource_delegate.py | 2 +- .../service_template_config_map_circleci.yaml | 1513 +++++------------ .../infra/gateways/s3_file_storage_gateway.py | 79 + .../infra/gateways/s3_filesystem_gateway.py | 11 +- .../infra/gateways/s3_llm_artifact_gateway.py | 36 + .../model_engine_server}/infra/infra_utils.py | 2 +- .../infra/repositories/__init__.py | 12 +- .../batch_job_record_repository.py | 2 +- .../db_batch_job_record_repository.py | 16 +- ...ocker_image_batch_job_bundle_repository.py | 16 +- .../db_model_bundle_repository.py | 14 +- .../db_model_endpoint_record_repository.py | 22 +- .../infra/repositories/db_repository_mixin.py | 2 +- .../repositories/db_trigger_repository.py | 134 ++ .../repositories/ecr_docker_repository.py | 16 +- .../repositories/feature_flag_repository.py | 0 .../repositories/llm_fine_tune_repository.py | 8 +- .../model_endpoint_cache_repository.py | 2 +- .../model_endpoint_record_repository.py | 4 +- .../redis_feature_flag_repository.py | 4 +- .../redis_model_endpoint_cache_repository.py | 6 +- ...s3_file_llm_fine_tune_events_repository.py | 89 + .../s3_file_llm_fine_tune_repository.py | 14 +- .../infra/services/__init__.py | 0 .../batch_job_orchestration_service.py | 2 +- ...image_batch_job_llm_fine_tuning_service.py | 73 +- .../infra/services/image_cache_service.py | 59 +- .../live_batch_job_orchestration_service.py | 55 +- .../infra/services/live_batch_job_service.py | 14 +- .../services/live_endpoint_builder_service.py | 223 ++- .../live_llm_model_endpoint_service.py | 12 +- .../services/live_model_endpoint_service.py | 41 +- .../services/model_endpoint_cache_service.py | 8 +- .../service_builder}/__init__.py | 0 .../service_builder/celery.py | 13 + .../service_builder/tasks_v1.py | 84 +- {server => model-engine}/mypy.ini | 12 +- .../requirements-test.txt | 2 + {server => model-engine}/requirements.in | 4 +- {server => model-engine}/requirements.txt | 128 +- model-engine/requirements_override.txt | 2 + .../service_config_circleci.yaml | 60 + model-engine/setup.cfg | 35 + model-engine/setup.py | 20 + model-engine/tests/README.md | 7 + .../tests}/__init__.py | 0 .../tests/integration}/__init__.py | 0 .../tests/integration/inference/conftest.py | 22 +- .../inference/test_async_inference.py | 12 +- .../tests/unit/api/conftest.py | 163 +- .../tests/unit/api/test_app.py | 0 .../tests/unit/api/test_batch_jobs.py | 136 +- .../test_docker_image_batch_job_bundles.py | 2 +- .../tests/unit/api/test_llms.py | 41 +- .../tests/unit/api/test_model_bundles.py | 17 +- .../tests/unit/api/test_model_endpoints.py | 76 +- .../unit/api/test_model_endpoints_docs.py | 2 +- .../tests/unit/api/test_tasks.py | 14 +- model-engine/tests/unit/api/test_triggers.py | 312 ++++ .../tests/unit/common}/__init__.py | 0 .../tests/unit/common/test_batch_jobs_dtos.py | 2 +- .../tests/unit/conftest.py | 483 +++++- .../tests/unit/domain/conftest.py | 57 +- .../domain/test_async_inference_use_cases.py | 10 +- ...docker_image_batch_job_bundle_use_cases.py | 28 +- .../tests/unit/domain/test_entities.py | 8 +- .../tests/unit/domain/test_llm_use_cases.py | 320 +++- .../domain/test_model_bundle_use_cases.py | 14 +- .../domain/test_model_endpoint_use_cases.py | 363 +++- .../test_streaming_inference_use_cases.py | 12 +- .../domain/test_sync_inference_use_cases.py | 10 +- .../tests/unit/inference/test_forwarding.py | 64 +- .../unit/inference/test_http_forwarder.py | 46 + .../tests/unit/infra/gateways/conftest.py | 30 +- .../test_k8s_endpoint_resource_delegate.py | 96 +- ...est_live_sqs_endpoint_resource_delegate.py | 64 +- .../gateways/test_k8s_resource_parser.py | 2 +- ...test_live_async_model_inference_gateway.py | 4 +- .../test_live_batch_job_progress_gateway.py | 4 +- ...est_live_docker_image_batch_job_gateway.py | 2 +- .../test_live_model_endpoint_infra_gateway.py | 33 +- ...est_live_model_endpoints_schema_gateway.py | 8 +- ...eaming_model_endpoint_inference_gateway.py | 18 +- ...e_sync_model_endpoint_inference_gateway.py | 18 +- .../tests/unit/infra/repositories/conftest.py | 12 +- .../test_db_batch_job_record_repository.py | 8 +- ...ocker_image_batch_job_bundle_repository.py | 16 +- .../test_db_model_bundle_repository.py | 10 +- ...est_db_model_endpoint_record_repository.py | 18 +- .../test_redis_feature_flag_repository.py | 2 +- ...t_redis_model_endpoint_cache_repository.py | 2 +- .../tests/unit/infra/services/conftest.py | 6 +- ...image_batch_job_llm_fine_tuning_service.py | 65 + .../services/test_image_cache_service.py | 58 + ...st_live_batch_job_orchestration_service.py | 16 +- .../services/test_live_batch_job_service.py | 10 +- .../test_live_endpoint_builder_service.py | 34 +- .../test_live_model_endpoint_service.py | 10 +- .../test_model_endpoint_cache_service.py | 6 +- requirements-dev.txt | 3 +- server/Dockerfile.openapi | 6 - server/Makefile | 43 - server/docker-compose.yml | 155 -- server/llm_engine_server/api/app.py | 36 - server/llm_engine_server/common/constants.py | 13 - .../llm_engine_server/common/datadog_utils.py | 10 - server/llm_engine_server/common/dtos/llms.py | 171 -- server/llm_engine_server/common/env_vars.py | 67 - server/llm_engine_server/common/settings.py | 66 - .../llm_engine_server/core/aws/sfn_client.py | 21 - server/llm_engine_server/core/kubernetes.py | 81 - .../core/testing_utilities.py | 140 -- .../db/migrations/alembic.ini | 85 - .../db/migrations/alembic/README | 1 - .../db/migrations/alembic/env.py | 91 - .../db/migrations/alembic/script.py.mako | 24 - server/llm_engine_server/db/ml_infra_pg.py | 10 - .../services/llm_fine_tuning_service.py | 30 - .../use_cases/llm_fine_tuning_use_cases.py | 82 - ...-runnable-img-converted-from-artifact.yaml | 13 - .../inference/configs/service--forwarder.yaml | 13 - .../configs/service--streaming_forwarder.yaml | 10 - .../inference/forwarding/http_forwarder.py | 167 -- .../llm_engine_server/inference/limits.conf | 2 - .../inference/pytorch_or_tf.Dockerfile | 81 - .../inference/pytorch_or_tf.base.Dockerfile | 78 - .../inference/pytorch_or_tf.user.Dockerfile | 8 - .../inference/user.Dockerfile | 8 - .../datadog_monitoring_metrics_gateway.py | 23 - .../scripts/autogenerate_client_and_docs.py | 39 - .../scripts/copy_to_public_client.sh | 11 - .../service_builder/celery.py | 13 - server/pyproject.toml | 6 - server/requirements_override.txt | 4 - server/service_configs/service_config.yaml | 66 - .../service_config_circleci.yaml | 65 - server/setup.cfg | 18 - server/setup.py | 19 - server/tests/README.md | 7 - server/tests/unit/common/__init__.py | 0 server/tests/unit/core/test_env.py | 105 -- server/tests/unit/db/common/test_query.py | 18 - .../tests/unit/db/common/test_repository.py | 69 - server/tests/unit/db/conftest.py | 467 ----- .../tests/unit/db/test_endpoint_row_lock.py | 22 - server/tests/unit/db/test_llm_engine.py | 160 -- server/tests/unit/db/test_model.py | 140 -- .../services/test_image_cache_service.py | 38 - 366 files changed, 8216 insertions(+), 6041 deletions(-) rename {server => model-engine}/Dockerfile (57%) rename {server/llm_engine_server => model-engine/model_engine_server}/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/__init__.py (100%) create mode 100644 model-engine/model_engine_server/api/app.py rename {server/llm_engine_server => model-engine/model_engine_server}/api/batch_jobs_v1.py (78%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/dependencies.py (66%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/docker_image_batch_job_bundles_v1.py (84%) create mode 100644 model-engine/model_engine_server/api/files_v1.py rename {server/llm_engine_server => model-engine/model_engine_server}/api/llms_v1.py (68%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/model_bundles_v1.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/model_bundles_v2.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/model_endpoints_docs_v1.py (84%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/model_endpoints_v1.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/api/tasks_v1.py (87%) create mode 100644 model-engine/model_engine_server/api/triggers_v1.py rename {server/llm_engine_server => model-engine/model_engine_server}/api/worker.py (70%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/config.py (60%) create mode 100644 model-engine/model_engine_server/common/constants.py create mode 100644 model-engine/model_engine_server/common/datadog_utils.py rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/batch_jobs.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/docker_repository.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/endpoint_builder.py (92%) create mode 100644 model-engine/model_engine_server/common/dtos/files.py create mode 100644 model-engine/model_engine_server/common/dtos/llms.py rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/model_bundles.py (98%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/model_endpoints.py (97%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/resource_manager.py (64%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/dtos/tasks.py (94%) create mode 100644 model-engine/model_engine_server/common/dtos/triggers.py create mode 100644 model-engine/model_engine_server/common/env_vars.py rename {server/llm_engine_server => model-engine/model_engine_server}/common/errors.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/io.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/pydantic_types/endpoint_predict_payload.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/resource_limits.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/serialization_utils.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/common/service_requests.py (93%) create mode 100644 model-engine/model_engine_server/common/settings.py rename {server/llm_engine_server => model-engine/model_engine_server}/common/types.py (98%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/auth/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/auth/authentication_repository.py (87%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/auth/fake_authentication_repository.py (85%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/aws/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/aws/roles.py (88%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/aws/secrets.py (64%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/aws/storage_client.py (86%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/celery/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/celery/app.py (96%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/celery/s3.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/config.py (64%) rename server/llm_engine_server/core/configs/circleci.yaml => model-engine/model_engine_server/core/configs/default.yaml (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/docker/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/docker/docker_image.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/docker/ecr.py (82%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/docker/kaniko_template.yaml (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/docker/kaniko_template_circleci.yaml (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/docker/remote_build.py (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/domain_exceptions.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/fake_notification_gateway.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/loggers.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/notification_gateway.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/env.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/format.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/git.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/python_utils.py (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/timer.py (98%) rename {server/llm_engine_server => model-engine/model_engine_server}/core/utils/url.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/base.py (73%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/endpoint_row_lock.py (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/local_setup.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/__init__.py (68%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/common/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/common/query.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/common/record.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/constants.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/exceptions.py (100%) rename server/llm_engine_server/db/models/llm_engine.py => model-engine/model_engine_server/db/models/hosted_model_inference.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/model.py (93%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/utils/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/db/models/utils/misc.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/authorization/__init__.py (100%) rename server/llm_engine_server/domain/authorization/scale_authorization_module.py => model-engine/model_engine_server/domain/authorization/live_authorization_module.py (67%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/__init__.py (85%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/batch_job_entity.py (78%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/common_types.py (69%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/docker_image_batch_job_bundle_entity.py (80%) create mode 100644 model-engine/model_engine_server/domain/entities/file_entity.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/gpu_type.py (65%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/llm_entity.py (100%) rename server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py => model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py (54%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/model_bundle_entity.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/model_endpoint_entity.py (88%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/entities/owned_entity.py (100%) create mode 100644 model-engine/model_engine_server/domain/entities/trigger_entity.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/exceptions.py (77%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/__init__.py (79%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/async_model_endpoint_inference_gateway.py (88%) create mode 100644 model-engine/model_engine_server/domain/gateways/cron_job_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/docker_image_batch_job_gateway.py (88%) create mode 100644 model-engine/model_engine_server/domain/gateways/file_storage_gateway.py create mode 100644 model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/model_endpoints_schema_gateway.py (86%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/model_primitive_gateway.py (84%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/monitoring_metrics_gateway.py (67%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/streaming_model_endpoint_inference_gateway.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/sync_model_endpoint_inference_gateway.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/gateways/task_queue_gateway.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/repositories/__init__.py (65%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/repositories/docker_image_batch_job_bundle_repository.py (93%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/repositories/docker_repository.py (94%) create mode 100644 model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/repositories/model_bundle_repository.py (95%) create mode 100644 model-engine/model_engine_server/domain/repositories/trigger_repository.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/services/__init__.py (82%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/services/batch_job_service.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/services/endpoint_builder_service.py (89%) create mode 100644 model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/services/llm_model_endpoint_service.py (89%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/services/model_endpoint_service.py (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/async_inference_use_cases.py (84%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/batch_job_use_cases.py (81%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/docker_image_batch_job_bundle_use_cases.py (90%) create mode 100644 model-engine/model_engine_server/domain/use_cases/file_use_cases.py create mode 100644 model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/llm_model_endpoint_use_cases.py (75%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/model_bundle_use_cases.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/model_endpoint_use_cases.py (82%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/model_endpoints_schema_use_cases.py (64%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/streaming_inference_use_cases.py (77%) rename {server/llm_engine_server => model-engine/model_engine_server}/domain/use_cases/sync_inference_use_cases.py (76%) create mode 100644 model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py rename {server/llm_engine_server => model-engine/model_engine_server}/entrypoints/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/entrypoints/init_database.py (91%) rename server/llm_engine_server/entrypoints/init_llm_engine_models.py => model-engine/model_engine_server/entrypoints/init_spellbook_models.py (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/entrypoints/k8s_cache.py (66%) rename {server/llm_engine_server => model-engine/model_engine_server}/entrypoints/start_batch_job_orchestration.py (78%) rename {server/llm_engine_server => model-engine/model_engine_server}/entrypoints/start_docker_image_batch_job_init_container.py (78%) rename {server/llm_engine_server => model-engine/model_engine_server}/entrypoints/start_fastapi_server.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/async_inference/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/async_inference/celery.py (64%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/async_inference/tasks.py (75%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/async_inference/vpa.yaml (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/base.Dockerfile (72%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/common.py (92%) create mode 100644 model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml create mode 100644 model-engine/model_engine_server/inference/configs/service--forwarder.yaml rename {server/llm_engine_server => model-engine/model_engine_server}/inference/configs/service--http_forwarder.yaml (82%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/domain/gateways/inference_monitoring_metrics_gateway.py (89%) create mode 100644 model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/inference/download_and_inject_bundle.py (96%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/forwarding/__init__.py (100%) create mode 100644 model-engine/model_engine_server/inference/forwarding/celery_forwarder.py rename {server/llm_engine_server => model-engine/model_engine_server}/inference/forwarding/forwarding.py (82%) create mode 100644 model-engine/model_engine_server/inference/forwarding/http_forwarder.py rename {server/llm_engine_server => model-engine/model_engine_server}/inference/infra/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/infra/gateways/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py (50%) create mode 100644 model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/inference/inject_bundle.Dockerfile (79%) create mode 100644 model-engine/model_engine_server/inference/limits.conf rename {server/llm_engine_server => model-engine/model_engine_server}/inference/post_inference_hooks.py (62%) create mode 100644 model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile create mode 100644 model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile rename {server/llm_engine_server => model-engine/model_engine_server}/inference/requirements_base.txt (69%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/service_requests.py (89%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/constants.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/destination_rule.yaml (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/fastapi_server.py (86%) create mode 100644 model-engine/model_engine_server/inference/sync_inference/server.py rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/start_fastapi_server.py (65%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/virtual_service.yaml (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/inference/sync_inference/vpa.yaml (100%) create mode 100644 model-engine/model_engine_server/inference/user.Dockerfile rename {server/llm_engine_server => model-engine/model_engine_server}/infra/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/__init__.py (87%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/aiohttp_sse_client.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/batch_job_orchestration_gateway.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/batch_job_progress_gateway.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/celery_task_queue_gateway.py (71%) create mode 100644 model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/fake_model_primitive_gateway.py (82%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/fake_monitoring_metrics_gateway.py (64%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/filesystem_gateway.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/k8s_resource_parser.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_async_model_endpoint_inference_gateway.py (83%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_batch_job_orchestration_gateway.py (81%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_batch_job_progress_gateway.py (69%) create mode 100644 model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_docker_image_batch_job_gateway.py (75%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_model_endpoint_infra_gateway.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_model_endpoints_schema_gateway.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_streaming_model_endpoint_inference_gateway.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/live_sync_model_endpoint_inference_gateway.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/model_endpoint_infra_gateway.py (96%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/endpoint_resource_gateway.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py (93%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/image_cache_gateway.py (84%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/k8s_endpoint_resource_delegate.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/k8s_resource_types.py (79%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/live_endpoint_resource_gateway.py (84%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/sqs_endpoint_resource_delegate.py (95%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/resources/templates/service_template_config_map_circleci.yaml (69%) create mode 100644 model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/infra/gateways/s3_filesystem_gateway.py (77%) create mode 100644 model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py rename {server/llm_engine_server => model-engine/model_engine_server}/infra/infra_utils.py (96%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/__init__.py (75%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/batch_job_record_repository.py (97%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/db_batch_job_record_repository.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/db_docker_image_batch_job_bundle_repository.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/db_model_bundle_repository.py (96%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/db_model_endpoint_record_repository.py (94%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/db_repository_mixin.py (85%) create mode 100644 model-engine/model_engine_server/infra/repositories/db_trigger_repository.py rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/ecr_docker_repository.py (69%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/feature_flag_repository.py (100%) rename server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py => model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py (69%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/model_endpoint_cache_repository.py (93%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/model_endpoint_record_repository.py (97%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/redis_feature_flag_repository.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/repositories/redis_model_endpoint_cache_repository.py (91%) create mode 100644 model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py rename server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py => model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py (81%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/__init__.py (100%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/batch_job_orchestration_service.py (91%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/docker_image_batch_job_llm_fine_tuning_service.py (58%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/image_cache_service.py (62%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/live_batch_job_orchestration_service.py (88%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/live_batch_job_service.py (92%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/live_endpoint_builder_service.py (81%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/live_llm_model_endpoint_service.py (81%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/live_model_endpoint_service.py (90%) rename {server/llm_engine_server => model-engine/model_engine_server}/infra/services/model_endpoint_cache_service.py (83%) rename {server/llm_engine_server/scripts => model-engine/model_engine_server/service_builder}/__init__.py (100%) create mode 100644 model-engine/model_engine_server/service_builder/celery.py rename {server/llm_engine_server => model-engine/model_engine_server}/service_builder/tasks_v1.py (54%) rename {server => model-engine}/mypy.ini (60%) rename {server => model-engine}/requirements-test.txt (90%) rename {server => model-engine}/requirements.in (93%) rename {server => model-engine}/requirements.txt (72%) create mode 100644 model-engine/requirements_override.txt create mode 100644 model-engine/service_configs/service_config_circleci.yaml create mode 100644 model-engine/setup.cfg create mode 100644 model-engine/setup.py create mode 100644 model-engine/tests/README.md rename {server/llm_engine_server/service_builder => model-engine/tests}/__init__.py (100%) rename {server/tests => model-engine/tests/integration}/__init__.py (100%) rename {server => model-engine}/tests/integration/inference/conftest.py (86%) rename {server => model-engine}/tests/integration/inference/test_async_inference.py (94%) rename {server => model-engine}/tests/unit/api/conftest.py (86%) rename {server => model-engine}/tests/unit/api/test_app.py (100%) rename {server => model-engine}/tests/unit/api/test_batch_jobs.py (81%) rename {server => model-engine}/tests/unit/api/test_docker_image_batch_job_bundles.py (99%) rename {server => model-engine}/tests/unit/api/test_llms.py (78%) rename {server => model-engine}/tests/unit/api/test_model_bundles.py (96%) rename {server => model-engine}/tests/unit/api/test_model_endpoints.py (88%) rename {server => model-engine}/tests/unit/api/test_model_endpoints_docs.py (97%) rename {server => model-engine}/tests/unit/api/test_tasks.py (96%) create mode 100644 model-engine/tests/unit/api/test_triggers.py rename {server/tests/integration => model-engine/tests/unit/common}/__init__.py (100%) rename {server => model-engine}/tests/unit/common/test_batch_jobs_dtos.py (95%) rename {server => model-engine}/tests/unit/conftest.py (88%) rename {server => model-engine}/tests/unit/domain/conftest.py (86%) rename {server => model-engine}/tests/unit/domain/test_async_inference_use_cases.py (91%) rename {server => model-engine}/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py (92%) rename {server => model-engine}/tests/unit/domain/test_entities.py (86%) rename {server => model-engine}/tests/unit/domain/test_llm_use_cases.py (67%) rename {server => model-engine}/tests/unit/domain/test_model_bundle_use_cases.py (97%) rename {server => model-engine}/tests/unit/domain/test_model_endpoint_use_cases.py (72%) rename {server => model-engine}/tests/unit/domain/test_streaming_inference_use_cases.py (88%) rename {server => model-engine}/tests/unit/domain/test_sync_inference_use_cases.py (90%) rename {server => model-engine}/tests/unit/inference/test_forwarding.py (81%) create mode 100644 model-engine/tests/unit/inference/test_http_forwarder.py rename {server => model-engine}/tests/unit/infra/gateways/conftest.py (51%) rename {server => model-engine}/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py (87%) rename {server => model-engine}/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py (88%) rename {server => model-engine}/tests/unit/infra/gateways/test_k8s_resource_parser.py (97%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py (94%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py (91%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py (94%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py (84%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py (96%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py (86%) rename {server => model-engine}/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py (84%) rename {server => model-engine}/tests/unit/infra/repositories/conftest.py (97%) rename {server => model-engine}/tests/unit/infra/repositories/test_db_batch_job_record_repository.py (97%) rename {server => model-engine}/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py (94%) rename {server => model-engine}/tests/unit/infra/repositories/test_db_model_bundle_repository.py (97%) rename {server => model-engine}/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py (95%) rename {server => model-engine}/tests/unit/infra/repositories/test_redis_feature_flag_repository.py (88%) rename {server => model-engine}/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py (92%) rename {server => model-engine}/tests/unit/infra/services/conftest.py (92%) create mode 100644 model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py create mode 100644 model-engine/tests/unit/infra/services/test_image_cache_service.py rename {server => model-engine}/tests/unit/infra/services/test_live_batch_job_orchestration_service.py (94%) rename {server => model-engine}/tests/unit/infra/services/test_live_batch_job_service.py (94%) rename {server => model-engine}/tests/unit/infra/services/test_live_endpoint_builder_service.py (87%) rename {server => model-engine}/tests/unit/infra/services/test_live_model_endpoint_service.py (97%) rename {server => model-engine}/tests/unit/infra/services/test_model_endpoint_cache_service.py (92%) delete mode 100644 server/Dockerfile.openapi delete mode 100644 server/Makefile delete mode 100644 server/docker-compose.yml delete mode 100644 server/llm_engine_server/api/app.py delete mode 100644 server/llm_engine_server/common/constants.py delete mode 100644 server/llm_engine_server/common/datadog_utils.py delete mode 100644 server/llm_engine_server/common/dtos/llms.py delete mode 100644 server/llm_engine_server/common/env_vars.py delete mode 100644 server/llm_engine_server/common/settings.py delete mode 100644 server/llm_engine_server/core/aws/sfn_client.py delete mode 100644 server/llm_engine_server/core/kubernetes.py delete mode 100644 server/llm_engine_server/core/testing_utilities.py delete mode 100644 server/llm_engine_server/db/migrations/alembic.ini delete mode 100644 server/llm_engine_server/db/migrations/alembic/README delete mode 100644 server/llm_engine_server/db/migrations/alembic/env.py delete mode 100644 server/llm_engine_server/db/migrations/alembic/script.py.mako delete mode 100644 server/llm_engine_server/db/ml_infra_pg.py delete mode 100644 server/llm_engine_server/domain/services/llm_fine_tuning_service.py delete mode 100644 server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py delete mode 100644 server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml delete mode 100644 server/llm_engine_server/inference/configs/service--forwarder.yaml delete mode 100644 server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml delete mode 100644 server/llm_engine_server/inference/forwarding/http_forwarder.py delete mode 100644 server/llm_engine_server/inference/limits.conf delete mode 100644 server/llm_engine_server/inference/pytorch_or_tf.Dockerfile delete mode 100644 server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile delete mode 100644 server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile delete mode 100644 server/llm_engine_server/inference/user.Dockerfile delete mode 100644 server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py delete mode 100644 server/llm_engine_server/scripts/autogenerate_client_and_docs.py delete mode 100755 server/llm_engine_server/scripts/copy_to_public_client.sh delete mode 100644 server/llm_engine_server/service_builder/celery.py delete mode 100644 server/pyproject.toml delete mode 100644 server/requirements_override.txt delete mode 100644 server/service_configs/service_config.yaml delete mode 100644 server/service_configs/service_config_circleci.yaml delete mode 100644 server/setup.cfg delete mode 100644 server/setup.py delete mode 100644 server/tests/README.md delete mode 100644 server/tests/unit/common/__init__.py delete mode 100644 server/tests/unit/core/test_env.py delete mode 100644 server/tests/unit/db/common/test_query.py delete mode 100644 server/tests/unit/db/common/test_repository.py delete mode 100644 server/tests/unit/db/conftest.py delete mode 100644 server/tests/unit/db/test_endpoint_row_lock.py delete mode 100644 server/tests/unit/db/test_llm_engine.py delete mode 100644 server/tests/unit/db/test_model.py delete mode 100644 server/tests/unit/infra/services/test_image_cache_service.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 77c2a7d1..0deee84f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -93,7 +93,7 @@ jobs: - run: name: Build Docker Image command: | - docker build . -f server/Dockerfile -t llm-engine:$CIRCLE_SHA1 + docker build . -f model-engine/Dockerfile -t llm-engine:$CIRCLE_SHA1 integration_tests: executor: ubuntu-large steps: @@ -138,19 +138,19 @@ commands: steps: - python/install-packages: pkg-manager: pip - app-dir: server + app-dir: model-engine - python/install-packages: pkg-manager: pip - app-dir: server + app-dir: model-engine pip-dependency-file: requirements-test.txt - python/install-packages: pkg-manager: pip - app-dir: server + app-dir: model-engine pip-dependency-file: requirements_override.txt - run: name: Install Server command: | - pushd server + pushd model-engine pip install -e . popd install_client: @@ -187,12 +187,12 @@ commands: - run: name: Type Check command: | - pushd server + pushd model-engine mypy . --install-types --non-interactive popd - run: name: Unit Tests command: | - pushd server - WORKSPACE=.. pytest + pushd model-engine + GIT_TAG=$(git rev-parse HEAD) WORKSPACE=.. pytest popd diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7844d3df..a560c8e3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,13 +17,13 @@ repos: - id: isort name: "python:isort" - repo: https://github.com/jazzband/pip-tools - rev: 6.6.2 + rev: 7.0.0 hooks: - id: pip-compile - files: server/requirements\.(in|txt) + files: model-engine/requirements\.(in|txt) args: [ - server/requirements.in, + model-engine/requirements.in, --allow-unsafe, --no-emit-index-url, --no-emit-trusted-host, diff --git a/charts/llm-engine/templates/cacher_deployment.yaml b/charts/llm-engine/templates/cacher_deployment.yaml index 6c833c35..1191cb40 100644 --- a/charts/llm-engine/templates/cacher_deployment.yaml +++ b/charts/llm-engine/templates/cacher_deployment.yaml @@ -49,7 +49,7 @@ spec: args: - python - -m - - server.llm_engine_server.entrypoints.k8s_cache + - model_engine_server.entrypoints.k8s_cache resources: {{- toYaml .Values.resources | nindent 12 }} {{- include "llmEngine.cacherEnv" . | indent 10 }} diff --git a/charts/llm-engine/templates/database_init_job.yaml b/charts/llm-engine/templates/database_init_job.yaml index f743d0b6..571dd1f8 100644 --- a/charts/llm-engine/templates/database_init_job.yaml +++ b/charts/llm-engine/templates/database_init_job.yaml @@ -33,7 +33,7 @@ spec: args: - python - -m - - server.llm_engine_server.entrypoints.init_database + - model_engine_server.entrypoints.init_database {{- include "llmEngine.serviceEnv" . | indent 10 }} {{- include "llmEngine.volumeMounts" . | indent 10 }} serviceAccountName: {{ include "llmEngine.fullname" . }} diff --git a/charts/llm-engine/templates/endpoint_builder_deployment.yaml b/charts/llm-engine/templates/endpoint_builder_deployment.yaml index fbd85a69..a88e07c0 100644 --- a/charts/llm-engine/templates/endpoint_builder_deployment.yaml +++ b/charts/llm-engine/templates/endpoint_builder_deployment.yaml @@ -49,7 +49,7 @@ spec: - ddtrace-run args: - celery - - --app=server.llm_engine_server.service_builder + - --app=server.model_engine_server.service_builder - worker - --loglevel=INFO - --concurrency=2 diff --git a/charts/llm-engine/templates/gateway_deployment.yaml b/charts/llm-engine/templates/gateway_deployment.yaml index e2753524..f727d2d2 100644 --- a/charts/llm-engine/templates/gateway_deployment.yaml +++ b/charts/llm-engine/templates/gateway_deployment.yaml @@ -63,7 +63,7 @@ spec: args: - python - -m - - server.llm_engine_server.entrypoints.start_fastapi_server + - model_engine_server.entrypoints.start_fastapi_server resources: {{- toYaml .Values.resources | nindent 12 }} {{- include "llmEngine.gatewayEnv" . | indent 10 }} diff --git a/charts/llm-engine/templates/llm_engine_init_job.yaml b/charts/llm-engine/templates/llm_engine_init_job.yaml index 1892d087..25d1e6c3 100644 --- a/charts/llm-engine/templates/llm_engine_init_job.yaml +++ b/charts/llm-engine/templates/llm_engine_init_job.yaml @@ -33,7 +33,7 @@ spec: args: - python - -m - - server.llm_engine_server.entrypoints.init_llm_engine_models + - model_engine_server.entrypoints.init_llm_engine_models - --gateway-url - 'http://{{- include "llmEngine.fullname" . }}.{{ .Release.Namespace }}:{{ .Values.service.port }}' {{- include "llmEngine.serviceEnv" . | indent 10 }} diff --git a/charts/llm-engine/templates/service_template_config_map.yaml b/charts/llm-engine/templates/service_template_config_map.yaml index 08ce1424..0f277c45 100644 --- a/charts/llm-engine/templates/service_template_config_map.yaml +++ b/charts/llm-engine/templates/service_template_config_map.yaml @@ -180,7 +180,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -221,9 +221,9 @@ data: - ddtrace-run - python - -m - - server.llm_engine_server.inference.forwarding.http_forwarder + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/service--http_forwarder.yaml + - /workspace/server/model_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers @@ -266,7 +266,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/server/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -559,7 +559,7 @@ data: args: - python - -m - - server.llm_engine_server.entrypoints.start_batch_job_orchestration + - model_engine_server.entrypoints.start_batch_job_orchestration - --job-id - ${JOB_ID} - --owner @@ -669,7 +669,7 @@ data: command: - python - -m - - server.llm_engine_server.entrypoints.start_docker_image_batch_job_init_container + - model_engine_server.entrypoints.start_docker_image_batch_job_init_container - ${INPUT_LOCATION} - --remote-file - ${S3_FILE} diff --git a/server/Dockerfile b/model-engine/Dockerfile similarity index 57% rename from server/Dockerfile rename to model-engine/Dockerfile index 59d466cb..7e186157 100644 --- a/server/Dockerfile +++ b/model-engine/Dockerfile @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:experimental -FROM python:3.8.8-slim as llm-engine +FROM python:3.8.8-slim as model-engine WORKDIR /workspace RUN apt-get update && apt-get install -y \ @@ -18,7 +18,6 @@ RUN apt-get update && apt-get install -y \ python3-dev \ gcc \ build-essential \ - postgresql \ telnet \ && rm -rf /var/lib/apt/lists/* @@ -26,7 +25,7 @@ RUN curl -Lo /bin/aws-iam-authenticator https://github.com/kubernetes-sigs/aws-i RUN chmod +x /bin/aws-iam-authenticator # Install kubectl -RUN curl -LO "https://dl.k8s.io/release/v1.17.9/bin/linux/amd64/kubectl" \ +RUN curl -LO "https://dl.k8s.io/release/v1.23.13/bin/linux/amd64/kubectl" \ && chmod +x kubectl \ && mv kubectl /usr/local/bin/kubectl @@ -34,17 +33,19 @@ RUN curl -LO "https://dl.k8s.io/release/v1.17.9/bin/linux/amd64/kubectl" \ RUN pip install pip==23.0.1 RUN chmod -R 777 /workspace -## grab llm_engine_server project (w/ requirements install layer caching) -WORKDIR /workspace/server/ -COPY server/requirements-test.txt /workspace/server/requirements-test.txt -COPY server/requirements.txt /workspace/server/requirements.txt -COPY server/requirements_override.txt /workspace/server/requirements_override.txt +# Install AWS CLI +RUN pip install awscli==1.25.62 --no-cache-dir + +## grab model_engine_server project (w/ requirements install layer caching) +WORKDIR /workspace/model-engine/ +COPY model-engine/requirements-test.txt /workspace/model-engine/requirements-test.txt +COPY model-engine/requirements.txt /workspace/model-engine/requirements.txt +COPY model-engine/requirements_override.txt /workspace/model-engine/requirements_override.txt RUN pip install -r requirements-test.txt --no-cache-dir RUN pip install -r requirements.txt --no-cache-dir RUN pip install -r requirements_override.txt --no-cache-dir -COPY server/pyproject.toml /workspace/server/pyproject.toml -COPY server/setup.py /workspace/server/setup.py -COPY server/llm_engine_server /workspace/server/llm_engine_server +COPY model-engine/setup.py /workspace/model-engine/setup.py +COPY model-engine/model_engine_server /workspace/model-engine/model_engine_server RUN pip install -e . WORKDIR /workspace @@ -52,9 +53,3 @@ ENV PYTHONPATH /workspace ENV WORKSPACE /workspace EXPOSE 5000 -EXPOSE 5001 -EXPOSE 5002 -EXPOSE 5005 - -RUN useradd -m user -s /bin/bash -USER user diff --git a/server/llm_engine_server/__init__.py b/model-engine/model_engine_server/__init__.py similarity index 100% rename from server/llm_engine_server/__init__.py rename to model-engine/model_engine_server/__init__.py diff --git a/server/llm_engine_server/api/__init__.py b/model-engine/model_engine_server/api/__init__.py similarity index 100% rename from server/llm_engine_server/api/__init__.py rename to model-engine/model_engine_server/api/__init__.py diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py new file mode 100644 index 00000000..786f097a --- /dev/null +++ b/model-engine/model_engine_server/api/app.py @@ -0,0 +1,53 @@ +import os +from pathlib import Path + +from fastapi import FastAPI, Response +from fastapi.staticfiles import StaticFiles +from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1 +from model_engine_server.api.dependencies import get_or_create_aioredis_pool +from model_engine_server.api.docker_image_batch_job_bundles_v1 import ( + docker_image_batch_job_bundle_router_v1, +) +from model_engine_server.api.files_v1 import file_router_v1 +from model_engine_server.api.llms_v1 import llm_router_v1 +from model_engine_server.api.model_bundles_v1 import model_bundle_router_v1 +from model_engine_server.api.model_bundles_v2 import model_bundle_router_v2 +from model_engine_server.api.model_endpoints_docs_v1 import model_endpoints_docs_router_v1 +from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 +from model_engine_server.api.tasks_v1 import inference_task_router_v1 +from model_engine_server.api.triggers_v1 import trigger_router_v1 + +app = FastAPI(title="launch", version="1.0.0", redoc_url="/api") + +app.include_router(batch_job_router_v1) +app.include_router(inference_task_router_v1) +app.include_router(model_bundle_router_v1) +app.include_router(model_bundle_router_v2) +app.include_router(model_endpoint_router_v1) +app.include_router(model_endpoints_docs_router_v1) +app.include_router(docker_image_batch_job_bundle_router_v1) +app.include_router(llm_router_v1) +app.include_router(file_router_v1) +app.include_router(trigger_router_v1) + +# TODO: Remove this once we have a better way to serve internal docs +INTERNAL_DOCS_PATH = str(Path(__file__).parents[3] / "launch_internal/site") +if os.path.exists(INTERNAL_DOCS_PATH): + app.mount( + "/python-docs", + StaticFiles(directory=INTERNAL_DOCS_PATH, html=True), + name="python-docs", + ) + + +@app.on_event("startup") +def load_redis(): + get_or_create_aioredis_pool() + + +@app.get("/healthcheck") +@app.get("/healthz") +@app.get("/readyz") +def healthcheck() -> Response: + """Returns 200 if the app is healthy.""" + return Response(status_code=200) diff --git a/server/llm_engine_server/api/batch_jobs_v1.py b/model-engine/model_engine_server/api/batch_jobs_v1.py similarity index 78% rename from server/llm_engine_server/api/batch_jobs_v1.py rename to model-engine/model_engine_server/api/batch_jobs_v1.py index ac83425f..7e939d9c 100644 --- a/server/llm_engine_server/api/batch_jobs_v1.py +++ b/model-engine/model_engine_server/api/batch_jobs_v1.py @@ -1,40 +1,44 @@ -from fastapi import APIRouter, Depends, HTTPException -from llm_engine_server.api.dependencies import ( +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.batch_jobs import ( CreateBatchJobV1Request, CreateBatchJobV1Response, CreateDockerImageBatchJobV1Request, CreateDockerImageBatchJobV1Response, GetBatchJobV1Response, GetDockerImageBatchJobV1Response, + ListDockerImageBatchJobsV1Response, UpdateBatchJobV1Request, UpdateBatchJobV1Response, UpdateDockerImageBatchJobV1Request, UpdateDockerImageBatchJobV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( EndpointLabelsException, EndpointResourceInvalidRequestException, ) -from llm_engine_server.domain.use_cases.batch_job_use_cases import ( +from model_engine_server.domain.use_cases.batch_job_use_cases import ( CreateBatchJobV1UseCase, CreateDockerImageBatchJobV1UseCase, GetBatchJobV1UseCase, GetDockerImageBatchJobV1UseCase, + ListDockerImageBatchJobsV1UseCase, UpdateBatchJobV1UseCase, UpdateDockerImageBatchJobV1UseCase, ) @@ -150,16 +154,14 @@ async def create_docker_image_batch_job( ) from exc except EndpointResourceInvalidRequestException as exc: raise HTTPException( - status_code=400, - detail=f"Final endpoint resources requested is invalid: {exc}", + status_code=400, detail=f"Final endpoint resources requested is invalid: {exc}" ) from exc except EndpointLabelsException as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc @batch_job_router_v1.get( - "/docker-image-batch-jobs/{batch_job_id}", - response_model=GetDockerImageBatchJobV1Response, + "/docker-image-batch-jobs/{batch_job_id}", response_model=GetDockerImageBatchJobV1Response ) async def get_docker_image_batch_job( batch_job_id: str, @@ -179,9 +181,32 @@ async def get_docker_image_batch_job( ) from exc +@batch_job_router_v1.get( + "/docker-image-batch-jobs", + response_model=ListDockerImageBatchJobsV1Response, +) +async def list_docker_image_batch_jobs( + trigger_id: Optional[str] = Query(default=None), + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> ListDockerImageBatchJobsV1Response: + """ + Lists docker image batch jobs spawned by trigger with given ID + """ + add_trace_resource_name("batch_jobs_di_get_trigger") + logger.info(f"GET /docker-image-batch-jobs?trigger_id={trigger_id}") + try: + use_case = ListDockerImageBatchJobsV1UseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + ) + return await use_case.execute(user=auth, trigger_id=trigger_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc + + @batch_job_router_v1.put( - "/docker-image-batch-jobs/{batch_job_id}", - response_model=UpdateDockerImageBatchJobV1Response, + "/docker-image-batch-jobs/{batch_job_id}", response_model=UpdateDockerImageBatchJobV1Response ) async def update_docker_image_batch_job( batch_job_id: str, diff --git a/server/llm_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py similarity index 66% rename from server/llm_engine_server/api/dependencies.py rename to model-engine/model_engine_server/api/dependencies.py index f00211a8..c9e00eb9 100644 --- a/server/llm_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -6,33 +6,43 @@ import aioredis from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBasic, HTTPBasicCredentials -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.core.auth.authentication_repository import AuthenticationRepository, User -from llm_engine_server.core.auth.fake_authentication_repository import FakeAuthenticationRepository -from llm_engine_server.db.base import SessionAsync, SessionReadOnlyAsync -from llm_engine_server.domain.gateways import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User +from model_engine_server.core.auth.fake_authentication_repository import ( + FakeAuthenticationRepository, +) +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync +from model_engine_server.domain.gateways import ( + CronJobGateway, DockerImageBatchJobGateway, + FileStorageGateway, + LLMArtifactGateway, ModelPrimitiveGateway, TaskQueueGateway, ) -from llm_engine_server.domain.repositories import ( +from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, DockerRepository, + LLMFineTuneEventsRepository, ModelBundleRepository, + TriggerRepository, ) -from llm_engine_server.domain.services import ( +from model_engine_server.domain.services import ( BatchJobService, + LLMFineTuningService, LLMModelEndpointService, ModelEndpointService, ) -from llm_engine_server.infra.gateways import ( +from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, FakeMonitoringMetricsGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobOrchestrationGateway, LiveBatchJobProgressGateway, + LiveCronJobGateway, LiveDockerImageBatchJobGateway, LiveModelEndpointInfraGateway, LiveModelEndpointsSchemaGateway, @@ -40,42 +50,51 @@ LiveSyncModelEndpointInferenceGateway, ModelEndpointInfraGateway, S3FilesystemGateway, + S3LLMArtifactGateway, +) +from model_engine_server.infra.gateways.fake_model_primitive_gateway import ( + FakeModelPrimitiveGateway, ) -from llm_engine_server.infra.gateways.fake_model_primitive_gateway import FakeModelPrimitiveGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( FakeSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( LiveSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, ) -from llm_engine_server.infra.repositories import ( +from model_engine_server.infra.gateways.s3_file_storage_gateway import S3FileStorageGateway +from model_engine_server.infra.repositories import ( DbBatchJobRecordRepository, DbDockerImageBatchJobBundleRepository, DbModelBundleRepository, DbModelEndpointRecordRepository, + DbTriggerRepository, ECRDockerRepository, RedisModelEndpointCacheRepository, - S3FileLLMFineTuningJobRepository, + S3FileLLMFineTuneEventsRepository, + S3FileLLMFineTuneRepository, ) -from llm_engine_server.infra.services import ( +from model_engine_server.infra.services import ( DockerImageBatchJobLLMFineTuningService, LiveBatchJobService, LiveModelEndpointService, ) -from llm_engine_server.infra.services.live_llm_model_endpoint_service import ( +from model_engine_server.infra.services.live_llm_model_endpoint_service import ( LiveLLMModelEndpointService, ) from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session +logger = make_logger(filename_wo_ext(__name__)) + AUTH = HTTPBasic(auto_error=False) @@ -89,10 +108,12 @@ class ExternalInterfaces: docker_repository: DockerRepository docker_image_batch_job_bundle_repository: DockerImageBatchJobBundleRepository model_bundle_repository: ModelBundleRepository + trigger_repository: TriggerRepository model_endpoint_service: ModelEndpointService batch_job_service: BatchJobService llm_model_endpoint_service: LLMModelEndpointService - llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService + llm_fine_tuning_service: LLMFineTuningService + llm_fine_tune_events_repository: LLMFineTuneEventsRepository resource_gateway: EndpointResourceGateway endpoint_creation_task_queue_gateway: TaskQueueGateway @@ -100,6 +121,10 @@ class ExternalInterfaces: model_endpoint_infra_gateway: ModelEndpointInfraGateway docker_image_batch_job_gateway: DockerImageBatchJobGateway model_primitive_gateway: ModelPrimitiveGateway + file_storage_gateway: FileStorageGateway + filesystem_gateway: FilesystemGateway + llm_artifact_gateway: LLMArtifactGateway + cron_job_gateway: CronJobGateway def _get_external_interfaces( @@ -150,6 +175,7 @@ def _get_external_interfaces( use_asyncio=(not CIRCLECI), ) filesystem_gateway = S3FilesystemGateway() + llm_artifact_gateway = S3LLMArtifactGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) @@ -171,6 +197,7 @@ def _get_external_interfaces( session=session, read_only=read_only ) batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=read_only) + trigger_repository = DbTriggerRepository(session=session, read_only=read_only) batch_job_orchestration_gateway = LiveBatchJobOrchestrationGateway() batch_job_progress_gateway = LiveBatchJobProgressGateway(filesystem_gateway=filesystem_gateway) batch_job_service = LiveBatchJobService( @@ -180,23 +207,26 @@ def _get_external_interfaces( batch_job_progress_gateway=batch_job_progress_gateway, ) - model_primitive_gateway: ModelPrimitiveGateway model_primitive_gateway = FakeModelPrimitiveGateway() docker_image_batch_job_gateway = LiveDockerImageBatchJobGateway() + cron_job_gateway = LiveCronJobGateway() - llm_fine_tuning_job_repository = S3FileLLMFineTuningJobRepository( + llm_fine_tune_repository = S3FileLLMFineTuneRepository( file_path=os.getenv( - "S3_FILE_LLM_FINE_TUNING_JOB_REPOSITORY", - hmi_config.s3_file_llm_fine_tuning_job_repository, + "S3_FILE_LLM_FINE_TUNE_REPOSITORY", + hmi_config.s3_file_llm_fine_tune_repository, ), ) + llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository() llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService( docker_image_batch_job_gateway=docker_image_batch_job_gateway, docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository, - llm_fine_tuning_job_repository=llm_fine_tuning_job_repository, + llm_fine_tune_repository=llm_fine_tune_repository, ) + file_storage_gateway = S3FileStorageGateway() + external_interfaces = ExternalInterfaces( docker_repository=ECRDockerRepository(), model_bundle_repository=model_bundle_repository, @@ -211,29 +241,59 @@ def _get_external_interfaces( docker_image_batch_job_bundle_repository=docker_image_batch_job_bundle_repository, docker_image_batch_job_gateway=docker_image_batch_job_gateway, llm_fine_tuning_service=llm_fine_tuning_service, + llm_fine_tune_events_repository=llm_fine_tune_events_repository, + file_storage_gateway=file_storage_gateway, + filesystem_gateway=filesystem_gateway, + llm_artifact_gateway=llm_artifact_gateway, + trigger_repository=trigger_repository, + cron_job_gateway=cron_job_gateway, ) return external_interfaces +def get_default_external_interfaces() -> ExternalInterfaces: + session = async_scoped_session(SessionAsync, scopefunc=asyncio.current_task) # type: ignore + return _get_external_interfaces(read_only=False, session=session) + + +def get_default_external_interfaces_read_only() -> ExternalInterfaces: + session = async_scoped_session( # type: ignore + SessionReadOnlyAsync, scopefunc=asyncio.current_task # type: ignore + ) + return _get_external_interfaces(read_only=True, session=session) + + async def get_external_interfaces(): try: - session = async_scoped_session(SessionAsync, scopefunc=asyncio.current_task) - yield _get_external_interfaces(read_only=False, session=session) + from plugins.dependencies import get_external_interfaces as get_custom_external_interfaces + + logger.info("Using custom external interfaces") + yield get_custom_external_interfaces() + except ModuleNotFoundError: + logger.info("Using default external interfaces") + yield get_default_external_interfaces() finally: pass async def get_external_interfaces_read_only(): try: - session = async_scoped_session(SessionReadOnlyAsync, scopefunc=asyncio.current_task) - yield _get_external_interfaces(read_only=True, session=session) + from plugins.dependencies import ( + get_external_interfaces_read_only as get_custom_external_interfaces_read_only, + ) + + logger.info("Using custom external interfaces") + yield get_custom_external_interfaces_read_only() + except ModuleNotFoundError: + logger.info("Using default external interfaces") + yield get_default_external_interfaces_read_only() finally: pass def get_auth_repository() -> Iterator[AuthenticationRepository]: """ - Dependency for an AuthenticationRepository. This implementation returns a Scale-specific repository. + Dependency for an AuthenticationRepository. This implementation returns a fake repository. """ try: yield FakeAuthenticationRepository() diff --git a/server/llm_engine_server/api/docker_image_batch_job_bundles_v1.py b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py similarity index 84% rename from server/llm_engine_server/api/docker_image_batch_job_bundles_v1.py rename to model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py index c39ab0e9..96cc3d49 100644 --- a/server/llm_engine_server/api/docker_image_batch_job_bundles_v1.py +++ b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py @@ -1,27 +1,27 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, DockerImageBatchJobBundleV1Response, ListDockerImageBatchJobBundleV1Response, ) -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( CreateDockerImageBatchJobBundleV1UseCase, GetDockerImageBatchJobBundleByIdV1UseCase, GetLatestDockerImageBatchJobBundleByNameV1UseCase, @@ -34,8 +34,7 @@ @docker_image_batch_job_bundle_router_v1.post( - "/docker-image-batch-job-bundles", - response_model=CreateDockerImageBatchJobBundleV1Response, + "/docker-image-batch-job-bundles", response_model=CreateDockerImageBatchJobBundleV1Response ) async def create_docker_image_batch_job_bundle( request: CreateDockerImageBatchJobBundleV1Request, @@ -60,8 +59,7 @@ async def create_docker_image_batch_job_bundle( @docker_image_batch_job_bundle_router_v1.get( - "/docker-image-batch-job-bundles", - response_model=ListDockerImageBatchJobBundleV1Response, + "/docker-image-batch-job-bundles", response_model=ListDockerImageBatchJobBundleV1Response ) async def list_docker_image_batch_job_model_bundles( bundle_name: Optional[str] = Query(default=None), @@ -84,8 +82,7 @@ async def list_docker_image_batch_job_model_bundles( @docker_image_batch_job_bundle_router_v1.get( - "/docker-image-batch-job-bundles/latest", - response_model=DockerImageBatchJobBundleV1Response, + "/docker-image-batch-job-bundles/latest", response_model=DockerImageBatchJobBundleV1Response ) async def get_latest_docker_image_batch_job_bundle( bundle_name: str, diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py new file mode 100644 index 00000000..a2d23ba3 --- /dev/null +++ b/model-engine/model_engine_server/api/files_v1.py @@ -0,0 +1,127 @@ +"""Files API routes for the hosted model inference service.""" + +from fastapi import APIRouter, Depends, HTTPException, UploadFile +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.files import ( + DeleteFileResponse, + GetFileContentResponse, + GetFileResponse, + ListFilesResponse, + UploadFileResponse, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.use_cases.file_use_cases import ( + DeleteFileUseCase, + GetFileContentUseCase, + GetFileUseCase, + ListFilesUseCase, + UploadFileUseCase, +) + +file_router_v1 = APIRouter(prefix="/v1") +logger = make_logger(filename_wo_ext(__name__)) + + +@file_router_v1.post("/files", response_model=UploadFileResponse) +async def upload_file( + file: UploadFile, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UploadFileResponse: + add_trace_resource_name("files_upload") + logger.info(f"POST /files with filename {file.filename} for {auth}") + use_case = UploadFileUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute( + user=auth, + filename=file.filename, + content=file.file.read(), + ) + + +@file_router_v1.get("/files/{file_id}", response_model=GetFileResponse) +async def get_file( + file_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFileResponse: + add_trace_resource_name("files_get") + logger.info(f"GET /files/{file_id} for {auth}") + try: + use_case = GetFileUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, file_id=file_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified file could not be found.", + ) from exc + + +@file_router_v1.get("/files", response_model=ListFilesResponse) +async def list_files( + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> ListFilesResponse: + add_trace_resource_name("files_list") + logger.info(f"GET /files for {auth}") + use_case = ListFilesUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth) + + +@file_router_v1.delete("/files/{file_id}", response_model=DeleteFileResponse) +async def delete_file( + file_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> DeleteFileResponse: + add_trace_resource_name("files_delete") + logger.info(f"DELETE /files/{file_id} for {auth}") + try: + use_case = DeleteFileUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, file_id=file_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified file could not be found.", + ) from exc + + +@file_router_v1.get("/files/{file_id}/content", response_model=GetFileContentResponse) +async def get_file_content( + file_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFileContentResponse: + """ + Describe the LLM Model endpoint with given name. + """ + add_trace_resource_name("files_content_get") + logger.info(f"GET /files/{file_id}/content for {auth}") + try: + use_case = GetFileContentUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, file_id=file_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified file could not be found.", + ) from exc diff --git a/server/llm_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py similarity index 68% rename from server/llm_engine_server/api/llms_v1.py rename to model-engine/model_engine_server/api/llms_v1.py index 1d54ca5e..50bbbbe8 100644 --- a/server/llm_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -1,63 +1,69 @@ """LLM Model Endpoint routes for the hosted model inference service. """ from typing import Optional +from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.llms import ( - CancelFineTuneJobResponse, +from model_engine_server.common.datadog_utils import add_trace_request_id, add_trace_resource_name +from model_engine_server.common.dtos.llms import ( + CancelFineTuneResponse, CompletionStreamV1Request, CompletionStreamV1Response, CompletionSyncV1Request, CompletionSyncV1Response, - CreateFineTuneJobRequest, - CreateFineTuneJobResponse, + CreateFineTuneRequest, + CreateFineTuneResponse, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, - GetFineTuneJobResponse, + GetFineTuneEventsResponse, + GetFineTuneResponse, GetLLMModelEndpointV1Response, - ListFineTuneJobResponse, + ListFineTunesResponse, ListLLMModelEndpointsV1Response, + ModelDownloadRequest, + ModelDownloadResponse, ) -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.common.dtos.tasks import TaskStatus -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectAlreadyExistsException, ObjectHasInvalidValueException, ObjectNotApprovedException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( EndpointLabelsException, EndpointResourceInvalidRequestException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, LLMFineTuningMethodNotImplementedException, + LLMFineTuningQuotaReached, UpstreamServiceError, ) -from llm_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( - CancelFineTuneJobV1UseCase, - CreateFineTuneJobV1UseCase, - GetFineTuneJobV1UseCase, - ListFineTuneJobV1UseCase, +from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( + CancelFineTuneV1UseCase, + CreateFineTuneV1UseCase, + GetFineTuneEventsV1UseCase, + GetFineTuneV1UseCase, + ListFineTunesV1UseCase, ) -from llm_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, CreateLLMModelEndpointV1UseCase, GetLLMModelEndpointByNameV1UseCase, ListLLMModelEndpointsV1UseCase, + ModelDownloadV1UseCase, ) -from llm_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase from sse_starlette.sse import EventSourceResponse llm_router_v1 = APIRouter(prefix="/v1/llm") @@ -135,8 +141,7 @@ async def list_model_endpoints( @llm_router_v1.get( - "/model-endpoints/{model_endpoint_name}", - response_model=GetLLMModelEndpointV1Response, + "/model-endpoints/{model_endpoint_name}", response_model=GetLLMModelEndpointV1Response ) async def get_model_endpoint( model_endpoint_name: str, @@ -182,10 +187,11 @@ async def create_completion_sync_task( return await use_case.execute( user=auth, model_endpoint_name=model_endpoint_name, request=request ) - except UpstreamServiceError as exc: - return CompletionSyncV1Response( - status=TaskStatus.FAILURE, outputs=[], traceback=exc.content.decode() - ) + except UpstreamServiceError: + request_id = str(uuid4()) + add_trace_request_id(request_id) + logger.exception(f"Upstream service error for request {request_id}") + return CompletionSyncV1Response(request_id=request_id, output=None) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -228,15 +234,12 @@ async def event_generator(): yield {"data": message.json()} return EventSourceResponse(event_generator()) - except UpstreamServiceError as exc: + except UpstreamServiceError: + request_id = str(uuid4()) + add_trace_request_id(request_id) + logger.exception(f"Upstream service error for request {request_id}") return EventSourceResponse( - iter( - ( - CompletionStreamV1Response( - status=TaskStatus.FAILURE, traceback=exc.content.decode() - ).json(), - ) - ) + iter((CompletionStreamV1Response(request_id=request_id).json(),)) ) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( @@ -252,36 +255,45 @@ async def event_generator(): ) from exc -@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneJobResponse) -async def create_fine_tune_job( - request: CreateFineTuneJobRequest, +@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse) +async def create_fine_tune( + request: CreateFineTuneRequest, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CreateFineTuneJobResponse: +) -> CreateFineTuneResponse: add_trace_resource_name("fine_tunes_create") logger.info(f"POST /fine-tunes with {request} for {auth}") try: - use_case = CreateFineTuneJobV1UseCase( + use_case = CreateFineTuneV1UseCase( llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_fine_tune_events_repository=external_interfaces.llm_fine_tune_events_repository, + file_storage_gateway=external_interfaces.file_storage_gateway, ) return await use_case.execute(user=auth, request=request) - except (LLMFineTuningMethodNotImplementedException, InvalidRequestException) as exc: + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except ( + LLMFineTuningMethodNotImplementedException, + LLMFineTuningQuotaReached, + InvalidRequestException, + ) as exc: raise HTTPException( status_code=400, detail=str(exc), ) from exc -@llm_router_v1.get("/fine-tunes/{fine_tune_id}", response_model=GetFineTuneJobResponse) -async def get_fine_tune_job( +@llm_router_v1.get("/fine-tunes/{fine_tune_id}", response_model=GetFineTuneResponse) +async def get_fine_tune( fine_tune_id: str, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> GetFineTuneJobResponse: +) -> GetFineTuneResponse: add_trace_resource_name("fine_tunes_get") logger.info(f"GET /fine-tunes/{fine_tune_id} for {auth}") try: - use_case = GetFineTuneJobV1UseCase( + use_case = GetFineTuneV1UseCase( llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, ) return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) @@ -292,29 +304,29 @@ async def get_fine_tune_job( ) from exc -@llm_router_v1.get("/fine-tunes", response_model=ListFineTuneJobResponse) -async def list_fine_tune_jobs( +@llm_router_v1.get("/fine-tunes", response_model=ListFineTunesResponse) +async def list_fine_tunes( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> ListFineTuneJobResponse: +) -> ListFineTunesResponse: add_trace_resource_name("fine_tunes_list") logger.info(f"GET /fine-tunes for {auth}") - use_case = ListFineTuneJobV1UseCase( + use_case = ListFineTunesV1UseCase( llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, ) return await use_case.execute(user=auth) -@llm_router_v1.put("/fine-tunes/{fine_tune_id}/cancel", response_model=CancelFineTuneJobResponse) -async def cancel_fine_tune_job( +@llm_router_v1.put("/fine-tunes/{fine_tune_id}/cancel", response_model=CancelFineTuneResponse) +async def cancel_fine_tune( fine_tune_id: str, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CancelFineTuneJobResponse: +) -> CancelFineTuneResponse: add_trace_resource_name("fine_tunes_cancel") logger.info(f"PUT /fine-tunes/{fine_tune_id}/cancel for {auth}") try: - use_case = CancelFineTuneJobV1UseCase( + use_case = CancelFineTuneV1UseCase( llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, ) return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) @@ -323,3 +335,46 @@ async def cancel_fine_tune_job( status_code=404, detail="The specified fine-tune job could not be found.", ) from exc + + +@llm_router_v1.get("/fine-tunes/{fine_tune_id}/events", response_model=GetFineTuneEventsResponse) +async def get_fine_tune_events( + fine_tune_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFineTuneEventsResponse: + add_trace_resource_name("fine_tunes_events_get") + logger.info(f"GET /fine-tunes/{fine_tune_id}/events for {auth}") + try: + use_case = GetFineTuneEventsV1UseCase( + llm_fine_tune_events_repository=external_interfaces.llm_fine_tune_events_repository, + llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + ) + return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified fine-tune job's events could not be found.", + ) from exc + + +@llm_router_v1.post("/model-endpoints/download", response_model=ModelDownloadResponse) +async def download_model_endpoint( + request: ModelDownloadRequest, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> ModelDownloadResponse: + add_trace_resource_name("model_endpoints_download") + logger.info(f"POST /model-endpoints/download with {request} for {auth}") + try: + use_case = ModelDownloadV1UseCase( + filesystem_gateway=external_interfaces.filesystem_gateway, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectHasInvalidValueException) as exc: + raise HTTPException( + status_code=404, + detail="The requested fine-tuned model could not be found.", + ) from exc diff --git a/server/llm_engine_server/api/model_bundles_v1.py b/model-engine/model_engine_server/api/model_bundles_v1.py similarity index 91% rename from server/llm_engine_server/api/model_bundles_v1.py rename to model-engine/model_engine_server/api/model_bundles_v1.py index efcf43ab..de24f860 100644 --- a/server/llm_engine_server/api/model_bundles_v1.py +++ b/model-engine/model_engine_server/api/model_bundles_v1.py @@ -1,16 +1,16 @@ -"""Model Bundle v1 routes for the LLMEngine service.""" +"""Model Bundle v1 routes for the hosted model inference service.""" from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CreateModelBundleV1Request, CreateModelBundleV1Response, @@ -18,15 +18,15 @@ ModelBundleOrderBy, ModelBundleV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.use_cases.model_bundle_use_cases import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV1UseCase, CreateModelBundleV1UseCase, GetLatestModelBundleByNameV1UseCase, diff --git a/server/llm_engine_server/api/model_bundles_v2.py b/model-engine/model_engine_server/api/model_bundles_v2.py similarity index 92% rename from server/llm_engine_server/api/model_bundles_v2.py rename to model-engine/model_engine_server/api/model_bundles_v2.py index 00d4ffed..94801916 100644 --- a/server/llm_engine_server/api/model_bundles_v2.py +++ b/model-engine/model_engine_server/api/model_bundles_v2.py @@ -3,14 +3,14 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV2Request, CreateModelBundleV2Request, CreateModelBundleV2Response, @@ -18,15 +18,15 @@ ModelBundleOrderBy, ModelBundleV2Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.use_cases.model_bundle_use_cases import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV2UseCase, CreateModelBundleV2UseCase, GetLatestModelBundleByNameV2UseCase, diff --git a/server/llm_engine_server/api/model_endpoints_docs_v1.py b/model-engine/model_engine_server/api/model_endpoints_docs_v1.py similarity index 84% rename from server/llm_engine_server/api/model_endpoints_docs_v1.py rename to model-engine/model_engine_server/api/model_endpoints_docs_v1.py index 5ccb8a30..9b7f1d1f 100644 --- a/server/llm_engine_server/api/model_endpoints_docs_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_docs_v1.py @@ -2,14 +2,14 @@ from fastapi.encoders import jsonable_encoder from fastapi.openapi.docs import get_redoc_html from fastapi.responses import JSONResponse -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.use_cases.model_endpoints_schema_use_cases import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.use_cases.model_endpoints_schema_use_cases import ( GetModelEndpointsSchemaV1UseCase, ) from starlette.responses import HTMLResponse diff --git a/server/llm_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py similarity index 94% rename from server/llm_engine_server/api/model_endpoints_v1.py rename to model-engine/model_engine_server/api/model_endpoints_v1.py index a1e28df3..d37f8bf6 100644 --- a/server/llm_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -6,14 +6,14 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, DeleteModelEndpointV1Response, @@ -23,22 +23,22 @@ UpdateModelEndpointV1Request, UpdateModelEndpointV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectAlreadyExistsException, ObjectHasInvalidValueException, ObjectNotApprovedException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, EndpointLabelsException, EndpointResourceInvalidRequestException, ExistingEndpointOperationInProgressException, ) -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CreateModelEndpointV1UseCase, DeleteModelEndpointByIdV1UseCase, GetModelEndpointByIdV1UseCase, diff --git a/server/llm_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py similarity index 87% rename from server/llm_engine_server/api/tasks_v1.py rename to model-engine/model_engine_server/api/tasks_v1.py index e0318d94..74b5b634 100644 --- a/server/llm_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -1,35 +1,37 @@ +import asyncio + from fastapi import APIRouter, Depends, HTTPException -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( EndpointUnsupportedInferenceTypeException, UpstreamServiceError, ) -from llm_engine_server.domain.use_cases.async_inference_use_cases import ( +from model_engine_server.domain.use_cases.async_inference_use_cases import ( CreateAsyncInferenceTaskV1UseCase, GetAsyncInferenceTaskV1UseCase, ) -from llm_engine_server.domain.use_cases.streaming_inference_use_cases import ( +from model_engine_server.domain.use_cases.streaming_inference_use_cases import ( CreateStreamingInferenceTaskV1UseCase, ) -from llm_engine_server.domain.use_cases.sync_inference_use_cases import ( +from model_engine_server.domain.use_cases.sync_inference_use_cases import ( CreateSyncInferenceTaskV1UseCase, ) from sse_starlette.sse import EventSourceResponse @@ -125,6 +127,11 @@ async def create_sync_inference_task( status_code=400, detail=f"Unsupported inference type: {str(exc)}", ) from exc + except asyncio.exceptions.TimeoutError as exc: + raise HTTPException( + status_code=408, + detail="Request timed out.", + ) from exc @inference_task_router_v1.post("/streaming-tasks") diff --git a/model-engine/model_engine_server/api/triggers_v1.py b/model-engine/model_engine_server/api/triggers_v1.py new file mode 100644 index 00000000..cc32180e --- /dev/null +++ b/model-engine/model_engine_server/api/triggers_v1.py @@ -0,0 +1,176 @@ +from fastapi import APIRouter, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + verify_authentication, +) +from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.dtos.triggers import ( + CreateTriggerV1Request, + CreateTriggerV1Response, + DeleteTriggerV1Response, + GetTriggerV1Response, + ListTriggersV1Response, + UpdateTriggerV1Request, + UpdateTriggerV1Response, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( + DockerImageNotFoundException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( + CronSyntaxException, + EndpointLabelsException, + EndpointResourceInvalidRequestException, + TriggerNameAlreadyExistsException, +) +from model_engine_server.domain.use_cases.trigger_use_cases import ( + CreateTriggerUseCase, + DeleteTriggerUseCase, + GetTriggerUseCase, + ListTriggersUseCase, + UpdateTriggerUseCase, +) + +trigger_router_v1 = APIRouter(prefix="/v1") + +logger = make_logger(filename_wo_ext(__name__)) + + +@trigger_router_v1.post("/triggers", response_model=CreateTriggerV1Response) +async def create_trigger( + request: CreateTriggerV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CreateTriggerV1Response: + """ + Creates and runs a trigger + """ + add_trace_resource_name("triggers_post") + logger.info(f"POST /triggers with {request} for {auth}") + try: + use_case = CreateTriggerUseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + docker_image_batch_job_bundle_repository=external_interfaces.docker_image_batch_job_bundle_repository, + docker_repository=external_interfaces.docker_repository, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, detail="The specified batch job bundle could not be found" + ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail=f"The specified docker image {exc.repository}:{exc.tag} was not found", + ) + except ObjectHasInvalidValueException as exc: + raise HTTPException( + status_code=400, + detail=f"The user specified an invalid value: {exc}", + ) from exc + except EndpointResourceInvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Default trigger resource request is invalid: {exc}", + ) + except EndpointLabelsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except CronSyntaxException as exc: + raise HTTPException( + status_code=400, + detail=f"The user specified an invalid value for cron_schedule: {exc}", + ) + except TriggerNameAlreadyExistsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + + +@trigger_router_v1.get("/triggers", response_model=ListTriggersV1Response) +async def list_triggers( + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> ListTriggersV1Response: + """ + Lists descriptions of all triggers + """ + add_trace_resource_name("triggers_get") + logger.info(f"GET /triggers for {auth}") + use_case = ListTriggersUseCase(trigger_repository=external_interfaces.trigger_repository) + return await use_case.execute(user=auth) + + +@trigger_router_v1.get("/triggers/{trigger_id}", response_model=GetTriggerV1Response) +async def get_trigger( + trigger_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> GetTriggerV1Response: + """ + Describes the trigger with the given ID + """ + add_trace_resource_name("triggers_id_get") + logger.info(f"GET /triggers/{trigger_id} for {auth}") + try: + use_case = GetTriggerUseCase(trigger_repository=external_interfaces.trigger_repository) + return await use_case.execute(user=auth, trigger_id=trigger_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc + + +@trigger_router_v1.put("/triggers/{trigger_id}", response_model=UpdateTriggerV1Response) +async def update_trigger( + trigger_id: str, + request: UpdateTriggerV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateTriggerV1Response: + """ + Updates the trigger with the given ID + """ + add_trace_resource_name("triggers_id_put") + logger.info(f"PUT /triggers/{trigger_id} with {request} for {auth}") + try: + use_case = UpdateTriggerUseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + ) + return await use_case.execute(user=auth, trigger_id=trigger_id, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc + except CronSyntaxException as exc: + raise HTTPException( + status_code=400, + detail=f"The user specified an invalid value for cron_schedule: {exc}", + ) + + +@trigger_router_v1.delete("/triggers/{trigger_id}", response_model=DeleteTriggerV1Response) +async def delete_trigger( + trigger_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> DeleteTriggerV1Response: + """ + Deletes the trigger with the given ID + """ + add_trace_resource_name("trigger_id_delete") + logger.info(f"DELETE /triggers/{trigger_id} for {auth}") + try: + use_case = DeleteTriggerUseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + ) + return await use_case.execute(user=auth, trigger_id=trigger_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc diff --git a/server/llm_engine_server/api/worker.py b/model-engine/model_engine_server/api/worker.py similarity index 70% rename from server/llm_engine_server/api/worker.py rename to model-engine/model_engine_server/api/worker.py index 95b02c59..776b5dd6 100644 --- a/server/llm_engine_server/api/worker.py +++ b/model-engine/model_engine_server/api/worker.py @@ -5,13 +5,9 @@ CONCURRENCY_LIMIT = 32 -class LLMEngineWorker(UvicornWorker): +class LaunchWorker(UvicornWorker): """Overrides the configuration of the Uvicorn Worker.""" # uvloop and httptools are both faster than their alternatives, but they are not compatible # with Windows or PyPy. - CONFIG_KWARGS = { - "loop": "uvloop", - "http": "httptools", - "limit_concurrency": CONCURRENCY_LIMIT, - } + CONFIG_KWARGS = {"loop": "uvloop", "http": "httptools", "limit_concurrency": CONCURRENCY_LIMIT} diff --git a/server/llm_engine_server/common/__init__.py b/model-engine/model_engine_server/common/__init__.py similarity index 100% rename from server/llm_engine_server/common/__init__.py rename to model-engine/model_engine_server/common/__init__.py diff --git a/server/llm_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py similarity index 60% rename from server/llm_engine_server/common/config.py rename to model-engine/model_engine_server/common/config.py index 11d5fede..4022ceb5 100644 --- a/server/llm_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -7,7 +7,7 @@ from typing import Sequence import yaml -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import filename_wo_ext, make_logger logger = make_logger(filename_wo_ext(__file__)) @@ -20,22 +20,41 @@ DEFAULT_SERVICE_CONFIG_PATH = str( ( - Path(__file__).absolute().parent.parent.parent / "service_configs" / "service_config.yaml" + Path(__file__).absolute().parent.parent.parent + / "service_configs" + / "service_config_circleci.yaml" ).absolute() ) SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH) +# duplicated from llm/ia3_finetune +def get_model_cache_directory_name(model_name: str): + """How huggingface maps model names to directory names in their cache for model files. + We adopt this when storing model cache files in s3. + + Args: + model_name (str): Name of the huggingface model + """ + name = "models--" + model_name.replace("/", "--") + return name + + @dataclass class HostedModelInferenceServiceConfig: endpoint_namespace: str + billing_queue_arn: str cache_redis_url: str sqs_profile: str sqs_queue_policy_template: str sqs_queue_tag_template: str - s3_file_llm_fine_tuning_job_repository: str - datadog_trace_enabled: str + model_primitive_host: str + s3_file_llm_fine_tune_repository: str + hf_user_fine_tuned_weights_prefix: str + istio_enabled: bool + datadog_trace_enabled: bool + tgi_repository: str @classmethod def from_yaml(cls, yaml_path): diff --git a/model-engine/model_engine_server/common/constants.py b/model-engine/model_engine_server/common/constants.py new file mode 100644 index 00000000..567df502 --- /dev/null +++ b/model-engine/model_engine_server/common/constants.py @@ -0,0 +1,10 @@ +from pathlib import Path + +BILLING_POST_INFERENCE_HOOK: str = "billing" +CALLBACK_POST_INFERENCE_HOOK: str = "callback" +READYZ_FPATH: str = "/tmp/readyz" +DEFAULT_CELERY_TASK_NAME: str = "hosted_model_inference.inference.async_inference.tasks.predict" +LIRA_CELERY_TASK_NAME: str = "ml_serve.celery_service.exec_func" + +PROJECT_ROOT: Path = Path(__file__).parents[2].absolute() +HOSTED_MODEL_INFERENCE_ROOT: Path = PROJECT_ROOT / "model-engine" diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py new file mode 100644 index 00000000..c73fa2f9 --- /dev/null +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -0,0 +1,19 @@ +from ddtrace import tracer + + +def add_trace_resource_name(tag: str): + """Adds a custom tag to a given dd trace corresponding to the route + (e.g. get_model_bundles for GET /model-bundles, etc.) so that we can filter in Datadog easier + """ + current_span = tracer.current_span() + if current_span: + current_span.set_tag("launch.resource_name", tag) + + +def add_trace_request_id(request_id: str): + """Adds a custom tag to a given dd trace corresponding to the request id + so that we can filter in Datadog easier + """ + current_span = tracer.current_span() + if current_span: + current_span.set_tag("launch.request_id", request_id) diff --git a/server/llm_engine_server/common/dtos/__init__.py b/model-engine/model_engine_server/common/dtos/__init__.py similarity index 100% rename from server/llm_engine_server/common/dtos/__init__.py rename to model-engine/model_engine_server/common/dtos/__init__.py diff --git a/server/llm_engine_server/common/dtos/batch_jobs.py b/model-engine/model_engine_server/common/dtos/batch_jobs.py similarity index 94% rename from server/llm_engine_server/common/dtos/batch_jobs.py rename to model-engine/model_engine_server/common/dtos/batch_jobs.py index e1fc45fa..ce1af0c8 100644 --- a/server/llm_engine_server/common/dtos/batch_jobs.py +++ b/model-engine/model_engine_server/common/dtos/batch_jobs.py @@ -4,11 +4,12 @@ from datetime import datetime, timedelta from typing import Any, Collection, Dict, List, Optional -from llm_engine_server.common import dict_not_none -from llm_engine_server.domain.entities import ( +from model_engine_server.common import dict_not_none +from model_engine_server.domain.entities import ( BatchJobSerializationFormat, BatchJobStatus, CpuSpecificationType, + DockerImageBatchJob, GpuType, StorageSpecificationType, ) @@ -100,6 +101,8 @@ class CreateDockerImageBatchJobV1Request(BaseModel): CreateDockerImageBatchJobResourceRequests() ) + override_job_max_runtime_s: Optional[int] = None + @root_validator def exactly_one_name_or_id(cls, values): bundle_name = values.get("docker_image_batch_job_bundle_name") @@ -123,6 +126,10 @@ class GetDockerImageBatchJobV1Response(BaseModel): status: BatchJobStatus +class ListDockerImageBatchJobsV1Response(BaseModel): + jobs: List[DockerImageBatchJob] + + class UpdateDockerImageBatchJobV1Request(BaseModel): cancel: bool diff --git a/server/llm_engine_server/common/dtos/docker_repository.py b/model-engine/model_engine_server/common/dtos/docker_repository.py similarity index 100% rename from server/llm_engine_server/common/dtos/docker_repository.py rename to model-engine/model_engine_server/common/dtos/docker_repository.py diff --git a/server/llm_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py similarity index 92% rename from server/llm_engine_server/common/dtos/endpoint_builder.py rename to model-engine/model_engine_server/common/dtos/endpoint_builder.py index 9817fbbc..0edbeaaf 100644 --- a/server/llm_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -28,6 +28,7 @@ class BuildEndpointRequest(BaseModel): child_fn_info: Optional[Dict[str, Any]] # TODO: remove this if we don't need it. post_inference_hooks: Optional[List[str]] labels: Dict[str, str] + billing_tags: Optional[Dict[str, Any]] prewarm: bool = True high_priority: Optional[bool] default_callback_url: Optional[str] diff --git a/model-engine/model_engine_server/common/dtos/files.py b/model-engine/model_engine_server/common/dtos/files.py new file mode 100644 index 00000000..94b54474 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/files.py @@ -0,0 +1,47 @@ +""" +DTOs for Files API. +""" +from typing import List + +from pydantic import BaseModel, Field + + +class UploadFileResponse(BaseModel): + """Response object for uploading a file.""" + + id: str = Field(..., description="ID of the uploaded file.") + """ID of the uploaded file.""" + + +class GetFileResponse(BaseModel): + """Response object for retrieving a file.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + filename: str = Field(..., description="File name.") + """File name.""" + size: int = Field(..., description="Length of the file, in characters.") + """Length of the file, in characters.""" + + +class ListFilesResponse(BaseModel): + """Response object for listing files.""" + + files: List[GetFileResponse] = Field(..., description="List of file IDs, names, and sizes.") + """List of file IDs, names, and sizes.""" + + +class DeleteFileResponse(BaseModel): + """Response object for deleting a file.""" + + deleted: bool = Field(..., description="Whether deletion was successful.") + """Whether deletion was successful.""" + + +class GetFileContentResponse(BaseModel): + """Response object for retrieving a file's content.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + content: str = Field(..., description="File content.") + """File content.""" diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py new file mode 100644 index 00000000..d62f7992 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -0,0 +1,230 @@ +""" +DTOs for LLM APIs. +""" + +from typing import Any, Dict, List, Optional + +from model_engine_server.common.dtos.model_endpoints import ( + CpuSpecificationType, + GetModelEndpointV1Response, + GpuType, + ModelEndpointType, + StorageSpecificationType, +) +from model_engine_server.domain.entities import ( + BatchJobStatus, + CallbackAuth, + FineTuneHparamValueType, + LLMFineTuneEvent, + LLMInferenceFramework, + LLMSource, + ModelEndpointStatus, + Quantization, +) +from pydantic import BaseModel, Field, HttpUrl + + +class CreateLLMModelEndpointV1Request(BaseModel): + name: str + + # LLM specific fields + model_name: str + source: LLMSource = LLMSource.HUGGING_FACE + inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED + inference_framework_image_tag: str + num_shards: int = 1 + """ + Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. Only affect behavior for text-generation-inference models + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models + """ + + # General endpoint fields + metadata: Dict[str, Any] # TODO: JSON type + post_inference_hooks: Optional[List[str]] + endpoint_type: ModelEndpointType = ModelEndpointType.SYNC + cpus: CpuSpecificationType + gpus: int + memory: StorageSpecificationType + gpu_type: GpuType + storage: Optional[StorageSpecificationType] + optimize_costs: Optional[bool] + min_workers: int + max_workers: int + per_worker: int + labels: Dict[str, str] + prewarm: Optional[bool] + high_priority: Optional[bool] + billing_tags: Optional[Dict[str, Any]] + default_callback_url: Optional[HttpUrl] + default_callback_auth: Optional[CallbackAuth] + public_inference: Optional[bool] = True # LLM endpoints are public by default. + + +class CreateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +class GetLLMModelEndpointV1Response(BaseModel): + id: str + """ + The autogenerated ID of the Launch endpoint. + """ + + name: str + model_name: Optional[str] = None + source: LLMSource + status: ModelEndpointStatus + inference_framework: LLMInferenceFramework + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + quantize: Optional[Quantization] = None + spec: Optional[GetModelEndpointV1Response] = None + + +class ListLLMModelEndpointsV1Response(BaseModel): + model_endpoints: List[GetLLMModelEndpointV1Response] + + +# Delete and update use the default Launch endpoint APIs. + + +class CompletionSyncV1Request(BaseModel): + """ + Request object for a synchronous prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0, le=1) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + + +class TokenOutput(BaseModel): + token: str + log_prob: float + + +class CompletionOutput(BaseModel): + text: str + num_completion_tokens: int + tokens: Optional[List[TokenOutput]] = None + + +class CompletionSyncV1Response(BaseModel): + """ + Response object for a synchronous prompt completion task. + """ + + request_id: str + output: Optional[CompletionOutput] = None + + +class CompletionStreamV1Request(BaseModel): + """ + Request object for a stream prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0, le=1) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. Only affects behavior for text-generation-inference models + """ + + +class CompletionStreamOutput(BaseModel): + text: str + finished: bool + num_completion_tokens: Optional[int] = None + token: Optional[TokenOutput] = None + + +class CompletionStreamV1Response(BaseModel): + """ + Response object for a stream prompt completion task. + """ + + request_id: str + output: Optional[CompletionStreamOutput] = None + + +class CreateFineTuneRequest(BaseModel): + model: str + training_file: str + validation_file: Optional[str] = None + # fine_tuning_method: str # TODO enum + uncomment when we support multiple methods + hyperparameters: Dict[str, FineTuneHparamValueType] # validated somewhere else + suffix: Optional[str] = None + wandb_config: Optional[Dict[str, Any]] = None + """ + Config to pass to wandb for init. See https://docs.wandb.ai/ref/python/init + Must include `api_key` field which is the wandb API key. + """ + + +class CreateFineTuneResponse(BaseModel): + id: str + + +class GetFineTuneResponse(BaseModel): + id: str = Field(..., description="Unique ID of the fine tune") + fine_tuned_model: Optional[str] = Field( + default=None, + description="Name of the resulting fine-tuned model. This can be plugged into the " + "Completion API ones the fine-tune is complete", + ) + status: BatchJobStatus = Field(..., description="Status of the requested fine tune.") + + +class ListFineTunesResponse(BaseModel): + jobs: List[GetFineTuneResponse] + + +class CancelFineTuneResponse(BaseModel): + success: bool + + +class GetFineTuneEventsResponse(BaseModel): + # LLMFineTuneEvent is entity layer technically, but it's really simple + events: List[LLMFineTuneEvent] + + +class ModelDownloadRequest(BaseModel): + model_name: str = Field(..., description="Name of the fine tuned model") + download_format: Optional[str] = Field( + default="hugging_face", + description="Format that you want the downloaded urls to be compatible with. Currently only supports hugging_face", + ) + + +class ModelDownloadResponse(BaseModel): + urls: Dict[str, str] = Field( + ..., description="Dictionary of (file_name, url) pairs to download the model from." + ) diff --git a/server/llm_engine_server/common/dtos/model_bundles.py b/model-engine/model_engine_server/common/dtos/model_bundles.py similarity index 98% rename from server/llm_engine_server/common/dtos/model_bundles.py rename to model-engine/model_engine_server/common/dtos/model_bundles.py index 5d3ece47..778b2942 100644 --- a/server/llm_engine_server/common/dtos/model_bundles.py +++ b/model-engine/model_engine_server/common/dtos/model_bundles.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( ModelBundleEnvironmentParams, ModelBundleFlavors, ModelBundlePackagingType, diff --git a/server/llm_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py similarity index 97% rename from server/llm_engine_server/common/dtos/model_endpoints.py rename to model-engine/model_engine_server/common/dtos/model_endpoints.py index 956e8ee1..301a2d45 100644 --- a/server/llm_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -10,7 +10,7 @@ from enum import Enum from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -62,6 +62,7 @@ class CreateModelEndpointV1Request(BaseModel): labels: Dict[str, str] prewarm: Optional[bool] high_priority: Optional[bool] + billing_tags: Optional[Dict[str, Any]] default_callback_url: Optional[HttpUrl] default_callback_auth: Optional[CallbackAuth] public_inference: Optional[bool] = Field(default=False) @@ -87,6 +88,7 @@ class UpdateModelEndpointV1Request(BaseModel): labels: Optional[Dict[str, str]] prewarm: Optional[bool] high_priority: Optional[bool] + billing_tags: Optional[Dict[str, Any]] default_callback_url: Optional[HttpUrl] default_callback_auth: Optional[CallbackAuth] public_inference: Optional[bool] diff --git a/server/llm_engine_server/common/dtos/resource_manager.py b/model-engine/model_engine_server/common/dtos/resource_manager.py similarity index 64% rename from server/llm_engine_server/common/dtos/resource_manager.py rename to model-engine/model_engine_server/common/dtos/resource_manager.py index cb6bea9a..e156f77e 100644 --- a/server/llm_engine_server/common/dtos/resource_manager.py +++ b/model-engine/model_engine_server/common/dtos/resource_manager.py @@ -1,4 +1,4 @@ -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest from pydantic import BaseModel diff --git a/server/llm_engine_server/common/dtos/tasks.py b/model-engine/model_engine_server/common/dtos/tasks.py similarity index 94% rename from server/llm_engine_server/common/dtos/tasks.py rename to model-engine/model_engine_server/common/dtos/tasks.py index ecd01802..5b0bf580 100644 --- a/server/llm_engine_server/common/dtos/tasks.py +++ b/model-engine/model_engine_server/common/dtos/tasks.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Optional -from llm_engine_server.domain.entities import CallbackAuth +from model_engine_server.domain.entities import CallbackAuth from pydantic import BaseModel diff --git a/model-engine/model_engine_server/common/dtos/triggers.py b/model-engine/model_engine_server/common/dtos/triggers.py new file mode 100644 index 00000000..ee4d2121 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/triggers.py @@ -0,0 +1,51 @@ +""" +Contains various input and output types relating to Triggers for the server. +""" +import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class CreateTriggerV1Request(BaseModel): + name: str + cron_schedule: str + bundle_id: str + default_job_config: Optional[Dict[str, Any]] + default_job_metadata: Optional[Dict[str, str]] + + +class CreateTriggerV1Response(BaseModel): + trigger_id: str + + +class GetTriggerV1Response(BaseModel): + id: str + name: str + owner: str + created_by: str + created_at: datetime.datetime + cron_schedule: str + docker_image_batch_job_bundle_id: str + default_job_config: Optional[Dict[str, Any]] = Field(default=None) + default_job_metadata: Optional[Dict[str, str]] = Field(default=None) + + class Config: + orm_mode = True + + +class ListTriggersV1Response(BaseModel): + triggers: List[GetTriggerV1Response] + + +class UpdateTriggerV1Request(BaseModel): + cron_schedule: Optional[str] + suspend: Optional[bool] + + +class UpdateTriggerV1Response(BaseModel): + success: bool + + +class DeleteTriggerV1Response(BaseModel): + success: bool diff --git a/model-engine/model_engine_server/common/env_vars.py b/model-engine/model_engine_server/common/env_vars.py new file mode 100644 index 00000000..a51a7698 --- /dev/null +++ b/model-engine/model_engine_server/common/env_vars.py @@ -0,0 +1,77 @@ +""" +A place for defining, setting, and referencing all environment variables used in Launch. +""" +import os +from typing import Optional, Sequence + +from model_engine_server.common.constants import PROJECT_ROOT +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger + +__all__: Sequence[str] = ( + "CIRCLECI", + "GIT_TAG", + "LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH", + "LAUNCH_SERVICE_TEMPLATE_FOLDER", + "LOCAL", + "SKIP_AUTH", + "WORKSPACE", + "get_boolean_env_var", +) + +logger = make_logger(logger_name()) + + +def get_boolean_env_var(name: str) -> bool: + """For all env vars that are either on or off. + + An env var is ON iff: + - it is defined + - its value is the literal string 'true' + + If it is present but not set to 'true', it is considered to be OFF. + """ + value = os.environ.get(name) + if value is None: + return False + value = value.strip().lower() + return "true" == value + + +CIRCLECI: bool = get_boolean_env_var("CIRCLECI") + +LOCAL: bool = get_boolean_env_var("LOCAL") +"""Indicates that Launch is running in a local development environment. Also used for local testing. +""" + +SKIP_AUTH: bool = get_boolean_env_var("SKIP_AUTH") or infra_config().identity_service_url is None +"""Indicates that Launch is running in a development environment where authentication is not +required. +""" + +WORKSPACE: str = os.environ.get("WORKSPACE", "~/models") +"""The working directory where hosted_model_inference is installed. +""" + +LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH: str = os.environ.get( + "LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH", + os.path.join( + PROJECT_ROOT, + "model_engine_server/infra/gateways/resources/templates", + "service_template_config_map_circleci.yaml", + ), +) +"""The path to the config map containing the Launch service template. +""" + +LAUNCH_SERVICE_TEMPLATE_FOLDER: Optional[str] = os.environ.get("LAUNCH_SERVICE_TEMPLATE_FOLDER") +"""The path to the folder containing the Launch service template. If set, this overrides +LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH. +""" + +if LOCAL: + logger.warning("LOCAL development & testing mode is ON") + +GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND") +if GIT_TAG == "GIT_TAG_NOT_FOUND": + raise ValueError("GIT_TAG environment variable must be set") diff --git a/server/llm_engine_server/common/errors.py b/model-engine/model_engine_server/common/errors.py similarity index 100% rename from server/llm_engine_server/common/errors.py rename to model-engine/model_engine_server/common/errors.py diff --git a/server/llm_engine_server/common/io.py b/model-engine/model_engine_server/common/io.py similarity index 92% rename from server/llm_engine_server/common/io.py rename to model-engine/model_engine_server/common/io.py index 8ee049db..2247b6f4 100644 --- a/server/llm_engine_server/common/io.py +++ b/model-engine/model_engine_server/common/io.py @@ -1,4 +1,4 @@ -"""LLMEngine Input/Output utils.""" +"""Launch Input/Output utils.""" import os import boto3 diff --git a/server/llm_engine_server/common/pydantic_types/endpoint_predict_payload.py b/model-engine/model_engine_server/common/pydantic_types/endpoint_predict_payload.py similarity index 100% rename from server/llm_engine_server/common/pydantic_types/endpoint_predict_payload.py rename to model-engine/model_engine_server/common/pydantic_types/endpoint_predict_payload.py diff --git a/server/llm_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py similarity index 94% rename from server/llm_engine_server/common/resource_limits.py rename to model-engine/model_engine_server/common/resource_limits.py index cadf4001..ee19af55 100644 --- a/server/llm_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -1,18 +1,18 @@ from typing import Optional, Union, cast -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( CpuSpecificationType, GpuType, ModelBundle, StorageSpecificationType, TritonEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.infra.gateways.k8s_resource_parser import ( format_bytes, parse_cpu_request, parse_mem_request, @@ -40,6 +40,7 @@ GpuType.NVIDIA_TESLA_T4: T4_INSTANCE_LIMITS, GpuType.NVIDIA_AMPERE_A10: A10_INSTANCE_LIMITS, GpuType.NVIDIA_AMPERE_A100: A100_INSTANCE_LIMITS, + GpuType.NVIDIA_AMPERE_A100E: A100_INSTANCE_LIMITS, } FORWARDER_CPU_USAGE = 0.5 diff --git a/server/llm_engine_server/common/serialization_utils.py b/model-engine/model_engine_server/common/serialization_utils.py similarity index 100% rename from server/llm_engine_server/common/serialization_utils.py rename to model-engine/model_engine_server/common/serialization_utils.py diff --git a/server/llm_engine_server/common/service_requests.py b/model-engine/model_engine_server/common/service_requests.py similarity index 93% rename from server/llm_engine_server/common/service_requests.py rename to model-engine/model_engine_server/common/service_requests.py index 3e20fbe5..9f5327d4 100644 --- a/server/llm_engine_server/common/service_requests.py +++ b/model-engine/model_engine_server/common/service_requests.py @@ -3,8 +3,8 @@ from typing import Any, Dict, Optional import requests -from llm_engine_server.common.errors import HTTP429Exception, UpstreamHTTPSvcError -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.common.errors import HTTP429Exception, UpstreamHTTPSvcError +from model_engine_server.core.loggers import filename_wo_ext, make_logger from tenacity import ( RetryError, Retrying, diff --git a/model-engine/model_engine_server/common/settings.py b/model-engine/model_engine_server/common/settings.py new file mode 100644 index 00000000..7dc6c6bb --- /dev/null +++ b/model-engine/model_engine_server/common/settings.py @@ -0,0 +1,114 @@ +# This file contains standard settings for ML serve. +# + +import hashlib +from typing import List, Tuple + +from model_engine_server.common.config import hmi_config +from model_engine_server.core.config import infra_config + +DEPLOYMENT_PREFIX = "launch" +LEGACY_DEPLOYMENT_PREFIX = "hmi" +SERVICE_BUILDER_QUEUE_PREFIX = "model-engine" +SERVICE_BUILDER_QUEUE_SUFFIX = "service-builder" +HOSTED_INFERENCE_SERVER_NAME = "hostedinference" +LAUNCH_SERVER_NAME = "launch" +K8S_CACHER_NAME = "launch-k8s-cacher" +PYSPARK_DEFAULT_ENDPOINT_PARAMS = dict( + cpus=3, + memory="12Gi", + gpus=1, + gpu_type="nvidia-tesla-t4", + min_workers=0, + max_workers=50, + per_worker=40, +) # TODO: we could probably determine an appropriate value for max_workers based on the size of the batch +PYSPARK_DEFAULT_MAX_EXECUTORS = 50 +PYSPARK_DEFAULT_PARTITION_SIZE = 500 + +RESTRICTED_ENDPOINT_LABELS = set( + [ + "user_id", + "endpoint_name", + ] +) + +REQUIRED_ENDPOINT_LABELS = set( + [ + "team", + "product", + ] +) + +PRETRAINED_ENDPOINTS_CREATED_BY = ["nucleus-model-zoo", "bloom", "llm", "pretrained"] + + +def generate_deployment_name(user_id, endpoint_name): + return "-".join(_generate_deployment_name_parts(user_id, endpoint_name)) + + +def _generate_queue_name(user_id, endpoint_name): + return ".".join(_generate_deployment_name_parts(user_id, endpoint_name)) + + +def generate_destination(user_id: str, endpoint_name: str, endpoint_type: str) -> str: + if endpoint_type == "async": + return _generate_queue_name(user_id, endpoint_name) + elif endpoint_type in {"sync", "streaming"}: + return generate_deployment_name(user_id, endpoint_name) + else: + raise ValueError(f"Invalid endpoint_type: {endpoint_type}") + + +def _generate_deployment_name_parts(user_id: str, endpoint_name: str) -> List[str]: + user_endpoint_hash = hashlib.md5((user_id + endpoint_name).encode("utf-8")).hexdigest() + return [ + DEPLOYMENT_PREFIX, + user_id[:24], + endpoint_name[:8], + user_endpoint_hash[:8], + ] + + +def generate_batch_job_name(user_id: str, endpoint_name: str): + batch_job_partial_name = "-".join(_generate_deployment_name_parts(user_id, endpoint_name)) + return f"batch-job-{batch_job_partial_name}" + + +def get_sync_endpoint_hostname_and_url(deployment_name: str) -> Tuple[str, str]: + hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}" + return hostname, f"http://{hostname}/predict" + + +def get_sync_endpoint_elb_url(deployment_name: str) -> str: + return f"http://{deployment_name}.{infra_config().dns_host_domain}/predict" + + +def get_service_builder_queue(service_identifier=None): + return ( + f"{SERVICE_BUILDER_QUEUE_PREFIX}-{service_identifier}.{SERVICE_BUILDER_QUEUE_SUFFIX}" + if service_identifier + else f"{SERVICE_BUILDER_QUEUE_PREFIX}.{SERVICE_BUILDER_QUEUE_SUFFIX}" + ) + + +def get_quart_server_name(service_identifier=None): + return ( + f"{HOSTED_INFERENCE_SERVER_NAME}-{service_identifier}" + if service_identifier + else HOSTED_INFERENCE_SERVER_NAME + ) + + +def get_gateway_server_name(service_identifier=None): + return ( + f"{LAUNCH_SERVER_NAME}-{service_identifier}" if service_identifier else LAUNCH_SERVER_NAME + ) + + +def get_service_builder_logs_location(user_id: str, endpoint_name: str): + return f"s3://{infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" + + +def get_k8s_cacher_service_name(service_identifier=None): + return f"{K8S_CACHER_NAME}-{service_identifier}" if service_identifier else K8S_CACHER_NAME diff --git a/server/llm_engine_server/common/types.py b/model-engine/model_engine_server/common/types.py similarity index 98% rename from server/llm_engine_server/common/types.py rename to model-engine/model_engine_server/common/types.py index 22508887..93ccbed6 100644 --- a/server/llm_engine_server/common/types.py +++ b/model-engine/model_engine_server/common/types.py @@ -117,3 +117,4 @@ class EndpointBuilderParams(EndpointParams): app_config: Optional[Dict[str, Any]] = None child_fn_info: Optional[Dict[str, Any]] = None post_inference_hooks: Optional[List[str]] = None + billing_tags: Optional[Dict[str, Any]] = None diff --git a/server/llm_engine_server/core/__init__.py b/model-engine/model_engine_server/core/__init__.py similarity index 100% rename from server/llm_engine_server/core/__init__.py rename to model-engine/model_engine_server/core/__init__.py diff --git a/server/llm_engine_server/core/auth/__init__.py b/model-engine/model_engine_server/core/auth/__init__.py similarity index 100% rename from server/llm_engine_server/core/auth/__init__.py rename to model-engine/model_engine_server/core/auth/__init__.py diff --git a/server/llm_engine_server/core/auth/authentication_repository.py b/model-engine/model_engine_server/core/auth/authentication_repository.py similarity index 87% rename from server/llm_engine_server/core/auth/authentication_repository.py rename to model-engine/model_engine_server/core/auth/authentication_repository.py index ce1cf9b9..1b1d1f77 100644 --- a/server/llm_engine_server/core/auth/authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/authentication_repository.py @@ -16,6 +16,13 @@ class AuthenticationRepository(ABC): With the context of the Model Primitive service, this just refers to a (user_id, team_id) pair. """ + @staticmethod + @abstractmethod + def is_allowed_team(team: str) -> bool: + """ + Returns whether the provided team is an allowed team. + """ + @abstractmethod def get_auth_from_user_id(self, user_id: str) -> Optional[User]: """ diff --git a/server/llm_engine_server/core/auth/fake_authentication_repository.py b/model-engine/model_engine_server/core/auth/fake_authentication_repository.py similarity index 85% rename from server/llm_engine_server/core/auth/fake_authentication_repository.py rename to model-engine/model_engine_server/core/auth/fake_authentication_repository.py index 5da02827..d3e5f4c1 100644 --- a/server/llm_engine_server/core/auth/fake_authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/fake_authentication_repository.py @@ -1,6 +1,6 @@ from typing import Dict, Optional -from llm_engine_server.core.auth.authentication_repository import AuthenticationRepository, User +from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User class FakeAuthenticationRepository(AuthenticationRepository): @@ -9,6 +9,10 @@ def __init__(self, user_team_override: Optional[Dict[str, str]] = None): user_team_override = {} self.user_team_override = user_team_override + @staticmethod + def is_allowed_team(team: str) -> bool: + return True + def get_auth_from_user_id(self, user_id: str) -> Optional[User]: team_id = self.user_team_override.get(user_id, user_id) return User(user_id=user_id, team_id=team_id, is_privileged_user=True) diff --git a/server/llm_engine_server/core/aws/__init__.py b/model-engine/model_engine_server/core/aws/__init__.py similarity index 100% rename from server/llm_engine_server/core/aws/__init__.py rename to model-engine/model_engine_server/core/aws/__init__.py diff --git a/server/llm_engine_server/core/aws/roles.py b/model-engine/model_engine_server/core/aws/roles.py similarity index 88% rename from server/llm_engine_server/core/aws/roles.py rename to model-engine/model_engine_server/core/aws/roles.py index 65548087..d33efeca 100644 --- a/server/llm_engine_server/core/aws/roles.py +++ b/model-engine/model_engine_server/core/aws/roles.py @@ -11,7 +11,7 @@ import boto3 from boto3 import Session, client from botocore.client import BaseClient -from llm_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) @@ -73,9 +73,6 @@ def client(self, client_type: str, region_name: str = "us-west-2") -> BaseClient """Creates the specified Boto3 :param:`client_type` using the AWS credentials. The :param:`client_type` parameter is any valid value for `boto3.client` (e.g. `"s3"`). - - NOTE: Use the us-west-2 region unless you are absolutely sure you require a different region. - All Scale AWS services are in US West 2. """ return boto3.client( client_type, @@ -122,15 +119,6 @@ def session(role: Optional[str], session_type: SessionT = Session) -> SessionT: :param:`session_type` defines the type of session to return. Most users will use the default boto3 type. Some users required a special type (e.g aioboto3 session). - - Includes fall-back logic to work with setups that do not use a credentials file - in the .aws folder in the user's home folder. In this setting, it is ok for :param:`role` - to be an ARN. Otherwise, the `profile_to_arn` mapping in `scaleml.config` is used to - locate the correct ARN for the given AWS profile name. - - NOTE: The fall-back is required for this to work with setups that use `aws-okta`. - - :raises: botocore.exceptions.ProfileNotFound, ValueError """ # Do not assume roles in CIRCLECI if os.getenv("CIRCLECI"): @@ -178,7 +166,6 @@ def _session_aws_okta( return sesh -# returns scale user (e.g. pranav.pillai) def get_current_user() -> str: """Uses AWS sts to obtain the profile name of the currently authenticated AWS account.""" arn = client("sts").get_caller_identity().get("Arn") @@ -196,14 +183,14 @@ def parse_arn_string(arn: str) -> ArnData: if not 2 <= len(bits) <= 3: raise ValueError( f"Invalid format for AWS ARN string: {arn} -- " - f"Expecting either 2 or 3 parts separated by '/'" + f"Expecting either 2 or 3 parts seperated by '/'" ) account_and_source: List[str] = bits[0].split("::") if len(account_and_source) != 2: raise ValueError( f"Expecting ARN string to have 2 parts in the first '/' part, " - f"separated by '::'. Instead found {account_and_source} from " + f"seperated by '::'. Instead found {account_and_source} from " f"arn={arn}" ) @@ -234,8 +221,8 @@ def parse_arn_string(arn: str) -> ArnData: except ValueError as err: raise ValueError( "ARN format invalid: expecting account ID to appear as 2nd to last " - "value separated by ':' within the first value separated by '/' and " - "second value separated by '::' -- " + "value seperated by ':' within the first value seperated by '/' and " + "second value seperated by '::' -- " f"arn={arn} and expecting {account_str} to be account ID" ) from err diff --git a/server/llm_engine_server/core/aws/secrets.py b/model-engine/model_engine_server/core/aws/secrets.py similarity index 64% rename from server/llm_engine_server/core/aws/secrets.py rename to model-engine/model_engine_server/core/aws/secrets.py index 4d8e8941..37ed25e1 100644 --- a/server/llm_engine_server/core/aws/secrets.py +++ b/model-engine/model_engine_server/core/aws/secrets.py @@ -5,8 +5,8 @@ import boto3 from botocore.exceptions import ClientError -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger logger = make_logger(filename_wo_ext(__file__)) @@ -15,13 +15,9 @@ def get_key_file(secret_name: str, aws_profile: Optional[str] = None): if aws_profile is not None: session = boto3.Session(profile_name=aws_profile) - secret_manager = session.client( - "secretsmanager", region_name=ml_infra_config().default_region - ) + secret_manager = session.client("secretsmanager", region_name=infra_config().default_region) else: - secret_manager = boto3.client( - "secretsmanager", region_name=ml_infra_config().default_region - ) + secret_manager = boto3.client("secretsmanager", region_name=infra_config().default_region) try: secret_value = json.loads( secret_manager.get_secret_value(SecretId=secret_name)["SecretString"] diff --git a/server/llm_engine_server/core/aws/storage_client.py b/model-engine/model_engine_server/core/aws/storage_client.py similarity index 86% rename from server/llm_engine_server/core/aws/storage_client.py rename to model-engine/model_engine_server/core/aws/storage_client.py index 526eda3a..c73c500d 100644 --- a/server/llm_engine_server/core/aws/storage_client.py +++ b/model-engine/model_engine_server/core/aws/storage_client.py @@ -3,9 +3,9 @@ import smart_open from botocore.client import BaseClient -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.aws.roles import session +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) @@ -20,7 +20,7 @@ def sync_storage_client(**kwargs) -> BaseClient: - return session(ml_infra_config().profile_ml_worker).client("s3", **kwargs) + return session(infra_config().profile_ml_worker).client("s3", **kwargs) def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefined-builtin @@ -30,10 +30,7 @@ def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefin def sync_storage_client_keepalive( - s3_client: BaseClient, - buckets: Iterable[str], - interval: int, - is_cancelled: Callable[[], bool], + s3_client: BaseClient, buckets: Iterable[str], interval: int, is_cancelled: Callable[[], bool] ) -> None: """Keeps connection pool warmed up for access on list of S3 buckets. diff --git a/server/llm_engine_server/core/celery/__init__.py b/model-engine/model_engine_server/core/celery/__init__.py similarity index 100% rename from server/llm_engine_server/core/celery/__init__.py rename to model-engine/model_engine_server/core/celery/__init__.py diff --git a/server/llm_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py similarity index 96% rename from server/llm_engine_server/core/celery/app.py rename to model-engine/model_engine_server/core/celery/app.py index 924d268f..7e87d2f0 100644 --- a/server/llm_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -9,9 +9,9 @@ from celery.app import backends from celery.app.control import Inspect from celery.result import AsyncResult -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import ( +from model_engine_server.core.aws.roles import session +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import ( CustomJSONFormatter, logger_name, make_logger, @@ -25,7 +25,7 @@ # This is because the Celery code does not actually work when you try and # override the backend with a class instead of a URL, despite the fact # that the `backend` constructor arg type is a Union[str, Type[celery.backends.base.Backend]] -backends.BACKEND_ALIASES["s3"] = "llm_engine_server.core.celery.s3:S3Backend" +backends.BACKEND_ALIASES["s3"] = "model_engine_server.core.celery.s3:S3Backend" @unique @@ -60,7 +60,7 @@ class TaskVisibility(IntEnum): 2. When making requests to such deployment, you'll have to do: ```python - from scaleml.io.celery import TaskVisibility, celery_app + from model_engine_server.core.celery.app import TaskVisibility, celery_app app = celery_app(None, task_visibility=TaskVisibility.VISIBILITY_1M) future_result = app.send_task("some.task.name", args=["some", "args"], queue="some-queue") ``` @@ -171,7 +171,7 @@ def get_redis_host_port(): port = os.getenv("REDIS_PORT") # In the case of k8s, pick the right endpoint based on the config elif os.getenv("KUBERNETES_SERVICE_HOST"): - host = ml_infra_config().redis_host + host = infra_config().redis_host port = 6379 # For debugging purposes elif os.getenv("USE_REDIS_LOCALHOST") == "1": @@ -180,8 +180,8 @@ def get_redis_host_port(): port = 6379 # In the case of local testing, pick the right endpoint based on the config elif os.getenv("KUBECONFIG"): - logger.info(f"Inferring redis host from config env: {ml_infra_config().env}") - host = f"redis-elasticache-message-broker.{ml_infra_config().dns_host_domain}" + logger.info(f"Inferring redis host from config env: {infra_config().env}") + host = f"redis-elasticache-message-broker.{infra_config().dns_host_domain}" port = 6379 logger.info(f"Using Redis host and port: {host}:{port}") @@ -292,7 +292,7 @@ def celery_app( 2. When making requests to such deployment, you'll have to do: ```python - from scaleml.io.celery import TaskVisibility, celery_app + from model_engine_server.core.celery import TaskVisibility, celery_app app = celery_app(None, task_visibility=TaskVisibility.VISIBILITY_1M) future_result = app.send_task("some.task.name", args=["some", "args"], queue="some-queue") ``` @@ -342,10 +342,10 @@ def celery_app( # FIXME: serializer. Until we figure out how to run as a non-root user, it might be better to avoid pickle. :param s3_bucket: [optional] Bucket name to store task results when using S3 as backend. The results uri will be - "s3:////...". Defaults to "scale-ml" (s3://scale-ml/tmp/celery/). + "s3:////...". :param s3_base_path: [optional] Base path for task results when using S3 as backend. The results uri will be - "s3:////...". Defaults to "tmp/celery/" (s3://scale-ml/tmp/celery/). + "s3:////...". :param backend_protocol: [optional] Backend protocol to use, currently supports "s3" and "redis". Defaults to "s3". Redis might be faster than S3 but is not persistent, so using "redis" is discouraged. @@ -356,7 +356,7 @@ def celery_app( :param broker_type: [defaults to "redis"] The broker type. We currently support "redis" and "sqs". - :param aws_role: [optional] AWS role to use. If none, will default to default for s3 backends, + :param aws_role: [optional] AWS role to use. :param extra_changes: Extra keyword arguments to Celery app. Visit https://docs.celeryproject.org/en/stable/userguide/configuration.html to see options. @@ -430,7 +430,7 @@ def celery_app( } if s3_bucket is None: - s3_bucket = ml_infra_config().s3_bucket + s3_bucket = infra_config().s3_bucket backend_url, extra_conf_changes = _get_backend_url_and_conf( backend_protocol, @@ -504,7 +504,7 @@ def _get_backend_url_and_conf( elif backend_protocol == "s3": backend_url = "s3://" if aws_role is None: - aws_session = session(ml_infra_config().profile_ml_worker) + aws_session = session(infra_config().profile_ml_worker) else: aws_session = session(aws_role) out_conf_changes.update( diff --git a/server/llm_engine_server/core/celery/s3.py b/model-engine/model_engine_server/core/celery/s3.py similarity index 100% rename from server/llm_engine_server/core/celery/s3.py rename to model-engine/model_engine_server/core/celery/s3.py diff --git a/server/llm_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py similarity index 64% rename from server/llm_engine_server/core/config.py rename to model-engine/model_engine_server/core/config.py index b3a02a85..53942fb2 100644 --- a/server/llm_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -2,7 +2,7 @@ The configuration file is loaded from the ML_INFRA_SERVICES_CONFIG_PATH environment variable. If this is not set, the default configuration file is used from -llm_engine_server.core/configs/default.yaml. +model_engine_server.core/configs/default.yaml. """ import os from contextlib import contextmanager @@ -12,7 +12,7 @@ from typing import Optional, Sequence import yaml -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import filename_wo_ext, make_logger logger = make_logger(filename_wo_ext(__file__)) @@ -21,16 +21,16 @@ "CONFIG_PATH", "config_context", "get_config_path_for_env_name", - "ml_infra_config", + "infra_config", "use_config_context", ) -DEFAULT_CONFIG_PATH = Path(__file__).parent / "configs" / "circleci.yaml" +DEFAULT_CONFIG_PATH = Path(__file__).parent / "configs" / "default.yaml" CONFIG_PATH: str = os.getenv("ML_INFRA_SERVICES_CONFIG_PATH", str(DEFAULT_CONFIG_PATH)) @dataclass -class MLInfraServicesConfig: +class InfraConfig: env: str k8s_cluster_name: str dns_host_domain: str @@ -41,45 +41,46 @@ class MLInfraServicesConfig: s3_bucket: str profile_ml_worker: str = "default" profile_ml_inference_worker: str = "default" + identity_service_url: Optional[str] = None @classmethod - def from_yaml(cls, yaml_path) -> "MLInfraServicesConfig": + def from_yaml(cls, yaml_path) -> "InfraConfig": with open(yaml_path, "r") as f: raw_data = yaml.safe_load(f) - return MLInfraServicesConfig(**raw_data) + return InfraConfig(**raw_data) def read_default_config(): logger.info(f"Using config file path: `{CONFIG_PATH}`") - return MLInfraServicesConfig.from_yaml(CONFIG_PATH) + return InfraConfig.from_yaml(CONFIG_PATH) -_ml_infra_config: Optional[MLInfraServicesConfig] = None +_infra_config: Optional[InfraConfig] = None -def ml_infra_config() -> MLInfraServicesConfig: - global _ml_infra_config - if _ml_infra_config is None: - _ml_infra_config = read_default_config() - return _ml_infra_config +def infra_config() -> InfraConfig: + global _infra_config + if _infra_config is None: + _infra_config = read_default_config() + return _infra_config @contextmanager def config_context(config_path: str): """Context manager that temporarily changes the config file path.""" - global _ml_infra_config - current_config = deepcopy(_ml_infra_config) + global _infra_config + current_config = deepcopy(_infra_config) try: - _ml_infra_config = MLInfraServicesConfig.from_yaml(config_path) + _infra_config = InfraConfig.from_yaml(config_path) yield finally: - _ml_infra_config = current_config + _infra_config = current_config def use_config_context(config_path: str): """Use the config file at the given path.""" - global _ml_infra_config - _ml_infra_config = MLInfraServicesConfig.from_yaml(config_path) + global _infra_config + _infra_config = InfraConfig.from_yaml(config_path) def get_config_path_for_env_name(env_name: str) -> Path: diff --git a/server/llm_engine_server/core/configs/circleci.yaml b/model-engine/model_engine_server/core/configs/default.yaml similarity index 91% rename from server/llm_engine_server/core/configs/circleci.yaml rename to model-engine/model_engine_server/core/configs/default.yaml index 2ef8183b..745d3fc2 100644 --- a/server/llm_engine_server/core/configs/circleci.yaml +++ b/model-engine/model_engine_server/core/configs/default.yaml @@ -5,6 +5,6 @@ default_region: "us-west-2" ml_account_id: "000000000000" docker_repo_prefix: "000000000000.dkr.ecr.us-west-2.amazonaws.com" redis_host: "redis-message-broker-master.default" -s3_bucket: "scale-ml-circleci" +s3_bucket: "test-bucket" profile_ml_worker: "default" profile_ml_inference_worker: "default" diff --git a/server/llm_engine_server/core/docker/__init__.py b/model-engine/model_engine_server/core/docker/__init__.py similarity index 100% rename from server/llm_engine_server/core/docker/__init__.py rename to model-engine/model_engine_server/core/docker/__init__.py diff --git a/server/llm_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py similarity index 90% rename from server/llm_engine_server/core/docker/docker_image.py rename to model-engine/model_engine_server/core/docker/docker_image.py index b46065a0..f61294c3 100644 --- a/server/llm_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -16,16 +16,12 @@ import boto3 import click import docker -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import make_logger +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import make_logger from .remote_build import MODELS_ROOT, build_remote_wrapper -logger = make_logger("llm_engine_server.core.docker.docker_image", log_level=logging.INFO) - -REGISTRY_ID = ml_infra_config().ml_account_id -ECR_REGION = ml_infra_config().default_region -ECR_REPO = f"{REGISTRY_ID}.dkr.ecr.{ECR_REGION}.amazonaws.com" +logger = make_logger("ml_serve.docker_image", log_level=logging.INFO) def _get_aws_creds() -> Dict[str, str]: @@ -108,7 +104,7 @@ def build( ) # Make sure not to do this after grabbing the AWS creds, so that we don't print them out. tag = _get_image_tag(image_tag) - image = f"{ECR_REPO}/{service_name}:{tag}" + image = f"{infra_config().docker_repo_prefix}/{service_name}:{tag}" local_args["image"] = image @@ -168,7 +164,7 @@ def build( } }, environment={ - "AWS_PROFILE": ml_infra_config().profile_ml_worker, + "AWS_PROFILE": infra_config().profile_ml_worker, "AWS_CONFIG_FILE": "/root/.aws/config", }, remove=True, @@ -190,14 +186,14 @@ def push(service_name: str, image_tag: Optional[str] = None) -> None: logger.info(f"push args: {local_args}") docker_client = docker.from_env() - ecr_client = boto3.client("ecr", region_name=ECR_REGION) - token = ecr_client.get_authorization_token(registryIds=[REGISTRY_ID]) + ecr_client = boto3.client("ecr", region_name=infra_config().default_region) + token = ecr_client.get_authorization_token(registryIds=[infra_config().ml_account_id]) username, password = ( base64.b64decode(token["authorizationData"][0]["authorizationToken"]).decode().split(":") ) output = docker_client.images.push( - repository=f"{ECR_REPO}/{service_name}", + repository=f"{infra_config().docker_repo_prefix}/{service_name}", tag=_get_image_tag(image_tag), auth_config={"username": username, "password": password}, stream=True, diff --git a/server/llm_engine_server/core/docker/ecr.py b/model-engine/model_engine_server/core/docker/ecr.py similarity index 82% rename from server/llm_engine_server/core/docker/ecr.py rename to model-engine/model_engine_server/core/docker/ecr.py index 1192a4d2..aaf9ef6f 100644 --- a/server/llm_engine_server/core/docker/ecr.py +++ b/model-engine/model_engine_server/core/docker/ecr.py @@ -1,18 +1,17 @@ from typing import Dict, List, Optional import boto3 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.utils.git import tag +from model_engine_server.core.config import infra_config +from model_engine_server.core.utils.git import tag DEFAULT_FILTER = {"tagStatus": "TAGGED"} def repository_exists(repository_name: str): - ecr = boto3.client("ecr", region_name=ml_infra_config().default_region) + ecr = boto3.client("ecr", region_name=infra_config().default_region) try: response = ecr.describe_repositories( - registryId=ml_infra_config().ml_account_id, - repositoryNames=[repository_name], + registryId=infra_config().ml_account_id, repositoryNames=[repository_name] ) if response.get("repositories"): return True @@ -23,7 +22,7 @@ def repository_exists(repository_name: str): def batch_image_exists( *, - region_name: str = ml_infra_config().default_region, + region_name: str = infra_config().default_region, repository_name: str, image_tags: Optional[List[str]] = None, image_digests: Optional[List[str]] = None, @@ -45,7 +44,7 @@ def batch_image_exists( client = session.client("ecr", region_name=region_name) try: client.describe_images( - registryId=ml_infra_config().ml_account_id, + registryId=infra_config().ml_account_id, repositoryName=repository_name, imageIds=[ *[{"imageTag": t} for t in image_tags], @@ -61,7 +60,7 @@ def batch_image_exists( def image_exists( *, - region_name: str = ml_infra_config().default_region, + region_name: str = infra_config().default_region, repository_name: str, image_name: Optional[str] = None, image_tag: Optional[str] = None, @@ -88,10 +87,10 @@ def ecr_exists_for_repo(repo_name: str, image_tag: Optional[str] = None): """Check if image exists in ECR""" if image_tag is None: image_tag = tag() - ecr = boto3.client("ecr", region_name=ml_infra_config().default_region) + ecr = boto3.client("ecr", region_name=infra_config().default_region) try: ecr.describe_images( - registryId=ml_infra_config().ml_account_id, + registryId=infra_config().ml_account_id, repositoryName=repo_name, imageIds=[{"imageTag": image_tag}], ) diff --git a/server/llm_engine_server/core/docker/kaniko_template.yaml b/model-engine/model_engine_server/core/docker/kaniko_template.yaml similarity index 95% rename from server/llm_engine_server/core/docker/kaniko_template.yaml rename to model-engine/model_engine_server/core/docker/kaniko_template.yaml index d87f89f0..dfda89e3 100644 --- a/server/llm_engine_server/core/docker/kaniko_template.yaml +++ b/model-engine/model_engine_server/core/docker/kaniko_template.yaml @@ -33,7 +33,7 @@ spec: - "--cache=$USE_CACHE" - "--cache-copy-layers=$USE_CACHE" - "--cache-run-layers=$USE_CACHE" - - "--cache-repo=000000000000.dkr.ecr.us-west-2.amazonaws.com/kaniko-cache" + - "--cache-repo=$CACHE_REPO" - "--cleanup" - "--snapshot-mode=redo" - "--use-new-run" diff --git a/server/llm_engine_server/core/docker/kaniko_template_circleci.yaml b/model-engine/model_engine_server/core/docker/kaniko_template_circleci.yaml similarity index 100% rename from server/llm_engine_server/core/docker/kaniko_template_circleci.yaml rename to model-engine/model_engine_server/core/docker/kaniko_template_circleci.yaml diff --git a/server/llm_engine_server/core/docker/remote_build.py b/model-engine/model_engine_server/core/docker/remote_build.py similarity index 95% rename from server/llm_engine_server/core/docker/remote_build.py rename to model-engine/model_engine_server/core/docker/remote_build.py index 99881cb4..8b250cfd 100644 --- a/server/llm_engine_server/core/docker/remote_build.py +++ b/model-engine/model_engine_server/core/docker/remote_build.py @@ -20,13 +20,13 @@ from kubernetes import config as kube_config from kubernetes import watch from kubernetes.config.config_exception import ConfigException -from llm_engine_server.core.aws import storage_client -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.aws import storage_client +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) -S3_BUCKET = os.environ.get("S3_BUCKET", ml_infra_config().s3_bucket) +S3_BUCKET = os.environ.get("S3_BUCKET", infra_config().s3_bucket) SUB_BUCKET = "tmp/docker_contexts" # Adjust if either this file or kaniko_template.yaml moves! OWN_FILE_PATH = Path(__file__).resolve() @@ -60,7 +60,7 @@ def zip_context( Takes a path to a folder, zips up the folder and sticks it into s3 :param s3_file_name: Bucket/file for context tar.gz, will upload to here - :param context: Path to context for dockerfile, relative to calling script, i.e. box_detection/ or scaleml/ if you're running from models/ + :param context: Path to context for dockerfile, relative to calling script :param folders_to_include: List of paths to subfolders needed to build docker image, relative to context :param ignore_file: File (e.g. .dockerignore) containing things to ignore when preparing docker context. Relative to context. Contents of file are parsed according to tar's --exclude-from, which differs slightly from @@ -142,8 +142,7 @@ def start_build_job( custom_tags_serialized = json.dumps(custom_tags) destination_template = Template( - f"--destination={ml_infra_config().ml_account_id}.dkr.ecr." - f"{ml_infra_config().default_region}.amazonaws.com/$REPO_AND_TAG" + f"--destination={infra_config().docker_repo_prefix}/$REPO_AND_TAG" ) job_name = f"kaniko-{str(uuid.uuid4())[:8]}" @@ -157,15 +156,11 @@ def start_build_job( aws_secret_access_key = "" if os.getenv("CIRCLECI"): aws_access_key_id_result = subprocess.run( - ["aws", "configure", "get", "aws_access_key_id"], - check=False, - stdout=PIPE, + ["aws", "configure", "get", "aws_access_key_id"], check=False, stdout=PIPE ) aws_access_key_id = aws_access_key_id_result.stdout.decode().strip() aws_secret_access_key_result = subprocess.run( - ["aws", "configure", "get", "aws_secret_access_key"], - check=False, - stdout=PIPE, + ["aws", "configure", "get", "aws_secret_access_key"], check=False, stdout=PIPE ) aws_secret_access_key = aws_secret_access_key_result.stdout.decode().strip() job = Template(template_f.read()).substitute( @@ -175,6 +170,7 @@ def start_build_job( S3_BUCKET=S3_BUCKET, S3_FILE=s3_file_name, USE_CACHE="true" if use_cache else "false", + CACHE_REPO=f"{infra_config().docker_repo_prefix}/kaniko-cache", AWS_ACCESS_KEY_ID=aws_access_key_id, AWS_SECRET_ACCESS_KEY=aws_secret_access_key, NAMESPACE=NAMESPACE, @@ -197,7 +193,7 @@ def start_build_job( if not os.path.exists("/tmp"): os.makedirs("/tmp") pip_conf_file = "/tmp/.codeartifact-pip-conf" - aws_profile = ml_infra_config().profile_ml_worker + aws_profile = infra_config().profile_ml_worker subprocess.check_output( [ f"AWS_PROFILE={aws_profile} python scripts_py3/scale_scripts/exe/maybe_refresh_codeartifact.py --export {pip_conf_file}" @@ -209,13 +205,7 @@ def start_build_job( pip_conf_base64 = b64encode(f_conf.read().encode("utf-8")).decode("utf-8") data = {"data": {"codeartifact_pip_conf": pip_conf_base64}} subprocess.check_output( - [ - "kubectl", - "patch", - "secret", - "codeartifact-pip-conf", - f"-p={json.dumps(data)}", - ] + ["kubectl", "patch", "secret", "codeartifact-pip-conf", f"-p={json.dumps(data)}"] ).decode("utf-8") print(f"Executing Kaniko build command:\n{container_spec}") @@ -262,7 +252,7 @@ def build_remote( calling_path = Path(context).resolve() if folders_to_include is None: if calling_path == MODELS_ROOT: - default_folders = {"scaleml/"} + default_folders = {} # find the models/ project folder that this Dockerfile comes from parts = dockerfile.split("/") @@ -477,7 +467,7 @@ def build_remote_block( @click.option( "--folders", required=False, - help="Comma separated list of folders (relative to context), e.g. 'scaleml/,template-project/", + help="Comma separated list of folders (relative to context", ) @click.option( "--no-cache", diff --git a/server/llm_engine_server/core/domain_exceptions.py b/model-engine/model_engine_server/core/domain_exceptions.py similarity index 100% rename from server/llm_engine_server/core/domain_exceptions.py rename to model-engine/model_engine_server/core/domain_exceptions.py diff --git a/server/llm_engine_server/core/fake_notification_gateway.py b/model-engine/model_engine_server/core/fake_notification_gateway.py similarity index 90% rename from server/llm_engine_server/core/fake_notification_gateway.py rename to model-engine/model_engine_server/core/fake_notification_gateway.py index 1b2ed19d..909c3037 100644 --- a/server/llm_engine_server/core/fake_notification_gateway.py +++ b/model-engine/model_engine_server/core/fake_notification_gateway.py @@ -1,7 +1,7 @@ from collections import defaultdict from typing import List -from llm_engine_server.core.notification_gateway import NotificationApp, NotificationGateway +from model_engine_server.core.notification_gateway import NotificationApp, NotificationGateway class FakeNotificationGateway(NotificationGateway): diff --git a/server/llm_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py similarity index 91% rename from server/llm_engine_server/core/loggers.py rename to model-engine/model_engine_server/core/loggers.py index 5a0877a1..91b69758 100644 --- a/server/llm_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -120,19 +120,6 @@ def make_json_logger(name: str, log_level: int = logging.INFO) -> logging.Logger stream_handler = logging.StreamHandler() in_kubernetes = os.getenv("KUBERNETES_SERVICE_HOST") if in_kubernetes: - # Somewhat hacky way of determining if we're running in a Datadog environment. - # Note that if you 'kubectl logs' the pod, you'll still see the JSON logs. But you really should - # just be looking at the logs in Datadog at that point. - # - # NOTE: If you're thinking of disabling this outside of your local machine, please consider - # just piping to `jq` instead, e.g.: - # - # $ kubectl logs -lapp=celery-autoscaler-singleton | jq -r '[.time, .level, .message] | join(" - ")' - # - # this spits out: - # - # 2021-04-08T23:40:03.148308 - INFO - Missing params, skipping deployment : - # 2021-04-08T23:40:03.148440 - INFO - Missing params, skipping deployment : stream_handler.setFormatter(CustomJSONFormatter()) else: # Reading JSON logs in your terminal is kinda hard, and you can't make use of the structured data @@ -246,8 +233,8 @@ def silence_chatty_logger(*logger_names, quieter=logging.FATAL) -> None: Accepts a variable number of logger names. """ - for name in logger_names: - log = logging.getLogger(name) + for logger_name in logger_names: + log = logging.getLogger(logger_name) log.setLevel(quieter) @@ -282,19 +269,19 @@ def loggers_at_level(*loggers_or_names, new_level: int) -> None: To illustrate use, see this pseudocode example: >>>> import logging - >>>> from llm_engine_server.core.loggers import loggers_at_level, make_logger + >>>> from model_engine_server.core.loggers import loggers_at_level, make_logger >>>> >>>> your_logger = make_logger('your_logger') >>>> >>>> with loggers_at_level( >>>> your_logger, - >>>> 'llm_engine_server.core.loggers', + >>>> 'model_engine_server.core.loggers', >>>> 'document_core.utils.k8s', >>>> new_level=logging.FATAL, >>>> ): >>>> # do_something_while_those_loggers_will_only_log_FATAL_messages >>>> your_logger.info("this will not be logged") - >>>> logging.getLogger('llm_engine_server.core.loggers').warning("neither will this") + >>>> logging.getLogger('model_engine_server.core.loggers').warning("neither will this") >>>> >>>> your_logger.info("this will be logged") """ diff --git a/server/llm_engine_server/core/notification_gateway.py b/model-engine/model_engine_server/core/notification_gateway.py similarity index 100% rename from server/llm_engine_server/core/notification_gateway.py rename to model-engine/model_engine_server/core/notification_gateway.py diff --git a/server/llm_engine_server/core/utils/__init__.py b/model-engine/model_engine_server/core/utils/__init__.py similarity index 100% rename from server/llm_engine_server/core/utils/__init__.py rename to model-engine/model_engine_server/core/utils/__init__.py diff --git a/server/llm_engine_server/core/utils/env.py b/model-engine/model_engine_server/core/utils/env.py similarity index 100% rename from server/llm_engine_server/core/utils/env.py rename to model-engine/model_engine_server/core/utils/env.py diff --git a/server/llm_engine_server/core/utils/format.py b/model-engine/model_engine_server/core/utils/format.py similarity index 100% rename from server/llm_engine_server/core/utils/format.py rename to model-engine/model_engine_server/core/utils/format.py diff --git a/server/llm_engine_server/core/utils/git.py b/model-engine/model_engine_server/core/utils/git.py similarity index 100% rename from server/llm_engine_server/core/utils/git.py rename to model-engine/model_engine_server/core/utils/git.py diff --git a/server/llm_engine_server/core/utils/python_utils.py b/model-engine/model_engine_server/core/utils/python_utils.py similarity index 95% rename from server/llm_engine_server/core/utils/python_utils.py rename to model-engine/model_engine_server/core/utils/python_utils.py index 2cfcd2f8..a9297d42 100644 --- a/server/llm_engine_server/core/utils/python_utils.py +++ b/model-engine/model_engine_server/core/utils/python_utils.py @@ -3,7 +3,7 @@ from importlib import import_module from typing import Any, Optional -from llm_engine_server.core.utils.format import split_module_value, strip_non_empty +from model_engine_server.core.utils.format import split_module_value, strip_non_empty def dynamic_load(module_name: str, value_name: Optional[str], validate: bool = True) -> Any: diff --git a/server/llm_engine_server/core/utils/timer.py b/model-engine/model_engine_server/core/utils/timer.py similarity index 98% rename from server/llm_engine_server/core/utils/timer.py rename to model-engine/model_engine_server/core/utils/timer.py index 4f72b7a4..6936cfa7 100644 --- a/server/llm_engine_server/core/utils/timer.py +++ b/model-engine/model_engine_server/core/utils/timer.py @@ -26,7 +26,7 @@ class timer: # pylint: disable=invalid-name The other use case is to pass in a `name` and a `logger`. The timing will be recorded when the context block is exited: - >>> from llm_engine_server.core.loggers import make_logger + >>> from model_engine_server.core.loggers import make_logger >>> >>> log = make_logger("my-main-program") >>> diff --git a/server/llm_engine_server/core/utils/url.py b/model-engine/model_engine_server/core/utils/url.py similarity index 91% rename from server/llm_engine_server/core/utils/url.py rename to model-engine/model_engine_server/core/utils/url.py index e9d6c758..ec80747c 100644 --- a/server/llm_engine_server/core/utils/url.py +++ b/model-engine/model_engine_server/core/utils/url.py @@ -85,9 +85,6 @@ def parse_attachment_url(url: str) -> ParsedURL: if match: bucket, key = match.group(1), match.group(2) - # pattern from https://docs.google.com/document/d/1WLbQXkQL7PLo0rkjU0RsI4SPAqUvV0WV1-FWkzicduc/edit - # scale-cds://62f2a2942a57fb0024e4dc3e/dgb6etBCrUHtOMQ#s3/scale-cds-private-us-west-2 - # scale-cds://57743957186fd0060017f1a1/json/0e09cdfc-adbb-4d88-acf7-d75a478328e3 match = re.search("scale-cds://(\\w+)/([\\-\\w\\/]+)", url) if match: bucket, key = match.group(1), match.group(2) @@ -98,8 +95,8 @@ def parse_attachment_url(url: str) -> ParsedURL: "Invalid attachment URL: no bucket or key specified: \n" f"'{url}'" ) - def clean(val): - return val and val.strip("/") + def clean(v): + return v and v.strip("/") return ParsedURL( protocol=clean(protocol), diff --git a/server/llm_engine_server/db/__init__.py b/model-engine/model_engine_server/db/__init__.py similarity index 100% rename from server/llm_engine_server/db/__init__.py rename to model-engine/model_engine_server/db/__init__.py diff --git a/server/llm_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py similarity index 73% rename from server/llm_engine_server/db/base.py rename to model-engine/model_engine_server/db/base.py index 68f4a7bf..37496e19 100644 --- a/server/llm_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -4,7 +4,9 @@ from typing import Iterator, Optional import sqlalchemy -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.aws.secrets import get_key_file +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base @@ -14,13 +16,16 @@ logger = make_logger(filename_wo_ext(__file__)) +def get_key_file_name(environment: str) -> str: + return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "") + + def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool = True) -> str: """Gets the URL of the Postgresql engine depending on the environment.""" if os.getenv("ML_INFRA_DATABASE_URL"): # In CircleCI environment, we set up a test in another container and specify the URL. engine_url = os.getenv("ML_INFRA_DATABASE_URL") - else: - assert "pytest" in sys.modules, "Must specify ML_INFRA_DATABASE_URL or be in a testing env." + elif "pytest" in sys.modules: # If we are in a local testing environment, we can set up a test psql instance. # pylint: disable=import-outside-toplevel import testing.postgresql @@ -30,6 +35,24 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool ) postgresql = Postgresql() engine_url = postgresql.url() + else: + key_file = os.environ.get("DB_SECRET_NAME") + if env is None: + env = infra_config().env + if key_file is None: + key_file = get_key_file_name(env) # type: ignore + logger.info(f"Using key file {key_file}") + db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE") + creds = get_key_file(key_file, db_secret_aws_profile) + + user = creds.get("username") + password = creds.get("password") + host = creds.get("clusterHostRo") if read_only else creds.get("clusterHost") + port = str(creds.get("port")) + dbname = creds.get("dbname") + logger.info(f"Connecting to db {host}:{port}, name {dbname}") + + engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}" assert engine_url @@ -51,32 +74,32 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool # but hopefully should completely eliminate # any of the postgres connection errors we've been seeing. -ml_infra_pg_engine = create_engine( +pg_engine = create_engine( get_engine_url(read_only=False, sync=True), echo=False, future=True, pool_pre_ping=True, ) -ml_infra_pg_engine_read_only = create_engine( +pg_engine_read_only = create_engine( get_engine_url(read_only=True, sync=True), echo=False, future=True, pool_pre_ping=True, ) -ml_infra_pg_engine_async = create_async_engine( +pg_engine_async = create_async_engine( get_engine_url(read_only=False, sync=False), echo=False, future=True, pool_pre_ping=True, ) -ml_infra_pg_engine_read_only_async = create_async_engine( +pg_engine_read_only_async = create_async_engine( get_engine_url(read_only=True, sync=False), echo=False, future=True, pool_pre_ping=True, max_overflow=5, ) -ml_infra_pg_engine_async_null_pool = create_async_engine( +pg_engine_async_null_pool = create_async_engine( get_engine_url(read_only=False, sync=False), echo=False, future=True, @@ -89,13 +112,13 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool # if you're running a synchronous program where concurrency of database connections is not # super important (e.g. Celery workers that use long-standing connections, and Celery is currently # synchronous). Use SessionAsync and SessionReadOnlyAsync in ASGI applications. -Session = sessionmaker(autocommit=False, autoflush=False, bind=ml_infra_pg_engine) -SessionReadOnly = sessionmaker(autocommit=False, autoflush=False, bind=ml_infra_pg_engine_read_only) +Session = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine) +SessionReadOnly = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine_read_only) SessionAsync = async_scoped_session( session_factory=async_sessionmaker( autocommit=False, autoflush=False, - bind=ml_infra_pg_engine_async, + bind=pg_engine_async, expire_on_commit=False, ), scopefunc=asyncio.current_task, @@ -104,7 +127,7 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool session_factory=async_sessionmaker( autocommit=False, autoflush=False, - bind=ml_infra_pg_engine_async_null_pool, + bind=pg_engine_async_null_pool, expire_on_commit=False, ), scopefunc=asyncio.current_task, @@ -113,7 +136,7 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool async_sessionmaker( autocommit=False, autoflush=False, - bind=ml_infra_pg_engine_read_only_async, + bind=pg_engine_read_only_async, expire_on_commit=False, ), scopefunc=asyncio.current_task, diff --git a/server/llm_engine_server/db/endpoint_row_lock.py b/model-engine/model_engine_server/db/endpoint_row_lock.py similarity index 95% rename from server/llm_engine_server/db/endpoint_row_lock.py rename to model-engine/model_engine_server/db/endpoint_row_lock.py index ed5786ac..676546f6 100644 --- a/server/llm_engine_server/db/endpoint_row_lock.py +++ b/model-engine/model_engine_server/db/endpoint_row_lock.py @@ -4,7 +4,7 @@ import time from contextlib import AbstractContextManager -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import filename_wo_ext, make_logger from sqlalchemy import BIGINT, cast, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.session import Session @@ -17,14 +17,10 @@ def get_lock_key(user_id: str, endpoint_name: str) -> int: uid_hash = int.from_bytes( - hashlib.sha256(bytes(user_id, "utf-8")).digest()[:4], - byteorder="little", - signed=False, + hashlib.sha256(bytes(user_id, "utf-8")).digest()[:4], byteorder="little", signed=False ) endpoint_name_hash = int.from_bytes( - hashlib.sha256(bytes(endpoint_name, "utf-8")).digest()[:4], - byteorder="little", - signed=False, + hashlib.sha256(bytes(endpoint_name, "utf-8")).digest()[:4], byteorder="little", signed=False ) return 2**32 * uid_hash + endpoint_name_hash - 2**63 diff --git a/server/llm_engine_server/db/local_setup.py b/model-engine/model_engine_server/db/local_setup.py similarity index 92% rename from server/llm_engine_server/db/local_setup.py rename to model-engine/model_engine_server/db/local_setup.py index cb446d5e..4db34463 100644 --- a/server/llm_engine_server/db/local_setup.py +++ b/model-engine/model_engine_server/db/local_setup.py @@ -2,13 +2,13 @@ import os import psycopg2 -from llm_engine_server.db.base import Base -from llm_engine_server.db.models import * +from model_engine_server.db.base import Base +from model_engine_server.db.models import * from sqlalchemy import create_engine from sqlalchemy.engine import Engine from tenacity import Retrying, stop_after_attempt, wait_exponential -SCHEMAS = ["llm_engine", "model"] +SCHEMAS = ["hosted_model_inference", "model"] def init_database(database_url: str, psycopg_connection): diff --git a/server/llm_engine_server/db/models/__init__.py b/model-engine/model_engine_server/db/models/__init__.py similarity index 68% rename from server/llm_engine_server/db/models/__init__.py rename to model-engine/model_engine_server/db/models/__init__.py index bd6a9788..e7a62852 100644 --- a/server/llm_engine_server/db/models/__init__.py +++ b/model-engine/model_engine_server/db/models/__init__.py @@ -1,6 +1,6 @@ from typing import Sequence -from .llm_engine import BatchJob, Bundle, DockerImageBatchJobBundle, Endpoint +from .hosted_model_inference import BatchJob, Bundle, DockerImageBatchJobBundle, Endpoint, Trigger from .model import Model, ModelArtifact, ModelVersion __all__: Sequence[str] = [ @@ -11,4 +11,5 @@ "Model", "ModelArtifact", "ModelVersion", + "Trigger", ] diff --git a/server/llm_engine_server/db/models/common/__init__.py b/model-engine/model_engine_server/db/models/common/__init__.py similarity index 100% rename from server/llm_engine_server/db/models/common/__init__.py rename to model-engine/model_engine_server/db/models/common/__init__.py diff --git a/server/llm_engine_server/db/models/common/query.py b/model-engine/model_engine_server/db/models/common/query.py similarity index 100% rename from server/llm_engine_server/db/models/common/query.py rename to model-engine/model_engine_server/db/models/common/query.py diff --git a/server/llm_engine_server/db/models/common/record.py b/model-engine/model_engine_server/db/models/common/record.py similarity index 92% rename from server/llm_engine_server/db/models/common/record.py rename to model-engine/model_engine_server/db/models/common/record.py index ae9602df..d2ecd2ce 100644 --- a/server/llm_engine_server/db/models/common/record.py +++ b/model-engine/model_engine_server/db/models/common/record.py @@ -2,9 +2,9 @@ from typing import Generic, Optional, Sequence, TypeVar -from llm_engine_server.db.base import Base -from llm_engine_server.db.models.common.query import Query -from llm_engine_server.db.models.exceptions import EntityNotFoundError +from model_engine_server.db.base import Base +from model_engine_server.db.models.common.query import Query +from model_engine_server.db.models.exceptions import EntityNotFoundError from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/server/llm_engine_server/db/models/constants.py b/model-engine/model_engine_server/db/models/constants.py similarity index 100% rename from server/llm_engine_server/db/models/constants.py rename to model-engine/model_engine_server/db/models/constants.py diff --git a/server/llm_engine_server/db/models/exceptions.py b/model-engine/model_engine_server/db/models/exceptions.py similarity index 100% rename from server/llm_engine_server/db/models/exceptions.py rename to model-engine/model_engine_server/db/models/exceptions.py diff --git a/server/llm_engine_server/db/models/llm_engine.py b/model-engine/model_engine_server/db/models/hosted_model_inference.py similarity index 94% rename from server/llm_engine_server/db/models/llm_engine.py rename to model-engine/model_engine_server/db/models/hosted_model_inference.py index e508c188..8e028a2c 100644 --- a/server/llm_engine_server/db/models/llm_engine.py +++ b/model-engine/model_engine_server/db/models/hosted_model_inference.py @@ -18,7 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import relationship, selectinload from sqlalchemy.sql import func, text -from sqlalchemy.sql.expression import update +from sqlalchemy.sql.expression import delete, update from sqlalchemy.sql.schema import CheckConstraint, Index, UniqueConstraint from xid import XID @@ -105,7 +105,7 @@ class Bundle(Base): CheckConstraint( "(flavor = 'triton_enhanced_runnable_image') = (triton_enhanced_runnable_image_readiness_initial_delay_seconds IS NOT NULL)" ), - {"schema": "llm_engine"}, + {"schema": "hosted_model_inference"}, ) id = Column(Text, primary_key=True) @@ -433,7 +433,7 @@ class Endpoint(Base): unique=True, postgresql_where=text("endpoint_metadata ? '_llm'"), ), - {"schema": "llm_engine"}, + {"schema": "hosted_model_inference"}, ) id = Column(Text, primary_key=True) @@ -441,7 +441,7 @@ class Endpoint(Base): created_by = Column(String(SHORT_STRING), index=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) last_updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=time_now) - current_bundle_id = Column(Text, ForeignKey("llm_engine.bundles.id")) + current_bundle_id = Column(Text, ForeignKey("hosted_model_inference.bundles.id")) endpoint_metadata = Column(JSONB, default={}) creation_task_id = Column(Text) endpoint_type = Column(Text, default="async") @@ -623,7 +623,7 @@ async def delete(cls, session: AsyncSession, endpoint: "Endpoint") -> None: class BatchJob(Base): __tablename__ = "batch_jobs" - __table_args__ = ({"schema": "llm_engine"},) + __table_args__ = ({"schema": "hosted_model_inference"},) id = Column(Text, primary_key=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) @@ -632,9 +632,11 @@ class BatchJob(Base): created_by = Column(String(SHORT_STRING), index=True, nullable=False) owner = Column(String(SHORT_STRING), index=True, nullable=False) model_bundle_id = Column( - Text, ForeignKey("llm_engine.bundles.id", ondelete="SET NULL"), nullable=False + Text, ForeignKey("hosted_model_inference.bundles.id", ondelete="SET NULL"), nullable=False + ) + model_endpoint_id = Column( + Text, ForeignKey("hosted_model_inference.endpoints.id"), nullable=True ) - model_endpoint_id = Column(Text, ForeignKey("llm_engine.endpoints.id"), nullable=True) task_ids_location = Column(Text, nullable=True) result_location = Column(Text, nullable=True) @@ -704,7 +706,7 @@ async def update_by_id( class DockerImageBatchJobBundle(Base): __tablename__ = "docker_image_batch_job_bundles" - __table_args__ = ({"schema": "llm_engine"},) + __table_args__ = ({"schema": "hosted_model_inference"},) id = Column("id", Text, primary_key=True) name = Column("name", Text, nullable=False) @@ -806,7 +808,7 @@ class Trigger(Base): __tablename__ = "triggers" __table_args__ = ( UniqueConstraint("name", "owner", name="uq_triggers_name_owner"), - {"schema": "llm_engine"}, + {"schema": "hosted_model_inference"}, ) id = Column("id", String, nullable=False, primary_key=True) @@ -820,7 +822,7 @@ class Trigger(Base): docker_image_batch_job_bundle_id = Column( "docker_image_batch_job_bundle_id", String, - ForeignKey("llm_engine.docker_image_batch_job_bundles.id"), + ForeignKey("hosted_model_inference.docker_image_batch_job_bundles.id"), nullable=False, ) default_job_config = Column("default_job_config", JSONB, nullable=True) @@ -845,3 +847,33 @@ def __init__( self.docker_image_batch_job_bundle_id = docker_image_batch_job_bundle_id self.default_job_config = default_job_config self.default_job_metadata = default_job_metadata + + @classmethod + async def create(cls, session: AsyncSession, trigger: "Trigger") -> None: + session.add(trigger) + await session.commit() + + @classmethod + async def select_all_by_owner(cls, session: AsyncSession, owner: str) -> List["Trigger"]: + triggers = await session.execute(select(Trigger).filter_by(owner=owner)) + return triggers.scalars().all() + + @classmethod + async def select_by_id(cls, session: AsyncSession, trigger_id: str) -> Optional["Trigger"]: + trigger = await session.execute(select(Trigger).filter_by(id=trigger_id)) + return trigger.scalar_one_or_none() + + @classmethod + async def update_by_id( + cls, session: AsyncSession, trigger_id: str, kwargs: Dict[str, Any] + ) -> None: + update_kwargs = kwargs.copy() + stmt = update(Trigger).where(Trigger.id == trigger_id).values(**update_kwargs) + await session.execute(stmt) + await session.commit() + + @classmethod + async def delete_by_id(cls, session: AsyncSession, trigger_id: str) -> None: + stmt = delete(Trigger).where(Trigger.id == trigger_id) + await session.execute(stmt) + await session.commit() diff --git a/server/llm_engine_server/db/models/model.py b/model-engine/model_engine_server/db/models/model.py similarity index 93% rename from server/llm_engine_server/db/models/model.py rename to model-engine/model_engine_server/db/models/model.py index 043e170b..d5c6fef9 100644 --- a/server/llm_engine_server/db/models/model.py +++ b/model-engine/model_engine_server/db/models/model.py @@ -106,9 +106,9 @@ class ModelVersion(Base): Column("model_id", Text, ForeignKey("model.models.id"), index=True, nullable=False), Column("version_number", Integer, index=True, nullable=False), Column( - "llm_engine_model_bundle_id", + "launch_model_bundle_id", Text, - # ForeignKey("llm_engine.bundles.id"), # This is currently breaking tests. + # ForeignKey("hosted_model_inference.bundles.id"), # This is currently breaking tests. index=True, nullable=True, ), @@ -116,14 +116,9 @@ class ModelVersion(Base): Column("tags", ARRAY(Text), index=True, nullable=False), Column("metadata", JSON, index=False, server_default="{}"), Column("created_by", String(SHORT_STRING), index=True, nullable=False), - Column( - "created_at", - DateTime(timezone=True), - server_default=func.now(), - nullable=False, - ), + Column("created_at", DateTime(timezone=True), server_default=func.now(), nullable=False), UniqueConstraint("model_id", "version_number", name="model_id_version_number_uc"), - UniqueConstraint("llm_engine_model_bundle_id", name="llm_engine_model_bundle_id_uc"), + UniqueConstraint("launch_model_bundle_id", name="launch_model_bundle_id_uc"), UniqueConstraint("nucleus_model_id", name="nucleus_model_id_uc"), schema="model", ) @@ -132,7 +127,7 @@ def __init__( self, model_id: Optional[str] = None, version_number: Optional[int] = None, - llm_engine_model_bundle_id: Optional[str] = None, + launch_model_bundle_id: Optional[str] = None, nucleus_model_id: Optional[str] = None, tags: Optional[List[str]] = None, metadata: Optional[Any] = None, @@ -142,7 +137,7 @@ def __init__( self.id = f"mov_{get_xid()}" self.model_id = model_id self.version_number = version_number - self.llm_engine_model_bundle_id = llm_engine_model_bundle_id + self.launch_model_bundle_id = launch_model_bundle_id self.nucleus_model_id = nucleus_model_id self.tags = tags or [] self.metadata = metadata @@ -175,11 +170,11 @@ def select( return models @staticmethod - def select_by_llm_engine_model_bundle_id( - session: Session, llm_engine_model_bundle_id: str + def select_by_launch_model_bundle_id( + session: Session, launch_model_bundle_id: str ) -> Optional["ModelVersion"]: model_version = session.execute( - select(ModelVersion).filter_by(llm_engine_model_bundle_id=llm_engine_model_bundle_id) + select(ModelVersion).filter_by(launch_model_bundle_id=launch_model_bundle_id) ).scalar_one_or_none() return model_version diff --git a/server/llm_engine_server/db/models/utils/__init__.py b/model-engine/model_engine_server/db/models/utils/__init__.py similarity index 100% rename from server/llm_engine_server/db/models/utils/__init__.py rename to model-engine/model_engine_server/db/models/utils/__init__.py diff --git a/server/llm_engine_server/db/models/utils/misc.py b/model-engine/model_engine_server/db/models/utils/misc.py similarity index 100% rename from server/llm_engine_server/db/models/utils/misc.py rename to model-engine/model_engine_server/db/models/utils/misc.py diff --git a/server/llm_engine_server/domain/__init__.py b/model-engine/model_engine_server/domain/__init__.py similarity index 100% rename from server/llm_engine_server/domain/__init__.py rename to model-engine/model_engine_server/domain/__init__.py diff --git a/server/llm_engine_server/domain/authorization/__init__.py b/model-engine/model_engine_server/domain/authorization/__init__.py similarity index 100% rename from server/llm_engine_server/domain/authorization/__init__.py rename to model-engine/model_engine_server/domain/authorization/__init__.py diff --git a/server/llm_engine_server/domain/authorization/scale_authorization_module.py b/model-engine/model_engine_server/domain/authorization/live_authorization_module.py similarity index 67% rename from server/llm_engine_server/domain/authorization/scale_authorization_module.py rename to model-engine/model_engine_server/domain/authorization/live_authorization_module.py index 19eaab5e..f895cefe 100644 --- a/server/llm_engine_server/domain/authorization/scale_authorization_module.py +++ b/model-engine/model_engine_server/domain/authorization/live_authorization_module.py @@ -1,17 +1,19 @@ -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CreateModelBundleV1Request, CreateModelBundleV2Request, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import CustomFramework, ModelBundleFrameworkType, OwnedEntity -from llm_engine_server.domain.entities.model_bundle_entity import RunnableImageLike -from llm_engine_server.domain.entities.model_endpoint_entity import ModelEndpointRecord - -LLM_ENGINE_INTEGRATION_TEST_USER: str = "62bc820451dbea002b1c5421" +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import ( + CustomFramework, + ModelBundleFrameworkType, + ModelEndpointRecord, + OwnedEntity, + RunnableImageLike, +) -class ScaleAuthorizationModule: +class LiveAuthorizationModule: """ This class contains authorization utilities. All methods expect User objects given from authn. """ @@ -29,13 +31,9 @@ def check_access_create_bundle_v1(user: User, request: CreateModelBundleV1Reques def check_access_create_bundle_v2(user: User, request: CreateModelBundleV2Request) -> bool: """Checks whether the provided user is authorized to create the requested model bundle.""" # External customers cannot use custom images. - return ( - user.is_privileged_user - or user.user_id == LLM_ENGINE_INTEGRATION_TEST_USER - or ( - not isinstance(request.flavor, RunnableImageLike) - and not isinstance(request.flavor.framework, CustomFramework) - ) + return user.is_privileged_user or ( + not isinstance(request.flavor, RunnableImageLike) + and not isinstance(request.flavor.framework, CustomFramework) ) @staticmethod @@ -52,12 +50,12 @@ def check_access_write_owned_entity(user: User, owned_entity: OwnedEntity) -> bo @staticmethod def get_aws_role_for_user(user: User) -> str: """Returns the AWS role that should be assumed with the user's resources.""" - return ml_infra_config().profile_ml_inference_worker + return infra_config().profile_ml_inference_worker @staticmethod def get_s3_bucket_for_user(user: User) -> str: """Returns the AWS role that should be assumed with the user's resources.""" - return ml_infra_config().s3_bucket + return infra_config().s3_bucket @staticmethod def check_endpoint_public_inference_for_user( diff --git a/server/llm_engine_server/domain/entities/__init__.py b/model-engine/model_engine_server/domain/entities/__init__.py similarity index 85% rename from server/llm_engine_server/domain/entities/__init__.py rename to model-engine/model_engine_server/domain/entities/__init__.py index a906eb83..a3ed7393 100644 --- a/server/llm_engine_server/domain/entities/__init__.py +++ b/model-engine/model_engine_server/domain/entities/__init__.py @@ -6,10 +6,13 @@ BatchJobRecord, BatchJobSerializationFormat, BatchJobStatus, + DockerImageBatchJob, ) -from .common_types import CpuSpecificationType, StorageSpecificationType +from .common_types import CpuSpecificationType, FineTuneHparamValueType, StorageSpecificationType +from .file_entity import FileMetadata from .gpu_type import GpuType from .llm_entity import LLMInferenceFramework, LLMMetadata, LLMSource, Quantization +from .llm_fine_tune_entity import LLMFineTuneEvent from .model_bundle_entity import ( ArtifactLike, CloudpickleArtifactFlavor, @@ -44,8 +47,9 @@ ModelEndpointUserConfigState, ) from .owned_entity import OwnedEntity +from .trigger_entity import Trigger -__all__: Sequence[str] = ( +__all__: Sequence[str] = [ "ArtifactLike", "BatchJob", "BatchJobProgress", @@ -58,7 +62,11 @@ "CloudpickleArtifactFlavor", "CpuSpecificationType", "CustomFramework", + "DockerImageBatchJob", + "FileMetadata", "GpuType", + "FineTuneHparamValueType", + "LLMFineTuneEvent", "LLMInferenceFramework", "LLMMetadata", "LLMSource", @@ -86,6 +94,7 @@ "StorageSpecificationType", "StreamingEnhancedRunnableImageFlavor", "TensorflowFramework", + "Trigger", "TritonEnhancedRunnableImageFlavor", "ZipArtifactFlavor", -) +] diff --git a/server/llm_engine_server/domain/entities/batch_job_entity.py b/model-engine/model_engine_server/domain/entities/batch_job_entity.py similarity index 78% rename from server/llm_engine_server/domain/entities/batch_job_entity.py rename to model-engine/model_engine_server/domain/entities/batch_job_entity.py index fe16398e..e80f9fd4 100644 --- a/server/llm_engine_server/domain/entities/batch_job_entity.py +++ b/model-engine/model_engine_server/domain/entities/batch_job_entity.py @@ -1,10 +1,10 @@ from datetime import datetime from enum import Enum -from typing import Optional +from typing import Dict, Optional -from llm_engine_server.domain.entities.model_bundle_entity import ModelBundle -from llm_engine_server.domain.entities.model_endpoint_entity import ModelEndpoint -from llm_engine_server.domain.entities.owned_entity import OwnedEntity +from model_engine_server.domain.entities.model_bundle_entity import ModelBundle +from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpoint +from model_engine_server.domain.entities.owned_entity import OwnedEntity from pydantic import BaseModel @@ -59,3 +59,5 @@ class DockerImageBatchJob(BaseModel): created_at: datetime completed_at: Optional[datetime] status: BatchJobStatus # the status map relatively nicely onto BatchJobStatus + annotations: Optional[Dict[str, str]] = None + override_job_max_runtime_s: Optional[int] = None diff --git a/server/llm_engine_server/domain/entities/common_types.py b/model-engine/model_engine_server/domain/entities/common_types.py similarity index 69% rename from server/llm_engine_server/domain/entities/common_types.py rename to model-engine/model_engine_server/domain/entities/common_types.py index 899a2973..3556723c 100644 --- a/server/llm_engine_server/domain/entities/common_types.py +++ b/model-engine/model_engine_server/domain/entities/common_types.py @@ -2,3 +2,4 @@ CpuSpecificationType = Union[str, int, float] StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. +FineTuneHparamValueType = Union[str, int, float] # should suffice for now diff --git a/server/llm_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py similarity index 80% rename from server/llm_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py rename to model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py index 02a14990..1ed2838d 100644 --- a/server/llm_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py @@ -1,8 +1,8 @@ import datetime from typing import Dict, List, Optional -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.owned_entity import OwnedEntity +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.owned_entity import OwnedEntity class DockerImageBatchJobBundle(OwnedEntity): diff --git a/model-engine/model_engine_server/domain/entities/file_entity.py b/model-engine/model_engine_server/domain/entities/file_entity.py new file mode 100644 index 00000000..f21314eb --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/file_entity.py @@ -0,0 +1,15 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class FileMetadata(BaseModel): + """ + This is the entity-layer class for a File from the Files API. + """ + + id: str + filename: str + size: int + owner: str + updated_at: datetime diff --git a/server/llm_engine_server/domain/entities/gpu_type.py b/model-engine/model_engine_server/domain/entities/gpu_type.py similarity index 65% rename from server/llm_engine_server/domain/entities/gpu_type.py rename to model-engine/model_engine_server/domain/entities/gpu_type.py index a8c4ade4..5dc2c459 100644 --- a/server/llm_engine_server/domain/entities/gpu_type.py +++ b/model-engine/model_engine_server/domain/entities/gpu_type.py @@ -2,8 +2,9 @@ class GpuType(str, Enum): - """Lists allowed GPU types for LLMEngine.""" + """Lists allowed GPU types for Launch.""" NVIDIA_TESLA_T4 = "nvidia-tesla-t4" NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" + NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" diff --git a/server/llm_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py similarity index 100% rename from server/llm_engine_server/domain/entities/llm_entity.py rename to model-engine/model_engine_server/domain/entities/llm_entity.py diff --git a/server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py similarity index 54% rename from server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py rename to model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py index 487483ae..13188c06 100644 --- a/server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py @@ -1,14 +1,19 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel -class LLMFineTuneJobTemplate(BaseModel): +class LLMFineTuneTemplate(BaseModel): docker_image_batch_job_bundle_id: str - launch_bundle_config: Dict[str, Any] launch_endpoint_config: Dict[str, Any] default_hparams: Dict[str, Any] required_params: List[str] class Config: orm_mode = True + + +class LLMFineTuneEvent(BaseModel): + timestamp: Optional[float] = None + message: str + level: str diff --git a/server/llm_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py similarity index 91% rename from server/llm_engine_server/domain/entities/model_bundle_entity.py rename to model-engine/model_engine_server/domain/entities/model_bundle_entity.py index 70e03494..247539d0 100644 --- a/server/llm_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -3,8 +3,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME -from llm_engine_server.domain.entities.owned_entity import OwnedEntity +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.domain.entities.owned_entity import OwnedEntity from pydantic import BaseModel, Field, root_validator from typing_extensions import Literal @@ -18,6 +18,7 @@ class ModelBundlePackagingType(str, Enum): CLOUDPICKLE = "cloudpickle" ZIP = "zip" + LIRA = "lira" class ModelBundleFrameworkType(str, Enum): @@ -176,7 +177,7 @@ class TritonEnhancedRunnableImageFlavor(RunnableImageLike): flavor: Literal[ModelBundleFlavorType.TRITON_ENHANCED_RUNNABLE_IMAGE] triton_model_repository: str - triton_model_replicas: Optional[Dict[str, int]] + triton_model_replicas: Optional[Dict[str, str]] triton_num_cpu: float triton_commit_tag: str triton_storage: Optional[str] @@ -235,8 +236,15 @@ class Config: orm_mode = True def is_runnable(self) -> bool: - """True iff the model bundle calls for it.""" - return isinstance(self.flavor, RunnableImageLike) + """True iff the model bundle calls for it. + + If it is set to 'true', then this function will only return true if the :param:`model_bundle`'s + packaging_type is `ModelBundlePackagingType.LIRA` or if the :param:`model_bundle`'s flavor is + an instance of `RunnableImageLike`. Otherwise, it will return false. + """ + return self.packaging_type == ModelBundlePackagingType.LIRA or isinstance( + self.flavor, RunnableImageLike + ) def celery_task_name(self): return LIRA_CELERY_TASK_NAME if self.is_runnable() else DEFAULT_CELERY_TASK_NAME diff --git a/server/llm_engine_server/domain/entities/model_endpoint_entity.py b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py similarity index 88% rename from server/llm_engine_server/domain/entities/model_endpoint_entity.py rename to model-engine/model_engine_server/domain/entities/model_endpoint_entity.py index ec16e2b9..809035ba 100644 --- a/server/llm_engine_server/domain/entities/model_endpoint_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py @@ -3,15 +3,15 @@ from typing import Any, Dict, List, Optional, Union from fastapi.openapi.models import OpenAPI -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.serialization_utils import b64_to_python_json, python_json_to_b64 -from llm_engine_server.domain.entities.common_types import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.serialization_utils import b64_to_python_json, python_json_to_b64 +from model_engine_server.domain.entities.common_types import ( CpuSpecificationType, StorageSpecificationType, ) -from llm_engine_server.domain.entities.gpu_type import GpuType -from llm_engine_server.domain.entities.model_bundle_entity import ModelBundle -from llm_engine_server.domain.entities.owned_entity import OwnedEntity +from model_engine_server.domain.entities.gpu_type import GpuType +from model_engine_server.domain.entities.model_bundle_entity import ModelBundle +from model_engine_server.domain.entities.owned_entity import OwnedEntity from pydantic import BaseModel, Field from typing_extensions import Literal @@ -84,6 +84,8 @@ class ModelEndpointConfig(BaseModel): bundle_name: str post_inference_hooks: Optional[List[str]] user_id: Optional[str] = None + billing_queue: Optional[str] = None + billing_tags: Optional[Dict[str, Any]] = None default_callback_url: Optional[str] = None default_callback_auth: Optional[CallbackAuth] diff --git a/server/llm_engine_server/domain/entities/owned_entity.py b/model-engine/model_engine_server/domain/entities/owned_entity.py similarity index 100% rename from server/llm_engine_server/domain/entities/owned_entity.py rename to model-engine/model_engine_server/domain/entities/owned_entity.py diff --git a/model-engine/model_engine_server/domain/entities/trigger_entity.py b/model-engine/model_engine_server/domain/entities/trigger_entity.py new file mode 100644 index 00000000..ac515865 --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/trigger_entity.py @@ -0,0 +1,20 @@ +import datetime +from typing import Any, Dict, Optional + +from model_engine_server.domain.entities.owned_entity import OwnedEntity + + +class Trigger(OwnedEntity): + id: str + name: str + owner: str + created_by: str + created_at: datetime.datetime + + cron_schedule: str + docker_image_batch_job_bundle_id: str + default_job_config: Optional[Dict[str, Any]] + default_job_metadata: Optional[Dict[str, str]] + + class Config: + orm_mode = True diff --git a/server/llm_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py similarity index 77% rename from server/llm_engine_server/domain/exceptions.py rename to model-engine/model_engine_server/domain/exceptions.py index 4763690a..66b6f708 100644 --- a/server/llm_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -1,4 +1,4 @@ -from llm_engine_server.core.domain_exceptions import DomainException +from model_engine_server.core.domain_exceptions import DomainException class ExistingEndpointOperationInProgressException(DomainException): @@ -37,7 +37,7 @@ class EndpointInfraStateNotFound(DomainException): class EndpointResourceInfraException(DomainException): """ Thrown if the endpoint resource request passes validation, but failed for unhandled reasons. - This corresponds to a 503 error and requires investigation by the LLMEngine team. + This corresponds to a 503 error and requires investigation by the Launch team. """ @@ -47,6 +47,12 @@ class EndpointLabelsException(DomainException): """ +class EndpointBillingTagsMalformedException(DomainException): + """ + Thrown if endpoint billing tags are malformed (i.e. wrong type, wrong keys, etc.) + """ + + class TooManyRequestsException(DomainException): """ Thrown if an endpoint returns a 429 exception for too many requests. @@ -77,7 +83,25 @@ class LLMFineTuningMethodNotImplementedException(DomainException): """ +class LLMFineTuningQuotaReached(DomainException): + """ + Thrown if the user has run too many fine-tunes. + """ + + class InvalidRequestException(DomainException): """ Thrown if the user request is invalid. """ + + +class CronSyntaxException(DomainException): + """ + Thrown if the requested cron schedule has invalid syntax. + """ + + +class TriggerNameAlreadyExistsException(DomainException): + """ + Thrown if the requested name already exists in the trigger repository + """ diff --git a/server/llm_engine_server/domain/gateways/__init__.py b/model-engine/model_engine_server/domain/gateways/__init__.py similarity index 79% rename from server/llm_engine_server/domain/gateways/__init__.py rename to model-engine/model_engine_server/domain/gateways/__init__.py index caa45f72..fea6d2b5 100644 --- a/server/llm_engine_server/domain/gateways/__init__.py +++ b/model-engine/model_engine_server/domain/gateways/__init__.py @@ -1,5 +1,8 @@ from .async_model_endpoint_inference_gateway import AsyncModelEndpointInferenceGateway +from .cron_job_gateway import CronJobGateway from .docker_image_batch_job_gateway import DockerImageBatchJobGateway +from .file_storage_gateway import FileStorageGateway +from .llm_artifact_gateway import LLMArtifactGateway from .model_endpoints_schema_gateway import ModelEndpointsSchemaGateway from .model_primitive_gateway import ModelPrimitiveGateway from .monitoring_metrics_gateway import MonitoringMetricsGateway @@ -9,7 +12,10 @@ __all__ = ( "AsyncModelEndpointInferenceGateway", + "CronJobGateway", "DockerImageBatchJobGateway", + "FileStorageGateway", + "LLMArtifactGateway", "ModelEndpointsSchemaGateway", "ModelPrimitiveGateway", "MonitoringMetricsGateway", diff --git a/server/llm_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py similarity index 88% rename from server/llm_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py index 7aebae40..bff654c2 100644 --- a/server/llm_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, diff --git a/model-engine/model_engine_server/domain/gateways/cron_job_gateway.py b/model-engine/model_engine_server/domain/gateways/cron_job_gateway.py new file mode 100644 index 00000000..c4bb289b --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/cron_job_gateway.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob + + +class CronJobGateway(ABC): + """ + Base class for K8s CronJob Gateway + """ + + @abstractmethod + async def create_cronjob( + self, + *, + request_host: str, + trigger_id: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Dict[str, str], + ) -> None: + """ + Create a cron job from a bundle and trigger. + + Args: + request_host: URL to forward the batch job creation request + trigger_id: The ID of the trigger + created_by: The user who created the trigger + owner: The user who owns the trigger + cron_schedule: Cron-formatted string representing the cron job's invocation schedule + docker_image_batch_job_bundle_id: The ID of the docker image batch job bundle + default_job_config: The user-specified input to the batch job. Exposed as a file mounted at mount_location to the batch job + job_config: K8s team/product labels + resource_requests: The resource requests for the batch job + + Returns: + None + """ + pass + + @abstractmethod + async def list_jobs( + self, + *, + owner: str, + trigger_id: Optional[str], + ) -> List[DockerImageBatchJob]: + """ + Lists all docker image batch jobs spawned by the trigger with the given ID, otherwise by owner if trigger_id is None + + Args: + trigger_id: the ID of the trigger pointing to the cron job + + Returns: + List of docker image batch jobs spawned by the trigger with the given ID, otherwise by owner if trigger_id is None + """ + pass + + @abstractmethod + async def update_cronjob( + self, + *, + trigger_id: str, + cron_schedule: Optional[str], + suspend: Optional[bool], + ) -> None: + """ + Partially updates the schedule field and/or the suspend field of the specified cron job. + + Args: + trigger_id: the ID of the trigger pointing to the cron job + cron_schedule: New cron schedule parameter representing the cron job's invocation schedule + suspend: The active status of the trigger, False means paused and True means unpaused + + Returns: + None + """ + pass + + @abstractmethod + async def delete_cronjob( + self, + *, + trigger_id: str, + ) -> None: + """ + Deletes the specified cron job. + + Args: + trigger_id: the ID of the trigger pointing to the cron job + + Returns: + None + """ + pass diff --git a/server/llm_engine_server/domain/gateways/docker_image_batch_job_gateway.py b/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py similarity index 88% rename from server/llm_engine_server/domain/gateways/docker_image_batch_job_gateway.py rename to model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py index eafd355c..43af4e04 100644 --- a/server/llm_engine_server/domain/gateways/docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py @@ -1,14 +1,13 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob class DockerImageBatchJobGateway(ABC): """ Base class for docker image batch job gateway - """ @abstractmethod @@ -25,6 +24,8 @@ async def create_docker_image_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, labels: Dict[str, str], mount_location: Optional[str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, ) -> str: """ Create a docker image batch job @@ -38,6 +39,7 @@ async def create_docker_image_batch_job( repo: The ECR repo where the docker image running the batch job lies tag: The tag of the docker image labels: K8s team/product labels + annotations: K8s annotations resource_requests: The resource requests for the batch job. mount_location: Location on filesystem where runtime-provided file contents get mounted diff --git a/model-engine/model_engine_server/domain/gateways/file_storage_gateway.py b/model-engine/model_engine_server/domain/gateways/file_storage_gateway.py new file mode 100644 index 00000000..b76fdd77 --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/file_storage_gateway.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from model_engine_server.domain.entities import FileMetadata + + +class FileStorageGateway(ABC): + """ + Base class for file storage gateway + """ + + @abstractmethod + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + """ + Get file URL from file ID + + Args: + owner: The user who owns the file. + file_id: The ID of the file. + + Returns: + The URL of the file, or None if it does not exist. + """ + pass + + @abstractmethod + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + """ + Upload a file + + Args: + owner: The user who owns the file. + filename: The name of the file. + content: The content of the file. + + Returns: + The ID of the file. + """ + pass + + @abstractmethod + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + """ + Get metadata about a file. + + Args: + owner: The user who owns the file. + file_id: The ID of the file. + + Returns: + Information about the file, or None if it does not exist. + """ + pass + + @abstractmethod + async def list_files(self, owner: str) -> List[FileMetadata]: + """ + List all files for a given owner. + + Args: + owner: The owner whose files to list. + + Returns: + The list of files. + """ + pass + + @abstractmethod + async def delete_file(self, owner: str, file_id: str) -> bool: + """ + Delete a file. + + Args: + owner: The user who owns the files. + file_id: The ID of the file. + + Returns: + Whether the file was deleted successfully. + """ + pass + + @abstractmethod + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + """ + Get a file's content. + + Args: + owner: The user who owns the file. + file_id: The ID of the file. + + Returns: + The content of the file, or None if it does not exist. + """ + pass diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py new file mode 100644 index 00000000..dba41676 --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod +from typing import List + + +class LLMArtifactGateway(ABC): + """ + Abstract Base Class for interacting with llm artifacts. + """ + + @abstractmethod + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + """ + Gets a list of URLs for all files associated with a given model. + """ + pass diff --git a/server/llm_engine_server/domain/gateways/model_endpoints_schema_gateway.py b/model-engine/model_engine_server/domain/gateways/model_endpoints_schema_gateway.py similarity index 86% rename from server/llm_engine_server/domain/gateways/model_endpoints_schema_gateway.py rename to model-engine/model_engine_server/domain/gateways/model_endpoints_schema_gateway.py index dd2347d1..fea71580 100644 --- a/server/llm_engine_server/domain/gateways/model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/model_endpoints_schema_gateway.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Sequence -from llm_engine_server.domain.entities import ModelEndpointRecord, ModelEndpointsSchema +from model_engine_server.domain.entities import ModelEndpointRecord, ModelEndpointsSchema class ModelEndpointsSchemaGateway: diff --git a/server/llm_engine_server/domain/gateways/model_primitive_gateway.py b/model-engine/model_engine_server/domain/gateways/model_primitive_gateway.py similarity index 84% rename from server/llm_engine_server/domain/gateways/model_primitive_gateway.py rename to model-engine/model_engine_server/domain/gateways/model_primitive_gateway.py index 365aadf5..0b58e17d 100644 --- a/server/llm_engine_server/domain/gateways/model_primitive_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/model_primitive_gateway.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.domain.entities.model_bundle_entity import ModelBundleFrameworkType +from model_engine_server.domain.entities.model_bundle_entity import ModelBundleFrameworkType class ModelPrimitiveGateway(ABC): """ - Base class for interactions with Scale Model Primitive. + Base class for interactions with Model Primitive. """ @abstractmethod diff --git a/server/llm_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py similarity index 67% rename from server/llm_engine_server/domain/gateways/monitoring_metrics_gateway.py rename to model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index 33ff03c2..28a561cf 100644 --- a/server/llm_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -24,6 +24,26 @@ def emit_successful_build_metric(self): """ + @abstractmethod + def emit_build_time_metric(self, duration_seconds: float): + """ + Service builder build time metric + """ + + @abstractmethod + def emit_image_build_cache_hit_metric(self, image_type: str): + """ + Service builder image build cache hit metric + + """ + + @abstractmethod + def emit_image_build_cache_miss_metric(self, image_type: str): + """ + Service builder image build cache miss metric + + """ + @abstractmethod def emit_docker_failed_build_metric(self): """ diff --git a/server/llm_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py similarity index 94% rename from server/llm_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py index cd4ded50..565cbe45 100644 --- a/server/llm_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import AsyncIterable -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, ) diff --git a/server/llm_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py similarity index 94% rename from server/llm_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py index 0ad6921c..99ec36fa 100644 --- a/server/llm_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, ) diff --git a/server/llm_engine_server/domain/gateways/task_queue_gateway.py b/model-engine/model_engine_server/domain/gateways/task_queue_gateway.py similarity index 91% rename from server/llm_engine_server/domain/gateways/task_queue_gateway.py rename to model-engine/model_engine_server/domain/gateways/task_queue_gateway.py index a667b813..bf41892f 100644 --- a/server/llm_engine_server/domain/gateways/task_queue_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/task_queue_gateway.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.tasks import CreateAsyncTaskV1Response, GetAsyncTaskV1Response +from model_engine_server.common.dtos.tasks import CreateAsyncTaskV1Response, GetAsyncTaskV1Response class TaskQueueGateway(ABC): diff --git a/server/llm_engine_server/domain/repositories/__init__.py b/model-engine/model_engine_server/domain/repositories/__init__.py similarity index 65% rename from server/llm_engine_server/domain/repositories/__init__.py rename to model-engine/model_engine_server/domain/repositories/__init__.py index 96236895..00718521 100644 --- a/server/llm_engine_server/domain/repositories/__init__.py +++ b/model-engine/model_engine_server/domain/repositories/__init__.py @@ -2,10 +2,14 @@ from .docker_image_batch_job_bundle_repository import DockerImageBatchJobBundleRepository from .docker_repository import DockerRepository +from .llm_fine_tune_events_repository import LLMFineTuneEventsRepository from .model_bundle_repository import ModelBundleRepository +from .trigger_repository import TriggerRepository __all__: Sequence[str] = [ "DockerRepository", "DockerImageBatchJobBundleRepository", + "LLMFineTuneEventsRepository", "ModelBundleRepository", + "TriggerRepository", ] diff --git a/server/llm_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py b/model-engine/model_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py similarity index 93% rename from server/llm_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py rename to model-engine/model_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py index 019f4150..eb3c7318 100644 --- a/server/llm_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py +++ b/model-engine/model_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Sequence -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) diff --git a/server/llm_engine_server/domain/repositories/docker_repository.py b/model-engine/model_engine_server/domain/repositories/docker_repository.py similarity index 94% rename from server/llm_engine_server/domain/repositories/docker_repository.py rename to model-engine/model_engine_server/domain/repositories/docker_repository.py index 184fd9da..b2d410a1 100644 --- a/server/llm_engine_server/domain/repositories/docker_repository.py +++ b/model-engine/model_engine_server/domain/repositories/docker_repository.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse class DockerRepository(ABC): diff --git a/model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py b/model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py new file mode 100644 index 00000000..004739ab --- /dev/null +++ b/model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import List + +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent + + +class LLMFineTuneEventsRepository(ABC): + @abstractmethod + async def get_fine_tune_events( + self, user_id: str, model_endpoint_name: str + ) -> List[LLMFineTuneEvent]: + pass + + @abstractmethod + async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: + pass diff --git a/server/llm_engine_server/domain/repositories/model_bundle_repository.py b/model-engine/model_engine_server/domain/repositories/model_bundle_repository.py similarity index 95% rename from server/llm_engine_server/domain/repositories/model_bundle_repository.py rename to model-engine/model_engine_server/domain/repositories/model_bundle_repository.py index 2dc74c9d..067df488 100644 --- a/server/llm_engine_server/domain/repositories/model_bundle_repository.py +++ b/model-engine/model_engine_server/domain/repositories/model_bundle_repository.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Sequence -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.domain.entities import ( ModelBundle, ModelBundleFlavors, ModelBundlePackagingType, diff --git a/model-engine/model_engine_server/domain/repositories/trigger_repository.py b/model-engine/model_engine_server/domain/repositories/trigger_repository.py new file mode 100644 index 00000000..a8fd096f --- /dev/null +++ b/model-engine/model_engine_server/domain/repositories/trigger_repository.py @@ -0,0 +1,96 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Sequence + +from model_engine_server.domain.entities.trigger_entity import Trigger + + +class TriggerRepository(ABC): + @abstractmethod + async def create_trigger( + self, + *, + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], + ) -> Trigger: + """ + Creates a trigger. + Args: + name: User-set name of trigger + created_by: User creating trigger + owner: Team owning trigger + cron_schedule: Schedule of k8s CronJob + docker_image_batch_job_bundle_id: ID of docker image batch job bundle used by trigger + default_job_config: Optional config to specify parameters injected at runtime + default_job_metadata: Optional metdata tags for k8s jobs spawned by trigger + + Returns: + A trigger entity + """ + pass + + @abstractmethod + async def list_triggers( + self, + owner: str, + ) -> Sequence[Trigger]: + """ + Lists all triggers with a given owner + Args: + owner: Owner of trigger(s) + + Returns: + Sequence of trigger entities + """ + pass + + @abstractmethod + async def get_trigger( + self, + trigger_id: str, + ) -> Optional[Trigger]: + """ + Retrieves a single trigger by ID + Args: + trigger_id: ID of trigger we want + + Returns: + Associated trigger entity or None if we couldn't find it + """ + pass + + @abstractmethod + async def update_trigger( + self, + trigger_id: str, + cron_schedule: str, + ) -> bool: + """ + Updates the specified trigger's cron schedule + Args: + trigger_id: ID of trigger we want + cron_schedule: new cron schedule to replace the original + + Returns: + True or False, whether the update of the trigger was successful or not + """ + pass + + @abstractmethod + async def delete_trigger( + self, + trigger_id: str, + ) -> bool: + """ + Deletes the specified trigger + Args: + trigger_id: ID of trigger we want to delete + + Returns: + True or False, whether the deletion of the trigger was successful or not + """ + pass diff --git a/server/llm_engine_server/domain/services/__init__.py b/model-engine/model_engine_server/domain/services/__init__.py similarity index 82% rename from server/llm_engine_server/domain/services/__init__.py rename to model-engine/model_engine_server/domain/services/__init__.py index 723f62db..508a68e1 100644 --- a/server/llm_engine_server/domain/services/__init__.py +++ b/model-engine/model_engine_server/domain/services/__init__.py @@ -2,6 +2,7 @@ from .batch_job_service import BatchJobService from .endpoint_builder_service import EndpointBuilderService +from .llm_fine_tuning_service import LLMFineTuningService from .llm_model_endpoint_service import LLMModelEndpointService from .model_endpoint_service import ModelEndpointService @@ -9,5 +10,6 @@ "BatchJobService", "EndpointBuilderService", "LLMModelEndpointService", + "LLMFineTuningService", "ModelEndpointService", ] diff --git a/server/llm_engine_server/domain/services/batch_job_service.py b/model-engine/model_engine_server/domain/services/batch_job_service.py similarity index 92% rename from server/llm_engine_server/domain/services/batch_job_service.py rename to model-engine/model_engine_server/domain/services/batch_job_service.py index 9e92843d..4bac6e63 100644 --- a/server/llm_engine_server/domain/services/batch_job_service.py +++ b/model-engine/model_engine_server/domain/services/batch_job_service.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Dict, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from llm_engine_server.domain.entities import BatchJob, BatchJobSerializationFormat +from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests +from model_engine_server.domain.entities import BatchJob, BatchJobSerializationFormat class BatchJobService(ABC): diff --git a/server/llm_engine_server/domain/services/endpoint_builder_service.py b/model-engine/model_engine_server/domain/services/endpoint_builder_service.py similarity index 89% rename from server/llm_engine_server/domain/services/endpoint_builder_service.py rename to model-engine/model_engine_server/domain/services/endpoint_builder_service.py index ff521817..4b9079e0 100644 --- a/server/llm_engine_server/domain/services/endpoint_builder_service.py +++ b/model-engine/model_engine_server/domain/services/endpoint_builder_service.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, ) diff --git a/model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py b/model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py new file mode 100644 index 00000000..e55d0527 --- /dev/null +++ b/model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from model_engine_server.domain.entities import FineTuneHparamValueType +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob + + +class LLMFineTuningService(ABC): + @abstractmethod + async def create_fine_tune( + self, + created_by: str, + owner: str, + model: str, + training_file: str, + validation_file: Optional[str], + fine_tuning_method: str, + hyperparameters: Dict[str, FineTuneHparamValueType], + fine_tuned_model: str, + wandb_config: Optional[Dict[str, Any]], + ) -> str: + pass + + @abstractmethod + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: + pass + + @abstractmethod + async def list_fine_tunes(self, owner: str) -> List[DockerImageBatchJob]: + pass + + @abstractmethod + async def cancel_fine_tune(self, owner: str, fine_tune_id: str) -> bool: + pass + + @abstractmethod + async def get_fine_tune_model_name_from_id( + self, owner: str, fine_tune_id: str + ) -> Optional[str]: + pass diff --git a/server/llm_engine_server/domain/services/llm_model_endpoint_service.py b/model-engine/model_engine_server/domain/services/llm_model_endpoint_service.py similarity index 89% rename from server/llm_engine_server/domain/services/llm_model_endpoint_service.py rename to model-engine/model_engine_server/domain/services/llm_model_endpoint_service.py index f06279d5..07a7f9f6 100644 --- a/server/llm_engine_server/domain/services/llm_model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/llm_model_endpoint_service.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from typing import List, Optional -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.domain.entities import ModelEndpoint +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.domain.entities import ModelEndpoint class LLMModelEndpointService(ABC): diff --git a/server/llm_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py similarity index 95% rename from server/llm_engine_server/domain/services/model_endpoint_service.py rename to model-engine/model_engine_server/domain/services/model_endpoint_service.py index b53b5ace..90b50983 100644 --- a/server/llm_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -13,7 +13,7 @@ ModelEndpointType, StorageSpecificationType, ) -from llm_engine_server.domain.gateways import ( +from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, @@ -74,6 +74,7 @@ async def create_model_endpoint( results_s3_bucket: str, prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, owner: str, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], @@ -106,6 +107,7 @@ async def create_model_endpoint( to False high_priority: Makes all pods for this endpoint higher priority to enable faster pod spinup time. Higher priority pods will displace the lower priority dummy pods from shared pool. + billing_tags: Tags that get passed to scale's billing infra owner: The team ID of the creator of the model endpoint. default_callback_url: The default callback URL to use for the model endpoint. default_callback_auth: The default callback auth to use for the model endpoint. @@ -202,6 +204,7 @@ async def update_model_endpoint( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, public_inference: Optional[bool] = None, @@ -229,6 +232,7 @@ async def update_model_endpoint( to False high_priority: Makes all pods for this endpoint higher priority to enable faster pod spinup time. Higher priority pods will displace the lower priority dummy pods from shared pool. + billing_tags: Tags that get passed to scale's billing infra default_callback_url: The default callback URL to use for the model endpoint. default_callback_auth: The default callback auth to use for the model endpoint. public_inference: Whether to allow public inference. diff --git a/server/llm_engine_server/domain/use_cases/__init__.py b/model-engine/model_engine_server/domain/use_cases/__init__.py similarity index 100% rename from server/llm_engine_server/domain/use_cases/__init__.py rename to model-engine/model_engine_server/domain/use_cases/__init__.py diff --git a/server/llm_engine_server/domain/use_cases/async_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py similarity index 84% rename from server/llm_engine_server/domain/use_cases/async_inference_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py index a0a3ec9d..3b8a5ddf 100644 --- a/server/llm_engine_server/domain/use_cases/async_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py @@ -1,19 +1,19 @@ -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.services.model_endpoint_service import ModelEndpointService +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService DEFAULT_TASK_TIMEOUT_SECONDS = 86400 @@ -25,7 +25,7 @@ class CreateAsyncInferenceTaskV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request diff --git a/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py similarity index 81% rename from server/llm_engine_server/domain/use_cases/batch_job_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py index 7b5f7520..0a1bb1f5 100644 --- a/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py @@ -1,6 +1,7 @@ from datetime import datetime +from typing import Optional -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateBatchJobV1Request, CreateBatchJobV1Response, CreateDockerImageBatchJobResourceRequests, @@ -8,35 +9,36 @@ CreateDockerImageBatchJobV1Response, GetBatchJobV1Response, GetDockerImageBatchJobV1Response, + ListDockerImageBatchJobsV1Response, UpdateBatchJobV1Request, UpdateBatchJobV1Response, UpdateDockerImageBatchJobV1Request, UpdateDockerImageBatchJobV1Response, ) -from llm_engine_server.common.resource_limits import validate_resource_requests -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.resource_limits import validate_resource_requests +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.gateways.docker_image_batch_job_gateway import ( - DockerImageBatchJobGateway, -) -from llm_engine_server.domain.repositories import ( +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.gateways import CronJobGateway, DockerImageBatchJobGateway +from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, DockerRepository, ModelBundleRepository, + TriggerRepository, ) -from llm_engine_server.domain.services import BatchJobService, ModelEndpointService -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.services import BatchJobService, ModelEndpointService +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( validate_deployment_resources, + validate_labels, ) logger = make_logger(filename_wo_ext(__file__)) @@ -56,11 +58,12 @@ def __init__( self.batch_job_service = batch_job_service self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, request: CreateBatchJobV1Request ) -> CreateBatchJobV1Response: + validate_labels(request.labels) validate_deployment_resources( min_workers=0, max_workers=request.resource_requests.max_workers, @@ -109,7 +112,7 @@ class GetBatchJobV1UseCase: """ def __init__(self, batch_job_service: BatchJobService): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.batch_job_service = batch_job_service async def execute(self, user: User, batch_job_id: str) -> GetBatchJobV1Response: @@ -143,7 +146,7 @@ class UpdateBatchJobV1UseCase: def __init__(self, batch_job_service: BatchJobService): self.batch_job_service = batch_job_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, batch_job_id: str, request: UpdateBatchJobV1Request @@ -170,7 +173,7 @@ def __init__( self.docker_image_batch_job_gateway = docker_image_batch_job_gateway self.docker_image_batch_job_bundle_repository = docker_image_batch_job_bundle_repository self.docker_repository = docker_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, request: CreateDockerImageBatchJobV1Request @@ -238,6 +241,16 @@ async def execute( gpu_type=final_requests.gpu_type, ) + validate_labels(request.labels) + + if ( + request.override_job_max_runtime_s is not None + and request.override_job_max_runtime_s < 1 + ): + raise ObjectHasInvalidValueException( + "Please supply a positive integer value for batch job's maximum runtime (`override_job_max_runtime_s`)" + ) + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=user.user_id, owner=user.team_id, @@ -249,6 +262,7 @@ async def execute( resource_requests=final_requests, labels=request.labels, mount_location=batch_bundle.mount_location, + override_job_max_runtime_s=request.override_job_max_runtime_s, ) return CreateDockerImageBatchJobV1Response(job_id=job_id) @@ -277,6 +291,32 @@ async def execute(self, user: User, batch_job_id: str) -> GetDockerImageBatchJob return GetDockerImageBatchJobV1Response(status=job.status) +class ListDockerImageBatchJobsV1UseCase: + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.authz_module = LiveAuthorizationModule() + + async def execute( + self, user: User, trigger_id: Optional[str] + ) -> ListDockerImageBatchJobsV1Response: + if trigger_id: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + jobs = await self.cron_job_gateway.list_jobs(owner=user.team_id, trigger_id=trigger_id) + return ListDockerImageBatchJobsV1Response(jobs=jobs) + + class UpdateDockerImageBatchJobV1UseCase: """ Use case for cancelling a batch job. diff --git a/server/llm_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py b/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py similarity index 90% rename from server/llm_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py index 77c97329..29d40fe8 100644 --- a/server/llm_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py @@ -1,21 +1,21 @@ from typing import Optional -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, DockerImageBatchJobBundleV1Response, ListDockerImageBatchJobBundleV1Response, ) -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.repositories import DockerImageBatchJobBundleRepository +from model_engine_server.domain.repositories import DockerImageBatchJobBundleRepository class CreateDockerImageBatchJobBundleV1UseCase: @@ -88,7 +88,7 @@ async def execute( class GetDockerImageBatchJobBundleByIdV1UseCase: def __init__(self, docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository): self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, docker_image_batch_job_bundle_id: str diff --git a/model-engine/model_engine_server/domain/use_cases/file_use_cases.py b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py new file mode 100644 index 00000000..e646e8a0 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py @@ -0,0 +1,97 @@ +from model_engine_server.common.dtos.files import ( + DeleteFileResponse, + GetFileContentResponse, + GetFileResponse, + ListFilesResponse, + UploadFileResponse, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ObjectNotFoundException +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.gateways import FileStorageGateway + +logger = make_logger(filename_wo_ext(__file__)) + + +class UploadFileUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, filename: str, content: bytes) -> UploadFileResponse: + file_id = await self.file_storage_gateway.upload_file( + owner=user.team_id, + filename=filename, + content=content, + ) + return UploadFileResponse( + id=file_id, + ) + + +class GetFileUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, file_id: str) -> GetFileResponse: + file_metadata = await self.file_storage_gateway.get_file( + owner=user.team_id, + file_id=file_id, + ) + if file_metadata is None: + raise ObjectNotFoundException + return GetFileResponse( + id=file_metadata.id, + filename=file_metadata.filename, + size=file_metadata.size, + ) + + +class ListFilesUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User) -> ListFilesResponse: + files = await self.file_storage_gateway.list_files( + owner=user.team_id, + ) + return ListFilesResponse( + files=[ + GetFileResponse( + id=file_metadata.id, + filename=file_metadata.filename, + size=file_metadata.size, + ) + for file_metadata in files + ] + ) + + +class DeleteFileUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, file_id: str) -> DeleteFileResponse: + deleted = await self.file_storage_gateway.delete_file( + owner=user.team_id, + file_id=file_id, + ) + return DeleteFileResponse( + deleted=deleted, + ) + + +class GetFileContentUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, file_id: str) -> GetFileContentResponse: + content = await self.file_storage_gateway.get_file_content( + owner=user.team_id, + file_id=file_id, + ) + if content is None: + raise ObjectNotFoundException + return GetFileContentResponse( + id=file_id, + content=content, + ) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py new file mode 100644 index 00000000..331b0e48 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -0,0 +1,217 @@ +import datetime +import re + +from model_engine_server.common.dtos.llms import ( + CancelFineTuneResponse, + CreateFineTuneRequest, + CreateFineTuneResponse, + GetFineTuneEventsResponse, + GetFineTuneResponse, + ListFineTunesResponse, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ObjectNotFoundException +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import BatchJobStatus +from model_engine_server.domain.exceptions import InvalidRequestException, LLMFineTuningQuotaReached +from model_engine_server.domain.gateways import FileStorageGateway +from model_engine_server.domain.repositories import LLMFineTuneEventsRepository +from model_engine_server.domain.services import LLMFineTuningService, ModelEndpointService + +DEFAULT_FINE_TUNING_METHOD = "lora" + +MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER = 5 +MAX_LLM_ENDPOINTS_PER_INTERNAL_USER = 15 + +MAX_SUFFIX_LENGTH = 28 +# k8s labels need to be <= 62 characters, timestamp takes 13 characters, 2 characters for periods, +# model name is currently 17 long, but want to add a bit of buffer. + +logger = make_logger(filename_wo_ext(__file__)) + + +def is_model_name_suffix_valid(model_name: str): + pattern = "^[A-Za-z0-9-]+$" # TODO can we do spaces and underscores + return bool(re.match(pattern, model_name)) and len(model_name) <= MAX_SUFFIX_LENGTH + + +def ensure_model_name_is_valid_k8s_label(model_name: str): + """ + Ensure the model name is usable as a k8s label, + since we will end up creating a deployment with the model name as a label. + """ + return re.sub("[^-A-Za-z0-9_.]", "-", model_name).lstrip("-_.")[:62].rstrip("-_.") + + +class CreateFineTuneV1UseCase: + def __init__( + self, + llm_fine_tuning_service: LLMFineTuningService, + model_endpoint_service: ModelEndpointService, + llm_fine_tune_events_repository: LLMFineTuneEventsRepository, + file_storage_gateway: FileStorageGateway, + ): + self.llm_fine_tuning_service = llm_fine_tuning_service + self.model_endpoint_service = model_endpoint_service + self.llm_fine_tune_events_repository = llm_fine_tune_events_repository + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFineTuneResponse: + di_batch_jobs = await self.llm_fine_tuning_service.list_fine_tunes( + owner=user.team_id, + ) + in_progress_jobs = [ + job + for job in di_batch_jobs + if job.status in [BatchJobStatus.PENDING, BatchJobStatus.RUNNING] + ] + model_endpoints = await self.model_endpoint_service.list_model_endpoints( + owner=user.team_id, name=None, order_by=None + ) + + current_jobs_and_endpoints = len(in_progress_jobs) + len(model_endpoints) + + max_llm_endpoints_per_user = ( + MAX_LLM_ENDPOINTS_PER_INTERNAL_USER + if user.is_privileged_user + else MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER + ) + + if current_jobs_and_endpoints >= max_llm_endpoints_per_user: + raise LLMFineTuningQuotaReached( + f"Limit {max_llm_endpoints_per_user} fine-tunes/fine-tuned endpoints per user. " + f"Cancel/delete a total of " + f"{current_jobs_and_endpoints - max_llm_endpoints_per_user + 1} pending or " + f"running fine-tune(s) or fine-tuned endpoints to run another fine-tune." + ) + + if request.suffix is not None and not is_model_name_suffix_valid(request.suffix): + raise InvalidRequestException( + f"User-provided suffix is invalid, must only contain alphanumeric characters and dashes and be at most {MAX_SUFFIX_LENGTH} characters" + ) + time_now = datetime.datetime.utcnow().strftime("%y%m%d-%H%M%S") + # Colons breaks our download command. Keep delimiters as `.` + fine_tuned_model = ( + f"{request.model}.{request.suffix}.{time_now}" + if request.suffix is not None + else f"{request.model}.{time_now}" + ) + + # We need to ensure fine_tuned_model conforms to the k8s label spec + # This is unfortunately a leaky abstraction. This likely goes away if we redo how we implement fine-tuned + # models though + fine_tuned_model = ensure_model_name_is_valid_k8s_label(fine_tuned_model) + + if request.training_file.startswith("file-"): + training_file = await self.file_storage_gateway.get_url_from_id( + user.team_id, request.training_file + ) + if training_file is None: + raise ObjectNotFoundException("Training file does not exist") + else: + training_file = request.training_file + + if request.validation_file is not None and request.validation_file.startswith("file-"): + validation_file = await self.file_storage_gateway.get_url_from_id( + user.team_id, request.validation_file + ) + if validation_file is None: + raise ObjectNotFoundException("Validation file does not exist") + else: + validation_file = request.validation_file + + await self.llm_fine_tune_events_repository.initialize_events(user.team_id, fine_tuned_model) + fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune( + created_by=user.user_id, + owner=user.team_id, + model=request.model, + training_file=training_file, + validation_file=validation_file, + fine_tuning_method=DEFAULT_FINE_TUNING_METHOD, + hyperparameters=request.hyperparameters, + fine_tuned_model=fine_tuned_model, + wandb_config=request.wandb_config, + ) + return CreateFineTuneResponse( + id=fine_tune_id, + ) + + +class GetFineTuneV1UseCase: + def __init__(self, llm_fine_tuning_service: LLMFineTuningService): + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneResponse: + di_batch_job = await self.llm_fine_tuning_service.get_fine_tune( + owner=user.team_id, + fine_tune_id=fine_tune_id, + ) + if di_batch_job is None: + raise ObjectNotFoundException + if di_batch_job.annotations: + fine_tuned_model = di_batch_job.annotations.get("fine_tuned_model") + else: + fine_tuned_model = None + logger.warning(f"Fine-tune {di_batch_job.id} has no annotations. This is unexpected.") + return GetFineTuneResponse( + id=di_batch_job.id, + fine_tuned_model=fine_tuned_model, + status=di_batch_job.status, + ) + + +class ListFineTunesV1UseCase: + def __init__(self, llm_fine_tuning_service: LLMFineTuningService): + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User) -> ListFineTunesResponse: + di_batch_jobs = await self.llm_fine_tuning_service.list_fine_tunes( + owner=user.team_id, + ) + return ListFineTunesResponse( + jobs=[ + GetFineTuneResponse( + id=job.id, + status=job.status, + fine_tuned_model=job.annotations.get("fine_tuned_model") + if job.annotations + else None, + ) + for job in di_batch_jobs + ] + ) + + +class CancelFineTuneV1UseCase: + def __init__(self, llm_fine_tuning_service: LLMFineTuningService): + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User, fine_tune_id: str) -> CancelFineTuneResponse: + success = await self.llm_fine_tuning_service.cancel_fine_tune( + owner=user.team_id, + fine_tune_id=fine_tune_id, + ) + return CancelFineTuneResponse( + success=success, + ) + + +class GetFineTuneEventsV1UseCase: + def __init__( + self, + llm_fine_tune_events_repository: LLMFineTuneEventsRepository, + llm_fine_tuning_service: LLMFineTuningService, + ): + self.llm_fine_tune_events_repository = llm_fine_tune_events_repository + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneEventsResponse: + model_endpoint_name = await self.llm_fine_tuning_service.get_fine_tune_model_name_from_id( + user.team_id, fine_tune_id + ) + if model_endpoint_name is None: + raise ObjectNotFoundException(f"Fine-tune with id {fine_tune_id} not found") + events = await self.llm_fine_tune_events_repository.get_fine_tune_events( + user_id=user.team_id, model_endpoint_name=model_endpoint_name + ) + return GetFineTuneEventsResponse(events=events) diff --git a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py similarity index 75% rename from server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 08b25354..fa4eef0f 100644 --- a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1,9 +1,18 @@ +""" +TODO figure out how to do: (or if we want to do it) +List model endpoint history: GET model-endpoints//history +Read model endpoint creation logs: GET model-endpoints//creation-logs +""" + import json +import math import os from dataclasses import asdict -from typing import Any, AsyncIterable, Dict, Optional +from typing import Any, AsyncIterable, Dict, List, Optional +from uuid import uuid4 -from llm_engine_server.common.dtos.llms import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.llms import ( CompletionOutput, CompletionStreamOutput, CompletionStreamV1Request, @@ -14,22 +23,22 @@ CreateLLMModelEndpointV1Response, GetLLMModelEndpointV1Response, ListLLMModelEndpointsV1Response, + ModelDownloadRequest, + ModelDownloadResponse, + TokenOutput, ) -from llm_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus -from llm_engine_server.common.resource_limits import validate_resource_requests -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus +from model_engine_server.common.resource_limits import validate_resource_requests +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( LLMInferenceFramework, LLMMetadata, LLMSource, @@ -41,18 +50,24 @@ RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.exceptions import ( +from model_engine_server.domain.exceptions import ( EndpointLabelsException, EndpointUnsupportedInferenceTypeException, ) -from llm_engine_server.domain.repositories import ModelBundleRepository -from llm_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService +from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway +from model_engine_server.domain.repositories import ModelBundleRepository +from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from ...common.datadog_utils import add_trace_request_id +from ..authorization.live_authorization_module import LiveAuthorizationModule from .model_bundle_use_cases import CreateModelBundleV2UseCase from .model_endpoint_use_cases import ( _handle_post_inference_hooks, model_endpoint_entity_to_get_model_endpoint_response, + validate_billing_tags, validate_deployment_resources, + validate_labels, validate_post_inference_hooks, ) @@ -76,6 +91,12 @@ "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", "flan-t5-xxl": "google/flan-t5-xxl", "llama-7b": "decapoda-research/llama-7b-hf", + "llama-2-7b": "meta-llama/Llama-2-7b-hf", + "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", + "llama-2-13b": "meta-llama/Llama-2-13b-hf", + "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", + "llama-2-70b": "meta-llama/Llama-2-70b-hf", + "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", "falcon-7b": "tiiuae/falcon-7b", "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", "falcon-40b": "tiiuae/falcon-40b", @@ -97,6 +118,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( name=model_endpoint.record.name, model_name=llm_metadata["model_name"], source=llm_metadata["source"], + status=model_endpoint.record.status, inference_framework=llm_metadata["inference_framework"], inference_framework_image_tag=llm_metadata["inference_framework_image_tag"], num_shards=llm_metadata["num_shards"], @@ -132,7 +154,7 @@ def __init__( model_bundle_repository: ModelBundleRepository, model_endpoint_service: ModelEndpointService, ): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.create_model_bundle_use_case = create_model_bundle_use_case self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service @@ -192,6 +214,14 @@ async def create_text_generation_inference_bundle( checkpoint_path: Optional[str], ): command = [] + + # TGI requires max_input_length < max_total_tokens + max_input_length = 2047 + max_total_tokens = 2048 + if "llama-2" in model_name: + max_input_length = 4095 + max_total_tokens = 4096 + if checkpoint_path is not None: if checkpoint_path.startswith("s3://"): base_path = checkpoint_path.split("/")[-1] @@ -199,9 +229,14 @@ async def create_text_generation_inference_bundle( subcommands = [] s5cmd = "s5cmd" - subcommands.append( - f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}" - ) + # This is a hack for now to skip installing s5cmd for text-generation-inference:0.9.3-launch_s3, + # which has s5cmd binary already baked in. Otherwise, install s5cmd if it's not already available + if framework_image_tag != "0.9.3-launch_s3": + subcommands.append( + f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}" + ) + else: + s5cmd = "./s5cmd" if base_path.endswith(".tar"): # If the checkpoint file is a tar file, extract it into final_weights_folder @@ -218,7 +253,7 @@ async def create_text_generation_inference_bundle( ) subcommands.append( - f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} --num-shard {num_shards} --port 5005" + f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}" ) if quantize: @@ -247,9 +282,13 @@ async def create_text_generation_inference_bundle( "5005", "--hostname", "::", + "--max-input-length", + str(max_input_length), + "--max-total-tokens", + str(max_total_tokens), ] if quantize: - command = command + ["--quantize", str(quantize)] + command = command + [f"--quantize {quantize}"] return ( await self.create_model_bundle_use_case.execute( @@ -259,12 +298,12 @@ async def create_text_generation_inference_bundle( schema_location="TBA", flavor=StreamingEnhancedRunnableImageFlavor( flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, - repository="ghcr.io/huggingface/text-generation-inference", # TODO: let user choose repo + repository=hmi_config.tgi_repository, tag=framework_image_tag, command=command, streaming_command=command, protocol="http", - readiness_initial_delay_seconds=60, + readiness_initial_delay_seconds=10, healthcheck_route="/health", predict_route="/generate", streaming_predict_route="/generate_stream", @@ -272,6 +311,10 @@ async def create_text_generation_inference_bundle( ), metadata={}, ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. ) ).model_bundle_id @@ -316,6 +359,7 @@ async def create_deepspeed_bundle( ), metadata={}, ), + do_auth_check=False, ) ).model_bundle_id else: @@ -349,6 +393,7 @@ async def create_deepspeed_bundle( ), metadata={}, ), + do_auth_check=False, ) ).model_bundle_id @@ -362,6 +407,8 @@ async def execute( ) if request.labels is None: raise EndpointLabelsException("Endpoint labels cannot be None!") + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) validate_model_name(request.model_name, request.inference_framework) validate_num_shards(request.num_shards, request.inference_framework, request.gpus) @@ -494,7 +541,7 @@ class GetLLMModelEndpointByNameV1UseCase: def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): self.llm_model_endpoint_service = llm_model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndpointV1Response: """ @@ -529,6 +576,18 @@ class DeleteLLMModelEndpointByIdV1UseCase: pass +def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]: + tokens = [] + for i in range(len(result["token_probs"]["token_probs"])): + tokens.append( + TokenOutput( + token=result["token_probs"]["tokens"][i], + log_prob=math.log(result["token_probs"]["token_probs"][i]), + ) + ) + return tokens + + class CompletionSyncV1UseCase: """ Use case for running a prompt completion on an LLM endpoint. @@ -541,27 +600,43 @@ def __init__( ): self.model_endpoint_service = model_endpoint_service self.llm_model_endpoint_service = llm_model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() def model_output_to_completion_output( self, model_output: Dict[str, Any], model_endpoint: ModelEndpoint, + with_token_probs: Optional[bool], ) -> CompletionOutput: model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: completion_token_count = len(model_output["token_probs"]["tokens"]) + tokens = None + if with_token_probs: + tokens = deepspeed_result_to_tokens(model_output) return CompletionOutput( text=model_output["text"], num_completion_tokens=completion_token_count, + tokens=tokens, ) elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - return CompletionOutput( - text=model_output["generated_text"], - # len(model_output["details"]["prefill"]) does not return the correct value reliably - num_completion_tokens=model_output["details"]["generated_tokens"], - ) + try: + tokens = None + if with_token_probs: + tokens = [ + TokenOutput(token=t["text"], log_prob=t["logprob"]) + for t in model_output["details"]["tokens"] + ] + return CompletionOutput( + text=model_output["generated_text"], + # len(model_output["details"]["prefill"]) does not return the correct value reliably + num_completion_tokens=model_output["details"]["generated_tokens"], + tokens=tokens, + ) + except Exception as e: + logger.exception(f"Error parsing text-generation-inference output {model_output}") + raise e else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -586,6 +661,9 @@ async def execute( ObjectNotAuthorizedException: If the owner does not own the model endpoint. """ + request_id = str(uuid4()) + add_trace_request_id(request_id) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( owner=user.team_id, name=model_endpoint_name, order_by=None ) @@ -619,7 +697,7 @@ async def execute( endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED: args: Any = { - "prompts": request.prompts, + "prompts": [request.prompt], "token_probs": True, "generate_kwargs": { "do_sample": True, @@ -628,62 +706,63 @@ async def execute( }, "serialize_results_as_string": False, } + if request.stop_sequences is not None: + # Deepspeed models only accepts one stop sequence + args["stop_sequence"] = request.stop_sequences[0] inference_request = EndpointPredictV1Request(args=args) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, - predict_request=inference_request, + topic=model_endpoint.record.destination, predict_request=inference_request ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: return CompletionSyncV1Response( - status=predict_result.status, - outputs=[ - self.model_output_to_completion_output(result, model_endpoint) - for result in predict_result.result["result"] - ], + request_id=request_id, + output=self.model_output_to_completion_output( + predict_result.result["result"][0], + model_endpoint, + request.return_token_log_probs, + ), ) else: return CompletionSyncV1Response( - status=predict_result.status, - outputs=[], - traceback=predict_result.traceback, + request_id=request_id, + output=None, ) elif ( endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE ): - outputs = [] - - for prompt in request.prompts: - tgi_args: Any = { - "inputs": prompt, - "parameters": { - "max_new_tokens": request.max_new_tokens, - "temperature": request.temperature, - "decoder_input_details": True, - }, - } - inference_request = EndpointPredictV1Request(args=tgi_args) - predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, - predict_request=inference_request, - ) + tgi_args: Any = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + "decoder_input_details": True, + }, + } + if request.stop_sequences is not None: + tgi_args["parameters"]["stop"] = request.stop_sequences + if request.temperature > 0: + tgi_args["parameters"]["temperature"] = request.temperature + tgi_args["parameters"]["do_sample"] = True - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: - return CompletionSyncV1Response( - status=predict_result.status, - outputs=[], - traceback=predict_result.traceback, - ) + inference_request = EndpointPredictV1Request(args=tgi_args) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, predict_request=inference_request + ) - outputs.append(json.loads(predict_result.result["result"])) + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + return CompletionSyncV1Response( + request_id=request_id, + output=None, + ) + + output = json.loads(predict_result.result["result"]) return CompletionSyncV1Response( - status=predict_result.status, - outputs=[ - self.model_output_to_completion_output(output, model_endpoint) - for output in outputs - ], + request_id=request_id, + output=self.model_output_to_completion_output( + output, model_endpoint, request.return_token_log_probs + ), ) else: raise EndpointUnsupportedInferenceTypeException( @@ -703,7 +782,7 @@ def __init__( ): self.model_endpoint_service = model_endpoint_service self.llm_model_endpoint_service = llm_model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_name: str, request: CompletionStreamV1Request @@ -724,6 +803,8 @@ async def execute( ObjectNotAuthorizedException: If the owner does not own the model endpoint. """ + request_id = str(uuid4()) + add_trace_request_id(request_id) model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( owner=user.team_id, name=model_endpoint_name, order_by=None ) @@ -768,14 +849,22 @@ async def execute( }, "serialize_results_as_string": False, } + if request.stop_sequences is not None: + # Deepspeed models only accepts one stop sequence + args["stop_sequence"] = request.stop_sequences[0] elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: args = { "inputs": request.prompt, "parameters": { "max_new_tokens": request.max_new_tokens, - "temperature": request.temperature, }, } + if request.stop_sequences is not None: + args["parameters"]["stop"] = request.stop_sequences + if request.temperature > 0: + args["parameters"]["temperature"] = request.temperature + args["parameters"]["do_sample"] = True + inference_request = EndpointPredictV1Request(args=args) predict_result = inference_gateway.streaming_predict( @@ -789,7 +878,7 @@ async def execute( if res.status == TaskStatus.SUCCESS and result is not None: if "token" in result["result"]: yield CompletionStreamV1Response( - status=res.status, + request_id=request_id, output=CompletionStreamOutput( text=result["result"]["token"], finished=False, @@ -801,7 +890,7 @@ async def execute( result["result"]["response"][0]["token_probs"]["tokens"] ) yield CompletionStreamV1Response( - status=res.status, + request_id=request_id, output=CompletionStreamOutput( text=result["result"]["response"][0]["text"], finished=True, @@ -810,9 +899,8 @@ async def execute( ) else: yield CompletionStreamV1Response( - status=res.status, + request_id=request_id, output=None, - traceback=res.traceback, ) elif ( model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE @@ -825,21 +913,60 @@ async def execute( num_completion_tokens += 1 + token = None + if request.return_token_log_probs: + token = TokenOutput( + token=result["result"]["token"]["text"], + log_prob=result["result"]["token"]["logprob"], + ) yield CompletionStreamV1Response( - status=res.status, + request_id=request_id, output=CompletionStreamOutput( text=result["result"]["token"]["text"], finished=finished, num_completion_tokens=num_completion_tokens, + token=token, ), ) else: yield CompletionStreamV1Response( - status=res.status, + request_id=request_id, output=None, - traceback=res.traceback, ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" ) + + +class ModelDownloadV1UseCase: + def __init__( + self, + filesystem_gateway: FilesystemGateway, + model_endpoint_service: ModelEndpointService, + llm_artifact_gateway: LLMArtifactGateway, + ): + self.filesystem_gateway = filesystem_gateway + self.model_endpoint_service = model_endpoint_service + self.llm_artifact_gateway = llm_artifact_gateway + + async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownloadResponse: + model_endpoints = await self.model_endpoint_service.list_model_endpoints( + owner=user.team_id, name=request.model_name, order_by=None + ) + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {request.model_name}, got {len(model_endpoints)}" + ) + model_files = self.llm_artifact_gateway.get_model_weights_urls( + user.team_id, request.model_name + ) + urls = {} + for model_file in model_files: + # don't want to make s3 bucket full keys public, so trim to just keep file name + public_file_name = model_file.rsplit("/", 1)[-1] + urls[public_file_name] = self.filesystem_gateway.generate_signed_url(model_file) + return ModelDownloadResponse(urls=urls) diff --git a/server/llm_engine_server/domain/use_cases/model_bundle_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py similarity index 94% rename from server/llm_engine_server/domain/use_cases/model_bundle_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py index aa69bbde..be75e695 100644 --- a/server/llm_engine_server/domain/use_cases/model_bundle_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py @@ -1,7 +1,7 @@ from typing import Optional, Union from uuid import uuid4 -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CloneModelBundleV2Request, CreateModelBundleV1Request, @@ -15,16 +15,16 @@ ModelBundleV1Response, ModelBundleV2Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( DockerImageNotFoundException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( ArtifactLike, CloudpickleArtifactFlavor, CustomFramework, @@ -37,8 +37,8 @@ TensorflowFramework, ZipArtifactFlavor, ) -from llm_engine_server.domain.gateways import ModelPrimitiveGateway -from llm_engine_server.domain.repositories import DockerRepository, ModelBundleRepository +from model_engine_server.domain.gateways import ModelPrimitiveGateway +from model_engine_server.domain.repositories import DockerRepository, ModelBundleRepository class CreateModelBundleV1UseCase: @@ -52,7 +52,7 @@ def __init__( docker_repository: DockerRepository, model_primitive_gateway: ModelPrimitiveGateway, ): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.model_bundle_repository = model_bundle_repository self.docker_repository = docker_repository self.model_primitive_gateway = model_primitive_gateway @@ -145,7 +145,7 @@ async def execute( load_predict_fn_module_path=metadata.get("load_predict_fn_module_path", ""), load_model_fn_module_path=metadata.get("load_model_fn_module_path", ""), ) - else: # request.packaging_type == ModelBundlePackagingType.CLOUDPICKLE: + else: # request.packaging_type == ModelBundlePackagingType.LIRA: flavor = RunnableImageFlavor( flavor=ModelBundleFlavorType.RUNNABLE_IMAGE, repository="", # stub value, not used @@ -182,7 +182,7 @@ class CloneModelBundleV1UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, @@ -280,7 +280,7 @@ class GetModelBundleByIdV1UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_bundle_id: str) -> ModelBundleV1Response: """ @@ -346,13 +346,16 @@ def __init__( docker_repository: DockerRepository, model_primitive_gateway: ModelPrimitiveGateway, ): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.model_bundle_repository = model_bundle_repository self.docker_repository = docker_repository self.model_primitive_gateway = model_primitive_gateway async def execute( - self, user: User, request: CreateModelBundleV2Request + self, + user: User, + request: CreateModelBundleV2Request, + do_auth_check: bool = True, ) -> CreateModelBundleV2Response: """ Runs the use case to create a Model Bundle. @@ -360,6 +363,9 @@ async def execute( Args: user: The user who is creating the Model Bundle. request: A request object that contains the creation fields. + do_auth_check: Whether we should run the auth check. We're skipping the check + inside of the llm endpoint creation use case. This is fine as long as that use case + isn't directly exposed to the outside. Returns: A response object that contains the creation response fields. @@ -396,7 +402,7 @@ async def execute( tag=request.flavor.tag, ) - if not self.authz_module.check_access_create_bundle_v2(user, request): + if do_auth_check and not self.authz_module.check_access_create_bundle_v2(user, request): raise ObjectNotAuthorizedException created_by = user.user_id @@ -428,14 +434,14 @@ async def execute( ) app_config = request.flavor.app_config else: - location = "unused" # Nonempty to support legacy LLMEngine + location = "unused" # Nonempty to support legacy Launch requirements = [] env_params = { "framework_type": ModelBundleFrameworkType.CUSTOM, "ecr_repo": request.flavor.repository, "image_tag": request.flavor.tag, } - packaging_type = ModelBundlePackagingType.CLOUDPICKLE + packaging_type = ModelBundlePackagingType.LIRA app_config = None model_bundle = await self.model_bundle_repository.create_model_bundle( @@ -464,7 +470,7 @@ class CloneModelBundleV2UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, @@ -561,7 +567,7 @@ class GetModelBundleByIdV2UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_bundle_id: str) -> ModelBundleV2Response: """ diff --git a/server/llm_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py similarity index 82% rename from server/llm_engine_server/domain/use_cases/model_endpoint_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 4cf04c16..04e595d4 100644 --- a/server/llm_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -4,10 +4,14 @@ Read model endpoint creation logs: GET model-endpoints//creation-logs """ -from typing import List, Optional +import re +from typing import Any, Dict, List, Optional -from llm_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.constants import ( + BILLING_POST_INFERENCE_HOOK, + CALLBACK_POST_INFERENCE_HOOK, +) +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, DeleteModelEndpointV1Response, @@ -17,31 +21,34 @@ UpdateModelEndpointV1Request, UpdateModelEndpointV1Response, ) -from llm_engine_server.common.resource_limits import MAX_ENDPOINT_SIZE, validate_resource_requests -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.resource_limits import MAX_ENDPOINT_SIZE, validate_resource_requests +from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS, RESTRICTED_ENDPOINT_LABELS +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( ModelEndpoint, ModelEndpointType, StreamingEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.exceptions import ( +from model_engine_server.domain.exceptions import ( + EndpointBillingTagsMalformedException, EndpointInfraStateNotFound, EndpointLabelsException, EndpointResourceInvalidRequestException, ) -from llm_engine_server.domain.repositories import ModelBundleRepository -from llm_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.repositories import ModelBundleRepository +from model_engine_server.domain.services import ModelEndpointService CONVERTED_FROM_ARTIFACT_LIKE_KEY = "_CONVERTED_FROM_ARTIFACT_LIKE" +MODEL_BUNDLE_CHANGED_KEY = "_MODEL_BUNDLE_CHANGED" logger = make_logger(filename_wo_ext(__name__)) @@ -110,6 +117,62 @@ def validate_deployment_resources( ) +def validate_labels(labels: Dict[str, str]) -> None: + for required_label in REQUIRED_ENDPOINT_LABELS: + if required_label not in labels: + raise EndpointLabelsException( + f"Missing label '{required_label}' in labels. These are all required: {REQUIRED_ENDPOINT_LABELS}", + ) + + for restricted_label in RESTRICTED_ENDPOINT_LABELS: + if restricted_label in labels: + raise EndpointLabelsException(f"Cannot specify '{restricted_label}' in labels") + + try: + from plugins.known_users import ALLOWED_TEAMS + + # Make sure that the team is one of the values from a canonical set. + if labels["team"] not in ALLOWED_TEAMS: + raise EndpointLabelsException(f"Invalid team label, must be one of: {ALLOWED_TEAMS}") + except ModuleNotFoundError: + pass + + # Check k8s will accept the label values + regex_pattern = "(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?" # k8s label regex + for label_value in labels.values(): + if re.fullmatch(regex_pattern, label_value) is None: + raise EndpointLabelsException( + f"Invalid label value {label_value}, must match regex {regex_pattern}" + ) + + +def validate_billing_tags(billing_tags: Optional[Dict[str, Any]]) -> None: + if billing_tags is None: + return + + if type(billing_tags) != dict: + raise EndpointBillingTagsMalformedException("Billing tags must be a json dictionary") + + required_keys = { + "idempotencyKeyPrefix", + "product", + "type", + "subType", + "payee", + "payor", + "reference", + } + + missing_keys = required_keys - set(billing_tags) + if len(missing_keys) > 0: + raise EndpointBillingTagsMalformedException(f"Missing billing tag keys {missing_keys}") + for k, v in billing_tags.items(): + if type(k) != str or type(v) not in [str, dict]: + raise EndpointBillingTagsMalformedException( + "Billing tags must have string keys and string/dict values" + ) + + def validate_post_inference_hooks(user: User, post_inference_hooks: Optional[List[str]]) -> None: # We're going to ask for user-specified auth for callbacks instead of providing default auth # from Launch. Otherwise, we'd want to prevent non-privileged users from using the @@ -119,6 +182,7 @@ def validate_post_inference_hooks(user: User, post_inference_hooks: Optional[Lis for hook in post_inference_hooks: if hook not in [ + BILLING_POST_INFERENCE_HOOK, CALLBACK_POST_INFERENCE_HOOK, ]: raise ValueError(f"Unsupported post-inference hook {hook}") @@ -132,7 +196,7 @@ def __init__( ): self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, request: CreateModelEndpointV1Request @@ -144,6 +208,8 @@ async def execute( ) if request.labels is None: raise EndpointLabelsException("Endpoint labels cannot be None!") + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) bundle = await self.model_bundle_repository.get_model_bundle( model_bundle_id=request.model_bundle_id @@ -240,11 +306,14 @@ def __init__( ): self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_id: str, request: UpdateModelEndpointV1Request ) -> UpdateModelEndpointV1Response: + if request.labels is not None: + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) endpoint = await self.model_endpoint_service.get_model_endpoint( @@ -381,7 +450,7 @@ class GetModelEndpointByIdV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_endpoint_id: str) -> GetModelEndpointV1Response: """ @@ -409,7 +478,7 @@ async def execute(self, user: User, model_endpoint_id: str) -> GetModelEndpointV class DeleteModelEndpointByIdV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_endpoint_id: str) -> DeleteModelEndpointV1Response: model_endpoint = await self.model_endpoint_service.get_model_endpoint_record( diff --git a/server/llm_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py similarity index 64% rename from server/llm_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py index 5ebe873c..c35ce456 100644 --- a/server/llm_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py @@ -1,9 +1,9 @@ -from llm_engine_server.common.dtos.model_endpoints import GetModelEndpointsSchemaV1Response -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.common.dtos.model_endpoints import GetModelEndpointsSchemaV1Response +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.services import ModelEndpointService class GetModelEndpointsSchemaV1UseCase: @@ -13,7 +13,7 @@ class GetModelEndpointsSchemaV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User) -> GetModelEndpointsSchemaV1Response: """Execute the use case. diff --git a/server/llm_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py similarity index 77% rename from server/llm_engine_server/domain/use_cases/streaming_inference_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index 1eddf85a..f4dfce40 100644 --- a/server/llm_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -1,20 +1,20 @@ from typing import AsyncIterable -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.services.model_endpoint_service import ModelEndpointService +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService class CreateStreamingInferenceTaskV1UseCase: @@ -24,7 +24,7 @@ class CreateStreamingInferenceTaskV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request diff --git a/server/llm_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py similarity index 76% rename from server/llm_engine_server/domain/use_cases/sync_inference_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index 44fc9c11..7ef1f8bd 100644 --- a/server/llm_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -1,18 +1,18 @@ -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.services.model_endpoint_service import ModelEndpointService +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService class CreateSyncInferenceTaskV1UseCase: @@ -22,7 +22,7 @@ class CreateSyncInferenceTaskV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request @@ -41,6 +41,7 @@ async def execute( Raises: ObjectNotFoundException: If a model endpoint with the given ID could not be found. ObjectNotAuthorizedException: If the owner does not own the model endpoint. + asyncio.exceptions.TimeoutError: If the task times out. """ model_endpoint = await self.model_endpoint_service.get_model_endpoint( model_endpoint_id=model_endpoint_id diff --git a/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py b/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py new file mode 100644 index 00000000..b616c299 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py @@ -0,0 +1,243 @@ +import os + +from croniter import croniter +from model_engine_server.common.dtos.triggers import ( + CreateTriggerV1Request, + CreateTriggerV1Response, + DeleteTriggerV1Response, + GetTriggerV1Response, + ListTriggersV1Response, + UpdateTriggerV1Request, + UpdateTriggerV1Response, +) +from model_engine_server.common.resource_limits import validate_resource_requests +from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.config import infra_config +from model_engine_server.core.domain_exceptions import ( + DockerImageNotFoundException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.exceptions import CronSyntaxException, EndpointLabelsException +from model_engine_server.domain.gateways.cron_job_gateway import CronJobGateway +from model_engine_server.domain.repositories import ( + DockerImageBatchJobBundleRepository, + DockerRepository, + TriggerRepository, +) +from model_engine_server.domain.use_cases.model_endpoint_use_cases import validate_labels + +DEFAULT_HOST = f"https://model-engine.{infra_config().dns_host_domain}" + +ALLOWED_CRON_MACROS = set( + [ + "@yearly", + "@annually", + "@monthly", + "@weekly", + "@daily", + "@midnight", + "@hourly", + ] +) + + +def validate_cron(cron: str) -> None: + if len(cron) == 0: + raise CronSyntaxException("Cron expression cannot be empty.") + + if cron not in ALLOWED_CRON_MACROS: + # case on presence of macro identifier + if cron[0] == "@": + raise CronSyntaxException( + f"Unsupported macro supplied: '{cron}'. Please select from the following list, {ALLOWED_CRON_MACROS}." + ) + elif not croniter.is_valid(cron): + raise CronSyntaxException( + f"Invalid Cron syntax: '{cron}'. Please see https://crontab.guru." + ) + + +class CreateTriggerUseCase: + """Use case for creating a Trigger""" + + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + docker_image_batch_job_bundle_repository: DockerImageBatchJobBundleRepository, + docker_repository: DockerRepository, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.docker_image_batch_job_bundle_repository = docker_image_batch_job_bundle_repository + self.docker_repository = docker_repository + self.authz_module = LiveAuthorizationModule() + + async def execute( + self, + user: User, + request: CreateTriggerV1Request, + ) -> CreateTriggerV1Response: + batch_bundle = ( + await self.docker_image_batch_job_bundle_repository.get_docker_image_batch_job_bundle( + request.bundle_id + ) + ) + + if batch_bundle is None: + raise ObjectNotFoundException("The specified batch job bundle could not be found") + if not self.authz_module.check_access_read_owned_entity(user, batch_bundle): + raise ObjectNotAuthorizedException( + f"User {user} does not have permission for the specified batch job bundle" + ) + + if not self.docker_repository.image_exists( + image_tag=batch_bundle.image_tag, repository_name=batch_bundle.image_repository + ): + raise DockerImageNotFoundException( + repository=batch_bundle.image_repository, + tag=batch_bundle.image_tag, + ) # Error if docker image could not be found either + + # check if required resources exist + if None in [batch_bundle.cpus, batch_bundle.memory]: + raise ObjectHasInvalidValueException("Bundle must specify value for cpus and memory") + # validate resource request in cluster also + validate_resource_requests( + bundle=batch_bundle, + cpus=batch_bundle.cpus, + memory=batch_bundle.memory, + storage=batch_bundle.storage, + gpus=batch_bundle.gpus, + gpu_type=batch_bundle.gpu_type, + ) + + if request.default_job_metadata is None: + raise EndpointLabelsException( + f"Missing labels in default_job_metadata. These are all required: {REQUIRED_ENDPOINT_LABELS}" + ) + + validate_labels(request.default_job_metadata) + validate_cron(request.cron_schedule) + + trigger = await self.trigger_repository.create_trigger( + name=request.name, + created_by=user.user_id, + owner=user.team_id, + cron_schedule=request.cron_schedule, + docker_image_batch_job_bundle_id=request.bundle_id, + default_job_config=request.default_job_config, + default_job_metadata=request.default_job_metadata, + ) + + request.default_job_metadata["trigger_id"] = trigger.id + await self.cron_job_gateway.create_cronjob( + request_host=os.getenv("GATEWAY_URL") or DEFAULT_HOST, + trigger_id=trigger.id, + created_by=user.user_id, + owner=user.team_id, + cron_schedule=request.cron_schedule, + docker_image_batch_job_bundle_id=request.bundle_id, + default_job_config=request.default_job_config, + default_job_metadata=request.default_job_metadata, + ) + + return CreateTriggerV1Response(trigger_id=trigger.id) + + +class ListTriggersUseCase: + def __init__(self, trigger_repository: TriggerRepository): + self.trigger_repository = trigger_repository + + async def execute(self, user: User) -> ListTriggersV1Response: + triggers = await self.trigger_repository.list_triggers(owner=user.team_id) + return ListTriggersV1Response( + triggers=[GetTriggerV1Response.from_orm(trigger) for trigger in triggers] + ) + + +class GetTriggerUseCase: + def __init__(self, trigger_repository: TriggerRepository): + self.trigger_repository = trigger_repository + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, trigger_id: str) -> GetTriggerV1Response: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + return GetTriggerV1Response.from_orm(trigger) + + +class UpdateTriggerUseCase: + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.authz_module = LiveAuthorizationModule() + + async def execute( + self, user: User, trigger_id: str, request: UpdateTriggerV1Request + ) -> UpdateTriggerV1Response: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + success = True + if request.cron_schedule is not None: + validate_cron(request.cron_schedule) + success = await self.trigger_repository.update_trigger( + trigger_id=trigger_id, cron_schedule=request.cron_schedule + ) + + if success: + await self.cron_job_gateway.update_cronjob( + trigger_id=trigger.id, + cron_schedule=request.cron_schedule, + suspend=request.suspend, + ) + + return UpdateTriggerV1Response(success=success) + + +class DeleteTriggerUseCase: + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, trigger_id: str) -> DeleteTriggerV1Response: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + success = await self.trigger_repository.delete_trigger(trigger_id=trigger_id) + if success: + await self.cron_job_gateway.delete_cronjob(trigger_id=trigger_id) + + return DeleteTriggerV1Response(success=success) diff --git a/server/llm_engine_server/entrypoints/__init__.py b/model-engine/model_engine_server/entrypoints/__init__.py similarity index 100% rename from server/llm_engine_server/entrypoints/__init__.py rename to model-engine/model_engine_server/entrypoints/__init__.py diff --git a/server/llm_engine_server/entrypoints/init_database.py b/model-engine/model_engine_server/entrypoints/init_database.py similarity index 91% rename from server/llm_engine_server/entrypoints/init_database.py rename to model-engine/model_engine_server/entrypoints/init_database.py index cea6330f..30ca1a1c 100644 --- a/server/llm_engine_server/entrypoints/init_database.py +++ b/model-engine/model_engine_server/entrypoints/init_database.py @@ -2,13 +2,13 @@ import os import psycopg2 -from llm_engine_server.db.base import Base -from llm_engine_server.db.models import * +from model_engine_server.db.base import Base +from model_engine_server.db.models import * from sqlalchemy import create_engine from sqlalchemy.engine import Engine from tenacity import Retrying, stop_after_attempt, wait_exponential -SCHEMAS = ["llm_engine", "model"] +SCHEMAS = ["hosted_model_inference", "model"] def init_database(database_url: str, psycopg_connection): diff --git a/server/llm_engine_server/entrypoints/init_llm_engine_models.py b/model-engine/model_engine_server/entrypoints/init_spellbook_models.py similarity index 95% rename from server/llm_engine_server/entrypoints/init_llm_engine_models.py rename to model-engine/model_engine_server/entrypoints/init_spellbook_models.py index 80ffc275..46c6a53a 100644 --- a/server/llm_engine_server/entrypoints/init_llm_engine_models.py +++ b/model-engine/model_engine_server/entrypoints/init_spellbook_models.py @@ -2,7 +2,7 @@ from typing import Any, Dict import requests -from llm_engine_server.domain.entities import ModelEndpointType +from launch.api_client.model.model_endpoint_type import ModelEndpointType from tenacity import retry, stop_after_attempt, wait_fixed DEFAULT_NETWORK_TIMEOUT_SEC = 10 @@ -120,7 +120,7 @@ def spellbook_bundle_payload( "flavor": { "flavor": "runnable_image", "repository": "instant-llm", - "tag": f"llm_engine_llm_cuda_image_{git_commit}", + "tag": f"launch_llm_cuda_image_{git_commit}", "command": [ "dumb-init", "--", @@ -147,7 +147,7 @@ def spellbook_endpoint_payload( *, endpoint_name: str, bundle_name: str, - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC, + endpoint_type: ModelEndpointType = "async", min_workers: int = 0, max_workers: int = 1, memory: str = "185Gi", @@ -228,7 +228,7 @@ def create_model_endpoint( return response.json() -def create_llm_engine_deployments(gateway_url: str): +def create_spellbook_deployments(gateway_url: str): for model_name, service_config in SERVICE_CONFIGS.items(): bundle_payload = spellbook_bundle_payload( model_name=model_name, @@ -252,5 +252,4 @@ def create_llm_engine_deployments(gateway_url: str): args = parser.parse_args() ensure_gateway_ready(args.gateway_url) - # TODO: Renable this when we're ready to pre-init models - # create_llm_engine_deployments(args.gateway_url) + create_spellbook_deployments(args.gateway_url) diff --git a/server/llm_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py similarity index 66% rename from server/llm_engine_server/entrypoints/k8s_cache.py rename to model-engine/model_engine_server/entrypoints/k8s_cache.py index 593418c5..3802129b 100644 --- a/server/llm_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -10,44 +10,48 @@ from kubernetes import config as kube_config from kubernetes.config.config_exception import ConfigException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.db.base import SessionAsyncNullPool -from llm_engine_server.domain.repositories import DockerRepository -from llm_engine_server.infra.gateways import FakeMonitoringMetricsGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.db.base import SessionAsyncNullPool +from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways import ( + DatadogMonitoringMetricsGateway, + FakeMonitoringMetricsGateway, +) +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( FakeSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( LiveSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, ) -from llm_engine_server.infra.repositories import ECRDockerRepository -from llm_engine_server.infra.repositories.db_model_endpoint_record_repository import ( +from model_engine_server.infra.repositories import ECRDockerRepository +from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, ) -from llm_engine_server.infra.repositories.model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.model_endpoint_cache_repository import ( ModelEndpointCacheRepository, ) -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) -from llm_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( RedisModelEndpointCacheRepository, ) -from llm_engine_server.infra.services.image_cache_service import ImageCacheService -from llm_engine_server.infra.services.model_endpoint_cache_service import ( +from model_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.model_endpoint_cache_service import ( ModelEndpointCacheWriteService, ) @@ -91,7 +95,11 @@ async def main(args: Any): logger.info(f"Using cache redis url {redis_url}") cache_repo = RedisModelEndpointCacheRepository(redis_info=redis_url) - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + monitoring_metrics_gateway: MonitoringMetricsGateway + if SKIP_AUTH: + monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + else: + monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, session=SessionAsyncNullPool, @@ -105,7 +113,9 @@ async def main(args: Any): sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - k8s_resource_manager = LiveEndpointResourceGateway(sqs_delegate=sqs_delegate) + k8s_resource_manager = LiveEndpointResourceGateway( + sqs_delegate=sqs_delegate, + ) image_cache_gateway = ImageCacheGateway() docker_repo = ECRDockerRepository() while True: diff --git a/server/llm_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py similarity index 78% rename from server/llm_engine_server/entrypoints/start_batch_job_orchestration.py rename to model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 33863048..01a03445 100644 --- a/server/llm_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -4,13 +4,15 @@ from datetime import timedelta import aioredis -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.db.base import SessionAsyncNullPool -from llm_engine_server.domain.entities import BatchJobSerializationFormat -from llm_engine_server.infra.gateways import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH +from model_engine_server.db.base import SessionAsyncNullPool +from model_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, + DatadogMonitoringMetricsGateway, FakeMonitoringMetricsGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobProgressGateway, @@ -20,24 +22,24 @@ LiveSyncModelEndpointInferenceGateway, S3FilesystemGateway, ) -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( FakeSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( LiveSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, ) -from llm_engine_server.infra.repositories import ( +from model_engine_server.infra.repositories import ( DbBatchJobRecordRepository, DbModelEndpointRecordRepository, RedisModelEndpointCacheRepository, ) -from llm_engine_server.infra.services import ( +from model_engine_server.infra.services import ( LiveBatchJobOrchestrationService, LiveModelEndpointService, ) @@ -55,11 +57,13 @@ async def run_batch_job( redis = aioredis.Redis(connection_pool=pool) redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + monitoring_metrics_gateway: MonitoringMetricsGateway + if SKIP_AUTH: + monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + else: + monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( - monitoring_metrics_gateway=monitoring_metrics_gateway, - session=session, - read_only=False, + monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ) sqs_delegate: SQSEndpointResourceDelegate @@ -124,10 +128,7 @@ def entrypoint(): parser = argparse.ArgumentParser() parser.add_argument("--job-id", "-j", required=True, help="The ID of the batch job to run.") parser.add_argument( - "--owner", - "-o", - required=True, - help="The ID of the user who owns the batch job.", + "--owner", "-o", required=True, help="The ID of the user who owns the batch job." ) parser.add_argument("--input-path", "-i", required=True, help="The path to the input data.") parser.add_argument( diff --git a/server/llm_engine_server/entrypoints/start_docker_image_batch_job_init_container.py b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py similarity index 78% rename from server/llm_engine_server/entrypoints/start_docker_image_batch_job_init_container.py rename to model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py index c552b09a..f26662c3 100644 --- a/server/llm_engine_server/entrypoints/start_docker_image_batch_job_init_container.py +++ b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py @@ -1,11 +1,11 @@ import argparse import shutil -import llm_engine_server.core.aws.storage_client as storage_client -from llm_engine_server.common.serialization_utils import b64_to_str -from llm_engine_server.core.aws.storage_client import s3_fileobj_exists -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.core.utils.url import parse_attachment_url +import model_engine_server.core.aws.storage_client as storage_client +from model_engine_server.common.serialization_utils import b64_to_str +from model_engine_server.core.aws.storage_client import s3_fileobj_exists +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.utils.url import parse_attachment_url logger = make_logger(filename_wo_ext(__file__)) @@ -41,9 +41,4 @@ def main(input_local: str, local_file: str, remote_file: str, file_contents_b64e parser.add_argument("--remote-file", type=str) parser.add_argument("--file-contents-b64encoded", type=str) args = parser.parse_args() - main( - args.input_local, - args.local_file, - args.remote_file, - args.file_contents_b64encoded, - ) + main(args.input_local, args.local_file, args.remote_file, args.file_contents_b64encoded) diff --git a/server/llm_engine_server/entrypoints/start_fastapi_server.py b/model-engine/model_engine_server/entrypoints/start_fastapi_server.py similarity index 92% rename from server/llm_engine_server/entrypoints/start_fastapi_server.py rename to model-engine/model_engine_server/entrypoints/start_fastapi_server.py index d120a31b..119935ff 100644 --- a/server/llm_engine_server/entrypoints/start_fastapi_server.py +++ b/model-engine/model_engine_server/entrypoints/start_fastapi_server.py @@ -22,11 +22,11 @@ def start_gunicorn_server(port: int, num_workers: int, debug: bool) -> None: "--keep-alive", "2", "--worker-class", - "llm_engine_server.api.worker.LLMEngineWorker", + "model_engine_server.api.worker.LaunchWorker", "--workers", f"{num_workers}", *additional_args, - "llm_engine_server.api.app:app", + "model_engine_server.api.app:app", ] subprocess.run(command, check=True) diff --git a/server/llm_engine_server/inference/__init__.py b/model-engine/model_engine_server/inference/__init__.py similarity index 100% rename from server/llm_engine_server/inference/__init__.py rename to model-engine/model_engine_server/inference/__init__.py diff --git a/server/llm_engine_server/inference/async_inference/__init__.py b/model-engine/model_engine_server/inference/async_inference/__init__.py similarity index 100% rename from server/llm_engine_server/inference/async_inference/__init__.py rename to model-engine/model_engine_server/inference/async_inference/__init__.py diff --git a/server/llm_engine_server/inference/async_inference/celery.py b/model-engine/model_engine_server/inference/async_inference/celery.py similarity index 64% rename from server/llm_engine_server/inference/async_inference/celery.py rename to model-engine/model_engine_server/inference/async_inference/celery.py index 80ba64a0..3ea5db6d 100644 --- a/server/llm_engine_server/inference/async_inference/celery.py +++ b/model-engine/model_engine_server/inference/async_inference/celery.py @@ -1,23 +1,24 @@ import os -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.core.celery import TaskVisibility, celery_app -from llm_engine_server.inference.common import unset_sensitive_envvars +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.inference.common import unset_sensitive_envvars unset_sensitive_envvars() broker_type_str = os.getenv("BROKER_TYPE") broker_type = BrokerType(broker_type_str) s3_bucket: str = os.environ.get("CELERY_S3_BUCKET") # type: ignore celery_kwargs = dict( - name="llm_engine_server.inference.async_inference", - modules=["llm_engine_server.inference.async_inference.tasks"], + name="model_engine_server.inference.async_inference", + modules=["model_engine_server.inference.async_inference.tasks"], aws_role=os.environ["AWS_PROFILE"], s3_bucket=s3_bucket, # s3_base_path = TODO get from env var/config task_reject_on_worker_lost=False, worker_proc_alive_timeout=1500, broker_type=broker_type_str, - task_visibility=TaskVisibility.VISIBILITY_24H, # We're using SQS so this only changes task_time_limit + task_visibility=TaskVisibility.VISIBILITY_24H, + # We're using SQS so this only changes task_time_limit ) if broker_type == BrokerType.SQS: queue_name = os.getenv("SQS_QUEUE_NAME") @@ -26,7 +27,6 @@ dict(broker_transport_options={"predefined_queues": {queue_name: {"url": queue_url}}}) ) - async_inference_service = celery_app(**celery_kwargs) # type: ignore if __name__ == "__main__": diff --git a/server/llm_engine_server/inference/async_inference/tasks.py b/model-engine/model_engine_server/inference/async_inference/tasks.py similarity index 75% rename from server/llm_engine_server/inference/async_inference/tasks.py rename to model-engine/model_engine_server/inference/async_inference/tasks.py index 999cc270..074e12ef 100644 --- a/server/llm_engine_server/inference/async_inference/tasks.py +++ b/model-engine/model_engine_server/inference/async_inference/tasks.py @@ -3,22 +3,22 @@ from celery import Task from celery.signals import worker_process_init -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.common.serialization_utils import str_to_bool -from llm_engine_server.core.loggers import make_logger -from llm_engine_server.core.utils.timer import timer -from llm_engine_server.domain.entities import ModelEndpointConfig -from llm_engine_server.inference.async_inference.celery import async_inference_service -from llm_engine_server.inference.common import ( +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.common.serialization_utils import str_to_bool +from model_engine_server.core.loggers import make_logger +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.entities import ModelEndpointConfig +from model_engine_server.inference.async_inference.celery import async_inference_service +from model_engine_server.inference.common import ( get_endpoint_config, load_predict_fn_or_cls, run_predict, ) -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler logger = make_logger(__name__) @@ -41,6 +41,8 @@ def init_worker_global(): bundle_name=endpoint_config.bundle_name, post_inference_hooks=endpoint_config.post_inference_hooks, user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), @@ -92,6 +94,10 @@ def on_success(self, retval, task_id, args, kwargs): hooks.handle(request_params_pydantic, retval, task_id) # type: ignore -@async_inference_service.task(base=InferenceTask) +@async_inference_service.task( + base=InferenceTask, + # For legacy reasons, we need to use the old name. + name="hosted_model_inference.inference.async_inference.tasks.predict", +) def predict(request_params: Dict[str, Any], return_pickled=True): return predict.predict(request_params, return_pickled) diff --git a/server/llm_engine_server/inference/async_inference/vpa.yaml b/model-engine/model_engine_server/inference/async_inference/vpa.yaml similarity index 100% rename from server/llm_engine_server/inference/async_inference/vpa.yaml rename to model-engine/model_engine_server/inference/async_inference/vpa.yaml diff --git a/server/llm_engine_server/inference/base.Dockerfile b/model-engine/model_engine_server/inference/base.Dockerfile similarity index 72% rename from server/llm_engine_server/inference/base.Dockerfile rename to model-engine/model_engine_server/inference/base.Dockerfile index ab0f9310..34f09ab5 100644 --- a/server/llm_engine_server/inference/base.Dockerfile +++ b/model-engine/model_engine_server/inference/base.Dockerfile @@ -22,9 +22,9 @@ RUN apt-get update && apt-get install -y \ build-essential \ && rm -rf /var/lib/apt/lists/* -COPY --chown=root llm_engine /app/llm_engine -WORKDIR /app/llm_engine +COPY --chown=root model-engine /app/model-engine +WORKDIR /app/model-engine RUN pip install -e . WORKDIR /app -RUN pip install -r /app/llm_engine/llm_engine/inference/requirements_base.txt +RUN pip install -r /app/model-engine/model_engine_server/inference/requirements_base.txt diff --git a/server/llm_engine_server/inference/common.py b/model-engine/model_engine_server/inference/common.py similarity index 92% rename from server/llm_engine_server/inference/common.py rename to model-engine/model_engine_server/inference/common.py index e6191ca6..a1242371 100644 --- a/server/llm_engine_server/inference/common.py +++ b/model-engine/model_engine_server/inference/common.py @@ -9,13 +9,13 @@ import boto3 import cloudpickle -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request, RequestSchema -from llm_engine_server.common.io import open_wrapper -from llm_engine_server.common.serialization_utils import b64_to_python_json -from llm_engine_server.core.loggers import make_logger -from llm_engine_server.core.utils.timer import timer -from llm_engine_server.domain.entities import ModelEndpointConfig -from llm_engine_server.inference.service_requests import make_request +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, RequestSchema +from model_engine_server.common.io import open_wrapper +from model_engine_server.common.serialization_utils import b64_to_python_json +from model_engine_server.core.loggers import make_logger +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.entities import ModelEndpointConfig +from model_engine_server.inference.service_requests import make_request logger = make_logger(__name__) @@ -117,7 +117,6 @@ def load_predict_fn_or_cls(): return predict_fn_inner else: logger.info("Loading bundle from serialized object") - # e.g. s3://scale-ml/hosted-model-inference/predict_fns/abc123 with timer(logger=logger, name="download_and_deserialize_cloudpickle_bundle"): with open_wrapper(bundle_url, "rb") as f: @@ -131,7 +130,6 @@ def load_predict_fn_or_cls(): if "model" in bundle: model = bundle["model"] elif "load_model_fn" in bundle: - # e.g. s3://scale-ml/hosted-model-inference/tf-saved-models/tf-cpu-efficientdet-abc123.tar.gz with timer(logger=logger, name="load_model_fn"): if deserialized_config is None: model = bundle["load_model_fn"]() @@ -268,14 +266,14 @@ def get_endpoint_config(): def is_sensitive_envvar(var): - return var.startswith("LLM_ENGINE_") or var.startswith("HMI_") + return var.startswith("LAUNCH_") or var.startswith("HMI_") def unset_sensitive_envvars(): # Since all the pods are in the same namespace as of now, there are env vars e.g. - # `LLM_ENGINE__...` that store the IPs of various services (and also leak that the services exist) + # `LAUNCH__...` that store the IPs of various services (and also leak that the services exist) # Let's unset them here - # The names seem to be the name of the deployment, which always starts with `LLM_ENGINE_` or `HMI_`. + # The names seem to be the name of the deployment, which always starts with `LAUNCH_` or `HMI_`. logger.info("Unsetting environment variables...") sensitive_envvars = [var for var in os.environ if is_sensitive_envvar(var)] for var in sensitive_envvars: diff --git a/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml b/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml new file mode 100644 index 00000000..6fd6f920 --- /dev/null +++ b/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml @@ -0,0 +1,21 @@ +forwarder: + sync: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: false + serialize_results_as_string: false + wrap_response: false + async: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: false + serialize_results_as_string: false + wrap_response: false diff --git a/model-engine/model_engine_server/inference/configs/service--forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--forwarder.yaml new file mode 100644 index 00000000..9e284230 --- /dev/null +++ b/model-engine/model_engine_server/inference/configs/service--forwarder.yaml @@ -0,0 +1,19 @@ +forwarder: + sync: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: true + serialize_results_as_string: true + async: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: true + serialize_results_as_string: true diff --git a/server/llm_engine_server/inference/configs/service--http_forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml similarity index 82% rename from server/llm_engine_server/inference/configs/service--http_forwarder.yaml rename to model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml index f37694f8..f7f046d5 100644 --- a/server/llm_engine_server/inference/configs/service--http_forwarder.yaml +++ b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml @@ -6,7 +6,7 @@ forwarder: predict_route: "/predict" healthcheck_route: "/readyz" batch_route: null - llm_engine_unwrap: true + model_engine_unwrap: true serialize_results_as_string: true stream: user_port: 5005 @@ -14,5 +14,6 @@ forwarder: predict_route: "/stream" healthcheck_route: "/readyz" batch_route: null - llm_engine_unwrap: true + model_engine_unwrap: true serialize_results_as_string: false + max_concurrency: 20 diff --git a/server/llm_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py similarity index 89% rename from server/llm_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py rename to model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py index e23c2c74..13586992 100644 --- a/server/llm_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py @@ -5,7 +5,7 @@ Used to calculate proportion of successful/unsuccessful requests, differentiates between docker build vs other errors. -(Copy of llm_engine/domain/gateways/monitoring_metrics_gateway.py but used purely for +(Copy of model_engine_server/domain/gateways/monitoring_metrics_gateway.py but used purely for inference to avoid importing stuff in user code that we don't need.) """ diff --git a/model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py new file mode 100644 index 00000000..64161b3c --- /dev/null +++ b/model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import Dict + + +class UsageMetricsGateway(ABC): + """ + Base class for gateway that emits usage metrics to some store of metrics, e.g. Datadog or + Platform Money Infra. + + Inside inference/ because otherwise we import tons of stuff (in particular hmi_config) that + isn't safe to import inside of the inference code (since it contains sensitive data) + + TODO this code (at least in its current form) should be considered temporary, it's to enable + instantml billing + """ + + @abstractmethod + def emit_task_call_metric(self, idempotency_token: str, tags: Dict[str, str]): + """ + Emits the billing event to the billing queue + Args: + idempotency_token: Some per-request token + tags: User-defined tags to get passed to billing. Should be for internal only. + Right now `tags` is pretty strictly formatted, + and reflects the scale FinancialEvent schema (see EventbridgeUsageMetricsGateway) + + """ + pass diff --git a/server/llm_engine_server/inference/download_and_inject_bundle.py b/model-engine/model_engine_server/inference/download_and_inject_bundle.py similarity index 96% rename from server/llm_engine_server/inference/download_and_inject_bundle.py rename to model-engine/model_engine_server/inference/download_and_inject_bundle.py index 8637be59..74fb3b15 100644 --- a/server/llm_engine_server/inference/download_and_inject_bundle.py +++ b/model-engine/model_engine_server/inference/download_and_inject_bundle.py @@ -2,7 +2,7 @@ import os import shutil -from llm_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import make_logger logger = make_logger(__name__) diff --git a/server/llm_engine_server/inference/forwarding/__init__.py b/model-engine/model_engine_server/inference/forwarding/__init__.py similarity index 100% rename from server/llm_engine_server/inference/forwarding/__init__.py rename to model-engine/model_engine_server/inference/forwarding/__init__.py diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py new file mode 100644 index 00000000..16e7fc34 --- /dev/null +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -0,0 +1,170 @@ +import argparse +import json +from typing import Any, Dict, Optional, TypedDict, Union + +from celery import Celery, Task, states +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.format import format_stacktrace +from model_engine_server.inference.forwarding.forwarding import ( + Forwarder, + LoadForwarder, + load_named_config, +) + +logger = make_logger(logger_name()) + + +class ErrorResponse(TypedDict): + """The response payload for any inference request that encountered an error.""" + + error: str + error_metadata: str + + +class ErrorHandlingTask(Task): + """Sets a 'custom' field with error in the Task response for FAILURE. + + Used when services are ran via the Celery backend. + """ + + def after_return( + self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo + ) -> None: + """Handler that ensures custom error response information is available whenever a Task fails. + + Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value + :param:`retval` is an `Exception`, this handler extracts information from the `Exception` + and constructs a custom error response JSON value (see :func:`error_response` for details). + + This handler then re-propagates the Celery-required exception information (`"exc_type"` and + `"exc_message"`) while adding this new error response information under the `"custom"` key. + """ + if status == states.FAILURE and isinstance(retval, Exception): + logger.warning(f"Setting custom error response for failed task {task_id}") + + info: dict = raw_celery_response(self.backend, task_id) + result: dict = info["result"] + err: Exception = retval + + error_payload = error_response("Internal failure", err) + + # Inspired by pattern from: + # https://www.distributedpython.com/2018/09/28/celery-task-states/ + self.update_state( + state=states.FAILURE, + meta={ + "exc_type": result["exc_type"], + "exc_message": result["exc_message"], + "custom": json.dumps(error_payload, indent=False), + }, + ) + + +def raw_celery_response(backend, task_id: str) -> Dict[str, Any]: + key_info: str = backend.get_key_for_task(task_id) + info_as_str: str = backend.get(key_info) + info: dict = json.loads(info_as_str) + return info + + +def error_response(msg: str, e_unhandled: Exception) -> ErrorResponse: + stacktrace = format_stacktrace(e_unhandled) + return { + "error": str(e_unhandled), + "error_metadata": f"{msg}\n{stacktrace}", + } + + +def create_celery_service( + forwarder: Forwarder, + task_visibility: TaskVisibility, + queue_name: Optional[str] = None, + sqs_url: Optional[str] = None, +) -> Celery: + """ + Creates a celery application. + Returns: + app (celery.app.base.Celery): Celery app. + exec_func (celery.local.PromiseProxy): Callable task function. + """ + + app: Celery = celery_app( + name=None, + s3_bucket=infra_config().s3_bucket, + task_visibility=task_visibility, + broker_type=str(BrokerType.SQS.value if sqs_url else BrokerType.REDIS.value), + broker_transport_options={"predefined_queues": {queue_name: {"url": sqs_url}}} + if sqs_url + else None, + ) + + # See documentation for options: + # https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options + @app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True) + def exec_func(payload, *ignored_args, **ignored_kwargs): + if len(ignored_args) > 0: + logger.warning(f"Ignoring {len(ignored_args)} positional arguments: {ignored_args=}") + if len(ignored_kwargs) > 0: + logger.warning(f"Ignoring {len(ignored_kwargs)} keyword arguments: {ignored_kwargs=}") + try: + return forwarder(payload) + except Exception: + logger.exception("Celery service failed to respond to request.") + raise + + # Have celery service also accept pre-LIRA celery task name to ensure no downtime + # when transitioning from pre-LIRA single container architecture to LIRA + # multi-container-architecture. + @app.task( + base=ErrorHandlingTask, + name=DEFAULT_CELERY_TASK_NAME, + track_started=True, + ) + def exec_func_pre_lira(payload, *ignored_args, **ignored_kwargs): + return exec_func(payload, *ignored_args, **ignored_kwargs) + + return app + + +def start_celery_service( + app: Celery, + queue: str, + concurrency: int, +) -> None: + worker = app.Worker( + queues=[queue], + concurrency=concurrency, + loglevel="INFO", + optimization="fair", + # pool="solo" argument fixes the known issues of celery and some of the libraries. + # Particularly asyncio and torchvision transformers. + pool="solo", + ) + worker.start() + + +def entrypoint(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--set", type=str, action="append") + parser.add_argument("--task-visibility", type=str, required=True) + parser.add_argument("--num-workers", type=int, required=True) + parser.add_argument("--queue", type=str, required=True) + parser.add_argument("--sqs-url", type=str, default=None) + + args = parser.parse_args() + + forwarder_config = load_named_config(args.config, args.set) + forwarder_loader = LoadForwarder(**forwarder_config["async"]) + forwader = forwarder_loader.load(None, None) + + app = create_celery_service(forwader, TaskVisibility.VISIBILITY_24H, args.queue, args.sqs_url) + start_celery_service(app, args.queue, args.num_workers) + + +if __name__ == "__main__": + entrypoint() diff --git a/server/llm_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py similarity index 82% rename from server/llm_engine_server/inference/forwarding/forwarding.py rename to model-engine/model_engine_server/inference/forwarding/forwarding.py index 0517bc65..196942d5 100644 --- a/server/llm_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -3,17 +3,18 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterator, Optional, Sequence, Tuple +from typing import Any, Iterator, List, Optional, Sequence, Tuple import requests import sseclient -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import logger_name, make_logger -from llm_engine_server.inference.common import get_endpoint_config -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( +import yaml +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.inference.common import get_endpoint_config +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler __all__: Sequence[str] = ( "Forwarder", @@ -31,10 +32,10 @@ DEFAULT_PORT: int = 5005 -class LLMEngineSerializationMixin: - """Mixin class for optionally wrapping LLMEngine requests.""" +class ModelEngineSerializationMixin: + """Mixin class for optionally wrapping Model Engine requests.""" - llm_engine_unwrap: bool + model_engine_unwrap: bool serialize_results_as_string: bool def _get_serialize_results_as_string_value( @@ -46,7 +47,7 @@ def _get_serialize_results_as_string_value( return serialize_results_as_string elif KEY_SERIALIZE_RESULTS_AS_STRING in json_payload: - serialize_results_as_string = json_payload[KEY_SERIALIZE_RESULTS_AS_STRING] + serialize_results_as_string = bool(json_payload[KEY_SERIALIZE_RESULTS_AS_STRING]) logger.warning( f"Found '{KEY_SERIALIZE_RESULTS_AS_STRING}' in payload! " f"Overriding {self.serialize_results_as_string=} with " @@ -68,15 +69,17 @@ def _get_serialize_results_as_string_value( def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]: # TODO: eventually delete serialize_results_as_string: Optional[bool] = None + # IF we get a feature update in model_engine where it's able to allow a user to + # request this from the API, then we can determine that here. + # (NOTE: This is _potential_ future behavior) serialize_results_as_string = self._get_serialize_results_as_string_value( serialize_results_as_string, json_payload, # type: ignore ) - if self.llm_engine_unwrap: + if self.model_engine_unwrap: logger.info(f"Unwrapping {json_payload.keys()=}") - json_payload = json_payload["args"] - # TODO: eventually delete + json_payload = json_payload.get("args", json_payload) serialize_results_as_string = self._get_serialize_results_as_string_value( serialize_results_as_string, json_payload, # type: ignore @@ -91,7 +94,7 @@ def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]: @staticmethod def get_response_payload(using_serialize_results_as_string: bool, response: Any): - # LLMEngine expects a JSON object with a "result" key. + # Model Engine expects a JSON object with a "result" key. if using_serialize_results_as_string: response_as_string: str = json.dumps(response) return {"result": response_as_string} @@ -100,7 +103,7 @@ def get_response_payload(using_serialize_results_as_string: bool, response: Any) @dataclass -class Forwarder(LLMEngineSerializationMixin): +class Forwarder(ModelEngineSerializationMixin): """Forwards inference requests to another service via HTTP POST. Expects this user-defined inference service to be running on localhost. However, @@ -115,7 +118,7 @@ class Forwarder(LLMEngineSerializationMixin): """ predict_endpoint: str - llm_engine_unwrap: bool + model_engine_unwrap: bool serialize_results_as_string: bool post_inference_hooks_handler: PostInferenceHooksHandler wrap_response: bool @@ -174,8 +177,7 @@ class LoadForwarder: predict_route: str = "/predict" healthcheck_route: str = "/readyz" batch_route: Optional[str] = None - llm_engine_unwrap: bool = True - # TODO: this is a workaround + model_engine_unwrap: bool = True serialize_results_as_string: bool = True wrap_response: bool = True @@ -236,7 +238,7 @@ def endpoint(route: str) -> str: logger.info(f"Waiting for user-defined service to be ready at {hc}...") time.sleep(1) - logger.info(f"Unwrapping spellbook payload formatting?: {self.llm_engine_unwrap}") + logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}") logger.info(f"Serializing result as string?: {self.serialize_results_as_string}") if ENV_SERIALIZE_RESULTS_AS_STRING in os.environ: @@ -263,6 +265,8 @@ def endpoint(route: str) -> str: bundle_name=endpoint_config.bundle_name, post_inference_hooks=endpoint_config.post_inference_hooks, user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), @@ -270,7 +274,7 @@ def endpoint(route: str) -> str: return Forwarder( predict_endpoint=pred, - llm_engine_unwrap=self.llm_engine_unwrap, + model_engine_unwrap=self.model_engine_unwrap, serialize_results_as_string=serialize_results_as_string, post_inference_hooks_handler=handler, wrap_response=self.wrap_response, @@ -278,7 +282,7 @@ def endpoint(route: str) -> str: @dataclass -class StreamingForwarder(LLMEngineSerializationMixin): +class StreamingForwarder(ModelEngineSerializationMixin): """Forwards inference requests to another service via HTTP POST. Expects this user-defined inference service to be running on localhost. However, @@ -294,7 +298,7 @@ class StreamingForwarder(LLMEngineSerializationMixin): """ predict_endpoint: str - llm_engine_unwrap: bool + model_engine_unwrap: bool serialize_results_as_string: bool post_inference_hooks_handler: PostInferenceHooksHandler # unused for now @@ -345,7 +349,7 @@ class LoadStreamingForwarder: predict_route: str = "/predict" healthcheck_route: str = "/readyz" batch_route: Optional[str] = None - llm_engine_unwrap: bool = True + model_engine_unwrap: bool = True serialize_results_as_string: bool = False def load(self, resources: Path, cache: Any) -> StreamingForwarder: @@ -405,7 +409,7 @@ def endpoint(route: str) -> str: logger.info(f"Waiting for user-defined service to be ready at {hc}...") time.sleep(1) - logger.info(f"Unwrapping spellbook payload formatting?: {self.llm_engine_unwrap}") + logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}") logger.info(f"Serializing result as string?: {self.serialize_results_as_string}") if ENV_SERIALIZE_RESULTS_AS_STRING in os.environ: @@ -432,6 +436,8 @@ def endpoint(route: str) -> str: bundle_name=endpoint_config.bundle_name, post_inference_hooks=endpoint_config.post_inference_hooks, user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), @@ -439,7 +445,55 @@ def endpoint(route: str) -> str: return StreamingForwarder( predict_endpoint=pred, - llm_engine_unwrap=self.llm_engine_unwrap, + model_engine_unwrap=self.model_engine_unwrap, serialize_results_as_string=serialize_results_as_string, post_inference_hooks_handler=handler, ) + + +def load_named_config(config_uri, config_overrides=None): + with open(config_uri, "rt") as rt: + if config_uri.endswith(".json"): + return json.load(rt) + else: + c = yaml.safe_load(rt) + if config_overrides: + _substitute_config_overrides(c, config_overrides) + if len(c) == 1: + name = list(c.keys())[0] + c = c[name] + if "name" not in c: + c["name"] = name + return c + + +def _substitute_config_overrides(config: dict, config_overrides: List[str]) -> None: + """ + Modifies config based on config_overrides. + + config_overrides should be a list of strings of the form `key=value`, + where `key` can be of the form `key1.key2` to denote a substitution for config[key1][key2] + (nesting can be arbitrarily deep). + """ + for override in config_overrides: + split = override.split("=") + if len(split) != 2: + raise ValueError(f"Config override {override} must contain exactly one =") + key_path, value = split + try: + _set_value(config, key_path.split("."), value) + except Exception as e: + raise ValueError(f"Error setting {key_path} to {value} in {config}") from e + + +def _set_value(config: dict, key_path: List[str], value: Any) -> None: + """ + Modifies config by setting the value at config[key_path[0]][key_path[1]]... to be `value`. + """ + key = key_path[0] + if len(key_path) == 1: + config[key] = value + else: + if key not in config: + config[key] = dict() + _set_value(config[key], key_path[1:], value) diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py new file mode 100644 index 00000000..85de6ded --- /dev/null +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -0,0 +1,170 @@ +import argparse +import json +import os +import subprocess +from functools import lru_cache +from multiprocessing import BoundedSemaphore +from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType +from typing import Optional + +from fastapi import Depends, FastAPI, HTTPException +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.inference.forwarding.forwarding import ( + LoadForwarder, + LoadStreamingForwarder, + load_named_config, +) +from sse_starlette.sse import EventSourceResponse + +logger = make_logger(logger_name()) +app = FastAPI() + + +class MultiprocessingConcurrencyLimiter: + def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): + if concurrency is not None: + if concurrency < 1: + raise ValueError("Concurrency should be at least 1") + self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency) + self.blocking = ( + not fail_on_concurrency_limit + ) # we want to block if we want to queue up requests + else: + self.semaphore = None + self.blocking = False # Unused + + def __enter__(self): + logger.debug("Entering concurrency limiter semaphore") + if self.semaphore and not self.semaphore.acquire(block=self.blocking): + logger.warning("Too many requests, returning 429") + raise HTTPException(status_code=429, detail="Too many requests") + # Just raises an HTTPException. + # __exit__ should not run; otherwise the release() doesn't have an acquire() + + def __exit__(self, type, value, traceback): + logger.debug("Exiting concurrency limiter semaphore") + if self.semaphore: + self.semaphore.release() + + +@app.get("/healthz") +@app.get("/readyz") +def healthcheck(): + return "OK" + + +def get_config(): + overrides = os.getenv("CONFIG_OVERRIDES") + config_overrides = None + if overrides is not None: + config_overrides = overrides.split(";") + return load_named_config( + os.getenv("CONFIG_FILE"), + config_overrides, + ) + + +def get_forwarder_loader(): + config = get_config() + forwarder_loader = LoadForwarder(**config["sync"]) + return forwarder_loader + + +def get_streaming_forwarder_loader(): + config = get_config() + streaming_forwarder_loader = LoadStreamingForwarder(**config["stream"]) + return streaming_forwarder_loader + + +@lru_cache() +def get_concurrency_limiter(): + config = get_config() + concurrency = int(config.get("max_concurrency", 5)) + return MultiprocessingConcurrencyLimiter( + concurrency=concurrency, fail_on_concurrency_limit=True + ) + + +@lru_cache() +def load_forwarder(): + return get_forwarder_loader().load(None, None) + + +@lru_cache() +def load_streaming_forwarder(): + return get_streaming_forwarder_loader().load(None, None) + + +@app.post("/predict") +def predict( + request: EndpointPredictV1Request, + forwarder=Depends(load_forwarder), + limiter=Depends(get_concurrency_limiter), +): + with limiter: + return forwarder(request.dict()) + + +@app.post("/stream") +async def stream( + request: EndpointPredictV1Request, + forwarder=Depends(load_streaming_forwarder), + limiter=Depends(get_concurrency_limiter), +): + with limiter: + try: + payload = request.dict() + except Exception: + logger.error(f"Failed to decode payload from: {request}") + raise + else: + logger.debug(f"Received request: {payload}") + + # has internal error logging for each processing stage + responses = forwarder(payload) + + async def event_generator(): + for response in responses: + yield {"data": json.dumps(response)} + + return EventSourceResponse(event_generator()) + + +def entrypoint(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--num-workers", type=int, required=True) + parser.add_argument("--host", type=str, default="[::]") + parser.add_argument("--port", type=int, default=5000) + parser.add_argument("--set", type=str, action="append") + + args = parser.parse_args() + + values = [f"CONFIG_FILE={args.config}"] + if args.set is not None: + values.append(f"CONFIG_OVERRIDES={';'.join(args.set)}") + envs = [] + for v in values: + envs.extend(["--env", v]) + + command = [ + "gunicorn", + "--bind", + f"{args.host}:{args.port}", + "--timeout", + "1200", + "--keep-alive", + "2", + "--worker-class", + "uvicorn.workers.UvicornWorker", + "--workers", + str(args.num_workers), + *envs, + "model_engine_server.inference.forwarding.http_forwarder:app", + ] + subprocess.run(command) + + +if __name__ == "__main__": + entrypoint() diff --git a/server/llm_engine_server/inference/infra/__init__.py b/model-engine/model_engine_server/inference/infra/__init__.py similarity index 100% rename from server/llm_engine_server/inference/infra/__init__.py rename to model-engine/model_engine_server/inference/infra/__init__.py diff --git a/server/llm_engine_server/inference/infra/gateways/__init__.py b/model-engine/model_engine_server/inference/infra/gateways/__init__.py similarity index 100% rename from server/llm_engine_server/inference/infra/gateways/__init__.py rename to model-engine/model_engine_server/inference/infra/gateways/__init__.py diff --git a/server/llm_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py similarity index 50% rename from server/llm_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py rename to model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py index 8e7d3aa9..a8999723 100644 --- a/server/llm_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py @@ -1,12 +1,12 @@ from datadog import statsd -from llm_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( +from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( InferenceMonitoringMetricsGateway, ) class DatadogInferenceMonitoringMetricsGateway(InferenceMonitoringMetricsGateway): def emit_attempted_post_inference_hook(self, hook: str): - statsd.increment(f"scale_llm_engine_server.post_inference_hook.{hook}.attempt") + statsd.increment(f"scale_launch.post_inference_hook.{hook}.attempt") def emit_successful_post_inference_hook(self, hook: str): - statsd.increment(f"scale_llm_engine_server.post_inference_hook.{hook}.success") + statsd.increment(f"scale_launch.post_inference_hook.{hook}.success") diff --git a/model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py new file mode 100644 index 00000000..d3e76fdf --- /dev/null +++ b/model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py @@ -0,0 +1,10 @@ +from typing import Dict + +from model_engine_server.inference.domain.gateways.usage_metrics_gateway import UsageMetricsGateway + + +class FakeUsageMetricsGateway(UsageMetricsGateway): + """No-op usage metrics emitter""" + + def emit_task_call_metric(self, idempotency_token: str, tags: Dict[str, str]): + pass diff --git a/server/llm_engine_server/inference/inject_bundle.Dockerfile b/model-engine/model_engine_server/inference/inject_bundle.Dockerfile similarity index 79% rename from server/llm_engine_server/inference/inject_bundle.Dockerfile rename to model-engine/model_engine_server/inference/inject_bundle.Dockerfile index 94432467..84de0bbc 100644 --- a/server/llm_engine_server/inference/inject_bundle.Dockerfile +++ b/model-engine/model_engine_server/inference/inject_bundle.Dockerfile @@ -13,6 +13,6 @@ WORKDIR /app COPY ${LOCAL_BUNDLE_PATH} ${LOCAL_BUNDLE_PATH} -RUN python /app/llm_engine/llm_engine/inference/download_and_inject_bundle.py +RUN python /app/model-engine/model_engine_server/inference/download_and_inject_bundle.py ENV PYTHONPATH /app \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/limits.conf b/model-engine/model_engine_server/inference/limits.conf new file mode 100644 index 00000000..080aa4a3 --- /dev/null +++ b/model-engine/model_engine_server/inference/limits.conf @@ -0,0 +1,2 @@ +modelengine hard nproc 2000 +modelengine soft nproc 1000 diff --git a/server/llm_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py similarity index 62% rename from server/llm_engine_server/inference/post_inference_hooks.py rename to model-engine/model_engine_server/inference/post_inference_hooks.py index 626bd1b5..cd460a27 100644 --- a/server/llm_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -1,15 +1,23 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional +from uuid import uuid4 import requests -from llm_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth -from llm_engine_server.inference.common import _write_to_s3 -from llm_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( +from model_engine_server.common.constants import ( + BILLING_POST_INFERENCE_HOOK, + CALLBACK_POST_INFERENCE_HOOK, +) +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth +from model_engine_server.inference.common import _write_to_s3 +from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( InferenceMonitoringMetricsGateway, ) +from model_engine_server.inference.domain.gateways.usage_metrics_gateway import UsageMetricsGateway +from model_engine_server.inference.infra.gateways.fake_usage_metrics_gateway import ( + FakeUsageMetricsGateway, +) from tenacity import Retrying, stop_after_attempt, wait_exponential logger = make_logger(filename_wo_ext(__file__)) @@ -40,6 +48,41 @@ def handle( pass +class BillingHook(PostInferenceHook): + def __init__( + self, + endpoint_name: str, + bundle_name: str, + user_id: str, + billing_queue: Optional[str], + billing_tags: Optional[Dict[str, Any]], + ): + super().__init__(endpoint_name, bundle_name, user_id) + self._billing_queue = billing_queue + self._billing_tags = billing_tags or {} + + def handle( + self, + request_payload: EndpointPredictV1Request, + response: Dict[str, Any], + task_id: Optional[str], + ): + if not self._user_id or not self._billing_queue: + logger.error("Usage inputs could not be found for billing hook, aborting") + return + if not task_id: + task_id = str(uuid4()) + + events_queue: UsageMetricsGateway + try: + from plugins.eventbridge_usage_metrics_gateway import EventbridgeUsageMetricsGateway + + events_queue = EventbridgeUsageMetricsGateway(self._billing_queue) + except ModuleNotFoundError: + events_queue = FakeUsageMetricsGateway() + events_queue.emit_task_call_metric(idempotency_token=task_id, tags=self._billing_tags) + + class CallbackHook(PostInferenceHook): def __init__( self, @@ -85,6 +128,8 @@ def __init__( endpoint_name: str, bundle_name: str, user_id: str, + billing_queue: str, + billing_tags: Dict[str, Any], default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], post_inference_hooks: Optional[List[str]], @@ -97,7 +142,15 @@ def __init__( # TODO: Ensure that this process gracefully handles errors in # initializing each post-inference hook. hook_lower = hook.lower() - if hook_lower == CALLBACK_POST_INFERENCE_HOOK: + if hook_lower == BILLING_POST_INFERENCE_HOOK: + self._hooks[hook_lower] = BillingHook( + endpoint_name, + bundle_name, + user_id, + billing_queue, + billing_tags, + ) + elif hook_lower == CALLBACK_POST_INFERENCE_HOOK: self._hooks[hook_lower] = CallbackHook( endpoint_name, bundle_name, diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile new file mode 100644 index 00000000..edac54c9 --- /dev/null +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile @@ -0,0 +1,70 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Install basic packages. +# TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. +# It's probably more correct to add support for arbitrary user-specified base images, +# otherwise this base image gets bloated over time. +RUN apt-get update && apt-get install -y \ + apt-utils \ + dumb-init \ + git \ + ssh \ + emacs-nox \ + htop \ + iftop \ + vim \ + ffmpeg \ + libsm6 \ + libxext6 \ + libcurl4-openssl-dev \ + libssl-dev \ + python3-dev \ + gcc \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Apparently wget has a vulnerability so we remove it here +RUN apt-get remove wget -y + +# Create a virtualenv for python so we install our packages in the right place +# Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs +RUN python3 -m venv /venv +ENV PATH=/venv/bin:$PATH + +# Run everything as not-root user +RUN useradd -m modelengine -s /bin/bash +RUN chown -R modelengine /venv +RUN chown -R modelengine /app +# Limits for nproc and consequently number of files open +ADD model-engine/model_engine_server/inference/limits.conf /etc/security/limits.conf +USER modelengine + +# Not good for layer caching oh well +# The inference code should only need these few files/directories to function (hopefully) +# Don't copy the entire folder for security reasons + +RUN mkdir -p /app/model-engine +RUN mkdir -p /app/model-engine/model_engine_server + +RUN chown -R modelengine /app/model-engine + +COPY --chown=modelengine \ + model-engine/model_engine_server/inference/requirements_base.txt \ + /app/model-engine/model_engine_server/inference/requirements_base.txt +RUN pip install -r /app/model-engine/model_engine_server/inference/requirements_base.txt + +COPY --chown=modelengine model-engine/setup.py /app/model-engine/setup.py +COPY --chown=modelengine model-engine/model_engine_server.egg-info /app/model-engine/model_engine_server.egg-info +COPY --chown=modelengine model-engine/model_engine_server/__init__.py /app/model-engine/model_engine_server/__init__.py +COPY --chown=modelengine model-engine/model_engine_server/common /app/model-engine/model_engine_server/common +COPY --chown=modelengine model-engine/model_engine_server/core /app/model-engine/model_engine_server/core +COPY --chown=modelengine model-engine/model_engine_server/domain /app/model-engine/model_engine_server/domain +COPY --chown=modelengine model-engine/model_engine_server/infra /app/model-engine/model_engine_server/infra +COPY --chown=modelengine model-engine/model_engine_server/inference /app/model-engine/model_engine_server/inference +WORKDIR /app/model-engine +RUN pip install -e . + +WORKDIR /app diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile new file mode 100644 index 00000000..29f5cd81 --- /dev/null +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile @@ -0,0 +1,10 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ARG REQUIREMENTS_FILE +COPY --chown=modelengine ${REQUIREMENTS_FILE} /app/model-engine/model_engine_server/inference/requirements.txt +RUN --mount=type=secret,id=codeartifact-pip-conf,target=/etc/pip.conf,mode=0444 \ + PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf \ + pip install -r /app/model-engine/model_engine_server/inference/requirements.txt + +ENV PYTHONPATH /app diff --git a/server/llm_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt similarity index 69% rename from server/llm_engine_server/inference/requirements_base.txt rename to model-engine/model_engine_server/inference/requirements_base.txt index 1d543b2d..aa3acad0 100644 --- a/server/llm_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -1,9 +1,13 @@ +aioredis==2.0.1 +celery[redis,sqs,tblib]==5.3.1 fastapi==0.78.0 -uvicorn==0.17.6 -waitress==2.1.2 +gunicorn==20.1.0 +# Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it +importlib-metadata<5.0;python_version<"3.8" +json-log-formatter==0.5.2 smart_open==5.1.0 +tqdm==4.65.0 # Pin typing-extensions so aioitertools doesn't break typing-extensions>=4.1.1 -scale-launch>=0.1.0 -# Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it -importlib-metadata<5.0;python_version<"3.8" +uvicorn==0.17.6 +waitress==2.0.0 diff --git a/server/llm_engine_server/inference/service_requests.py b/model-engine/model_engine_server/inference/service_requests.py similarity index 89% rename from server/llm_engine_server/inference/service_requests.py rename to model-engine/model_engine_server/inference/service_requests.py index 3dd4485e..ad94c7f5 100644 --- a/server/llm_engine_server/inference/service_requests.py +++ b/model-engine/model_engine_server/inference/service_requests.py @@ -8,12 +8,12 @@ import boto3 import cloudpickle from celery.result import allow_join_result -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.errors import UpstreamHTTPSvcError -from llm_engine_server.common.io import open_wrapper -from llm_engine_server.common.service_requests import make_sync_request_with_retries -from llm_engine_server.core.celery import TaskVisibility, celery_app -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.errors import UpstreamHTTPSvcError +from model_engine_server.common.io import open_wrapper +from model_engine_server.common.service_requests import make_sync_request_with_retries +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.core.loggers import filename_wo_ext, make_logger logger = make_logger(filename_wo_ext(__file__)) @@ -48,8 +48,7 @@ def get_s3_client(): def _read_function_to_network_endpoint_info(): # Dictionary format: {servable_id: {remote: true/false, endpoint_type: "sync"/"async", destination: },...} - # destination is either a celery queue name, i.e. llm_engine_server., or the full url for an http request, - # i.e. http://.ml-internal.scale.com/predict. + # destination is either a celery queue name, i.e. launch., or the full url for an http request. details_json = os.getenv("CHILD_FN_INFO") if details_json is None: return None @@ -62,7 +61,7 @@ def _read_function_to_network_endpoint_info(): def make_request(servable_id: str, local_fn: Callable, args: List[Any], kwargs: Dict[str, Any]): # This is the external-facing entrypoint. Reads in details and decides to make a network request or not - # This function gets imported and called by the LLMEngine client. + # This function gets imported and called by the Launch client. current_fn_info = child_fn_info[servable_id] use_remote = current_fn_info["remote"] if use_remote: diff --git a/server/llm_engine_server/inference/sync_inference/__init__.py b/model-engine/model_engine_server/inference/sync_inference/__init__.py similarity index 100% rename from server/llm_engine_server/inference/sync_inference/__init__.py rename to model-engine/model_engine_server/inference/sync_inference/__init__.py diff --git a/server/llm_engine_server/inference/sync_inference/constants.py b/model-engine/model_engine_server/inference/sync_inference/constants.py similarity index 100% rename from server/llm_engine_server/inference/sync_inference/constants.py rename to model-engine/model_engine_server/inference/sync_inference/constants.py diff --git a/server/llm_engine_server/inference/sync_inference/destination_rule.yaml b/model-engine/model_engine_server/inference/sync_inference/destination_rule.yaml similarity index 100% rename from server/llm_engine_server/inference/sync_inference/destination_rule.yaml rename to model-engine/model_engine_server/inference/sync_inference/destination_rule.yaml diff --git a/server/llm_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py similarity index 86% rename from server/llm_engine_server/inference/sync_inference/fastapi_server.py rename to model-engine/model_engine_server/inference/sync_inference/fastapi_server.py index 78ec06d5..f25bece2 100644 --- a/server/llm_engine_server/inference/sync_inference/fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py @@ -5,18 +5,18 @@ from typing import Optional from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.inference.common import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.inference.common import ( get_endpoint_config, load_predict_fn_or_cls, run_predict, ) -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler -from llm_engine_server.inference.sync_inference.constants import ( +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from model_engine_server.inference.sync_inference.constants import ( CONCURRENCY, FAIL_ON_CONCURRENCY_LIMIT, NAME, @@ -26,7 +26,6 @@ class MultiprocessingConcurrencyLimiter: - # Shamelessly copied from std-ml-srv def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): if concurrency is not None: if concurrency < 1: @@ -54,7 +53,6 @@ def __exit__(self, type, value, traceback): def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter): - # Shamelessly copied from std-ml-srv def _inner(flask_func): @wraps(flask_func) def _inner_2(*args, **kwargs): @@ -78,6 +76,8 @@ def _inner_2(*args, **kwargs): bundle_name=endpoint_config.bundle_name, post_inference_hooks=endpoint_config.post_inference_hooks, user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), diff --git a/model-engine/model_engine_server/inference/sync_inference/server.py b/model-engine/model_engine_server/inference/sync_inference/server.py new file mode 100644 index 00000000..1713a394 --- /dev/null +++ b/model-engine/model_engine_server/inference/sync_inference/server.py @@ -0,0 +1,96 @@ +import os +from functools import wraps +from threading import BoundedSemaphore +from typing import Optional + +import waitress +from flask import Flask, Response, abort, request +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict + +logger = make_logger(filename_wo_ext(__file__)) + +NAME = "hosted-inference-sync-service" +CONCURRENCY = 2 # TODO read from env var?? what's our api +NUM_THREADS = CONCURRENCY + 1 # Extra thread for rejecting above-concurrency requests +FAIL_ON_CONCURRENCY_LIMIT = True # TODO read from env var?? +PORT = os.environ["PORT"] + + +class FlaskConcurrencyLimiter: + def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): + if concurrency is not None: + if concurrency < 1: + raise ValueError("Concurrency should be at least 1") + self.semaphore: Optional[BoundedSemaphore] = BoundedSemaphore(value=concurrency) + self.blocking = ( + not fail_on_concurrency_limit + ) # we want to block if we want to queue up requests + else: + self.semaphore = None + self.blocking = False # Unused + + def __enter__(self): + logger.debug("Entering concurrency limiter semaphore") + if self.semaphore and not self.semaphore.acquire(blocking=self.blocking): + logger.warning("Too many requests, returning 429") + abort(429) + # Just raises an HTTPException. + # __exit__ should not run; otherwise the release() doesn't have an acquire() + + def __exit__(self, type, value, traceback): + logger.debug("Exiting concurrency limiter semaphore") + if self.semaphore: + self.semaphore.release() + + +def with_concurrency_limit(concurrency_limiter: FlaskConcurrencyLimiter): + def _inner(flask_func): + @wraps(flask_func) + def _inner_2(*args, **kwargs): + with concurrency_limiter: + return flask_func(*args, **kwargs) + + return _inner_2 + + return _inner + + +app = Flask(NAME) +concurrency_limiter = FlaskConcurrencyLimiter(CONCURRENCY, FAIL_ON_CONCURRENCY_LIMIT) + +# How does this interact with threads? +# Analogous to init_worker() inside async_inference +predict_fn = load_predict_fn_or_cls() + + +@app.route("/healthcheck", methods=["GET"]) +@app.route("/healthz", methods=["GET"]) +@app.route("/readyz", methods=["GET"]) +def healthcheck(): + return Response(status=200, headers={}) + + +@app.route("/predict", methods=["POST"]) +@with_concurrency_limit(concurrency_limiter) +def predict(): + """ + Assumption: payload is a JSON with format {"url": , "args": , "returned_pickled": boolean} + Returns: Results of running the predict function on the request url. See `run_predict`. + + """ + try: + payload = request.get_json() + payload_pydantic = EndpointPredictV1Request.parse_obj(payload) + except Exception: + logger.error(f"Failed to decode payload from: {request}") + raise + else: + logger.debug(f"Received request: {payload}") + + return run_predict(predict_fn, payload_pydantic) + + +if __name__ == "__main__": + waitress.serve(app, port=PORT, url_scheme="https", threads=NUM_THREADS) diff --git a/server/llm_engine_server/inference/sync_inference/start_fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py similarity index 65% rename from server/llm_engine_server/inference/sync_inference/start_fastapi_server.py rename to model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py index c54b9e7b..97aea0ed 100644 --- a/server/llm_engine_server/inference/sync_inference/start_fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py @@ -1,15 +1,14 @@ import os import subprocess -from llm_engine_server.inference.common import unset_sensitive_envvars -from llm_engine_server.inference.sync_inference.constants import NUM_PROCESSES +from model_engine_server.inference.common import unset_sensitive_envvars +from model_engine_server.inference.sync_inference.constants import NUM_PROCESSES PORT = os.environ["PORT"] def start_server(): # TODO: HTTPS - # Copied from std-ml-srv command = [ "gunicorn", "--bind", @@ -22,7 +21,7 @@ def start_server(): "uvicorn.workers.UvicornWorker", "--workers", str(NUM_PROCESSES), - "llm_engine_server.inference.sync_inference.fastapi_server:app", + "model_engine_server.inference.sync_inference.fastapi_server:app", ] unset_sensitive_envvars() subprocess.run(command) diff --git a/server/llm_engine_server/inference/sync_inference/virtual_service.yaml b/model-engine/model_engine_server/inference/sync_inference/virtual_service.yaml similarity index 100% rename from server/llm_engine_server/inference/sync_inference/virtual_service.yaml rename to model-engine/model_engine_server/inference/sync_inference/virtual_service.yaml diff --git a/server/llm_engine_server/inference/sync_inference/vpa.yaml b/model-engine/model_engine_server/inference/sync_inference/vpa.yaml similarity index 100% rename from server/llm_engine_server/inference/sync_inference/vpa.yaml rename to model-engine/model_engine_server/inference/sync_inference/vpa.yaml diff --git a/model-engine/model_engine_server/inference/user.Dockerfile b/model-engine/model_engine_server/inference/user.Dockerfile new file mode 100644 index 00000000..6ed69146 --- /dev/null +++ b/model-engine/model_engine_server/inference/user.Dockerfile @@ -0,0 +1,8 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ARG REQUIREMENTS_FILE +COPY --chown=root ${REQUIREMENTS_FILE} /app/model-engine/model_engine_server/inference/requirements.txt +RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/model-engine/model_engine_server/inference/requirements.txt + +ENV PYTHONPATH /app diff --git a/server/llm_engine_server/infra/__init__.py b/model-engine/model_engine_server/infra/__init__.py similarity index 100% rename from server/llm_engine_server/infra/__init__.py rename to model-engine/model_engine_server/infra/__init__.py diff --git a/server/llm_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py similarity index 87% rename from server/llm_engine_server/infra/gateways/__init__.py rename to model-engine/model_engine_server/infra/gateways/__init__.py index c7c5a2af..0417527d 100644 --- a/server/llm_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -4,11 +4,12 @@ from .batch_job_progress_gateway import BatchJobProgressGateway from .celery_task_queue_gateway import CeleryTaskQueueGateway from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway +from .fake_model_primitive_gateway import FakeModelPrimitiveGateway from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway -from .filesystem_gateway import FilesystemGateway from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway from .live_batch_job_orchestration_gateway import LiveBatchJobOrchestrationGateway from .live_batch_job_progress_gateway import LiveBatchJobProgressGateway +from .live_cron_job_gateway import LiveCronJobGateway from .live_docker_image_batch_job_gateway import LiveDockerImageBatchJobGateway from .live_model_endpoint_infra_gateway import LiveModelEndpointInfraGateway from .live_model_endpoints_schema_gateway import LiveModelEndpointsSchemaGateway @@ -18,17 +19,19 @@ from .live_sync_model_endpoint_inference_gateway import LiveSyncModelEndpointInferenceGateway from .model_endpoint_infra_gateway import ModelEndpointInfraGateway from .s3_filesystem_gateway import S3FilesystemGateway +from .s3_llm_artifact_gateway import S3LLMArtifactGateway __all__: Sequence[str] = [ "BatchJobOrchestrationGateway", "BatchJobProgressGateway", "CeleryTaskQueueGateway", "DatadogMonitoringMetricsGateway", + "FakeModelPrimitiveGateway", "FakeMonitoringMetricsGateway", - "FilesystemGateway", "LiveAsyncModelEndpointInferenceGateway", "LiveBatchJobOrchestrationGateway", "LiveBatchJobProgressGateway", + "LiveCronJobGateway", "LiveDockerImageBatchJobGateway", "LiveModelEndpointInfraGateway", "LiveModelEndpointsSchemaGateway", @@ -36,4 +39,5 @@ "LiveSyncModelEndpointInferenceGateway", "ModelEndpointInfraGateway", "S3FilesystemGateway", + "S3LLMArtifactGateway", ] diff --git a/server/llm_engine_server/infra/gateways/aiohttp_sse_client.py b/model-engine/model_engine_server/infra/gateways/aiohttp_sse_client.py similarity index 100% rename from server/llm_engine_server/infra/gateways/aiohttp_sse_client.py rename to model-engine/model_engine_server/infra/gateways/aiohttp_sse_client.py diff --git a/server/llm_engine_server/infra/gateways/batch_job_orchestration_gateway.py b/model-engine/model_engine_server/infra/gateways/batch_job_orchestration_gateway.py similarity index 94% rename from server/llm_engine_server/infra/gateways/batch_job_orchestration_gateway.py rename to model-engine/model_engine_server/infra/gateways/batch_job_orchestration_gateway.py index 57c40394..5ce0bb1e 100644 --- a/server/llm_engine_server/infra/gateways/batch_job_orchestration_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/batch_job_orchestration_gateway.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict -from llm_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.entities import BatchJobSerializationFormat class BatchJobOrchestrationGateway(ABC): diff --git a/server/llm_engine_server/infra/gateways/batch_job_progress_gateway.py b/model-engine/model_engine_server/infra/gateways/batch_job_progress_gateway.py similarity index 92% rename from server/llm_engine_server/infra/gateways/batch_job_progress_gateway.py rename to model-engine/model_engine_server/infra/gateways/batch_job_progress_gateway.py index ab20bd34..e1da816d 100644 --- a/server/llm_engine_server/infra/gateways/batch_job_progress_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/batch_job_progress_gateway.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from llm_engine_server.domain.entities import BatchJobProgress +from model_engine_server.domain.entities import BatchJobProgress class BatchJobProgressGateway(ABC): diff --git a/server/llm_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py similarity index 71% rename from server/llm_engine_server/infra/gateways/celery_task_queue_gateway.py rename to model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 7017e6ef..66f39f83 100644 --- a/server/llm_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -1,42 +1,35 @@ from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, GetAsyncTaskV1Response, TaskStatus, ) -from llm_engine_server.core.celery import TaskVisibility, celery_app -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway - -logger = make_logger(filename_wo_ext(__file__)) +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.core.config import infra_config +from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway celery_redis = celery_app( None, - s3_bucket=ml_infra_config().s3_bucket, + s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value), ) celery_redis_24h = celery_app( None, - s3_bucket=ml_infra_config().s3_bucket, + s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value), task_visibility=TaskVisibility.VISIBILITY_24H, ) celery_sqs = celery_app( - None, s3_bucket=ml_infra_config().s3_bucket, broker_type=str(BrokerType.SQS.value) + None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.SQS.value) ) class CeleryTaskQueueGateway(TaskQueueGateway): def __init__(self, broker_type: BrokerType): self.broker_type = broker_type - assert self.broker_type in [ - BrokerType.SQS, - BrokerType.REDIS, - BrokerType.REDIS_24H, - ] + assert self.broker_type in [BrokerType.SQS, BrokerType.REDIS, BrokerType.REDIS_24H] def _get_celery_dest(self): if self.broker_type == BrokerType.SQS: @@ -55,16 +48,13 @@ def send_task( expires: Optional[int] = None, ) -> CreateAsyncTaskV1Response: celery_dest = self._get_celery_dest() - logger.info( - f"Sending task {task_name} with args {args} kwargs {kwargs} to queue {queue_name}" - ) + res = celery_dest.send_task( name=task_name, args=args, kwargs=kwargs, queue=queue_name, ) - logger.info(f"Response from sending task {task_name}: {res}") return CreateAsyncTaskV1Response(task_id=res.id) def get_task(self, task_id: str) -> GetAsyncTaskV1Response: diff --git a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py new file mode 100644 index 00000000..4dc73f69 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py @@ -0,0 +1,38 @@ +from datadog import statsd +from model_engine_server.core.config import infra_config +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway + + +class DatadogMonitoringMetricsGateway(MonitoringMetricsGateway): + def __init__(self): + self.tags = [f"env:{infra_config().env}"] + + def emit_attempted_build_metric(self): + statsd.increment("scale_launch.service_builder.attempt", tags=self.tags) + + def emit_successful_build_metric(self): + statsd.increment("scale_launch.service_builder.success", tags=self.tags) + + def emit_build_time_metric(self, duration_seconds: float): + statsd.distribution( + "scale_launch.service_builder.endpoint_build_time", duration_seconds, tags=self.tags + ) + + def emit_image_build_cache_hit_metric(self, image_type: str): + statsd.increment( + f"scale_launch.service_builder.{image_type}_image_cache_hit", tags=self.tags + ) + + def emit_image_build_cache_miss_metric(self, image_type: str): + statsd.increment( + f"scale_launch.service_builder.{image_type}_image_cache_miss", tags=self.tags + ) + + def emit_docker_failed_build_metric(self): + statsd.increment("scale_launch.service_builder.docker_failed", tags=self.tags) + + def emit_database_cache_hit_metric(self): + statsd.increment("scale_launch.database_cache.hit", tags=self.tags) + + def emit_database_cache_miss_metric(self): + statsd.increment("scale_launch.database_cache.miss", tags=self.tags) diff --git a/server/llm_engine_server/infra/gateways/fake_model_primitive_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_model_primitive_gateway.py similarity index 82% rename from server/llm_engine_server/infra/gateways/fake_model_primitive_gateway.py rename to model-engine/model_engine_server/infra/gateways/fake_model_primitive_gateway.py index 18bebf0b..e095fa12 100644 --- a/server/llm_engine_server/infra/gateways/fake_model_primitive_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_model_primitive_gateway.py @@ -1,7 +1,7 @@ from typing import Optional -from llm_engine_server.domain.entities import ModelBundleFrameworkType -from llm_engine_server.domain.gateways import ModelPrimitiveGateway +from model_engine_server.domain.entities import ModelBundleFrameworkType +from model_engine_server.domain.gateways import ModelPrimitiveGateway class FakeModelPrimitiveGateway(ModelPrimitiveGateway): diff --git a/server/llm_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py similarity index 64% rename from server/llm_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py rename to model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py index a41ee417..65b6cd7e 100644 --- a/server/llm_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -1,12 +1,15 @@ from collections import defaultdict -from llm_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.domain.gateways import MonitoringMetricsGateway class FakeMonitoringMetricsGateway(MonitoringMetricsGateway): def __init__(self): self.attempted_build = 0 self.successful_build = 0 + self.build_time_seconds = 0 + self.image_build_cache_hit = defaultdict(int) + self.image_build_cache_miss = defaultdict(int) self.docker_failed_build = 0 self.attempted_hook = defaultdict(int) self.successful_hook = defaultdict(int) @@ -16,6 +19,9 @@ def __init__(self): def reset(self): self.attempted_build = 0 self.successful_build = 0 + self.build_time_seconds = 0 + self.image_build_cache_hit = defaultdict(int) + self.image_build_cache_miss = defaultdict(int) self.docker_failed_build = 0 self.attempted_hook = defaultdict(int) self.successful_hook = defaultdict(int) @@ -28,6 +34,15 @@ def emit_attempted_build_metric(self): def emit_successful_build_metric(self): self.successful_build += 1 + def emit_build_time_metric(self, duration_seconds: float): + self.build_time_seconds += duration_seconds + + def emit_image_build_cache_hit_metric(self, image_type: str): + self.image_build_cache_hit[image_type] += 1 + + def emit_image_build_cache_miss_metric(self, image_type: str): + self.image_build_cache_miss[image_type] += 1 + def emit_docker_failed_build_metric(self): self.docker_failed_build += 1 diff --git a/server/llm_engine_server/infra/gateways/filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/filesystem_gateway.py similarity index 100% rename from server/llm_engine_server/infra/gateways/filesystem_gateway.py rename to model-engine/model_engine_server/infra/gateways/filesystem_gateway.py diff --git a/server/llm_engine_server/infra/gateways/k8s_resource_parser.py b/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py similarity index 100% rename from server/llm_engine_server/infra/gateways/k8s_resource_parser.py rename to model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py diff --git a/server/llm_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py similarity index 83% rename from server/llm_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py index be5f5537..7976e24f 100644 --- a/server/llm_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py @@ -1,15 +1,15 @@ import json -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, ) -from llm_engine_server.domain.gateways.async_model_endpoint_inference_gateway import ( +from model_engine_server.domain.gateways.async_model_endpoint_inference_gateway import ( AsyncModelEndpointInferenceGateway, ) -from llm_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway +from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway class LiveAsyncModelEndpointInferenceGateway(AsyncModelEndpointInferenceGateway): diff --git a/server/llm_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py similarity index 81% rename from server/llm_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py index 09eade35..40cc9f9d 100644 --- a/server/llm_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py @@ -1,20 +1,21 @@ from typing import Dict from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import BatchJobSerializationFormat -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways import BatchJobOrchestrationGateway -from llm_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways import BatchJobOrchestrationGateway +from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_batch_client, load_k8s_yaml, maybe_load_kube_config, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( BatchJobOrchestrationJobArguments, ) @@ -53,6 +54,7 @@ async def create_batch_job_orchestrator( BATCH_JOB_TIMEOUT=timeout_seconds, BATCH_JOB_MAX_RUNTIME=int(timeout_seconds + SHUTDOWN_GRACE_PERIOD), BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, + GIT_TAG=GIT_TAG, ) resource_key = "batch-job-orchestration-job.yaml" deployment_spec = load_k8s_yaml(resource_key, substitution_kwargs) diff --git a/server/llm_engine_server/infra/gateways/live_batch_job_progress_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py similarity index 69% rename from server/llm_engine_server/infra/gateways/live_batch_job_progress_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py index 1db10e0d..7de8f8aa 100644 --- a/server/llm_engine_server/infra/gateways/live_batch_job_progress_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py @@ -1,13 +1,14 @@ -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import BatchJobProgress -from llm_engine_server.infra.gateways import BatchJobProgressGateway, FilesystemGateway +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import BatchJobProgress +from model_engine_server.infra.gateways import BatchJobProgressGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway logger = make_logger(filename_wo_ext(__file__)) def get_batch_job_progress_location(user_id: str, batch_job_id: str): - return f"s3://{ml_infra_config().s3_bucket}/batch_job_progress/{user_id}/{batch_job_id}" + return f"s3://{infra_config().s3_bucket}/batch_job_progress/{user_id}/{batch_job_id}" class LiveBatchJobProgressGateway(BatchJobProgressGateway): @@ -21,7 +22,7 @@ def get_progress(self, owner: str, batch_job_id: str) -> BatchJobProgress: ) try: with self.filesystem_gateway.open( - progress_location, aws_profile=ml_infra_config().profile_ml_worker + progress_location, aws_profile=infra_config().profile_ml_worker ) as f: progress = BatchJobProgress.parse_raw(f.read()) except Exception: @@ -39,6 +40,6 @@ def update_progress(self, owner: str, batch_job_id: str, progress: BatchJobProgr user_id=owner, batch_job_id=batch_job_id ) with self.filesystem_gateway.open( - progress_location, "w", aws_profile=ml_infra_config().profile_ml_worker + progress_location, "w", aws_profile=infra_config().profile_ml_worker ) as f: f.write(progress.json()) diff --git a/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py new file mode 100644 index 00000000..b8316b25 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py @@ -0,0 +1,159 @@ +from typing import Any, Dict, List, Optional + +from kubernetes_asyncio.client.rest import ApiException +from model_engine_server.common import dict_not_none +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways.cron_job_gateway import CronJobGateway +from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( + LAUNCH_JOB_ID_LABEL_SELECTOR, + _parse_job_status_from_k8s_obj, +) +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( + get_kubernetes_batch_client, + load_k8s_yaml, + maybe_load_kube_config, +) +from model_engine_server.infra.gateways.resources.k8s_resource_types import CronTriggerArguments + +BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS = 10 + +logger = make_logger(filename_wo_ext(__file__)) + + +def _k8s_cron_job_name_from_id(trigger_id: str): + trigger_id_suffix = trigger_id[5:] # suffix following "trig_" contains xid + return f"launch-trigger-{trigger_id_suffix}" + + +class LiveCronJobGateway(CronJobGateway): + def __init__(self): + pass + + async def create_cronjob( + self, + *, + request_host: str, + trigger_id: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Dict[str, str], + ) -> None: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + cron_job_name = _k8s_cron_job_name_from_id(trigger_id) + + cron_trigger_key = "cron-trigger.yaml" + substitution_kwargs = CronTriggerArguments( + HOST=request_host, + NAME=cron_job_name, + CREATED_BY=created_by, + OWNER=owner, + TEAM=default_job_metadata["team"], + PRODUCT=default_job_metadata["product"], + TRIGGER_ID=trigger_id, + CRON_SCHEDULE=cron_schedule, + DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID=docker_image_batch_job_bundle_id, + JOB_CONFIG=self._format_dict_template_args(default_job_config or {}), + JOB_METADATA=self._format_dict_template_args(default_job_metadata), + BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS=BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS, + ) + cron_job_body = load_k8s_yaml(cron_trigger_key, substitution_kwargs) + + try: + await batch_client.create_namespaced_cron_job( + namespace=hmi_config.endpoint_namespace, body=cron_job_body + ) + except ApiException as exc: + logger.exception( + f"Exception encountered when creating batch cron job for docker image batch job bundle id '{docker_image_batch_job_bundle_id}' for {owner}" + ) + raise EndpointResourceInfraException from exc + + async def list_jobs( + self, + *, + owner: str, + trigger_id: Optional[str], + ) -> List[DockerImageBatchJob]: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + try: + label_selector = f"trigger_id={trigger_id}" if trigger_id else f"owner={owner}" + jobs = await batch_client.list_namespaced_job( + namespace=hmi_config.endpoint_namespace, + label_selector=label_selector, + ) + except ApiException as exc: + logger.exception("Got an exception when trying to list the Jobs") + raise EndpointResourceInfraException from exc + + return [ + DockerImageBatchJob( + id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR), + created_by=job.metadata.labels.get("created_by"), + owner=job.metadata.labels.get("owner"), + created_at=job.metadata.creation_timestamp, + completed_at=job.status.completion_time, + status=_parse_job_status_from_k8s_obj(job), + ) + for job in jobs.items + ] + + async def update_cronjob( + self, + *, + trigger_id: str, + cron_schedule: Optional[str], + suspend: Optional[bool], + ) -> None: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + cron_job_name = _k8s_cron_job_name_from_id(trigger_id) + partial_body = dict(spec=dict_not_none(schedule=cron_schedule, suspend=suspend)) + + try: + await batch_client.patch_namespaced_cron_job( + name=cron_job_name, + namespace=hmi_config.endpoint_namespace, + body=partial_body, + ) + except ApiException: + logger.exception( + f"Exception encountered when patching batch cron job for trigger id '{trigger_id}', requested object likely does not exist" + ) + + async def delete_cronjob( + self, + *, + trigger_id: str, + ) -> None: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + cron_job_name = _k8s_cron_job_name_from_id(trigger_id) + + try: + await batch_client.delete_namespaced_cron_job( + name=cron_job_name, namespace=hmi_config.endpoint_namespace + ) + except ApiException: + logger.exception( + f"Exception encountered when deleting batch cron job for trigger id '{trigger_id}', requested object likely does not exist" + ) + + @staticmethod + def _format_dict_template_args(obj: Dict[str, Any]) -> str: + return f"{obj}".replace("'", '"') diff --git a/server/llm_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py similarity index 75% rename from server/llm_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py index 2c2f7a88..bc1d6a9b 100644 --- a/server/llm_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py @@ -5,34 +5,33 @@ from kubernetes_asyncio.client.models.v1_job import V1Job from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.common.serialization_utils import python_json_to_b64 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities.batch_job_entity import BatchJobStatus, DockerImageBatchJob -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways.docker_image_batch_job_gateway import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.serialization_utils import python_json_to_b64 +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities.batch_job_entity import BatchJobStatus, DockerImageBatchJob +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( DockerImageBatchJobGateway, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_batch_client, load_k8s_yaml, maybe_load_kube_config, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( DictStrStr, DockerImageBatchJobCpuArguments, DockerImageBatchJobGpuArguments, ) from xid import XID -DEFAULT_MOUNT_LOCATION = "/restricted_llm_engine/batch_payload.json" +DEFAULT_MOUNT_LOCATION = "/restricted_launch/batch_payload.json" # Must match resources/docker...{cpu,gpu}.yaml's label selector -LLM_ENGINE_JOB_ID_LABEL_SELECTOR = "llm_engine_job_id" +LAUNCH_JOB_ID_LABEL_SELECTOR = "launch_job_id" OWNER_LABEL_SELECTOR = "owner" - ENV: str = os.environ.get("DD_ENV") # type: ignore GIT_TAG: str = os.environ.get("GIT_TAG") # type: ignore SERVICE_CONFIG_PATH: str = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH") # type: ignore @@ -44,7 +43,7 @@ Path(__file__).parent.absolute() / "resources/docker_image_batch_job_gpu.yaml" ) -BATCH_JOB_MAX_RUNTIME_SECONDS = 43200 # 12 hours +BATCH_JOB_MAX_RUNTIME_SECONDS = 86400 * 7 # 7 days BATCH_JOB_TTL_SECONDS_AFTER_FINISHED = 86400 * 3 # 3 days logger = make_logger(filename_wo_ext(__file__)) @@ -56,7 +55,7 @@ class K8sEnvDict(TypedDict): def _get_job_id(): - return f"job-{XID().string()}" + return f"ft-{XID().string()}" def _check_batch_job_id_valid(job_id: str): @@ -82,7 +81,21 @@ def _add_list_values( def _k8s_job_name_from_id(job_id: str): # "di" stands for "docker image" btw - return f"llm-engine-di-batch-job-{job_id}" + return f"launch-di-batch-job-{job_id}" + + +def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus: + status = job.status + # these counts are the number of pods in some given status + if status.failed is not None and status.failed > 0: + return BatchJobStatus.FAILURE + if status.succeeded is not None and status.succeeded > 0: + return BatchJobStatus.SUCCESS + if status.ready is not None and status.ready > 0: + return BatchJobStatus.RUNNING # empirically this doesn't happen + if status.active is not None and status.active > 0: + return BatchJobStatus.RUNNING # TODO this might be a mix of pending and running + return BatchJobStatus.PENDING class LiveDockerImageBatchJobGateway(DockerImageBatchJobGateway): @@ -102,6 +115,8 @@ async def create_docker_image_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, labels: Dict[str, str], mount_location: Optional[str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, ) -> str: await maybe_load_kube_config() @@ -116,6 +131,8 @@ async def create_docker_image_batch_job( created_by=created_by, owner=owner, labels=labels, + annotations=annotations, + override_job_max_runtime_s=override_job_max_runtime_s, ) batch_client = get_kubernetes_batch_client() @@ -144,10 +161,13 @@ def _generate_job_spec( created_by: str, owner: str, labels: Dict[str, str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, ) -> Tuple[str, Dict[str, Any]]: job_id = _get_job_id() job_name = _k8s_job_name_from_id(job_id) # why do we even have job_name and id job_config_b64encoded = python_json_to_b64(job_config) + job_runtime_limit = override_job_max_runtime_s or BATCH_JOB_MAX_RUNTIME_SECONDS storage = resource_requests.storage storage_dict = DictStrStr("") if storage is not None: @@ -169,20 +189,22 @@ def _generate_job_spec( CREATED_BY=created_by, OWNER=owner, JOB_ID=job_id, + GIT_TAG=GIT_TAG, # Batch Job Arguments - BATCH_JOB_MAX_RUNTIME=BATCH_JOB_MAX_RUNTIME_SECONDS, + BATCH_JOB_MAX_RUNTIME=job_runtime_limit, BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, - IMAGE=f"{ml_infra_config().docker_repo_prefix}/{repo}:{tag}", + IMAGE=f"{infra_config().docker_repo_prefix}/{repo}:{tag}", COMMAND=command, CPUS=str(resource_requests.cpus), MEMORY=str(resource_requests.memory), STORAGE_DICT=storage_dict, MOUNT_PATH=mount_path, - INPUT_LOCATION="--input-local", # TODO when we enable mounting remote s3files should be "--input-remote" + INPUT_LOCATION="--input-local", + # TODO when we enable mounting remote s3files should be "--input-remote" S3_FILE="unused", LOCAL_FILE_NAME=mount_location, FILE_CONTENTS_B64ENCODED=job_config_b64encoded, - AWS_ROLE=ml_infra_config().profile_ml_inference_worker, + AWS_ROLE=infra_config().profile_ml_inference_worker, # GPU Arguments GPU_TYPE=resource_requests.gpu_type.value, GPUS=resource_requests.gpus or 1, @@ -198,20 +220,22 @@ def _generate_job_spec( CREATED_BY=created_by, OWNER=owner, JOB_ID=job_id, + GIT_TAG=GIT_TAG, # Batch Job Arguments - BATCH_JOB_MAX_RUNTIME=BATCH_JOB_MAX_RUNTIME_SECONDS, + BATCH_JOB_MAX_RUNTIME=job_runtime_limit, BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, - IMAGE=f"{ml_infra_config().docker_repo_prefix}/{repo}:{tag}", + IMAGE=f"{infra_config().docker_repo_prefix}/{repo}:{tag}", COMMAND=command, CPUS=str(resource_requests.cpus), MEMORY=str(resource_requests.memory), STORAGE_DICT=storage_dict, MOUNT_PATH=mount_path, - INPUT_LOCATION="--input-local", # TODO when we enable mounting remote s3files should be "--input-remote" + INPUT_LOCATION="--input-local", + # TODO when we enable mounting remote s3files should be "--input-remote" S3_FILE="unused", LOCAL_FILE_NAME=mount_location, FILE_CONTENTS_B64ENCODED=job_config_b64encoded, - AWS_ROLE=ml_infra_config().profile_ml_inference_worker, + AWS_ROLE=infra_config().profile_ml_inference_worker, ) resource_spec = load_k8s_yaml(resource_key, substitution_kwargs) @@ -227,6 +251,13 @@ def _generate_job_spec( resource_spec["spec"]["template"]["spec"]["containers"][0]["env"] = _add_list_values( container_env_list, override_envs ) + if "annotations" in resource_spec["metadata"]: + resource_spec["metadata"]["annotations"].update(annotations) + else: + resource_spec["metadata"]["annotations"] = annotations + # add trigger_id label if job was spawned by trigger + if "trigger_id" in labels: + resource_spec["metadata"]["labels"]["trigger_id"] = labels["trigger_id"] return job_id, resource_spec async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[DockerImageBatchJob]: @@ -239,7 +270,7 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker try: jobs = await batch_client.list_namespaced_job( namespace=hmi_config.endpoint_namespace, - label_selector=f"{LLM_ENGINE_JOB_ID_LABEL_SELECTOR}={batch_job_id}", + label_selector=f"{LAUNCH_JOB_ID_LABEL_SELECTOR}={batch_job_id}", ) if len(jobs.items) == 0: logger.info(f"Job id {batch_job_id} not found") @@ -252,8 +283,9 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker raise EndpointResourceInfraException from exc job_labels = job.metadata.labels + annotations = job.metadata.annotations - status = self._parse_job_status_from_k8s_obj(job) + status = _parse_job_status_from_k8s_obj(job) return DockerImageBatchJob( id=batch_job_id, @@ -262,6 +294,7 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker created_at=job.metadata.creation_timestamp, completed_at=job.status.completion_time, status=status, + annotations=annotations, ) async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatchJob]: @@ -278,12 +311,13 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc return [ DockerImageBatchJob( - id=job.metadata.labels.get(LLM_ENGINE_JOB_ID_LABEL_SELECTOR), + id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR), created_by=job.metadata.labels.get("created_by"), owner=owner, created_at=job.metadata.creation_timestamp, completed_at=job.status.completion_time, - status=self._parse_job_status_from_k8s_obj(job), + annotations=job.metadata.annotations, + status=_parse_job_status_from_k8s_obj(job), ) for job in jobs.items ] @@ -321,17 +355,3 @@ async def _delete_docker_image_batch_job(self, batch_job_id: str) -> bool: ) raise EndpointResourceInfraException from exc return True - - @staticmethod - def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus: - status = job.status - # these counts are the number of pods in some given status - if status.failed is not None and status.failed > 0: - return BatchJobStatus.FAILURE - if status.succeeded is not None and status.succeeded > 0: - return BatchJobStatus.SUCCESS - if status.ready is not None and status.ready > 0: - return BatchJobStatus.RUNNING # empirically this doesn't happen - if status.active is not None and status.active > 0: - return BatchJobStatus.RUNNING # TODO this might be a mix of pending and running - return BatchJobStatus.PENDING diff --git a/server/llm_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py similarity index 90% rename from server/llm_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py index d4e282ad..4b73a386 100644 --- a/server/llm_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py @@ -1,13 +1,13 @@ import os from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.common.settings import ( +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.settings import ( RESTRICTED_ENDPOINT_LABELS, generate_deployment_name, get_service_builder_queue, ) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -15,14 +15,16 @@ ModelEndpointRecord, StorageSpecificationType, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways import TaskQueueGateway -from llm_engine_server.infra.gateways.model_endpoint_infra_gateway import ModelEndpointInfraGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways import TaskQueueGateway +from model_engine_server.infra.gateways.model_endpoint_infra_gateway import ( + ModelEndpointInfraGateway, +) +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -BUILD_TASK_NAME = "llm_engine_server.service_builder.tasks_v1.build_endpoint" +BUILD_TASK_NAME = "model_engine_server.service_builder.tasks_v1.build_endpoint" SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") @@ -68,6 +70,7 @@ def create_model_endpoint_infra( labels: Dict[str, str], prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -93,6 +96,7 @@ def create_model_endpoint_infra( labels=labels, prewarm=prewarm, high_priority=high_priority, + billing_tags=billing_tags, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ) @@ -122,6 +126,7 @@ async def update_model_endpoint_infra( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, ) -> str: @@ -159,6 +164,8 @@ async def update_model_endpoint_infra( infra_state.labels.update(labels) labels = infra_state.labels assert labels is not None + if billing_tags is None and endpoint_config is not None: + billing_tags = endpoint_config.billing_tags redact_restricted_labels(labels) if prewarm is None: if infra_state.prewarm is None: @@ -200,6 +207,7 @@ async def update_model_endpoint_infra( labels=labels, prewarm=prewarm, high_priority=high_priority, + billing_tags=billing_tags, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ) diff --git a/server/llm_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py similarity index 91% rename from server/llm_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py index 6ebff349..1f6dd7b0 100644 --- a/server/llm_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py @@ -5,7 +5,7 @@ from fastapi import routing from fastapi.openapi.utils import get_openapi_path from fastapi.utils import get_model_definitions -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, RequestSchema, @@ -13,8 +13,8 @@ SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import ( +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import ( CallbackAuth, CallbackBasicAuth, CallbackmTLSAuth, @@ -22,16 +22,15 @@ ModelEndpointsSchema, ModelEndpointType, ) -from llm_engine_server.domain.gateways import ModelEndpointsSchemaGateway +from model_engine_server.domain.gateways import ModelEndpointsSchemaGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway from pydantic import BaseModel from starlette.routing import BaseRoute -from . import FilesystemGateway - # Caches the default model definition so we don't need to recompute every time _default_model_definitions = None -API_REFERENCE_TITLE = "LLMEngine Endpoints API Reference" +API_REFERENCE_TITLE = "Launch Endpoints API Reference" API_REFERENCE_VERSION = "1.0.0" @@ -39,9 +38,7 @@ def predict_stub_async(payload: EndpointPredictV1Request) -> GetAsyncTaskV1Respo raise NotImplementedError -def predict_stub_sync( - payload: EndpointPredictV1Request, -) -> SyncEndpointPredictV1Response: +def predict_stub_sync(payload: EndpointPredictV1Request) -> SyncEndpointPredictV1Response: raise NotImplementedError @@ -123,9 +120,7 @@ def get_openapi( prefix = model_endpoint_name model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) result = get_openapi_path( - route=route, - model_name_map=model_name_map, - operation_ids=operation_ids, + route=route, model_name_map=model_name_map, operation_ids=operation_ids ) if result: path, security_schemes, path_definitions = result @@ -191,9 +186,7 @@ def update_schema_refs_with_prefix(schema: Dict[str, Any], prefix: str) -> None: LiveModelEndpointsSchemaGateway.update_schema_refs_with_prefix(item, prefix) @staticmethod - def get_model_name_map( - prefix: str, - ) -> Dict[Union[Type[BaseModel], Type[Enum]], str]: + def get_model_name_map(prefix: str) -> Dict[Union[Type[BaseModel], Type[Enum]], str]: return { CallbackAuth: "CallbackAuth", CallbackBasicAuth: "CallbackBasicAuth", @@ -223,9 +216,7 @@ def get_schemas_from_model_endpoint_record( try: if schema_location is not None: with self.filesystem_gateway.open( - schema_location, - "rb", - aws_profile=ml_infra_config().profile_ml_worker, + schema_location, "rb", aws_profile=infra_config().profile_ml_worker ) as f: schema = json.load(f) finally: diff --git a/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py similarity index 90% rename from server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 1bbcbc11..9103b3e9 100644 --- a/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -4,21 +4,21 @@ import orjson import requests import sseclient -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.common.env_vars import CIRCLECI, LOCAL -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError -from llm_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( +from model_engine_server.common.env_vars import CIRCLECI, LOCAL +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError +from model_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( StreamingModelEndpointInferenceGateway, ) -from llm_engine_server.infra.gateways.aiohttp_sse_client import EventSource -from llm_engine_server.infra.gateways.k8s_resource_parser import get_node_port +from model_engine_server.infra.gateways.aiohttp_sse_client import EventSource +from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port from orjson import JSONDecodeError from tenacity import ( AsyncRetrying, @@ -43,7 +43,7 @@ def _get_streaming_endpoint_url(deployment_name: str) -> str: elif LOCAL: # local development: the svc.cluster.local address is only available w/in the k8s cluster protocol = "https" - hostname = f"{deployment_name}.{ml_infra_config().dns_host_domain}" + hostname = f"{deployment_name}.{infra_config().dns_host_domain}" else: protocol = "http" # no need to hit external DNS resolution if we're w/in the k8s cluster diff --git a/server/llm_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py similarity index 90% rename from server/llm_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index 5022aeed..c29fcf53 100644 --- a/server/llm_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -3,20 +3,20 @@ import aiohttp import orjson import requests -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.common.env_vars import CIRCLECI, LOCAL -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError -from llm_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( +from model_engine_server.common.env_vars import CIRCLECI, LOCAL +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError +from model_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( SyncModelEndpointInferenceGateway, ) -from llm_engine_server.infra.gateways.k8s_resource_parser import get_node_port +from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port from orjson import JSONDecodeError from tenacity import ( AsyncRetrying, @@ -41,7 +41,7 @@ def _get_sync_endpoint_url(deployment_name: str) -> str: elif LOCAL: # local development: the svc.cluster.local address is only available w/in the k8s cluster protocol = "https" - hostname = f"{deployment_name}.{ml_infra_config().dns_host_domain}" + hostname = f"{deployment_name}.{infra_config().dns_host_domain}" else: protocol = "http" # no need to hit external DNS resolution if we're w/in the k8s cluster diff --git a/server/llm_engine_server/infra/gateways/model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py similarity index 96% rename from server/llm_engine_server/infra/gateways/model_endpoint_infra_gateway.py rename to model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py index 51ee3c13..7d349657 100644 --- a/server/llm_engine_server/infra/gateways/model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -38,6 +38,7 @@ def create_model_endpoint_infra( labels: Dict[str, str], prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -65,6 +66,7 @@ def create_model_endpoint_infra( to False high_priority: Makes all pods for this endpoint higher priority to enable faster pod spinup time. Higher priority pods will displace the lower priority dummy pods from shared pool. + billing_tags: Arbitrary tags passed to billing default_callback_url: The default callback URL to use for the model endpoint. default_callback_auth: The default callback auth to use for the model endpoint. @@ -91,6 +93,7 @@ async def update_model_endpoint_infra( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -98,6 +101,7 @@ async def update_model_endpoint_infra( Updates the underlying infrastructure for a Model Endpoint. Args: + billing_tags: Arbitrary tags passed to billing model_endpoint_record: The associated record of a model endpoint. min_workers: The minimum number of workers for the model endpoint. max_workers: The maximum number of workers for the model endpoint. diff --git a/server/llm_engine_server/infra/gateways/resources/__init__.py b/model-engine/model_engine_server/infra/gateways/resources/__init__.py similarity index 100% rename from server/llm_engine_server/infra/gateways/resources/__init__.py rename to model-engine/model_engine_server/infra/gateways/resources/__init__.py diff --git a/server/llm_engine_server/infra/gateways/resources/endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py similarity index 94% rename from server/llm_engine_server/infra/gateways/resources/endpoint_resource_gateway.py rename to model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py index c09f1247..145f675b 100644 --- a/server/llm_engine_server/infra/gateways/resources/endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import Dict, Generic, Sequence, Tuple, TypeVar -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, ModelEndpointType, diff --git a/server/llm_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py similarity index 93% rename from server/llm_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py index 1c9ad4a5..e8cfa497 100644 --- a/server/llm_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Sequence -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, SQSQueueInfo, ) diff --git a/server/llm_engine_server/infra/gateways/resources/image_cache_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py similarity index 84% rename from server/llm_engine_server/infra/gateways/resources/image_cache_gateway.py rename to model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py index fcb3be26..bdd15e27 100644 --- a/server/llm_engine_server/infra/gateways/resources/image_cache_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py @@ -1,14 +1,17 @@ -import hashlib import os from typing import Any, Dict, List, TypedDict, cast from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_apps_client, load_k8s_yaml, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ImageCacheArguments +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( + ImageCacheArguments, + compute_image_hash, +) logger = make_logger(filename_wo_ext(__file__)) @@ -21,7 +24,6 @@ class CachedImages(TypedDict): KUBERNETES_MAX_LENGTH = 64 -LLM_ENGINE_DEFAULT_NAMESPACE = "llm-engine" class ImageCacheGateway: @@ -35,7 +37,7 @@ async def create_or_update_image_cache(self, cached_images: CachedImages) -> Non base_path = os.getenv("WORKSPACE") if base_path is None: raise EnvironmentError("WORKSPACE env variable not found") - base_name = "llm-engine-image-cache" + base_name = "launch-image-cache" for compute_type, images in cached_images.items(): # Required for mypy TypedDict @@ -45,7 +47,7 @@ async def create_or_update_image_cache(self, cached_images: CachedImages) -> Non name = f"{base_name}-{compute_type}" substitution_kwargs = ImageCacheArguments( RESOURCE_NAME=name, - NAMESPACE=LLM_ENGINE_DEFAULT_NAMESPACE, + NAMESPACE=hmi_config.endpoint_namespace, ) resource_key = f"image-cache-{compute_type}.yaml" image_cache = load_k8s_yaml(resource_key, substitution_kwargs) @@ -53,9 +55,7 @@ async def create_or_update_image_cache(self, cached_images: CachedImages) -> Non labels = image_cache["spec"]["template"]["metadata"]["labels"] containers = image_cache["spec"]["template"]["spec"]["containers"] for image in images: - image_hash = str(hashlib.md5(str(image).encode()).hexdigest())[ - :KUBERNETES_MAX_LENGTH - ] + image_hash = compute_image_hash(image) labels[image_hash] = "True" base_container_dict = { @@ -92,7 +92,7 @@ async def _create_image_cache( try: await apps_api.create_namespaced_daemon_set( - namespace=LLM_ENGINE_DEFAULT_NAMESPACE, + namespace=hmi_config.endpoint_namespace, body=image_cache, ) logger.info(f"Created image cache daemonset {name}") @@ -100,7 +100,7 @@ async def _create_image_cache( if exc.status == 409: # Do not update existing daemonset if the cache is unchanged existing_daemonsets = await apps_api.list_namespaced_daemon_set( - namespace=LLM_ENGINE_DEFAULT_NAMESPACE + namespace=hmi_config.endpoint_namespace ) for daemonset in existing_daemonsets.items: if daemonset.metadata.name == name: @@ -116,7 +116,7 @@ async def _create_image_cache( f"Image cache daemonset {name} already exists, replacing with new values" ) await apps_api.replace_namespaced_daemon_set( - name=name, namespace=LLM_ENGINE_DEFAULT_NAMESPACE, body=image_cache + name=name, namespace=hmi_config.endpoint_namespace, body=image_cache ) elif exc.status == 404: logger.exception("ImageCache API not found. Is the ImageCache CRD installed?") diff --git a/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py similarity index 90% rename from server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index b6c4d2d2..7006ca1f 100644 --- a/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -15,34 +15,34 @@ ) from kubernetes_asyncio.client.rest import ApiException from kubernetes_asyncio.config import ConfigException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.env_vars import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import ( CIRCLECI, - LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH, - LLM_ENGINE_SERVICE_TEMPLATE_FOLDER, + LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH, + LAUNCH_SERVICE_TEMPLATE_FOLDER, ) -from llm_engine_server.common.serialization_utils import b64_to_python_json, str_to_bool -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.serialization_utils import b64_to_python_json, str_to_bool +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( ModelEndpointConfig, ModelEndpointDeploymentState, ModelEndpointInfraState, + ModelEndpointRecord, ModelEndpointResourceState, ModelEndpointType, ModelEndpointUserConfigState, - RunnableImageFlavor, RunnableImageLike, - StreamingEnhancedRunnableImageFlavor, TritonEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.use_cases.model_endpoint_use_cases import MODEL_BUNDLE_CHANGED_KEY +from model_engine_server.infra.gateways.k8s_resource_parser import ( get_per_worker_value_from_target_concurrency, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( - LLM_ENGINE_HIGH_PRIORITY_CLASS, +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( + LAUNCH_HIGH_PRIORITY_CLASS, CommonEndpointParams, HorizontalAutoscalingEndpointParams, ResourceArguments, @@ -60,13 +60,7 @@ # and where the user actually owns the files BASE_PATH_IN_ENDPOINT = "/app" -DATADOG_ENV_VAR = { - "DATADOG_TRACE_ENABLED", - "DD_SERVICE", - "DD_ENV", - "DD_VERSION", - "DD_AGENT_HOST", -} +DATADOG_ENV_VAR = {"DATADOG_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"} _lazy_load_kubernetes_clients = True _kubernetes_apps_api = None @@ -132,13 +126,14 @@ def get_kubernetes_autoscaling_client(): # pragma: no cover else: _kubernetes_autoscaling_api = None if not _kubernetes_autoscaling_api: - cluster_version = get_kubernetes_cluster_version() if not CIRCLECI else "1.26" + cluster_version = get_kubernetes_cluster_version() # For k8s cluster versions 1.23 - 1.25 we need to use the v2beta2 api # For 1.26+ v2beta2 has been deperecated and merged into v2 if version.parse(cluster_version) >= version.parse("1.26"): _kubernetes_autoscaling_api = kubernetes_asyncio.client.AutoscalingV2Api() else: _kubernetes_autoscaling_api = kubernetes_asyncio.client.AutoscalingV2beta2Api() + _kubernetes_autoscaling_api = kubernetes_asyncio.client.AutoscalingV2beta2Api() return _kubernetes_autoscaling_api @@ -163,11 +158,11 @@ def get_kubernetes_custom_objects_client(): # pragma: no cover def _endpoint_id_to_k8s_resource_group_name(endpoint_id: str) -> str: - return f"llm-engine-endpoint-id-{endpoint_id}".replace("_", "-") + return f"launch-endpoint-id-{endpoint_id}".replace("_", "-") def _k8s_resource_group_name_to_endpoint_id(k8s_resource_group_name: str) -> str: - return k8s_resource_group_name.replace("llm-engine-endpoint-id-", "").replace("-", "_") + return k8s_resource_group_name.replace("launch-endpoint-id-", "").replace("-", "_") _kube_config_loaded = False @@ -190,11 +185,11 @@ async def maybe_load_kube_config(): def load_k8s_yaml(key: str, substitution_kwargs: ResourceArguments) -> Dict[str, Any]: - if LLM_ENGINE_SERVICE_TEMPLATE_FOLDER is not None: - with open(os.path.join(LLM_ENGINE_SERVICE_TEMPLATE_FOLDER, key), "r") as f: + if LAUNCH_SERVICE_TEMPLATE_FOLDER is not None: + with open(os.path.join(LAUNCH_SERVICE_TEMPLATE_FOLDER, key), "r") as f: template_str = f.read() else: - with open(LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH, "r") as f: + with open(LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH, "r") as f: config_map_str = yaml.safe_load(f.read()) template_str = config_map_str["data"][key] @@ -334,7 +329,7 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common Dictionary with detected values """ main_container = self._get_main_container(deployment_config) - llm_engine_container = self._get_llm_engine_container(deployment_config) + launch_container = self._get_launch_container(deployment_config) resources = main_container.resources image = main_container.image @@ -343,7 +338,7 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common gpus = int((resources.limits or dict()).get("nvidia.com/gpu", 0)) storage = resources.requests.get("ephemeral-storage") - envlist = llm_engine_container.env + envlist = launch_container.env # Hack: for LIRA since the bundle_url isn't really a real env var # we use the `image` for now. This may change if we allow for unpickling # in LIRA. @@ -354,9 +349,9 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common # Temporary fix: new LIRA endpoints created should have these env vars # but old ones don't, so we can fetch them from the config. if aws_role is None: - aws_role = ml_infra_config().profile_ml_inference_worker + aws_role = infra_config().profile_ml_inference_worker if results_s3_bucket is None: - results_s3_bucket = ml_infra_config().s3_bucket + results_s3_bucket = infra_config().s3_bucket if bundle_url is None or aws_role is None or results_s3_bucket is None: raise ValueError("Failed to fetch common endpoint values.") @@ -395,7 +390,7 @@ def _get_main_container(deployment_config: V1Deployment) -> V1Container: return name_to_container["main"] @staticmethod - def _get_llm_engine_container(deployment_config: V1Deployment) -> V1Container: + def _get_launch_container(deployment_config: V1Deployment) -> V1Container: pod_containers = deployment_config.spec.template.spec.containers name_to_container = {container.name: container for container in pod_containers} @@ -415,7 +410,9 @@ def _get_llm_engine_container(deployment_config: V1Deployment) -> V1Container: # --- Private low level fns that interact with k8s @staticmethod - async def _create_deployment(deployment: Dict[str, Any], name: str) -> None: + async def _create_deployment( + model_endpoint_record: ModelEndpointRecord, deployment: Dict[str, Any], name: str + ) -> None: """ Lower-level function to create/patch a k8s deployment Args: @@ -436,32 +433,49 @@ async def _create_deployment(deployment: Dict[str, Any], name: str) -> None: ) except ApiException as exc: if exc.status == 409: - logger.info(f"Deployment {name} already exists, patching") - - if "replicas" in deployment["spec"]: - # Don't pass in replicas if we're doing an update, because we want to just - # let the autoscaler do its thing. - del deployment["spec"]["replicas"] - - logger.info(f"Deployment {name} contents: {deployment}") - - try: - await apps_client.patch_namespaced_deployment( + if ( + model_endpoint_record.metadata is not None + and MODEL_BUNDLE_CHANGED_KEY in model_endpoint_record.metadata + ): + logger.info( + f"Deployment {name} already exists, replacing since model bundle has changed" + ) + logger.info(f"Deployment {name} contents: {deployment}") + await apps_client.replace_namespaced_deployment( name=name, namespace=hmi_config.endpoint_namespace, body=deployment, ) - except ApiException as exc2: - if exc2.status in [409, 422]: - logger.info(f"Deployment {name} failed to patch, falling back to replacing") - await apps_client.replace_namespaced_deployment( + else: + logger.info(f"Deployment {name} already exists, patching") + + if "replicas" in deployment["spec"]: + # Don't pass in replicas if we're doing an update, because we want to just + # let the autoscaler do its thing. + del deployment["spec"]["replicas"] + logger.info(f"Deployment {name} contents: {deployment}") + + try: + await apps_client.patch_namespaced_deployment( name=name, namespace=hmi_config.endpoint_namespace, body=deployment, ) - else: - logger.exception("Got an exception when trying to patch the Deployment") - raise + except ApiException as exc2: + if exc2.status in [409, 422]: + logger.info( + f"Deployment {name} failed to patch, falling back to replacing" + ) + await apps_client.replace_namespaced_deployment( + name=name, + namespace=hmi_config.endpoint_namespace, + body=deployment, + ) + else: + logger.exception( + "Got an exception when trying to replace the Deployment" + ) + raise else: logger.exception("Got an exception when trying to apply the Deployment") raise @@ -989,12 +1003,10 @@ def _get_deployment_resource_name(request: CreateOrUpdateResourcesRequest) -> st model_endpoint_record = build_endpoint_request.model_endpoint_record flavor = model_endpoint_record.current_model_bundle.flavor - if isinstance(flavor, (RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor)): - flavor_class = "runnable-image" - elif isinstance(flavor, TritonEnhancedRunnableImageFlavor): + if isinstance(flavor, TritonEnhancedRunnableImageFlavor): flavor_class = "triton-enhanced-runnable-image" else: - flavor_class = "artifact" + flavor_class = "runnable-image" mode = model_endpoint_record.endpoint_type.value device = "gpu" if build_endpoint_request.gpus > 0 else "cpu" @@ -1033,6 +1045,7 @@ async def _create_or_update_resources( ): add_datadog_env_to_main_container(deployment_template) await self._create_deployment( + model_endpoint_record=request.build_endpoint_request.model_endpoint_record, deployment=deployment_template, name=k8s_resource_group_name, ) @@ -1081,7 +1094,7 @@ async def _create_or_update_resources( ModelEndpointType.SYNC, ModelEndpointType.STREAMING, }: - cluster_version = get_kubernetes_cluster_version() if not CIRCLECI else "1.26" + cluster_version = get_kubernetes_cluster_version() # For k8s cluster versions 1.23 - 1.25 we need to use the v2beta2 api # For 1.26+ v2beta2 has been deperecated and merged into v2 if version.parse(cluster_version) >= version.parse("1.26"): @@ -1116,6 +1129,38 @@ async def _create_or_update_resources( name=k8s_resource_group_name, ) + # TODO wsong: add flag to use istio and use these arguments + if hmi_config.istio_enabled: + virtual_service_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="virtual-service", + ) + virtual_service_template = load_k8s_yaml( + "virtual-service.yaml", virtual_service_arguments + ) + await self._create_virtual_service( + virtual_service=virtual_service_template, + name=k8s_resource_group_name, + ) + + destination_rule_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="destination-rule", + ) + destination_rule_template = load_k8s_yaml( + "destination-rule.yaml", destination_rule_arguments + ) + await self._create_destination_rule( + destination_rule=destination_rule_template, + name=k8s_resource_group_name, + ) + @staticmethod def _get_vertical_autoscaling_params( vpa_config, @@ -1212,15 +1257,14 @@ async def _get_resources( config_maps = await self._get_config_maps( endpoint_id=endpoint_id, deployment_name=k8s_resource_group_name ) - llm_engine_container = self._get_llm_engine_container(deployment_config) - envlist = llm_engine_container.env + launch_container = self._get_launch_container(deployment_config) + envlist = launch_container.env # Note: the env var PREWARM is either "true" or "false" string (or doesn't exist for legacy) # Convert this as early as possible to Optional[bool] to avoid bugs prewarm = str_to_bool(self._get_env_value_from_envlist(envlist, "PREWARM")) high_priority = ( - deployment_config.spec.template.spec.priority_class_name - == LLM_ENGINE_HIGH_PRIORITY_CLASS + deployment_config.spec.template.spec.priority_class_name == LAUNCH_HIGH_PRIORITY_CLASS ) infra_state = ModelEndpointInfraState( @@ -1287,7 +1331,7 @@ async def _get_all_resources( hpas_by_name = {hpa.metadata.name: hpa for hpa in hpas} vpas_by_name = {vpa["metadata"]["name"]: vpa for vpa in vpas} all_config_maps = await self._get_all_config_maps() - # can safely assume hpa with same name as deployment corresponds to the same LLMEngine Endpoint + # can safely assume hpa with same name as deployment corresponds to the same Launch Endpoint logger.info(f"Orphaned hpas: {set(hpas_by_name).difference(set(deployments_by_name))}") logger.info(f"Orphaned vpas: {set(vpas_by_name).difference(set(deployments_by_name))}") infra_states = {} @@ -1297,15 +1341,15 @@ async def _get_all_resources( hpa_config = hpas_by_name.get(name, None) vpa_config = vpas_by_name.get(name, None) common_params = self._get_common_endpoint_params(deployment_config) - llm_engine_container = self._get_llm_engine_container(deployment_config) + launch_container = self._get_launch_container(deployment_config) - envlist = llm_engine_container.env + envlist = launch_container.env # Convert as early as possible to Optional[bool] to avoid bugs prewarm = str_to_bool(self._get_env_value_from_envlist(envlist, "PREWARM")) high_priority = ( deployment_config.spec.template.spec.priority_class_name - == LLM_ENGINE_HIGH_PRIORITY_CLASS + == LAUNCH_HIGH_PRIORITY_CLASS ) if hpa_config: @@ -1350,7 +1394,7 @@ async def _get_all_resources( image=common_params["image"], num_queued_items=None, ) - if name.startswith("llm-engine-endpoint-id-"): + if name.startswith("launch-endpoint-id-"): key = _k8s_resource_group_name_to_endpoint_id(name) is_key_an_endpoint_id = True else: diff --git a/server/llm_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py similarity index 79% rename from server/llm_engine_server/infra/gateways/resources/k8s_resource_types.py rename to model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index f140fd5d..632ec7bf 100644 --- a/server/llm_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -1,31 +1,28 @@ import hashlib -import json from datetime import datetime from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.model_endpoints import BrokerName, BrokerType -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.common.resource_limits import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.model_endpoints import BrokerName, BrokerType +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG +from model_engine_server.common.resource_limits import ( FORWARDER_CPU_USAGE, FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_USAGE, ) -from llm_engine_server.common.serialization_utils import bool_to_str, python_json_to_b64 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import ( - ArtifactLike, +from model_engine_server.common.serialization_utils import python_json_to_b64 +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import ( ModelEndpointConfig, RunnableImageLike, StreamingEnhancedRunnableImageFlavor, TritonEnhancedRunnableImageFlavor, - ZipArtifactFlavor, ) -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CONVERTED_FROM_ARTIFACT_LIKE_KEY, ) -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.infra.gateways.k8s_resource_parser import ( get_node_port, get_target_concurrency_from_per_worker_value, ) @@ -33,10 +30,6 @@ __all__: Sequence[str] = ( "BatchJobOrchestrationJobArguments", "CommonEndpointParams", - "DeploymentArtifactAsyncCpuArguments", - "DeploymentArtifactAsyncGpuArguments", - "DeploymentArtifactSyncCpuArguments", - "DeploymentArtifactSyncGpuArguments", "DeploymentRunnableImageAsyncCpuArguments", "DeploymentRunnableImageAsyncGpuArguments", "DeploymentRunnableImageStreamingCpuArguments", @@ -47,6 +40,7 @@ "DeploymentTritonEnhancedRunnableImageAsyncGpuArguments", "DeploymentTritonEnhancedRunnableImageSyncCpuArguments", "DeploymentTritonEnhancedRunnableImageSyncGpuArguments", + "DestinationRuleArguments", "DictStrInt", "DictStrStr", "DockerImageBatchJobCpuArguments", @@ -56,25 +50,26 @@ "HorizontalAutoscalingEndpointParams", "HorizontalPodAutoscalerArguments", "ImageCacheArguments", - "LLM_ENGINE_DEFAULT_PRIORITY_CLASS", - "LLM_ENGINE_HIGH_PRIORITY_CLASS", + "CronTriggerArguments", + "LAUNCH_DEFAULT_PRIORITY_CLASS", + "LAUNCH_HIGH_PRIORITY_CLASS", "ResourceArguments", "ServiceArguments", "UserConfigArguments", "VerticalAutoscalingEndpointParams", "VerticalPodAutoscalerArguments", + "VirtualServiceArguments", "get_endpoint_resource_arguments_from_request", ) -# Constants for LLMEngine priority classes -LLM_ENGINE_HIGH_PRIORITY_CLASS = "llm-engine-high-priority" -LLM_ENGINE_DEFAULT_PRIORITY_CLASS = "llm-engine-default-priority" +# Constants for Launch priority classes +LAUNCH_HIGH_PRIORITY_CLASS = "model-engine-high-priority" +LAUNCH_DEFAULT_PRIORITY_CLASS = "model-engine-default-priority" KUBERNETES_MAX_LENGTH = 64 FORWARDER_PORT = 5000 USER_CONTAINER_PORT = 5005 ARTIFACT_LIKE_CONTAINER_PORT = FORWARDER_PORT -FORWARDER_IMAGE_TAG = "54f8f73bfb1cce62a2b42326ccf9f49b5b145126" class _BaseResourceArguments(TypedDict): @@ -86,6 +81,7 @@ class _BaseResourceArguments(TypedDict): PRODUCT: str CREATED_BY: str OWNER: str + GIT_TAG: str class _BaseEndpointArguments(_BaseResourceArguments): @@ -149,17 +145,6 @@ class _StreamingDeploymentArguments(TypedDict): STREAMING_PREDICT_ROUTE: str -class _ArtifactDeploymentArguments(_BaseDeploymentArguments): - """Keyword-arguments for substituting into artifact deployment templates.""" - - BUNDLE_URL: str - BASE_PATH: str - LOAD_PREDICT_FN_MODULE_PATH: str - LOAD_MODEL_FN_MODULE_PATH: str - CHILD_FN_INFO: str - PREWARM: str - - class _RunnableImageDeploymentArguments(_BaseDeploymentArguments): """Keyword-arguments for substituting into runnable image deployment templates.""" @@ -169,7 +154,6 @@ class _RunnableImageDeploymentArguments(_BaseDeploymentArguments): HEALTHCHECK_ROUTE: str READINESS_INITIAL_DELAY: int INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH: str - FORWARDER_IMAGE_TAG: str FORWARDER_CONFIG_FILE_NAME: str FORWARDER_CPUS_LIMIT: float FORWARDER_MEMORY_LIMIT: str @@ -220,28 +204,6 @@ class _TritonArguments(TypedDict): TRITON_COMMIT_TAG: str -class DeploymentArtifactAsyncCpuArguments(_ArtifactDeploymentArguments, _AsyncDeploymentArguments): - """Keyword-arguments for substituting into CPU async deployment templates with artifacts.""" - - -class DeploymentArtifactAsyncGpuArguments( - _ArtifactDeploymentArguments, _AsyncDeploymentArguments, _GpuArguments -): - """Keyword-arguments for substituting into GPU async deployment templates with artifacts.""" - - -class DeploymentArtifactSyncCpuArguments( - _ArtifactDeploymentArguments, _SyncArtifactDeploymentArguments -): - """Keyword-arguments for substituting into CPU sync deployment templates with artifacts.""" - - -class DeploymentArtifactSyncGpuArguments( - _ArtifactDeploymentArguments, _SyncArtifactDeploymentArguments, _GpuArguments -): - """Keyword-arguments for substituting into GPU sync deployment templates with artifacts.""" - - class DeploymentRunnableImageSyncCpuArguments( _RunnableImageDeploymentArguments, _SyncRunnableImageDeploymentArguments ): @@ -249,9 +211,7 @@ class DeploymentRunnableImageSyncCpuArguments( class DeploymentRunnableImageSyncGpuArguments( - _RunnableImageDeploymentArguments, - _SyncRunnableImageDeploymentArguments, - _GpuArguments, + _RunnableImageDeploymentArguments, _SyncRunnableImageDeploymentArguments, _GpuArguments ): """Keyword-arguments for substituting into GPU sync deployment templates for runnable images.""" @@ -281,9 +241,7 @@ class DeploymentRunnableImageAsyncGpuArguments( class DeploymentTritonEnhancedRunnableImageSyncCpuArguments( - _RunnableImageDeploymentArguments, - _SyncRunnableImageDeploymentArguments, - _TritonArguments, + _RunnableImageDeploymentArguments, _SyncRunnableImageDeploymentArguments, _TritonArguments ): """Keyword-arguments for substituting into CPU sync deployment templates for triton-enhanced runnable images. @@ -310,10 +268,7 @@ class DeploymentTritonEnhancedRunnableImageAsyncCpuArguments( class DeploymentTritonEnhancedRunnableImageAsyncGpuArguments( - _RunnableImageDeploymentArguments, - _AsyncDeploymentArguments, - _GpuArguments, - _TritonArguments, + _RunnableImageDeploymentArguments, _AsyncDeploymentArguments, _GpuArguments, _TritonArguments ): """Keyword-arguments for substituting GPU async deployment templates for triton-enhanced runnable images. @@ -355,6 +310,10 @@ class ServiceArguments(_BaseEndpointArguments): NODE_PORT_DICT: DictStrInt +class DestinationRuleArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into destination-rule templates.""" + + class VerticalPodAutoscalerArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into vertical pod autoscaler templates.""" @@ -362,6 +321,12 @@ class VerticalPodAutoscalerArguments(_BaseEndpointArguments): MEMORY: str +class VirtualServiceArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into virtual-service templates.""" + + DNS_HOST_DOMAIN: str + + class BatchJobOrchestrationJobArguments(_JobArguments): """Keyword-arguments for substituting into batch-job-orchestration-job templates.""" @@ -385,6 +350,23 @@ class ImageCacheArguments(TypedDict): NAMESPACE: str +class CronTriggerArguments(TypedDict): + """Keyword-arguments for substituting into cronjob trigger templates.""" + + HOST: str + NAME: str + CREATED_BY: str + OWNER: str + TEAM: str + PRODUCT: str + TRIGGER_ID: str + CRON_SCHEDULE: str + DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID: str + JOB_CONFIG: str + JOB_METADATA: str + BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS: int + + class CommonEndpointParams(TypedDict): cpus: str memory: str @@ -412,10 +394,6 @@ class VerticalAutoscalingEndpointParams(TypedDict): EndpointResourceArguments = Union[ - DeploymentArtifactAsyncCpuArguments, - DeploymentArtifactAsyncGpuArguments, - DeploymentArtifactSyncCpuArguments, - DeploymentArtifactSyncGpuArguments, DeploymentRunnableImageAsyncCpuArguments, DeploymentRunnableImageAsyncGpuArguments, DeploymentRunnableImageStreamingCpuArguments, @@ -426,11 +404,13 @@ class VerticalAutoscalingEndpointParams(TypedDict): DeploymentTritonEnhancedRunnableImageAsyncGpuArguments, DeploymentTritonEnhancedRunnableImageSyncCpuArguments, DeploymentTritonEnhancedRunnableImageSyncGpuArguments, + DestinationRuleArguments, EndpointConfigArguments, HorizontalPodAutoscalerArguments, ServiceArguments, UserConfigArguments, VerticalPodAutoscalerArguments, + VirtualServiceArguments, ] ResourceArguments = Union[ @@ -439,16 +419,20 @@ class VerticalAutoscalingEndpointParams(TypedDict): DockerImageBatchJobCpuArguments, DockerImageBatchJobGpuArguments, ImageCacheArguments, + CronTriggerArguments, ] +def compute_image_hash(image: str) -> str: + return str(hashlib.md5(str(image).encode()).hexdigest())[:KUBERNETES_MAX_LENGTH] + + def container_start_triton_cmd( triton_model_repository: str, - triton_model_replicas: Dict[str, int], + triton_model_replicas: Union[Dict[str, int], Dict[str, str]], ipv6_healthcheck: bool = False, ) -> List[str]: - # NOTE: this path is set in the Trtion-specific Dockerfile: - # std-ml-srv/ml_serve/triton/Dockerfile + # NOTE: this path is set in the Triton-specific Dockerfile: triton_start_command: List[str] = [ "python", "/install/tritonserver.py", @@ -485,15 +469,8 @@ def get_endpoint_resource_arguments_from_request( team = k8s_labels.get("team", "") product = k8s_labels.get("product", "") storage = build_endpoint_request.storage - prewarm = bool_to_str(build_endpoint_request.prewarm) or "false" - sqs_profile = "default" # TODO: Make this configurable - s3_bucket = ml_infra_config().s3_bucket - - load_predict_fn_module_path = "" - load_model_fn_module_path = "" - if isinstance(flavor, ZipArtifactFlavor): - load_predict_fn_module_path = flavor.load_predict_fn_module_path - load_model_fn_module_path = flavor.load_model_fn_module_path + sqs_profile = f"eks-{infra_config().profile_ml_worker}" # TODO: Make this configurable + s3_bucket = infra_config().s3_bucket storage_dict = DictStrStr("") if storage is not None: @@ -506,11 +483,11 @@ def get_endpoint_resource_arguments_from_request( f"endpoint ID: {model_endpoint_record.id}" ) - priority = LLM_ENGINE_DEFAULT_PRIORITY_CLASS + priority = LAUNCH_DEFAULT_PRIORITY_CLASS if build_endpoint_request.high_priority: - priority = LLM_ENGINE_HIGH_PRIORITY_CLASS + priority = LAUNCH_HIGH_PRIORITY_CLASS - image_hash = str(hashlib.md5(str(request.image).encode()).hexdigest())[:KUBERNETES_MAX_LENGTH] + image_hash = compute_image_hash(request.image) # In Circle CI, we use Redis on localhost instead of SQS broker_name = BrokerName.SQS.value if not CIRCLECI else BrokerName.REDIS.value @@ -538,7 +515,9 @@ def get_endpoint_resource_arguments_from_request( raise ValueError( "flavor.env['BASE_PATH'] is required for runnable image converted from artifact like bundle" ) - infra_service_config_volume_mount_path = f"{flavor.env['BASE_PATH']}/ml_infra_core/llm_engine_server.core/llm_engine_server.core/configs" + infra_service_config_volume_mount_path = ( + f"{flavor.env['BASE_PATH']}/model-engine/model_engine_server/core/configs" + ) forwarder_config_file_name = "service--forwarder-runnable-img-converted-from-artifact.yaml" triton_command = "" @@ -568,6 +547,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -588,7 +568,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -616,6 +595,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -636,7 +616,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -666,6 +645,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -687,7 +667,6 @@ def get_endpoint_resource_arguments_from_request( STREAMING_PREDICT_ROUTE=flavor.streaming_predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -710,6 +689,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -731,7 +711,6 @@ def get_endpoint_resource_arguments_from_request( STREAMING_PREDICT_ROUTE=flavor.streaming_predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -756,6 +735,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -776,7 +756,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -799,6 +778,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -819,7 +799,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -844,6 +823,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -864,7 +844,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -900,6 +879,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -920,7 +900,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -958,6 +937,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -978,7 +958,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -1009,6 +988,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, @@ -1029,7 +1009,6 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, @@ -1050,176 +1029,6 @@ def get_endpoint_resource_arguments_from_request( TRITON_COMMAND=triton_command, TRITON_COMMIT_TAG=flavor.triton_commit_tag, ) - elif endpoint_resource_name == "deployment-artifact-async-cpu": - assert isinstance(flavor, ArtifactLike) - return DeploymentArtifactAsyncCpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Async Deployment Arguments - CELERY_S3_BUCKET=s3_bucket, - QUEUE=sqs_queue_name, - BROKER_NAME=broker_name, - BROKER_TYPE=broker_type, - SQS_QUEUE_URL=sqs_queue_url, - SQS_PROFILE=sqs_profile, - ) - elif endpoint_resource_name == "deployment-artifact-async-gpu": - assert isinstance(flavor, ArtifactLike) - assert build_endpoint_request.gpu_type is not None - return DeploymentArtifactAsyncGpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Async Deployment Arguments - CELERY_S3_BUCKET=s3_bucket, - QUEUE=sqs_queue_name, - BROKER_NAME=broker_name, - BROKER_TYPE=broker_type, - SQS_QUEUE_URL=sqs_queue_url, - SQS_PROFILE=sqs_profile, - # GPU Deployment Arguments - GPU_TYPE=build_endpoint_request.gpu_type.value, - GPUS=build_endpoint_request.gpus, - ) - elif endpoint_resource_name == "deployment-artifact-sync-cpu": - assert isinstance(flavor, ArtifactLike) - return DeploymentArtifactSyncCpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Sync Artifact DeploymentArguments Arguments - ARTIFACT_LIKE_CONTAINER_PORT=ARTIFACT_LIKE_CONTAINER_PORT, - ) - elif endpoint_resource_name == "deployment-artifact-sync-gpu": - assert isinstance(flavor, ArtifactLike) - assert build_endpoint_request.gpu_type is not None - return DeploymentArtifactSyncGpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Sync Artifact DeploymentArguments Arguments - ARTIFACT_LIKE_CONTAINER_PORT=ARTIFACT_LIKE_CONTAINER_PORT, - # GPU Deployment Arguments - GPU_TYPE=build_endpoint_request.gpu_type.value, - GPUS=build_endpoint_request.gpus, - ) elif endpoint_resource_name == "user-config": app_config_serialized = python_json_to_b64(model_bundle.app_config) return UserConfigArguments( @@ -1232,6 +1041,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, CONFIG_DATA_SERIALIZED=app_config_serialized, ) elif endpoint_resource_name == "endpoint-config": @@ -1240,6 +1050,8 @@ def get_endpoint_resource_arguments_from_request( bundle_name=model_bundle.name, post_inference_hooks=build_endpoint_request.post_inference_hooks, user_id=user_id, + billing_queue=hmi_config.billing_queue_arn, + billing_tags=build_endpoint_request.billing_tags, default_callback_url=build_endpoint_request.default_callback_url, default_callback_auth=build_endpoint_request.default_callback_auth, ).serialize() @@ -1253,6 +1065,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, ENDPOINT_CONFIG_SERIALIZED=endpoint_config_serialized, ) elif endpoint_resource_name == "horizontal-pod-autoscaler": @@ -1269,6 +1082,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, API_VERSION=api_version, # Autoscaler arguments CONCURRENCY=concurrency, @@ -1294,11 +1108,39 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Service arguments NODE_PORT_DICT=node_port_dict, SERVICE_TYPE=service_type, SERVICE_TARGET_PORT=FORWARDER_PORT, ) + elif endpoint_resource_name == "virtual-service": + return VirtualServiceArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + DNS_HOST_DOMAIN=infra_config().dns_host_domain, + ) + elif endpoint_resource_name == "destination-rule": + return DestinationRuleArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + ) elif endpoint_resource_name == "vertical-pod-autoscaler": return VerticalPodAutoscalerArguments( # Base resource arguments @@ -1310,6 +1152,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), ) diff --git a/server/llm_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py similarity index 84% rename from server/llm_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py rename to model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 33880fb7..838cc592 100644 --- a/server/llm_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -1,23 +1,23 @@ from typing import Dict, Optional, Tuple -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, ModelEndpointType, ) -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, EndpointResourceGatewayCreateOrUpdateResourcesResponse, QueueInfo, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( K8SEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, ) @@ -69,7 +69,7 @@ async def create_or_update_resources( queue_url: Optional[str] = q.queue_url destination: str = q.queue_name else: - destination = f"llm-engine-endpoint-id-{endpoint_record.id.replace('_', '-')}" + destination = f"launch-endpoint-id-{endpoint_record.id.replace('_', '-')}" queue_name = None queue_url = None diff --git a/server/llm_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py similarity index 94% rename from server/llm_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py index 87dac35a..fae21d5e 100644 --- a/server/llm_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py @@ -5,11 +5,11 @@ import botocore.exceptions from aioboto3 import Session as AioSession from aiobotocore.client import AioBaseClient -from llm_engine_server.common.config import hmi_config -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.core.aws.roles import session +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, SQSQueueInfo, ) diff --git a/server/llm_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py similarity index 95% rename from server/llm_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py index 85aaa945..de3a59e3 100644 --- a/server/llm_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py @@ -45,4 +45,4 @@ async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResu @staticmethod def endpoint_id_to_queue_name(endpoint_id: str) -> str: - return f"llm-engine-endpoint-id-{endpoint_id}" + return f"launch-endpoint-id-{endpoint_id}" diff --git a/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml similarity index 69% rename from server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml rename to model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 3f2e519f..1f712fdb 100644 --- a/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -1,17 +1,17 @@ --- -# Source: llm-engine/templates/service_template_config_map.yaml +# Source: model-engine/templates/service_template_config_map.yaml # THIS FILE IS AUTOGENERATED USING `just autogen-templates`. PLEASE EDIT THE GOTEMPLATE FILE IN THE HELM CHART!!! apiVersion: v1 kind: ConfigMap metadata: - name: llm-engine-service-template-config + name: model-engine-service-template-config labels: team: infra - product: llm-engine - helm.sh/chart: llm-engine-0.1.0 + product: launch + helm.sh/chart: model-engine-0.1.0 app.kubernetes.io/managed-by: Helm - app.kubernetes.io/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + app.kubernetes.io/version: 7034db9f84a3a6009d2ef738e5497b300f24f6cd + tags.datadoghq.com/version: 7034db9f84a3a6009d2ef738e5497b300f24f6cd tags.datadoghq.com/env: circleci annotations: "helm.sh/hook": pre-install,pre-upgrade @@ -30,10 +30,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -65,10 +65,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -106,15 +106,17 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -123,7 +125,7 @@ data: - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - --set - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - --num-workers - "${PER_WORKER}" env: - name: DATADOG_TRACE_ENABLED @@ -133,7 +135,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -145,7 +147,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -174,7 +176,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs - name: tritonserver image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent @@ -253,7 +255,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -282,7 +284,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -299,10 +301,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -334,10 +336,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -375,15 +377,17 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -392,7 +396,7 @@ data: - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - --set - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - --num-workers - "${PER_WORKER}" env: - name: DATADOG_TRACE_ENABLED @@ -402,7 +406,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -414,7 +418,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -443,7 +447,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs - name: main securityContext: capabilities: @@ -478,7 +482,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -507,212 +511,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-artifact-async-cpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - annotations: - celery.scaleml.autoscaler/queue: ${QUEUE} - celery.scaleml.autoscaler/broker: ${BROKER_NAME} - celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" - celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" - celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" - celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - sidecar.istio.io/inject: "false" # TODO: switch to scuttle - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - priorityClassName: ${PRIORITY} - containers: - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: CELERY_S3_BUCKET - value: "${CELERY_S3_BUCKET}" - - name: BROKER_TYPE - value: "${BROKER_TYPE}" - - name: SQS_PROFILE - value: "${SQS_PROFILE}" - - name: SQS_QUEUE_NAME - value: "${QUEUE}" - - name: SQS_QUEUE_URL - value: "${SQS_QUEUE_URL}" - readinessProbe: - exec: - command: - - cat - - /tmp/readyz - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - # Not including --pool=solo means there's a worker process and a separate supervisor process - # meaning if the worker crashes (because of OOM or something) the supervisor process can mark the task as - # failed, which should get rid of infinite task retries - args: - - celery - - --app=llm_engine.inference.async_inference - - worker - - --loglevel=INFO - - --concurrency=1 - - --queues=${QUEUE} - - -O - - fair - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: ${BASE_PATH}/user_config - subPath: raw_data - - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -729,10 +528,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -757,10 +556,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -797,25 +596,27 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency + - --num-workers - "${PER_WORKER}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" @@ -824,7 +625,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -836,7 +637,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -867,7 +668,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -949,7 +750,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -978,7 +779,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -995,10 +796,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1023,10 +824,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1063,25 +864,27 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency + - --num-workers - "${PER_WORKER}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" @@ -1090,7 +893,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -1102,7 +905,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -1133,7 +936,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -1171,7 +974,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -1200,187 +1003,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-artifact-sync-cpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - priorityClassName: ${PRIORITY} - containers: - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: PORT - value: "${ARTIFACT_LIKE_CONTAINER_PORT}" - readinessProbe: - httpGet: - path: /readyz - port: ${ARTIFACT_LIKE_CONTAINER_PORT} - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - args: - - python - - -m - - llm_engine.inference.sync_inference.start_fastapi_server - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: ${BASE_PATH}/user_config - subPath: raw_data - - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -1397,10 +1020,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1425,10 +1048,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1465,7 +1088,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:7034db9f84a3a6009d2ef738e5497b300f24f6cd imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1473,9 +1096,9 @@ data: - ddtrace-run - python - -m - - server.llm_engine_server.inference.forwarding.http_forwarder + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/service--http_forwarder.yaml + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers @@ -1496,7 +1119,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -1508,7 +1131,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -1539,7 +1162,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -1577,7 +1200,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -1606,7 +1229,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -1623,10 +1246,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1658,10 +1281,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1704,15 +1327,17 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -1721,7 +1346,7 @@ data: - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - --set - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - --num-workers - "${PER_WORKER}" env: - name: DATADOG_TRACE_ENABLED @@ -1731,7 +1356,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -1743,7 +1368,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -1772,7 +1397,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs - name: tritonserver image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent @@ -1852,7 +1477,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -1881,7 +1506,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -1898,10 +1523,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1933,10 +1558,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1979,15 +1604,17 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -1996,7 +1623,7 @@ data: - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - --set - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - --num-workers - "${PER_WORKER}" env: - name: DATADOG_TRACE_ENABLED @@ -2006,7 +1633,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -2018,7 +1645,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -2045,241 +1672,24 @@ data: subPath: raw_data - name: endpoint-config mountPath: /workspace/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs - - name: main - securityContext: - capabilities: - drop: - - all - image: ${IMAGE} - imagePullPolicy: IfNotPresent - command: ${COMMAND} - env: ${MAIN_ENV} - readinessProbe: - httpGet: - path: ${HEALTHCHECK_ROUTE} - port: ${USER_CONTAINER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - nvidia.com/gpu: ${GPUS} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - - name: infra-service-config-volume - mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: /app/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /app/endpoint_config - subPath: raw_data - ports: - - containerPort: ${USER_CONTAINER_PORT} - name: http - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-artifact-async-gpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - annotations: - celery.scaleml.autoscaler/queue: ${QUEUE} - celery.scaleml.autoscaler/broker: ${BROKER_NAME} - celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" - celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" - celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" - celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - sidecar.istio.io/inject: "false" # TODO: switch to scuttle - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - priorityClassName: ${PRIORITY} - containers: - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: CELERY_S3_BUCKET - value: "${CELERY_S3_BUCKET}" - - name: BROKER_TYPE - value: "${BROKER_TYPE}" - - name: SQS_PROFILE - value: "${SQS_PROFILE}" - - name: SQS_QUEUE_NAME - value: "${QUEUE}" - - name: SQS_QUEUE_URL - value: "${SQS_QUEUE_URL}" + subPath: raw_data + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + - name: main + securityContext: + capabilities: + drop: + - all + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} readinessProbe: - exec: - command: - - cat - - /tmp/readyz - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - # Not including --pool=solo means there's a worker process and a separate supervisor process - # meaning if the worker crashes (because of OOM or something) the supervisor process can mark the task as - # failed, which should get rid of infinite task retries - args: - - celery - - --app=llm_engine.inference.async_inference - - worker - - --loglevel=INFO - - --concurrency=1 - - --queues=${QUEUE} - - -O - - fair + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 resources: requests: cpu: ${CPUS} @@ -2294,17 +1704,23 @@ data: - name: config-volume mountPath: /root/.aws/config subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config - mountPath: ${BASE_PATH}/user_config + mountPath: /app/user_config subPath: raw_data - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config + mountPath: /app/endpoint_config subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 securityContext: fsGroup: 65534 @@ -2323,7 +1739,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -2340,10 +1756,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2368,10 +1784,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2413,25 +1829,27 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency + - --num-workers - "${PER_WORKER}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" @@ -2440,7 +1858,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -2452,7 +1870,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -2483,7 +1901,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -2566,7 +1984,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -2595,7 +2013,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -2612,10 +2030,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2640,10 +2058,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2685,25 +2103,27 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency + - --num-workers - "${PER_WORKER}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" @@ -2712,7 +2132,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -2724,7 +2144,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -2755,7 +2175,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -2794,7 +2214,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -2823,193 +2243,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-artifact-sync-gpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - priorityClassName: ${PRIORITY} - containers: - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: PORT - value: "${ARTIFACT_LIKE_CONTAINER_PORT}" - readinessProbe: - httpGet: - path: /readyz - port: ${ARTIFACT_LIKE_CONTAINER_PORT} - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - args: - - python - - -m - - llm_engine.inference.sync_inference.start_fastapi_server - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - nvidia.com/gpu: ${GPUS} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: ${BASE_PATH}/user_config - subPath: raw_data - - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -3026,10 +2260,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3054,10 +2288,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3099,7 +2333,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:7034db9f84a3a6009d2ef738e5497b300f24f6cd imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -3107,9 +2341,9 @@ data: - ddtrace-run - python - -m - - server.llm_engine_server.inference.forwarding.http_forwarder + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/llm_engine_server/inference/configs/service--http_forwarder.yaml + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers @@ -3130,7 +2364,7 @@ data: - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -3142,7 +2376,7 @@ data: - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -3173,7 +2407,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -3212,7 +2446,7 @@ data: mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -3241,7 +2475,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -3258,10 +2492,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3280,10 +2514,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3302,10 +2536,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3337,10 +2571,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3354,6 +2588,62 @@ data: protocol: TCP name: http ${NODE_PORT_DICT} + virtual-service.yaml: |- + apiVersion: networking.istio.io/v1alpha3 + kind: VirtualService + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + hosts: + - ${RESOURCE_NAME}.${DNS_HOST_DOMAIN} + gateways: + - default/internal-gateway + http: + - route: + - destination: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + port: + number: 80 + destination-rule.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: DestinationRule + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + trafficPolicy: + loadBalancer: + simple: LEAST_REQUEST vertical-pod-autoscaler.yaml: |- apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler @@ -3366,10 +2656,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3404,11 +2694,11 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} spec: backoffLimit: 0 @@ -3423,62 +2713,63 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never nodeSelector: node-lifecycle: normal - serviceAccountName: llm-engine + serviceAccountName: model-engine volumes: - name: config-volume configMap: name: default-config containers: - name: main - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} env: - name: DD_SERVICE value: ${RESOURCE_NAME} - name: DATADOG_TRACE_ENABLED - value: "true" + value: "false" - name: DD_ENV value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: SERVICE_IDENTIFIER + - name: GATEWAY_URL + value: http://model-engine.default:80 - name: AWS_PROFILE value: default - name: ECR_READ_AWS_PROFILE value: default - - name: ML_INFRA_DATABASE_URL - valueFrom: - secretKeyRef: - key: database_url - name: ml-infra-pg + - name: S3_WRITE_AWS_PROFILE + value: default + - name: DB_SECRET_NAME + value: prod/ml_infra_pg - name: DEPLOY_SERVICE_CONFIG_PATH - value: /workspace/llm_engine/service_configs/service_config.yaml + value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH - value: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml + value: /workspace/model-engine/model_engine_server/core/configs/config.yaml - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: DD_VERSION + value: ${GIT_TAG} + - name: GIT_TAG + value: ${GIT_TAG} imagePullPolicy: Always command: - dumb-init @@ -3487,7 +2778,7 @@ data: args: - python - -m - - server.llm_engine_server.entrypoints.start_batch_job_orchestration + - model_engine_server.entrypoints.start_batch_job_orchestration - --job-id - ${JOB_ID} - --owner @@ -3522,11 +2813,11 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} spec: backoffLimit: 0 @@ -3541,16 +2832,16 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' spec: restartPolicy: Never nodeSelector: @@ -3572,35 +2863,36 @@ data: - name: DD_SERVICE value: ${RESOURCE_NAME} - name: DATADOG_TRACE_ENABLED - value: "true" + value: "false" - name: DD_ENV value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: SERVICE_IDENTIFIER + - name: GATEWAY_URL + value: http://model-engine.default:80 - name: AWS_PROFILE value: default - name: ECR_READ_AWS_PROFILE value: default - - name: ML_INFRA_DATABASE_URL - valueFrom: - secretKeyRef: - key: database_url - name: ml-infra-pg + - name: S3_WRITE_AWS_PROFILE + value: default + - name: DB_SECRET_NAME + value: prod/ml_infra_pg - name: DEPLOY_SERVICE_CONFIG_PATH - value: /workspace/llm_engine/service_configs/service_config.yaml + value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH - value: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml + value: /workspace/model-engine/model_engine_server/core/configs/config.yaml - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: DD_VERSION + value: ${GIT_TAG} + - name: GIT_TAG + value: ${GIT_TAG} imagePullPolicy: Always command: ${COMMAND} resources: @@ -3623,11 +2915,11 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} command: - python - -m - - server.llm_engine_server.entrypoints.start_docker_image_batch_job_init_container + - model_engine_server.entrypoints.start_docker_image_batch_job_init_container - ${INPUT_LOCATION} - --remote-file - ${S3_FILE} @@ -3660,11 +2952,11 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} spec: backoffLimit: 0 @@ -3679,16 +2971,16 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' spec: restartPolicy: Never nodeSelector: @@ -3715,35 +3007,36 @@ data: - name: DD_SERVICE value: ${RESOURCE_NAME} - name: DATADOG_TRACE_ENABLED - value: "true" + value: "false" - name: DD_ENV value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: SERVICE_IDENTIFIER + - name: GATEWAY_URL + value: http://model-engine.default:80 - name: AWS_PROFILE value: default - name: ECR_READ_AWS_PROFILE value: default - - name: ML_INFRA_DATABASE_URL - valueFrom: - secretKeyRef: - key: database_url - name: ml-infra-pg + - name: S3_WRITE_AWS_PROFILE + value: default + - name: DB_SECRET_NAME + value: prod/ml_infra_pg - name: DEPLOY_SERVICE_CONFIG_PATH - value: /workspace/llm_engine/service_configs/service_config.yaml + value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH - value: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml + value: /workspace/model-engine/model_engine_server/core/configs/config.yaml - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: DD_VERSION + value: ${GIT_TAG} + - name: GIT_TAG + value: ${GIT_TAG} imagePullPolicy: Always command: ${COMMAND} resources: @@ -3767,11 +3060,11 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} command: - python - -m - - server.llm_engine_server.entrypoints.start_docker_image_batch_job_init_container + - model_engine_server.entrypoints.start_docker_image_batch_job_init_container - ${INPUT_LOCATION} - --remote-file - ${S3_FILE} @@ -3800,8 +3093,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3815,8 +3108,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3837,8 +3130,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3852,8 +3145,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3878,8 +3171,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3893,8 +3186,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3919,8 +3212,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3934,8 +3227,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3952,3 +3245,53 @@ data: name: busybox command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] terminationGracePeriodSeconds: 0 + cron-trigger.yaml: |- + apiVersion: batch/v1 + kind: CronJob + metadata: + name: ${NAME} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + schedule: "${CRON_SCHEDULE}" + successfulJobsHistoryLimit: 0 + failedJobsHistoryLimit: 0 + jobTemplate: + spec: + backoffLimit: 0 + activeDeadlineSeconds: ${BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS} + template: + metadata: + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + containers: + - name: ${NAME} + image: curlimages/curl:7.72.0 + imagePullPolicy: IfNotPresent + command: + - curl + - -X + - 'POST' + - '${HOST}/v1/docker-image-batch-jobs' + - -H + - 'accept: application/json' + - -H + - 'Content-Type: application/json' + - -d + - '{ "docker_image_batch_job_bundle_id": "${DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID}", "job_config": ${JOB_CONFIG}, "labels": ${JOB_METADATA} }' + - -u + - '${OWNER}:' + restartPolicy: Never diff --git a/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py new file mode 100644 index 00000000..7f297f61 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py @@ -0,0 +1,79 @@ +import os +from typing import List, Optional + +from model_engine_server.core.config import infra_config +from model_engine_server.domain.gateways.file_storage_gateway import ( + FileMetadata, + FileStorageGateway, +) +from model_engine_server.infra.gateways import S3FilesystemGateway + + +def get_s3_key(owner: str, file_id: str): + return os.path.join(owner, file_id) + + +def get_s3_url(owner: str, file_id: str): + return f"s3://{infra_config().s3_bucket}/{get_s3_key(owner, file_id)}" + + +class S3FileStorageGateway(FileStorageGateway): + """ + Concrete implementation of a file storage gateway backed by S3. + """ + + def __init__(self): + self.filesystem_gateway = S3FilesystemGateway() + + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + return self.filesystem_gateway.generate_signed_url(get_s3_url(owner, file_id)) + + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + try: + obj = self.filesystem_gateway.get_s3_client({}).head_object( + Bucket=infra_config().s3_bucket, + Key=get_s3_key(owner, file_id), + ) + return FileMetadata( + id=file_id, + filename=get_s3_url(owner, file_id), + size=obj.get("ContentLength"), + owner=owner, + updated_at=obj.get("LastModified"), + ) + except: # noqa: E722 + return None + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + try: + with self.filesystem_gateway.open( + get_s3_url(owner, file_id), aws_profile=infra_config().profile_ml_worker + ) as f: + return f.read() + except: # noqa: E722 + return None + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + with self.filesystem_gateway.open( + get_s3_url(owner, filename), mode="w", aws_profile=infra_config().profile_ml_worker + ) as f: + f.write(content) + return filename + + async def delete_file(self, owner: str, file_id: str) -> bool: + try: + self.filesystem_gateway.get_s3_client({}).delete_object( + Bucket=infra_config().s3_bucket, + Key=get_s3_key(owner, file_id), + ) + return True + except: # noqa: E722 + return False + + async def list_files(self, owner: str) -> List[FileMetadata]: + objects = self.filesystem_gateway.get_s3_client({}).list_objects_v2( + Bucket=infra_config().s3_bucket, + Prefix=owner, + ) + files = [await self.get_file(owner, obj["Name"]) for obj in objects] + return [f for f in files if f is not None] diff --git a/server/llm_engine_server/infra/gateways/s3_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py similarity index 77% rename from server/llm_engine_server/infra/gateways/s3_filesystem_gateway.py rename to model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py index 4dab06ba..b0bf9e84 100644 --- a/server/llm_engine_server/infra/gateways/s3_filesystem_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py @@ -4,16 +4,15 @@ import boto3 import smart_open - -from . import FilesystemGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway class S3FilesystemGateway(FilesystemGateway): """ - Concrete implemention for interacting with a filesystem backed by S3. + Concrete implementation for interacting with a filesystem backed by S3. """ - def _get_s3_client(self, kwargs): + def get_s3_client(self, kwargs): profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) session = boto3.Session(profile_name=profile_name) client = session.client("s3") @@ -21,12 +20,12 @@ def _get_s3_client(self, kwargs): def open(self, uri: str, mode: str = "rt", **kwargs) -> IO: # This follows the 5.1.0 smart_open API - client = self._get_s3_client(kwargs) + client = self.get_s3_client(kwargs) transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str: - client = self._get_s3_client(kwargs) + client = self.get_s3_client(kwargs) match = re.search("^s3://([^/]+)/(.*?)$", uri) assert match diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py new file mode 100644 index 00000000..9ebb84e6 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -0,0 +1,36 @@ +import os +from typing import List + +import boto3 +from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.domain.gateways import LLMArtifactGateway + + +class S3LLMArtifactGateway(LLMArtifactGateway): + """ + Concrete implemention for interacting with a filesystem backed by S3. + """ + + def _get_s3_resource(self, kwargs): + profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) + session = boto3.Session(profile_name=profile_name) + resource = session.resource("s3") + return resource + + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + s3 = self._get_s3_resource(kwargs) + # parsing prefix to get S3 bucket name + bucket_name = hmi_config.hf_user_fine_tuned_weights_prefix.replace("s3://", "").split("/")[ + 0 + ] + bucket = s3.Bucket(bucket_name) + model_files: List[str] = [] + model_cache_name = get_model_cache_directory_name(model_name) + # parsing prefix to get /hosted-model-inference/fine_tuned_weights + fine_tuned_weights_prefix = "/".join( + hmi_config.hf_user_fine_tuned_weights_prefix.split("/")[-2:] + ) + prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" + for obj in bucket.objects.filter(Prefix=prefix): + model_files.append(f"s3://{hmi_config.s3_bucket_name}/{obj.key}") + return model_files diff --git a/server/llm_engine_server/infra/infra_utils.py b/model-engine/model_engine_server/infra/infra_utils.py similarity index 96% rename from server/llm_engine_server/infra/infra_utils.py rename to model-engine/model_engine_server/infra/infra_utils.py index db8e7182..b38083c9 100644 --- a/server/llm_engine_server/infra/infra_utils.py +++ b/model-engine/model_engine_server/infra/infra_utils.py @@ -2,7 +2,7 @@ from logging import LoggerAdapter from typing import Callable, Sequence -from llm_engine_server.common.env_vars import LOCAL +from model_engine_server.common.env_vars import LOCAL __all__: Sequence[str] = "make_exception_log" diff --git a/server/llm_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py similarity index 75% rename from server/llm_engine_server/infra/repositories/__init__.py rename to model-engine/model_engine_server/infra/repositories/__init__.py index 061baa94..bf109926 100644 --- a/server/llm_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -5,14 +5,16 @@ from .db_docker_image_batch_job_bundle_repository import DbDockerImageBatchJobBundleRepository from .db_model_bundle_repository import DbModelBundleRepository from .db_model_endpoint_record_repository import DbModelEndpointRecordRepository +from .db_trigger_repository import DbTriggerRepository from .ecr_docker_repository import ECRDockerRepository from .feature_flag_repository import FeatureFlagRepository -from .llm_fine_tuning_job_repository import LLMFineTuningJobRepository +from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository from .model_endpoint_record_repository import ModelEndpointRecordRepository from .redis_feature_flag_repository import RedisFeatureFlagRepository from .redis_model_endpoint_cache_repository import RedisModelEndpointCacheRepository -from .s3_file_llm_fine_tuning_job_repository import S3FileLLMFineTuningJobRepository +from .s3_file_llm_fine_tune_events_repository import S3FileLLMFineTuneEventsRepository +from .s3_file_llm_fine_tune_repository import S3FileLLMFineTuneRepository __all__: Sequence[str] = [ "BatchJobRecordRepository", @@ -20,12 +22,14 @@ "DbDockerImageBatchJobBundleRepository", "DbModelBundleRepository", "DbModelEndpointRecordRepository", + "DbTriggerRepository", "ECRDockerRepository", "FeatureFlagRepository", - "LLMFineTuningJobRepository", + "LLMFineTuneRepository", "ModelEndpointRecordRepository", "ModelEndpointCacheRepository", "RedisFeatureFlagRepository", "RedisModelEndpointCacheRepository", - "S3FileLLMFineTuningJobRepository", + "S3FileLLMFineTuneRepository", + "S3FileLLMFineTuneEventsRepository", ] diff --git a/server/llm_engine_server/infra/repositories/batch_job_record_repository.py b/model-engine/model_engine_server/infra/repositories/batch_job_record_repository.py similarity index 97% rename from server/llm_engine_server/infra/repositories/batch_job_record_repository.py rename to model-engine/model_engine_server/infra/repositories/batch_job_record_repository.py index 6b33ec29..982aaa5f 100644 --- a/server/llm_engine_server/infra/repositories/batch_job_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/batch_job_record_repository.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import List, Optional -from llm_engine_server.domain.entities import BatchJobRecord, BatchJobStatus +from model_engine_server.domain.entities import BatchJobRecord, BatchJobStatus class BatchJobRecordRepository(ABC): diff --git a/server/llm_engine_server/infra/repositories/db_batch_job_record_repository.py b/model-engine/model_engine_server/infra/repositories/db_batch_job_record_repository.py similarity index 90% rename from server/llm_engine_server/infra/repositories/db_batch_job_record_repository.py rename to model-engine/model_engine_server/infra/repositories/db_batch_job_record_repository.py index 3ded1566..6aa9feb0 100644 --- a/server/llm_engine_server/infra/repositories/db_batch_job_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_batch_job_record_repository.py @@ -1,24 +1,22 @@ from datetime import datetime from typing import Any, Dict, List, Optional -from llm_engine_server.common import dict_not_none -from llm_engine_server.db.models import BatchJob as OrmBatchJob -from llm_engine_server.domain.entities import BatchJobRecord, BatchJobStatus -from llm_engine_server.infra.repositories.batch_job_record_repository import ( +from model_engine_server.common import dict_not_none +from model_engine_server.db.models import BatchJob as OrmBatchJob +from model_engine_server.domain.entities import BatchJobRecord, BatchJobStatus +from model_engine_server.infra.repositories.batch_job_record_repository import ( BatchJobRecordRepository, ) -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.infra.repositories.db_model_bundle_repository import ( translate_model_bundle_orm_to_model_bundle, ) -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) -def translate_batch_job_orm_to_batch_job_record( - batch_job_orm: OrmBatchJob, -) -> BatchJobRecord: +def translate_batch_job_orm_to_batch_job_record(batch_job_orm: OrmBatchJob) -> BatchJobRecord: return BatchJobRecord( id=batch_job_orm.id, created_at=batch_job_orm.created_at, diff --git a/server/llm_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py similarity index 90% rename from server/llm_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py rename to model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py index d1774419..4fa1948c 100644 --- a/server/llm_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py @@ -1,17 +1,17 @@ from typing import Dict, List, Optional, Sequence -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import CorruptRecordInfraStateException -from llm_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( +from model_engine_server.domain.exceptions import CorruptRecordInfraStateException +from model_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( DockerImageBatchJobBundleRepository, ) -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) diff --git a/server/llm_engine_server/infra/repositories/db_model_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py similarity index 96% rename from server/llm_engine_server/infra/repositories/db_model_bundle_repository.py rename to model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py index 73700b9a..9408d59b 100644 --- a/server/llm_engine_server/infra/repositories/db_model_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py @@ -1,15 +1,15 @@ from typing import Any, Dict, List, Optional, Sequence -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.db.models import Bundle as OrmModelBundle -from llm_engine_server.domain.entities import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.db.models import Bundle as OrmModelBundle +from model_engine_server.domain.entities import ( ModelBundle, ModelBundleFlavors, ModelBundlePackagingType, ) -from llm_engine_server.domain.repositories import ModelBundleRepository -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.domain.repositories import ModelBundleRepository +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) @@ -57,7 +57,7 @@ async def create_model_bundle( ) async with self.session() as session: await OrmModelBundle.create(session, model_bundle_record) - model_bundle_record = await OrmModelBundle.select_by_id( # type: ignore + model_bundle_record = await OrmModelBundle.select_by_id( session=session, bundle_id=model_bundle_record.id ) return translate_model_bundle_orm_to_model_bundle(model_bundle_record) diff --git a/server/llm_engine_server/infra/repositories/db_model_endpoint_record_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py similarity index 94% rename from server/llm_engine_server/infra/repositories/db_model_endpoint_record_repository.py rename to model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py index b7803b3a..69fef3de 100644 --- a/server/llm_engine_server/infra/repositories/db_model_endpoint_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py @@ -3,21 +3,21 @@ from typing import Any, Callable, Dict, List, Optional from cachetools import TTLCache -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key -from llm_engine_server.db.models import Endpoint as OrmModelEndpoint -from llm_engine_server.domain.entities import ModelEndpointRecord -from llm_engine_server.domain.gateways import MonitoringMetricsGateway -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key +from model_engine_server.db.models import Endpoint as OrmModelEndpoint +from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.infra.repositories.db_model_bundle_repository import ( translate_model_bundle_orm_to_model_bundle, ) -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) from sqlalchemy import or_, text @@ -202,7 +202,7 @@ async def list_llm_model_endpoint_records( if owner: ownership_filters.append(OrmModelEndpoint.owner == owner) filters.append( - or_(*ownership_filters, OrmModelEndpoint.public_inference == True) # noqa + or_(*ownership_filters, OrmModelEndpoint.public_inference == True) # noqa: E712 ) async with self.session() as session: diff --git a/server/llm_engine_server/infra/repositories/db_repository_mixin.py b/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py similarity index 85% rename from server/llm_engine_server/infra/repositories/db_repository_mixin.py rename to model-engine/model_engine_server/infra/repositories/db_repository_mixin.py index f1d26a81..cd8bc402 100644 --- a/server/llm_engine_server/infra/repositories/db_repository_mixin.py +++ b/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py @@ -2,7 +2,7 @@ from functools import wraps from typing import Callable -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException from sqlalchemy.ext.asyncio import AsyncSession diff --git a/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py new file mode 100644 index 00000000..bb9cb5a3 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py @@ -0,0 +1,134 @@ +from typing import Any, Dict, Optional, Sequence + +from model_engine_server.common import dict_not_none +from model_engine_server.db.models import Trigger as OrmTrigger +from model_engine_server.domain.entities.trigger_entity import Trigger +from model_engine_server.domain.exceptions import ( + CorruptRecordInfraStateException, + TriggerNameAlreadyExistsException, +) +from model_engine_server.domain.repositories.trigger_repository import TriggerRepository +from model_engine_server.infra.repositories.db_repository_mixin import ( + DbRepositoryMixin, + raise_if_read_only, +) +from pydantic.error_wrappers import ValidationError +from sqlalchemy.exc import IntegrityError + + +class DbTriggerRepository(TriggerRepository, DbRepositoryMixin): + @raise_if_read_only + async def create_trigger( + self, + *, + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], + ) -> Trigger: + trigger_record = translate_kwargs_to_trigger_orm( + name=name, + created_by=created_by, + owner=owner, + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) + try: + async with self.session() as session: + await OrmTrigger.create(session, trigger_record) + trigger_record = await OrmTrigger.select_by_id( + session=session, trigger_id=trigger_record.id + ) + except IntegrityError: + raise TriggerNameAlreadyExistsException( + f"Trigger with name {name} already exists for {owner}" + ) + return translate_trigger_orm_to_entity(trigger_record) + + async def list_triggers(self, owner: str) -> Sequence[Trigger]: + async with self.session() as session: + trigger_records = await OrmTrigger.select_all_by_owner(session=session, owner=owner) + triggers = [translate_trigger_orm_to_entity(tr) for tr in trigger_records] + return triggers + + async def get_trigger(self, trigger_id: str) -> Optional[Trigger]: + async with self.session() as session: + trigger_record = await OrmTrigger.select_by_id(session=session, trigger_id=trigger_id) + if not trigger_record: + return None + + return translate_trigger_orm_to_entity(trigger_record) + + @raise_if_read_only + async def update_trigger( + self, + trigger_id: str, + cron_schedule: str, + ) -> bool: + async with self.session() as session: + trigger = await OrmTrigger.select_by_id(session=session, trigger_id=trigger_id) + if trigger is None: + return False + + await OrmTrigger.update_by_id( + session=session, trigger_id=trigger_id, kwargs=dict(cron_schedule=cron_schedule) + ) + return True + + @raise_if_read_only + async def delete_trigger( + self, + trigger_id: str, + ) -> bool: + async with self.session() as session: + trigger = await OrmTrigger.select_by_id(session=session, trigger_id=trigger_id) + if trigger is None: + return False + + await OrmTrigger.delete_by_id(session=session, trigger_id=trigger_id) + return True + + +def translate_trigger_orm_to_entity( + trigger_orm: OrmTrigger, +) -> Trigger: + kwargs = dict_not_none( + id=trigger_orm.id, + name=trigger_orm.name, + owner=trigger_orm.owner, + created_at=trigger_orm.created_at, + created_by=trigger_orm.created_by, + cron_schedule=trigger_orm.cron_schedule, + docker_image_batch_job_bundle_id=trigger_orm.docker_image_batch_job_bundle_id, + default_job_config=trigger_orm.default_job_config, + default_job_metadata=trigger_orm.default_job_metadata, + ) + try: + return Trigger.parse_obj(kwargs) + except ValidationError as exc: + raise CorruptRecordInfraStateException() from exc + + +def translate_kwargs_to_trigger_orm( + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], +) -> OrmTrigger: + return OrmTrigger( + name=name, + owner=owner, + created_by=created_by, + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) diff --git a/server/llm_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py similarity index 69% rename from server/llm_engine_server/infra/repositories/ecr_docker_repository.py rename to model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index eb61287d..ca5f7469 100644 --- a/server/llm_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -1,10 +1,10 @@ from typing import Optional -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.docker.ecr import image_exists as ecr_image_exists -from llm_engine_server.core.docker.remote_build import build_remote_block -from llm_engine_server.domain.repositories import DockerRepository +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists +from model_engine_server.core.docker.remote_build import build_remote_block +from model_engine_server.domain.repositories import DockerRepository class ECRDockerRepository(DockerRepository): @@ -18,12 +18,10 @@ def image_exists( ) def get_image_url(self, image_tag: str, repository_name: str) -> str: - return f"{ml_infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: - folders_to_include = [ - "llm_engine", - ] + folders_to_include = ["model-engine"] if image_params.requirements_folder: folders_to_include.append(image_params.requirements_folder) diff --git a/server/llm_engine_server/infra/repositories/feature_flag_repository.py b/model-engine/model_engine_server/infra/repositories/feature_flag_repository.py similarity index 100% rename from server/llm_engine_server/infra/repositories/feature_flag_repository.py rename to model-engine/model_engine_server/infra/repositories/feature_flag_repository.py diff --git a/server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py similarity index 69% rename from server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py rename to model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py index 4a7acc1e..b33d74de 100644 --- a/server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py +++ b/model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.domain.entities.llm_fine_tune_job_entity import LLMFineTuneJobTemplate +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate -class LLMFineTuningJobRepository(ABC): +class LLMFineTuneRepository(ABC): """ Basically a store of model name + fine tuning method -> docker image batch job bundle ids @@ -13,11 +13,11 @@ class LLMFineTuningJobRepository(ABC): @abstractmethod async def get_job_template_for_model( self, model_name: str, fine_tuning_method: str - ) -> Optional[LLMFineTuneJobTemplate]: + ) -> Optional[LLMFineTuneTemplate]: pass @abstractmethod async def write_job_template_for_model( - self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneJobTemplate + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate ): pass diff --git a/server/llm_engine_server/infra/repositories/model_endpoint_cache_repository.py b/model-engine/model_engine_server/infra/repositories/model_endpoint_cache_repository.py similarity index 93% rename from server/llm_engine_server/infra/repositories/model_endpoint_cache_repository.py rename to model-engine/model_engine_server/infra/repositories/model_endpoint_cache_repository.py index 2d8a22a9..3c26cbfd 100644 --- a/server/llm_engine_server/infra/repositories/model_endpoint_cache_repository.py +++ b/model-engine/model_engine_server/infra/repositories/model_endpoint_cache_repository.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.domain.entities import ModelEndpointInfraState class ModelEndpointCacheRepository(ABC): diff --git a/server/llm_engine_server/infra/repositories/model_endpoint_record_repository.py b/model-engine/model_engine_server/infra/repositories/model_endpoint_record_repository.py similarity index 97% rename from server/llm_engine_server/infra/repositories/model_endpoint_record_repository.py rename to model-engine/model_engine_server/infra/repositories/model_endpoint_record_repository.py index 48a222b3..3abaee21 100644 --- a/server/llm_engine_server/infra/repositories/model_endpoint_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/model_endpoint_record_repository.py @@ -2,8 +2,8 @@ from contextlib import AbstractAsyncContextManager from typing import Any, Dict, List, Optional, Sequence -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.domain.entities import ModelEndpointRecord __all__: Sequence[str] = ("ModelEndpointRecordRepository",) diff --git a/server/llm_engine_server/infra/repositories/redis_feature_flag_repository.py b/model-engine/model_engine_server/infra/repositories/redis_feature_flag_repository.py similarity index 92% rename from server/llm_engine_server/infra/repositories/redis_feature_flag_repository.py rename to model-engine/model_engine_server/infra/repositories/redis_feature_flag_repository.py index b40c1c30..8283ab23 100644 --- a/server/llm_engine_server/infra/repositories/redis_feature_flag_repository.py +++ b/model-engine/model_engine_server/infra/repositories/redis_feature_flag_repository.py @@ -1,7 +1,7 @@ from typing import Optional import aioredis -from llm_engine_server.infra.repositories.feature_flag_repository import FeatureFlagRepository +from model_engine_server.infra.repositories.feature_flag_repository import FeatureFlagRepository class RedisFeatureFlagRepository(FeatureFlagRepository): @@ -27,7 +27,7 @@ def __init__( @staticmethod def _to_redis_key(key: str): - return f"llm-engine-feature-flag:{key}" + return f"launch-feature-flag:{key}" async def write_feature_flag_bool(self, key: str, value: bool): if not isinstance(value, bool): diff --git a/server/llm_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py b/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py similarity index 91% rename from server/llm_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py rename to model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py index 7f2dde7a..fb1cf630 100644 --- a/server/llm_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py +++ b/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py @@ -2,8 +2,8 @@ from typing import Optional import aioredis -from llm_engine_server.domain.entities import ModelEndpointInfraState -from llm_engine_server.infra.repositories.model_endpoint_cache_repository import ( +from model_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.infra.repositories.model_endpoint_cache_repository import ( ModelEndpointCacheRepository, ) @@ -32,7 +32,7 @@ def __init__( @staticmethod def _find_redis_key(key: str): - return f"llm-engine-k8s-cache:{key}" + return f"launch-k8s-cache:{key}" async def write_endpoint_info( self, diff --git a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py new file mode 100644 index 00000000..90f179c9 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py @@ -0,0 +1,89 @@ +import json +import os +from json.decoder import JSONDecodeError +from typing import IO, List + +import boto3 +import smart_open +from model_engine_server.core.config import infra_config +from model_engine_server.core.domain_exceptions import ObjectNotFoundException +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent +from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( + LLMFineTuneEventsRepository, +) + +# Echoes llm/ia3_finetune/docker_image_fine_tuning_entrypoint.py +S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = ( + f"s3://{infra_config().s3_bucket}/hosted-model-inference/fine_tuned_weights" +) + + +class S3FileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): + def __init__(self): + pass + + # _get_s3_client + _open copypasted from s3_file_llm_fine_tune_repo, in turn from s3_filesystem_gateway + # sorry + def _get_s3_client(self, kwargs): + profile_name = kwargs.get("aws_profile", os.getenv("S3_WRITE_AWS_PROFILE")) + session = boto3.Session(profile_name=profile_name) + client = session.client("s3") + return client + + def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + # This follows the 5.1.0 smart_open API + client = self._get_s3_client(kwargs) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + # echoes llm/ia3_finetune/docker_image_fine_tuning_entrypoint.py + def _get_model_cache_directory_name(self, model_name: str): + """How huggingface maps model names to directory names in their cache for model files. + We adopt this when storing model cache files in s3. + + Args: + model_name (str): Name of the huggingface model + """ + name = "models--" + model_name.replace("/", "--") + return name + + def _get_file_location(self, user_id: str, model_endpoint_name: str): + model_cache_name = self._get_model_cache_directory_name(model_endpoint_name) + s3_file_location = ( + f"{S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX}/{user_id}/{model_cache_name}.jsonl" + ) + return s3_file_location + + async def get_fine_tune_events( + self, user_id: str, model_endpoint_name: str + ) -> List[LLMFineTuneEvent]: + s3_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + try: + with self._open(s3_file_location, "r") as f: + lines = f.readlines() + final_events = [] + for line in lines: + try: + event_dict = json.loads(line) + event = LLMFineTuneEvent( + timestamp=event_dict["timestamp"], + message=str(event_dict["message"]), + level=event_dict.get("level", "info"), + ) + except JSONDecodeError: + event = LLMFineTuneEvent( + message=line, + level="info", + ) + final_events.append(event) + return final_events + except Exception as exc: # TODO better exception + raise ObjectNotFoundException from exc + + async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: + s3_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + self._open(s3_file_location, "w") diff --git a/server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py similarity index 81% rename from server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py rename to model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py index e651dcd2..6b3ea8aa 100644 --- a/server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py @@ -4,13 +4,11 @@ import boto3 import smart_open -from llm_engine_server.domain.entities.llm_fine_tune_job_entity import LLMFineTuneJobTemplate -from llm_engine_server.infra.repositories.llm_fine_tuning_job_repository import ( - LLMFineTuningJobRepository, -) +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository -class S3FileLLMFineTuningJobRepository(LLMFineTuningJobRepository): +class S3FileLLMFineTuneRepository(LLMFineTuneRepository): def __init__(self, file_path: str): self.file_path = file_path @@ -32,7 +30,7 @@ def _get_key(model_name, fine_tuning_method): async def get_job_template_for_model( self, model_name: str, fine_tuning_method: str - ) -> Optional[LLMFineTuneJobTemplate]: + ) -> Optional[LLMFineTuneTemplate]: # can hot reload the file lol with self._open(self.file_path, "r") as f: data = json.load(f) @@ -40,10 +38,10 @@ async def get_job_template_for_model( job_template_dict = data.get(key, None) if job_template_dict is None: return None - return LLMFineTuneJobTemplate.parse_obj(job_template_dict) + return LLMFineTuneTemplate.parse_obj(job_template_dict) async def write_job_template_for_model( - self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneJobTemplate + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate ): # Use locally in script with self._open(self.file_path, "r") as f: diff --git a/server/llm_engine_server/infra/services/__init__.py b/model-engine/model_engine_server/infra/services/__init__.py similarity index 100% rename from server/llm_engine_server/infra/services/__init__.py rename to model-engine/model_engine_server/infra/services/__init__.py diff --git a/server/llm_engine_server/infra/services/batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/batch_job_orchestration_service.py similarity index 91% rename from server/llm_engine_server/infra/services/batch_job_orchestration_service.py rename to model-engine/model_engine_server/infra/services/batch_job_orchestration_service.py index bba6d661..bbfa54af 100644 --- a/server/llm_engine_server/infra/services/batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/batch_job_orchestration_service.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import timedelta -from llm_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.entities import BatchJobSerializationFormat class BatchJobOrchestrationService(ABC): diff --git a/server/llm_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py similarity index 58% rename from server/llm_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py rename to model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index 7773182f..d5edd4aa 100644 --- a/server/llm_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -1,22 +1,21 @@ import os -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.exceptions import ( +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.domain.entities import FineTuneHparamValueType +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.exceptions import ( InvalidRequestException, LLMFineTuningMethodNotImplementedException, ) -from llm_engine_server.domain.gateways.docker_image_batch_job_gateway import ( +from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( DockerImageBatchJobGateway, ) -from llm_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( +from model_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( DockerImageBatchJobBundleRepository, ) -from llm_engine_server.domain.services.llm_fine_tuning_service import LLMFineTuningService -from llm_engine_server.infra.repositories.llm_fine_tuning_job_repository import ( - LLMFineTuningJobRepository, -) +from model_engine_server.domain.services import LLMFineTuningService +from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository class DockerImageBatchJobLLMFineTuningService(LLMFineTuningService): @@ -24,30 +23,33 @@ def __init__( self, docker_image_batch_job_gateway: DockerImageBatchJobGateway, docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository, - llm_fine_tuning_job_repository: LLMFineTuningJobRepository, + llm_fine_tune_repository: LLMFineTuneRepository, ): self.docker_image_batch_job_gateway = docker_image_batch_job_gateway self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo - self.llm_fine_tuning_job_repository = llm_fine_tuning_job_repository + self.llm_fine_tune_repository = llm_fine_tune_repository - async def create_fine_tune_job( + async def create_fine_tune( self, created_by: str, owner: str, + model: str, training_file: str, - validation_file: str, - model_name: str, - base_model: str, + validation_file: Optional[str], fine_tuning_method: str, - hyperparameters: Dict[str, str], + hyperparameters: Dict[str, FineTuneHparamValueType], + fine_tuned_model: str, + wandb_config: Optional[Dict[str, Any]], ) -> str: - batch_job_template = await self.llm_fine_tuning_job_repository.get_job_template_for_model( - model_name=base_model, fine_tuning_method=fine_tuning_method + # fine_tuned_model must be a valid k8s label. Leaky implementation detail unfortunately. + batch_job_template = await self.llm_fine_tune_repository.get_job_template_for_model( + model_name=model, fine_tuning_method=fine_tuning_method ) if batch_job_template is None: raise LLMFineTuningMethodNotImplementedException( - f"Fine-tuning not implemented for the (base model, fine-tuning method) pairing of ({base_model}, {fine_tuning_method})" - ) + f"Fine-tuning not implemented for model type {model}" + # f"Fine-tuning not implemented for the (base model, fine-tuning method) pairing of ({base_model}, {fine_tuning_method})" + ) # TODO uncomment out error when we support multiple fine tuning methods for param in batch_job_template.required_params: if param not in hyperparameters: @@ -66,10 +68,10 @@ async def create_fine_tune_job( ) if di_batch_job_bundle is None: - raise LLMFineTuningMethodNotImplementedException("Fine-tuning job doesn't exist") + raise LLMFineTuningMethodNotImplementedException("Fine-tuning method doesn't exist") if not di_batch_job_bundle.public and di_batch_job_bundle.owner != owner: - raise LLMFineTuningMethodNotImplementedException("Fine-tuning job not accessible") + raise LLMFineTuningMethodNotImplementedException("Fine-tuning method not accessible") batch_job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=created_by, @@ -79,10 +81,10 @@ async def create_fine_tune_job( user_id=owner, training_file=training_file, validation_file=validation_file, - model_name=model_name, - launch_bundle_config=batch_job_template.launch_bundle_config, + model_name=fine_tuned_model, launch_endpoint_config=batch_job_template.launch_endpoint_config, hyperparameters=combined_hyperparameters, + wandb_config=wandb_config, ), env=di_batch_job_bundle.env, command=di_batch_job_bundle.command, @@ -95,15 +97,14 @@ async def create_fine_tune_job( gpu_type=di_batch_job_bundle.gpu_type, storage=di_batch_job_bundle.storage, ), - labels=dict(team="infra", product="llm-fine-tuning"), + labels=dict(team="infra", product="llm-fine-tune"), + annotations=dict(fine_tuned_model=fine_tuned_model), mount_location=di_batch_job_bundle.mount_location, ) return batch_job_id - async def get_fine_tune_job( - self, owner: str, fine_tune_id: str - ) -> Optional[DockerImageBatchJob]: + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: di_batch_job = await self.docker_image_batch_job_gateway.get_docker_image_batch_job( batch_job_id=fine_tune_id ) @@ -111,17 +112,25 @@ async def get_fine_tune_job( return None return di_batch_job - async def list_fine_tune_jobs(self, owner: str) -> List[DockerImageBatchJob]: + async def list_fine_tunes(self, owner: str) -> List[DockerImageBatchJob]: di_batch_jobs = await self.docker_image_batch_job_gateway.list_docker_image_batch_jobs( owner=owner ) return di_batch_jobs - async def cancel_fine_tune_job(self, owner: str, fine_tune_id: str) -> bool: - di_batch_job = self.get_fine_tune_job(owner, fine_tune_id) + async def cancel_fine_tune(self, owner: str, fine_tune_id: str) -> bool: + di_batch_job = await self.get_fine_tune(owner, fine_tune_id) if di_batch_job is None: return False cancel = await self.docker_image_batch_job_gateway.update_docker_image_batch_job( batch_job_id=fine_tune_id, cancel=True ) return cancel + + async def get_fine_tune_model_name_from_id( + self, owner: str, fine_tune_id: str + ) -> Optional[str]: + di_batch_job = await self.get_fine_tune(owner, fine_tune_id) + if di_batch_job is None or di_batch_job.annotations is None: + return None + return di_batch_job.annotations["fine_tuned_model"] diff --git a/server/llm_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py similarity index 62% rename from server/llm_engine_server/infra/services/image_cache_service.py rename to model-engine/model_engine_server/infra/services/image_cache_service.py index a48f8c95..db395a54 100644 --- a/server/llm_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -1,14 +1,17 @@ from datetime import datetime from typing import Dict, NamedTuple, Tuple -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import GpuType, ModelEndpointInfraState -from llm_engine_server.domain.repositories import DockerRepository -from llm_engine_server.infra.gateways.resources.image_cache_gateway import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways.resources.image_cache_gateway import ( CachedImages, ImageCacheGateway, ) -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) @@ -25,6 +28,14 @@ ), ) +DockerImage = NamedTuple( + "DockerImage", + ( + ("repo", str), + ("tag", str), + ), +) + class ImageCacheService: """ @@ -41,6 +52,41 @@ def __init__( self.image_cache_gateway = image_cache_gateway self.docker_repository = docker_repository + def _cache_finetune_llm_images( + self, images_to_cache_priority: Dict[str, Dict[str, CachePriority]] + ): + """ + Cache images used by fine tune LLM endpoints to reduce cold start time. + """ + # a cache priority to ensure llm endpoint images are always prioritized + llm_image_cache_priority = CachePriority( + is_high_priority=1, # make it a high priority + has_no_available_workers=1, + # assuming it has no available workers so that it will be at top after reverse sorting + last_updated_at=datetime.max, + # setting it to max to ensure it will be at top after reverse sorting + ) + + istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0") + tgi_image = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3" + ) + forwarder_image = DockerImage( + f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG + ) + + for llm_image in [istio_image, tgi_image, forwarder_image]: + if self.docker_repository.is_repo_name( + llm_image.repo + ) and not self.docker_repository.image_exists(llm_image.tag, llm_image.repo): + logger.warning( + f"Image {llm_image.repo}:{llm_image.tag} does not exist. Skipping caching ..." + ) + continue + image = f"{llm_image.repo}:{llm_image.tag}" + for key in ["a10", "a100"]: + images_to_cache_priority[key][image] = llm_image_cache_priority + async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]]): images_to_cache_priority: Dict[str, Dict[str, CachePriority]] = { "cpu": {}, @@ -48,6 +94,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi "a100": {}, "t4": {}, } + + self._cache_finetune_llm_images(images_to_cache_priority) + for endpoint_id, (_, state) in endpoint_infra_states.items(): record = await self.model_endpoint_record_repository.get_model_endpoint_record( endpoint_id diff --git a/server/llm_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py similarity index 88% rename from server/llm_engine_server/infra/services/live_batch_job_orchestration_service.py rename to model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index c847044a..258a4429 100644 --- a/server/llm_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -10,31 +10,32 @@ from datetime import datetime, timedelta from typing import List, Optional, Union -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, TaskStatus, ) -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.core.config import infra_config +from model_engine_server.core.domain_exceptions import ObjectNotFoundException +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( BatchJobProgress, BatchJobRecord, BatchJobSerializationFormat, BatchJobStatus, ModelEndpointStatus, ) -from llm_engine_server.domain.gateways import AsyncModelEndpointInferenceGateway -from llm_engine_server.domain.services import ModelEndpointService -from llm_engine_server.domain.use_cases.async_inference_use_cases import ( +from model_engine_server.domain.gateways import AsyncModelEndpointInferenceGateway +from model_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.use_cases.async_inference_use_cases import ( DEFAULT_TASK_TIMEOUT_SECONDS, ) -from llm_engine_server.infra.gateways import BatchJobProgressGateway, FilesystemGateway -from llm_engine_server.infra.repositories.batch_job_record_repository import ( +from model_engine_server.infra.gateways import BatchJobProgressGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.repositories.batch_job_record_repository import ( BatchJobRecordRepository, ) -from llm_engine_server.infra.services.batch_job_orchestration_service import ( +from model_engine_server.infra.services.batch_job_orchestration_service import ( BatchJobOrchestrationService, ) @@ -165,10 +166,7 @@ async def _run_batch_job( ) results = self._poll_tasks( - owner=owner, - job_id=job_id, - task_ids=task_ids, - timeout_timestamp=timeout_timestamp, + owner=owner, job_id=job_id, task_ids=task_ids, timeout_timestamp=timeout_timestamp ) result_location = batch_job_record.result_location @@ -204,10 +202,7 @@ async def _wait_for_endpoint_to_be_ready( model_endpoint = await self.model_endpoint_service.get_model_endpoint_record( model_endpoint_id=model_endpoint_id, ) - updating = { - ModelEndpointStatus.UPDATE_PENDING, - ModelEndpointStatus.UPDATE_IN_PROGRESS, - } + updating = {ModelEndpointStatus.UPDATE_PENDING, ModelEndpointStatus.UPDATE_IN_PROGRESS} assert model_endpoint while model_endpoint.status in updating: @@ -245,9 +240,7 @@ async def _read_or_submit_tasks( pending_task_ids_location = batch_job_record.task_ids_location if pending_task_ids_location is not None: with self.filesystem_gateway.open( - pending_task_ids_location, - "r", - aws_profile=ml_infra_config().profile_ml_worker, + pending_task_ids_location, "r", aws_profile=infra_config().profile_ml_worker ) as f: task_ids_serialized = f.read().splitlines() task_ids = [ @@ -261,9 +254,7 @@ async def _read_or_submit_tasks( task_ids = await self._submit_tasks(queue_name, input_path, task_name) pending_task_ids_location = self._get_pending_task_ids_location(job_id) with self.filesystem_gateway.open( - pending_task_ids_location, - "w", - aws_profile=ml_infra_config().profile_ml_worker, + pending_task_ids_location, "w", aws_profile=infra_config().profile_ml_worker ) as f: f.write("\n".join([tid.serialize() for tid in task_ids])) await self.batch_job_record_repository.update_batch_job_record( @@ -291,7 +282,7 @@ def _create_task( inputs: List[BatchEndpointInferencePrediction] = [] with self.filesystem_gateway.open( - input_path, "r", aws_profile=ml_infra_config().profile_ml_worker + input_path, "r", aws_profile=infra_config().profile_ml_worker ) as f: # Increase the CSV reader's field limit size from the default (131072) csv.field_size_limit(sys.maxsize) @@ -337,8 +328,7 @@ def _poll_tasks( self.batch_job_progress_gateway.update_progress(owner, job_id, progress) while pending_task_ids_set: new_results = executor.map( - self.async_model_endpoint_inference_gateway.get_task, - pending_task_ids_set, + self.async_model_endpoint_inference_gateway.get_task, pending_task_ids_set ) has_new_ready_tasks = False curr_timestamp = datetime.utcnow() @@ -362,8 +352,7 @@ def _poll_tasks( results = [ BatchEndpointInferencePredictionResponse( - response=task_id_to_result[task_id], - reference_id=task_id_to_ref_id_map[task_id], + response=task_id_to_result[task_id], reference_id=task_id_to_ref_id_map[task_id] ) for task_id in task_ids_only ] @@ -383,14 +372,14 @@ def _serialize_and_write_results( results_serialized = pickle.dumps(results) with self.filesystem_gateway.open( - result_location, "wb", aws_profile=ml_infra_config().profile_ml_worker + result_location, "wb", aws_profile=infra_config().profile_ml_worker ) as f: f.write(results_serialized) @staticmethod def _get_pending_task_ids_location(job_id: str) -> str: - return f"s3://{ml_infra_config().s3_bucket}/llm-engine/batch-jobs/{job_id}/pending_task_ids.txt" + return f"s3://{infra_config().s3_bucket}/launch/batch-jobs/{job_id}/pending_task_ids.txt" @staticmethod def _get_job_result_location(job_id: str) -> str: - return f"s3://{ml_infra_config().s3_bucket}/llm-engine/batch-jobs/{job_id}/result.json" + return f"s3://{infra_config().s3_bucket}/launch/batch-jobs/{job_id}/result.json" diff --git a/server/llm_engine_server/infra/services/live_batch_job_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_service.py similarity index 92% rename from server/llm_engine_server/infra/services/live_batch_job_service.py rename to model-engine/model_engine_server/infra/services/live_batch_job_service.py index 7e49c699..9036de50 100644 --- a/server/llm_engine_server/infra/services/live_batch_job_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_service.py @@ -1,8 +1,8 @@ from typing import Dict, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, BatchJobSerializationFormat, @@ -10,10 +10,10 @@ GpuType, ModelEndpointType, ) -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.domain.services import BatchJobService, ModelEndpointService -from llm_engine_server.infra.gateways import BatchJobOrchestrationGateway, BatchJobProgressGateway -from llm_engine_server.infra.repositories.batch_job_record_repository import ( +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.domain.services import BatchJobService, ModelEndpointService +from model_engine_server.infra.gateways import BatchJobOrchestrationGateway, BatchJobProgressGateway +from model_engine_server.infra.repositories.batch_job_record_repository import ( BatchJobRecordRepository, ) diff --git a/server/llm_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py similarity index 81% rename from server/llm_engine_server/infra/services/live_endpoint_builder_service.py rename to model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 35b93e28..61b381e0 100644 --- a/server/llm_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -2,33 +2,31 @@ import json import os import tempfile +import time from contextlib import AsyncExitStack from logging import LoggerAdapter -from typing import List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Set from datadog import statsd -from llm_engine_server.common.constants import ( - FEATURE_FLAG_USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE, -) -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, BuildEndpointStatus, ) -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.env_vars import LOCAL -from llm_engine_server.common.io import open_wrapper -from llm_engine_server.common.serialization_utils import bool_to_str -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.domain_exceptions import DockerBuildFailedException -from llm_engine_server.core.loggers import make_logger -from llm_engine_server.core.notification_gateway import NotificationApp, NotificationGateway -from llm_engine_server.core.utils.env import environment -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import LOCAL +from model_engine_server.common.io import open_wrapper +from model_engine_server.common.serialization_utils import bool_to_str +from model_engine_server.core.config import infra_config +from model_engine_server.core.domain_exceptions import DockerBuildFailedException +from model_engine_server.core.loggers import make_logger +from model_engine_server.core.notification_gateway import NotificationApp, NotificationGateway +from model_engine_server.core.utils.env import environment +from model_engine_server.domain.entities import ( ArtifactLike, - CloudpickleArtifactFlavor, CustomFramework, + ModelBundle, ModelBundleFlavorType, ModelEndpointConfig, ModelEndpointDeploymentState, @@ -42,28 +40,31 @@ TensorflowFramework, ZipArtifactFlavor, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways import MonitoringMetricsGateway -from llm_engine_server.domain.repositories import DockerRepository -from llm_engine_server.domain.services import EndpointBuilderService -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.domain.services import EndpointBuilderService +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CONVERTED_FROM_ARTIFACT_LIKE_KEY, ) -from llm_engine_server.infra.gateways import FilesystemGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.infra_utils import make_exception_log -from llm_engine_server.infra.repositories import FeatureFlagRepository, ModelEndpointCacheRepository -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.infra_utils import make_exception_log +from model_engine_server.infra.repositories import ( + FeatureFlagRepository, + ModelEndpointCacheRepository, +) +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) if LOCAL: with environment(KUBERNETES_SERVICE_HOST=None): - logger = make_logger("llm_engine_server.service_builder") + logger = make_logger("model_engine_server.service_builder") else: - logger = make_logger("llm_engine_server.service_builder") + logger = make_logger("model_engine_server.service_builder") __all__: Sequence[str] = ( "INITIAL_K8S_CACHE_TTL_SECONDS", @@ -78,6 +79,32 @@ INITIAL_K8S_CACHE_TTL_SECONDS: int = 60 MAX_IMAGE_TAG_LEN = 128 +RESTRICTED_ENV_VARS_KEYS = { + "BASE": [ + "DATADOG_TRACE_ENABLED", + "DD_AGENT_HOST", + "DD_ENV", + "DD_SERVICE", + "DD_VERSION", + "OMP_THREAD_LIMIT", + ], + "TRITON": [ + "AWS_PROFILE", + ], + "CELERY": [ + "CELERY_ELASTICACHE_ENABLED", + "CELERY_QUEUE", + "CELERY_TASK_VISIBILITY", + "S3_BUCKET", + ], + "TEMPORAL": [ + "TEMPORAL_TASK_QUEUE", + ], + "HTTP": [ + "HTTP_PORT", + ], +} + class LiveEndpointBuilderService(EndpointBuilderService): def __init__( @@ -103,6 +130,7 @@ def __init__( async def build_endpoint( self, build_endpoint_request: BuildEndpointRequest ) -> BuildEndpointResponse: + time_build_endpoint_start = time.time() self.monitoring_metrics_gateway.emit_attempted_build_metric() logger_extra = build_endpoint_request.dict() @@ -118,12 +146,6 @@ async def build_endpoint( self._validate_build_endpoint_request(build_endpoint_request) - use_multi_container_architecture_for_artifactlike_bundle = ( - await self.feature_flag_repo.read_feature_flag_bool( - FEATURE_FLAG_USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE - ) - ) - async with AsyncExitStack() as stack: lock_ctx = self.model_endpoint_record_repository.get_lock_context(model_endpoint_record) lock = await stack.enter_async_context(lock_ctx) @@ -139,29 +161,31 @@ async def build_endpoint( try: # First, build the image if the model bundle does not have a docker image if not model_bundle.is_runnable(): - if use_multi_container_architecture_for_artifactlike_bundle: - assert isinstance( - model_bundle.flavor, CloudpickleArtifactFlavor - ) or isinstance(model_bundle.flavor, ZipArtifactFlavor) - logger_adapter.info( - f"Create a new runnable image model bundle for artifact flavor model bundle {model_bundle.id=} ..." - ) + logger_adapter.info( + f"Create a new runnable image model bundle for artifact flavor model bundle {model_bundle.id=} ..." + ) logger_adapter.info("Building base & user image...") # Build service image in two steps for better caching. # First we build a base image, which is expected to be shared between # many different bundles. try: - base_image_params = self._get_base_image_params( + base_image_params = self.get_base_image_params( build_endpoint_request, logger_adapter ) base_image = await self._build_image( - base_image_params, build_endpoint_request, logger_adapter + base_image_params, + build_endpoint_request, + logger_adapter, + "base", ) user_image_params = self._get_user_image_params( base_image, build_endpoint_request, logger_adapter ) image = await self._build_image( - user_image_params, build_endpoint_request, logger_adapter + user_image_params, + build_endpoint_request, + logger_adapter, + "user", ) image_repo = user_image_params.repo @@ -189,10 +213,12 @@ async def build_endpoint( inject_bundle_image_params, build_endpoint_request, logger_adapter, + "inject_bundle", ) # Now that it's no longer needed, clean up serialized bundle file to save storage - model_bundle_path = inject_bundle_image_params.substitution_args[ # type: ignore + model_bundle_path = inject_bundle_image_params.substitution_args[ + # type: ignore "LOCAL_BUNDLE_PATH" ] if os.path.exists(model_bundle_path): @@ -205,27 +231,20 @@ async def build_endpoint( self.monitoring_metrics_gateway.emit_docker_failed_build_metric() raise - if use_multi_container_architecture_for_artifactlike_bundle: - self.convert_artifact_like_bundle_to_runnable_image( - build_endpoint_request, image_repo, image_tag - ) + self.convert_artifact_like_bundle_to_runnable_image( + build_endpoint_request, image_repo, image_tag + ) - # CONVERTED_FROM_ARTIFACT_LIKE_KEY will be checked by `get_endpoint_resource_arguments_from_request()` in k8s_resource_types.py - if not model_endpoint_record.metadata: - model_endpoint_record.metadata = {} - model_endpoint_record.metadata.update( - {CONVERTED_FROM_ARTIFACT_LIKE_KEY: True} - ) - await self.model_endpoint_record_repository.update_model_endpoint_record( - model_endpoint_id=endpoint_id, - metadata=model_endpoint_record.metadata, - ) + # CONVERTED_FROM_ARTIFACT_LIKE_KEY will be checked by `get_endpoint_resource_arguments_from_request()` in k8s_resource_types.py + if not model_endpoint_record.metadata: + model_endpoint_record.metadata = {} + model_endpoint_record.metadata.update({CONVERTED_FROM_ARTIFACT_LIKE_KEY: True}) else: flavor = model_bundle.flavor assert isinstance(flavor, RunnableImageLike) repository = ( - f"{ml_infra_config().docker_repo_prefix}/{flavor.repository}" + f"{infra_config().docker_repo_prefix}/{flavor.repository}" if self.docker_repository.is_repo_name(flavor.repository) else flavor.repository ) @@ -251,6 +270,13 @@ async def build_endpoint( except EndpointResourceInfraException: log_error("K8s resource update failed") raise + finally: + # Clean up CONVERTED_FROM_ARTIFACT_LIKE_KEY as it is for internal use only + if ( + model_endpoint_record.metadata is not None + and CONVERTED_FROM_ARTIFACT_LIKE_KEY in model_endpoint_record.metadata + ): + del model_endpoint_record.metadata[CONVERTED_FROM_ARTIFACT_LIKE_KEY] endpoint_info = ModelEndpointInfraState( deployment_name=build_endpoint_request.deployment_name, @@ -334,6 +360,13 @@ async def build_endpoint( except Exception: # noqa log_error(f"[Continuing] Failed to emit successful build metric for {endpoint_id=}") + try: + self.monitoring_metrics_gateway.emit_build_time_metric( + time.time() - time_build_endpoint_start + ) + except Exception: # noqa + log_error(f"[Continuing] Failed to emit endpoint build time metric for {endpoint_id=}") + return BuildEndpointResponse(status=BuildEndpointStatus.OK) def convert_artifact_like_bundle_to_runnable_image( @@ -343,7 +376,7 @@ def convert_artifact_like_bundle_to_runnable_image( image_tag: str, ) -> None: """ - With LLMEngine Inference Re-Architecture, we want to deploy endpoints with ArtifactLike bundle using + With Launch Inference Re-Architecture, we want to deploy endpoints with ArtifactLike bundle using multi-container architecture, which RunnableImageFlavor has already adopted. This function mutates the build_endpoint_request by converting the ArtifactLike bundle flavor into @@ -357,29 +390,30 @@ def convert_artifact_like_bundle_to_runnable_image( assert isinstance(model_bundle.flavor, ArtifactLike) new_model_bundle = model_bundle.copy() - if ml_infra_config().env == "circleci": - ml_infra_service_config_file = "config.yaml" + if infra_config().env == "circleci": + infra_config_file = "config.yaml" else: - ml_infra_service_config_file = ml_infra_config().env + ".yaml" + infra_config_file = infra_config().env + ".yaml" new_flavor = RunnableImageFlavor( flavor=ModelBundleFlavorType.RUNNABLE_IMAGE, repository=image_repo, tag=image_tag, + readiness_initial_delay_seconds=30, command=[ "dumb-init", "--", "ddtrace-run", "python", "-m", - "llm_engine_server.inference.sync_inference.start_fastapi_server", + "model_engine_server.inference.sync_inference.start_fastapi_server", ], env={ - "OMP_NUM_THREADS": '"1"', + "OMP_NUM_THREADS": "1", "BASE_PATH": "/app", "BUNDLE_URL": model_bundle.flavor.location, "AWS_PROFILE": build_endpoint_request.aws_role, - "RESULTS_S3_BUCKET": ml_infra_config().s3_bucket, + "RESULTS_S3_BUCKET": infra_config().s3_bucket, "CHILD_FN_INFO": json.dumps( build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info @@ -387,7 +421,7 @@ def convert_artifact_like_bundle_to_runnable_image( ), "PREWARM": bool_to_str(build_endpoint_request.prewarm) or "false", "PORT": "5005", - "ML_INFRA_SERVICES_CONFIG_PATH": f"/app/ml_infra_core/llm_engine_server.core/llm_engine_server.core/configs/{ml_infra_service_config_file}", + "ML_INFRA_SERVICES_CONFIG_PATH": f"/app/model-engine/model_engine_server/core/configs/{infra_config_file}", }, protocol="http", ) @@ -407,7 +441,7 @@ def convert_artifact_like_bundle_to_runnable_image( build_endpoint_request.model_endpoint_record.current_model_bundle = new_model_bundle - def _get_base_image_params( + def get_base_image_params( self, build_endpoint_request: BuildEndpointRequest, logger_adapter: LoggerAdapter, @@ -453,11 +487,11 @@ def _get_base_image_params( raise ValueError(f"Unsupported framework_type: {env_params.framework_type}") # The context should be whatever WORKDIR is in the container running the build app itself. - inference_folder = "llm_engine/llm_engine/inference" + inference_folder = "model-engine/model_engine_server/inference" base_path: str = os.getenv("WORKSPACE") # type: ignore return BuildImageRequest( - repo="llm-engine", + repo="launch/inference", image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, # type: ignore base_path=base_path, @@ -514,7 +548,7 @@ def _get_user_image_params( raise ValueError(f"Unsupported framework_type: {env_params.framework_type}") # The context should be whatever WORKDIR is in the container running the build app itself. - inference_folder = "llm_engine/llm_engine/inference" + inference_folder = "model-engine/model_engine_server/inference" base_path: str = os.getenv("WORKSPACE") # type: ignore requirements_folder = os.path.join(base_path, f"requirements_{requirements_hash}") @@ -551,6 +585,7 @@ def _get_inject_bundle_image_params( ) -> BuildImageRequest: model_endpoint_record = build_endpoint_request.model_endpoint_record model_bundle = model_endpoint_record.current_model_bundle + assert isinstance(model_bundle.flavor, ZipArtifactFlavor) bundle_id = model_bundle.id service_image_str = "-".join([base_image_params.image_tag, GIT_TAG, bundle_id]) @@ -564,7 +599,7 @@ def _get_inject_bundle_image_params( # The context should be whatever WORKDIR is in the container running the build app itself. dockerfile = "inject_bundle.Dockerfile" - inference_folder = "llm_engine/llm_engine/inference" + inference_folder = "model-engine/model_engine_server/inference" base_path: str = os.getenv("WORKSPACE") # type: ignore bundle_folder = os.path.join(base_path, f"bundle_{service_image_hash}") @@ -584,7 +619,7 @@ def _get_inject_bundle_image_params( substitution_args = { "LOCAL_BUNDLE_PATH": model_bundle_path, "LOAD_MODEL_MODULE_PATH": model_bundle.flavor.load_model_fn_module_path, # type: ignore - "LOAD_PREDICT_MODULE_PATH": model_bundle.flavor.load_predict_fn_module_path, # type: ignore + "LOAD_PREDICT_MODULE_PATH": model_bundle.flavor.load_predict_fn_module_path, } return BuildImageRequest( @@ -603,6 +638,7 @@ async def _build_image( image_params: BuildImageRequest, build_endpoint_request: BuildEndpointRequest, logger_adapter: LoggerAdapter, + image_type: str, ) -> str: """ Builds the service image and updates the endpoint status if the image building fails. @@ -625,11 +661,12 @@ async def _build_image( image_tag=image_params.image_tag, aws_profile=ECR_AWS_PROFILE, ): + self.monitoring_metrics_gateway.emit_image_build_cache_miss_metric(image_type) tags = [ f"kube_deployment:{build_endpoint_request.deployment_name}", f"user_id:{user_id}", ] - with statsd.timed("kaniko.build_time", tags=tags): + with statsd.timed(f"kaniko.{image_type}_build_time", tags=tags): try: build_result: BuildImageResponse = self.docker_repository.build_image( image_params, @@ -640,7 +677,7 @@ async def _build_image( build_result_status = False s3_logs_location: Optional[str] = None log_error( - "Unknown error encountered on image build" + "Unknown error encountered on image build. " f"No logs to write for {model_endpoint_name}, since docker build threw exception" ) else: @@ -653,7 +690,7 @@ async def _build_image( with self.filesystem_gateway.open( s3_logs_location, "w", - aws_profile=ml_infra_config().profile_ml_worker, + aws_profile=infra_config().profile_ml_worker, ) as file_out: file_out.write(build_result_logs) except Exception: # noqa @@ -682,11 +719,11 @@ async def _build_image( help_url = self.filesystem_gateway.generate_signed_url( s3_logs_location, expiration=43200, # 12 hours - aws_profile=ml_infra_config().profile_ml_worker, + aws_profile=infra_config().profile_ml_worker, ) else: help_url = ( - "https://app.datadoghq.com/logs?query=service%3Allm-engine-" + "https://app.datadoghq.com/logs?query=service%3Alaunch-" f"endpoint-builder%20env%3A{ENV}&cols=host%2Cservice&" "index=%2A&messageDisplay=inline&stream_sort=time%2C" "desc&viz=stream&live=true" @@ -702,7 +739,7 @@ async def _build_image( ) self.notification_gateway.send_notification( - title="LLMEngine Endpoint Build Failed", + title="Launch Endpoint Build Failed", description=message, help_url=help_url, notification_apps=[ @@ -715,6 +752,7 @@ async def _build_image( raise DockerBuildFailedException(f"Image build failed ({endpoint_id=})") else: + self.monitoring_metrics_gateway.emit_image_build_cache_hit_metric(image_type) logger_adapter.info( f"Image {image_params.repo}:{image_params.image_tag} already exists, " f"skipping build for {endpoint_id=}" @@ -728,8 +766,8 @@ def _validate_build_endpoint_request( ) -> None: """Raises ValueError if the request's AWS role isn't allowed.""" allowed_aws_roles = { - ml_infra_config().profile_ml_worker, - ml_infra_config().profile_ml_inference_worker, + infra_config().profile_ml_worker, + infra_config().profile_ml_inference_worker, } if build_endpoint_request.aws_role not in allowed_aws_roles: @@ -738,6 +776,23 @@ def _validate_build_endpoint_request( f"{allowed_aws_roles}." ) + model_bundle: ModelBundle = ( + build_endpoint_request.model_endpoint_record.current_model_bundle + ) + if isinstance(model_bundle.flavor, RunnableImageLike) and model_bundle.flavor.env: + restriced_env_vars = LiveEndpointBuilderService._get_restricted_env_vars( + model_bundle.flavor.env + ) + if len(restriced_env_vars) > 0: + raise ValueError( + f"Runnable image endpoints cannot set the following env vars: {restriced_env_vars}" + ) + + @staticmethod + def _get_restricted_env_vars(env_vars: Dict[str, str]) -> Set[str]: + restricted_env_vars = set(key for keys in RESTRICTED_ENV_VARS_KEYS.values() for key in keys) + return set(env_vars.keys()) & restricted_env_vars + @staticmethod def _get_requirements_hash(requirements: List[str]) -> str: """Identifying hash for endpoint's Python requirements.""" @@ -757,4 +812,4 @@ def _get_service_builder_logs_location(user_id: str, endpoint_name: str) -> str: This function uses creates a key from the endpoint's name and owning user's ID. It uses an S3 bucket that is accessible by the Gateway & Service Builder. """ - return f"s3://{ml_infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" + return f"s3://{infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" diff --git a/server/llm_engine_server/infra/services/live_llm_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py similarity index 81% rename from server/llm_engine_server/infra/services/live_llm_model_endpoint_service.py rename to model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py index 0eb4e3e4..41b6e5a9 100644 --- a/server/llm_engine_server/infra/services/live_llm_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py @@ -1,13 +1,13 @@ from typing import List, Optional -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.services import LLMModelEndpointService -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.services import LLMModelEndpointService +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) -from llm_engine_server.infra.services import LiveModelEndpointService +from model_engine_server.infra.services import LiveModelEndpointService logger = make_logger(filename_wo_ext(__file__)) diff --git a/server/llm_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py similarity index 90% rename from server/llm_engine_server/infra/services/live_model_endpoint_service.py rename to model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 5f671676..ced2a6e1 100644 --- a/server/llm_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -1,14 +1,14 @@ from typing import Any, Dict, List, Optional from datadog import statsd -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.common.settings import generate_deployment_name -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.common.settings import generate_deployment_name +from model_engine_server.core.domain_exceptions import ( ObjectAlreadyExistsException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -20,24 +20,25 @@ ModelEndpointType, StorageSpecificationType, ) -from llm_engine_server.domain.exceptions import EndpointDeleteFailedException -from llm_engine_server.domain.gateways import ( +from model_engine_server.domain.exceptions import EndpointDeleteFailedException +from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, ModelEndpointsSchemaGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, ) -from llm_engine_server.domain.services import ModelEndpointService -from llm_engine_server.infra.gateways import ModelEndpointInfraGateway -from llm_engine_server.infra.repositories import ModelEndpointCacheRepository -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.use_cases.model_endpoint_use_cases import MODEL_BUNDLE_CHANGED_KEY +from model_engine_server.infra.gateways import ModelEndpointInfraGateway +from model_engine_server.infra.repositories import ModelEndpointCacheRepository +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) logger = make_logger(filename_wo_ext(__file__)) -STATSD_CACHE_HIT_NAME = "llm_engine_server.get_infra_state.cache_hit" -STATSD_CACHE_MISS_NAME = "llm_engine_server.get_infra_state.cache_miss" +STATSD_CACHE_HIT_NAME = "launch.get_infra_state.cache_hit" +STATSD_CACHE_MISS_NAME = "launch.get_infra_state.cache_miss" class LiveModelEndpointService(ModelEndpointService): @@ -144,6 +145,7 @@ async def create_model_endpoint( results_s3_bucket: str, prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, owner: str, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth], @@ -267,6 +269,7 @@ async def update_model_endpoint( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, public_inference: Optional[bool] = None, @@ -297,6 +300,12 @@ async def update_model_endpoint( # f"Resource update on endpoint {name} in progress, try again later" # ) + if record.current_model_bundle.id != model_bundle_id: + if metadata is None: + metadata = {} + # MODEL_BUNDLE_CHANGED_KEY will be checked during _create_deployment in K8SEndpointResourceDelegate + metadata[MODEL_BUNDLE_CHANGED_KEY] = True + record = await self.model_endpoint_record_repository.update_model_endpoint_record( model_endpoint_id=model_endpoint_id, model_bundle_id=model_bundle_id, @@ -324,9 +333,15 @@ async def update_model_endpoint( default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ) + + # Clean up MODEL_BUNDLE_CHANGED_KEY as it is only for internal use + if metadata is not None and MODEL_BUNDLE_CHANGED_KEY in metadata: + del metadata[MODEL_BUNDLE_CHANGED_KEY] + await self.model_endpoint_record_repository.update_model_endpoint_record( model_endpoint_id=model_endpoint_id, creation_task_id=creation_task_id, + metadata=metadata, ) record = await self.model_endpoint_record_repository.get_model_endpoint_record( diff --git a/server/llm_engine_server/infra/services/model_endpoint_cache_service.py b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py similarity index 83% rename from server/llm_engine_server/infra/services/model_endpoint_cache_service.py rename to model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py index 824370f5..9169d883 100644 --- a/server/llm_engine_server/infra/services/model_endpoint_cache_service.py +++ b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py @@ -1,13 +1,13 @@ from typing import Dict, Tuple -from llm_engine_server.domain.entities import ModelEndpointInfraState -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.repositories.model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.model_endpoint_cache_repository import ( ModelEndpointCacheRepository, ) -from llm_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.image_cache_service import ImageCacheService class ModelEndpointCacheWriteService: diff --git a/server/llm_engine_server/scripts/__init__.py b/model-engine/model_engine_server/service_builder/__init__.py similarity index 100% rename from server/llm_engine_server/scripts/__init__.py rename to model-engine/model_engine_server/service_builder/__init__.py diff --git a/model-engine/model_engine_server/service_builder/celery.py b/model-engine/model_engine_server/service_builder/celery.py new file mode 100644 index 00000000..a5ac93e6 --- /dev/null +++ b/model-engine/model_engine_server/service_builder/celery.py @@ -0,0 +1,13 @@ +from model_engine_server.core.celery import celery_app +from model_engine_server.core.config import infra_config + +service_builder_service = celery_app( + name="model_engine_server.service_builder", + modules=[ + "model_engine_server.service_builder.tasks_v1", + ], + s3_bucket=infra_config().s3_bucket, +) + +if __name__ == "__main__": + service_builder_service.start() diff --git a/server/llm_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py similarity index 54% rename from server/llm_engine_server/service_builder/tasks_v1.py rename to model-engine/model_engine_server/service_builder/tasks_v1.py index 8548f149..539b6803 100644 --- a/server/llm_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -4,79 +4,103 @@ import aioredis from celery.signals import worker_process_init -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, ) -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from llm_engine_server.core.notification_gateway import NotificationGateway -from llm_engine_server.db.base import SessionAsyncNullPool -from llm_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway -from llm_engine_server.infra.gateways import FakeMonitoringMetricsGateway, S3FilesystemGateway -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( +from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH +from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway +from model_engine_server.db.base import SessionAsyncNullPool +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway +from model_engine_server.infra.gateways import ( + DatadogMonitoringMetricsGateway, + FakeMonitoringMetricsGateway, + S3FilesystemGateway, +) +from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( FakeSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( set_lazy_load_kubernetes_clients, ) -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( LiveSQSEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, ) -from llm_engine_server.infra.repositories import ( +from model_engine_server.infra.repositories import ( DbModelEndpointRecordRepository, ECRDockerRepository, RedisFeatureFlagRepository, RedisModelEndpointCacheRepository, ) -from llm_engine_server.infra.services import LiveEndpointBuilderService -from llm_engine_server.service_builder.celery import service_builder_service +from model_engine_server.infra.services import LiveEndpointBuilderService +from model_engine_server.service_builder.celery import service_builder_service # Need to disable lazy loading of k8s clients because each event loop should contain its own k8s # client, which constructs the aiohttp.ClientSession in the event loop. set_lazy_load_kubernetes_clients(False) -async def _build_endpoint( - build_endpoint_request: BuildEndpointRequest, -) -> BuildEndpointResponse: - session = SessionAsyncNullPool - pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) - redis = aioredis.Redis(connection_pool=pool) +def get_live_endpoint_builder_service( + session: Any, + redis: aioredis.Redis, +): sqs_delegate: SQSEndpointResourceDelegate - notification_gateway: NotificationGateway if CIRCLECI: sqs_delegate = FakeSQSEndpointResourceDelegate() else: sqs_delegate = LiveSQSEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - monitoring_metrics_gateway: MonitoringMetricsGateway - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() notification_gateway = FakeNotificationGateway() + monitoring_metrics_gateway: MonitoringMetricsGateway + if SKIP_AUTH: + monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + else: + monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() service = LiveEndpointBuilderService( docker_repository=ECRDockerRepository(), - resource_gateway=LiveEndpointResourceGateway(sqs_delegate=sqs_delegate), + resource_gateway=LiveEndpointResourceGateway( + sqs_delegate=sqs_delegate, + ), monitoring_metrics_gateway=monitoring_metrics_gateway, model_endpoint_record_repository=DbModelEndpointRecordRepository( - monitoring_metrics_gateway=monitoring_metrics_gateway, - session=session, - read_only=False, + monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ), model_endpoint_cache_repository=RedisModelEndpointCacheRepository(redis_client=redis), filesystem_gateway=S3FilesystemGateway(), notification_gateway=notification_gateway, feature_flag_repo=RedisFeatureFlagRepository(redis_client=redis), ) + + return service + + +async def _build_endpoint( + build_endpoint_request: BuildEndpointRequest, +) -> BuildEndpointResponse: + session = SessionAsyncNullPool + pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) + redis = aioredis.Redis(connection_pool=pool) + + service: LiveEndpointBuilderService + try: + from plugins.dependencies import ( + get_live_endpoint_builder_service as get_custom_live_endpoint_builder_service, + ) + + service = get_custom_live_endpoint_builder_service(session, redis) + except ModuleNotFoundError: + service = get_live_endpoint_builder_service(session, redis) + response = await service.build_endpoint(build_endpoint_request) await redis.close() await pool.disconnect() diff --git a/server/mypy.ini b/model-engine/mypy.ini similarity index 60% rename from server/mypy.ini rename to model-engine/mypy.ini index 316c36ef..9abfbeaa 100644 --- a/server/mypy.ini +++ b/model-engine/mypy.ini @@ -8,13 +8,19 @@ strict_optional = True plugins = pydantic.mypy exclude = clients -[mypy-llm_engine_server.core.*] +[mypy-model_engine_server.cli.*] ignore_errors = True -[mypy-llm_engine_server.db.*] +[mypy-model_engine_server.core.*] ignore_errors = True -[mypy-llm_engine_server.infra.repositories.*] +[mypy-model_engine_server.db.*] +ignore_errors = True + +[mypy-model_engine_server.infra.repositories.*] +ignore_errors = True + +[mypy-clients.*] ignore_errors = True [mypy-tests.*] diff --git a/server/requirements-test.txt b/model-engine/requirements-test.txt similarity index 90% rename from server/requirements-test.txt rename to model-engine/requirements-test.txt index 719527a1..f93718b3 100644 --- a/server/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -1,5 +1,6 @@ multiprocess==0.70.14 pytest==7.2.0 +pytest-asyncio==0.20.1 pytest-cov==2.10.0 moto==3.1.12 coverage==5.5 @@ -9,6 +10,7 @@ pytest-mypy-plugins==1.10.1 pytest-asyncio==0.20.1 pytest-pylint==0.18.0 types-cachetools==5.3.0.5 +types-croniter==1.4.0.0 types-PyYAML==6.0.7 types-redis==4.3.21.3 types-requests==2.27.26 diff --git a/server/requirements.in b/model-engine/requirements.in similarity index 93% rename from server/requirements.in rename to model-engine/requirements.in index 8d405e0f..5caed45c 100644 --- a/server/requirements.in +++ b/model-engine/requirements.in @@ -11,6 +11,7 @@ build==0.8.0 celery[redis,sqs,tblib]~=5.2 click~=8.1 cloudpickle==2.1.0 +croniter==1.4.1 dataclasses-json>=0.5.7 datadog-api-client==2.11.0 datadog~=0.46.0 @@ -30,7 +31,8 @@ protobuf~=3.20 psycopg2-binary==2.9.3 py-xid==0.3.0 pycurl~=7.44 # For celery[sqs] -pydantic~=1.10 +pydantic~=1.10.11 +python-multipart~=0.0.6 quart==0.18.3 requests-auth-aws-sigv4~=0.7 requests~=2.25 diff --git a/server/requirements.txt b/model-engine/requirements.txt similarity index 72% rename from server/requirements.txt rename to model-engine/requirements.txt index 69424330..2cf51929 100644 --- a/server/requirements.txt +++ b/model-engine/requirements.txt @@ -1,21 +1,21 @@ # -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # -# pip-compile --allow-unsafe --no-emit-index-url --no-emit-trusted-host --output-file=requirements.txt requirements.in +# pip-compile --allow-unsafe --index-url=https://pypi.org/simple --no-emit-index-url --no-emit-trusted-host model-engine/requirements.in # aiofiles==23.1.0 # via quart aiohttp==3.8.4 # via - # -r requirements.in + # -r model-engine/requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r requirements.in + # via -r model-engine/requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r requirements.in + # via -r model-engine/requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 @@ -30,11 +30,11 @@ async-timeout==4.0.2 # aioredis # redis asyncpg==0.27.0 - # via -r requirements.in + # via -r model-engine/requirements.in attrs==23.1.0 # via # aiohttp - # ddtrace-scale + # ddtrace backports-zoneinfo[tzdata]==0.2.1 # via # celery @@ -47,23 +47,24 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r requirements.in + # -r model-engine/requirements.in # celery + # kombu boto3-stubs[essential]==1.26.67 - # via -r requirements.in + # via -r model-engine/requirements.in botocore==1.31.1 # via - # -r requirements.in + # -r model-engine/requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r requirements.in + # via -r model-engine/requirements.in cachetools==5.3.1 # via google-auth celery[redis,sqs,tblib]==5.3.1 - # via -r requirements.in + # via -r model-engine/requirements.in certifi==2023.5.7 # via # datadog-api-client @@ -76,7 +77,7 @@ charset-normalizer==3.2.0 # requests click==8.1.4 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # click-didyoumean # click-plugins @@ -90,29 +91,31 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich +croniter==1.4.1 + # via -r model-engine/requirements.in dataclasses-json==0.5.9 - # via -r requirements.in -datadog-api-client==2.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in datadog==0.46.0 - # via -r requirements.in + # via -r model-engine/requirements.in +datadog-api-client==2.11.0 + # via -r model-engine/requirements.in ddtrace==0.49.2 - # via -r requirements.in + # via -r model-engine/requirements.in deprecation==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in docker==5.0.3 - # via -r requirements.in + # via -r model-engine/requirements.in docutils==0.20.1 # via readme-renderer exceptiongroup==1.1.2 # via anyio fastapi==0.78.0 - # via -r requirements.in + # via -r model-engine/requirements.in frozenlist==1.3.3 # via # aiohttp @@ -120,15 +123,15 @@ frozenlist==1.3.3 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r requirements.in + # via -r model-engine/requirements.in gitpython==3.1.32 - # via -r requirements.in + # via -r model-engine/requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in h11==0.14.0 # via # hypercorn @@ -139,7 +142,7 @@ h2==4.1.0 hpack==4.0.0 # via h2 httptools==0.5.0 - # via -r requirements.in + # via -r model-engine/requirements.in hypercorn==0.14.4 # via quart hyperframe==6.0.1 @@ -165,24 +168,24 @@ jaraco-classes==3.3.0 # via keyring jinja2==3.0.3 # via - # -r requirements.in + # -r model-engine/requirements.in # quart jmespath==1.0.1 # via # boto3 # botocore json-log-formatter==0.5.2 - # via -r requirements.in + # via -r model-engine/requirements.in keyring==24.2.0 # via twine -kombu==5.3.1 +kombu[sqs]==5.3.1 # via celery kubeconfig==1.1.1 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes==25.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes-asyncio==24.2.2 - # via -r requirements.in + # via -r model-engine/requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -222,11 +225,11 @@ mypy-extensions==1.0.0 oauthlib==3.2.2 # via requests-oauthlib orjson==3.8.6 - # via -r requirements.in + # via -r model-engine/requirements.in packaging==23.1 # via # build - # ddtrace-scale + # ddtrace # deprecation # marshmallow pep517==0.13.0 @@ -241,12 +244,12 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r requirements.in - # ddtrace-scale + # -r model-engine/requirements.in + # ddtrace psycopg2-binary==2.9.3 - # via -r requirements.in + # via -r model-engine/requirements.in py-xid==0.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -255,11 +258,12 @@ pyasn1-modules==0.3.0 # via google-auth pycurl==7.45.2 # via - # -r requirements.in + # -r model-engine/requirements.in # celery + # kombu pydantic==1.10.11 # via - # -r requirements.in + # -r model-engine/requirements.in # fastapi pygments==2.15.1 # via @@ -269,25 +273,28 @@ python-dateutil==2.8.2 # via # botocore # celery + # croniter # datadog-api-client # kubernetes # kubernetes-asyncio # pg8000 +python-multipart==0.0.6 + # via -r model-engine/requirements.in pyyaml==6.0 # via # kubeconfig # kubernetes # kubernetes-asyncio quart==0.18.3 - # via -r requirements.in + # via -r model-engine/requirements.in readme-renderer==40.0 # via twine redis==4.6.0 # via celery requests==2.31.0 # via - # -r requirements.in - # datadog-scale + # -r model-engine/requirements.in + # datadog # docker # kubernetes # requests-auth-aws-sigv4 @@ -295,7 +302,7 @@ requests==2.31.0 # requests-toolbelt # twine requests-auth-aws-sigv4==0.7 - # via -r requirements.in + # via -r model-engine/requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -303,7 +310,7 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r requirements.in + # via -r model-engine/requirements.in rsa==4.9 # via google-auth s3transfer==0.6.1 @@ -311,18 +318,18 @@ s3transfer==0.6.1 scramp==1.4.4 # via pg8000 sh==1.14.3 - # via -r requirements.in + # via -r model-engine/requirements.in six==1.16.0 # via # bleach - # ddtrace-scale + # ddtrace # google-auth # kubernetes # kubernetes-asyncio # python-dateutil # tenacity smart-open==5.2.1 - # via -r requirements.in + # via -r model-engine/requirements.in smmap==5.0.0 # via # gitdb @@ -333,12 +340,12 @@ sniffio==1.3.0 # via anyio sqlalchemy[asyncio]==2.0.4 # via - # -r requirements.in + # -r model-engine/requirements.in # alembic sse-starlette==1.6.1 - # via -r requirements.in + # via -r model-engine/requirements.in sseclient-py==1.7.2 - # via -r requirements.in + # via -r model-engine/requirements.in starlette==0.19.1 # via # fastapi @@ -347,12 +354,12 @@ tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r requirements.in - # ddtrace-scale + # -r model-engine/requirements.in + # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in tomli==2.0.1 # via # build @@ -360,10 +367,10 @@ tomli==2.0.1 # pep517 tqdm==4.65.0 # via - # -r requirements.in + # -r model-engine/requirements.in # twine twine==3.7.1 - # via -r requirements.in + # via -r model-engine/requirements.in types-awscrt==0.16.23 # via # botocore-stubs @@ -402,13 +409,14 @@ urllib3==1.26.16 # celery # datadog-api-client # google-auth + # kombu # kubernetes # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r requirements.in + # via -r model-engine/requirements.in uvloop==0.17.0 - # via -r requirements.in + # via -r model-engine/requirements.in vine==5.0.0 # via # amqp @@ -428,7 +436,7 @@ wsproto==1.2.0 # via hypercorn yarl==1.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # aiohttp zipp==3.16.0 # via diff --git a/model-engine/requirements_override.txt b/model-engine/requirements_override.txt new file mode 100644 index 00000000..0520f838 --- /dev/null +++ b/model-engine/requirements_override.txt @@ -0,0 +1,2 @@ +# Consists of packages that need to be explicitly different from those in requirements.txt +aioboto3==10.4.0 diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml new file mode 100644 index 00000000..0d3ae024 --- /dev/null +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -0,0 +1,60 @@ +# Config for Model Engine running in CircleCI +model_primitive_host: "none" + +# Endpoint config +# K8s namespace the endpoints will be created in +endpoint_namespace: model-engine + +# Asynchronous endpoints +# TODO: Try out localstack once e2e tests have been updated to use sqs as a broker_type +sqs_profile: nonexistent_sqs_profile +sqs_queue_policy_template: > + { + "Version": "2012-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Sid": "__owner_statement", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::000000000000:root" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + }, + { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::000000000000:role/default" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + } + ] + } + +sqs_queue_tag_template: > + { + "infra.scale.com/product": "MLInfraLaunchSQS", + "infra.scale.com/team": "${team}", + "infra.scale.com/contact": "yi.xu@scale.com", + "infra.scale.com/customer": "AllCustomers", + "infra.scale.com/financialOwner": "yi.xu@scale.com", + "Launch-Endpoint-Id": "${endpoint_id}", + "Launch-Endpoint-Name": "${endpoint_name}", + "Launch-Endpoint-Created-By": "${endpoint_created_by}" + } + +# Billing +billing_queue_arn: none +# There's a separate piece of infra that caches k8s state onto redis, so we need a url to it +cache_redis_url: redis://127.0.0.1:6379/15 + +s3_file_llm_fine_tune_repository: "s3://test-bucket" + +datadog_trace_enabled: false +istio_enabled: true +tgi_repository: "text-generation-inference" + +# S3 access +hf_user_fine_tuned_weights_prefix: "s3://test-bucket" diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg new file mode 100644 index 00000000..c47c17ed --- /dev/null +++ b/model-engine/setup.cfg @@ -0,0 +1,35 @@ +[aliases] +test=pytest + +[coverage:run] +omit = + hosted_model_inference/start_server.py, + hosted_model_inference/start_service_builder.py + +# TODO: Fix pylint errors +# [pylint] +# ignore-paths = test/* +# disable = +# I0011, +# R0801, R0902, R0903, R0913, +# W0703, W1202, W1203, W1514, +# C0114, C0411, +# E0611, +# W0511, +# W0622, +# output-format = colorized +# max-line-length = 120 + + +[tool:pytest] +addopts = + --verbose + --durations=0 + --cache-clear + --cov=hosted_model_inference + --cov-report=term-missing + --mypy + --mypy-ini-file=mypy.ini + --ignore=clients +# --pylint +# --pylint-rcfile=setup.cfg diff --git a/model-engine/setup.py b/model-engine/setup.py new file mode 100644 index 00000000..190be7c7 --- /dev/null +++ b/model-engine/setup.py @@ -0,0 +1,20 @@ +# To get circleci to work +from setuptools import find_packages, setup + +setup( + name="model_engine_server", + version="1.0.0", + packages=[p for p in find_packages() if "tests" not in p], + install_requires=[], + entry_points={ + "console_scripts": [ + "start-service-builder=model_engine_server.start_service_builder:entrypoint", + "start-server=model_engine_server.start_server:entrypoint", + "start-fastapi-server=model_engine_server.entrypoints.start_fastapi_server:entrypoint", + "start-batch-job-orchestration=model_engine_server.entrypoints.start_batch_job_orchestration:entrypoint", + "hosted-inference-server=model_engine_server.entrypoints.hosted_inference_server:entrypoint", + "autogen=model_engine_server.scripts.autogenerate_client_and_docs:entrypoint", + "launch-admin=model_engine_server.cli.bin:entrypoint", + ], + }, +) diff --git a/model-engine/tests/README.md b/model-engine/tests/README.md new file mode 100644 index 00000000..ed230099 --- /dev/null +++ b/model-engine/tests/README.md @@ -0,0 +1,7 @@ +## To Run Tests: + +```shell +pushd ../ +PYTHONPATH=hosted_model_inference WORKSPACE=. python3 -m pytest hosted_model_inference/tests --cov=hosted_model_inference +popd +``` diff --git a/server/llm_engine_server/service_builder/__init__.py b/model-engine/tests/__init__.py similarity index 100% rename from server/llm_engine_server/service_builder/__init__.py rename to model-engine/tests/__init__.py diff --git a/server/tests/__init__.py b/model-engine/tests/integration/__init__.py similarity index 100% rename from server/tests/__init__.py rename to model-engine/tests/integration/__init__.py diff --git a/server/tests/integration/inference/conftest.py b/model-engine/tests/integration/inference/conftest.py similarity index 86% rename from server/tests/integration/inference/conftest.py rename to model-engine/tests/integration/inference/conftest.py index ec6eee64..fcd63dfc 100644 --- a/server/tests/integration/inference/conftest.py +++ b/model-engine/tests/integration/inference/conftest.py @@ -13,10 +13,10 @@ from fastapi import Depends, FastAPI, Request from fastapi.responses import JSONResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.serialization_utils import python_json_to_b64 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth, ModelEndpointConfig +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.serialization_utils import python_json_to_b64 +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth, ModelEndpointConfig from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed MODULE_PATH = Path(__file__).resolve() @@ -67,6 +67,8 @@ def endpoint_config_location(callback_port: int, test_user_id: str) -> Iterator[ post_inference_hooks=["callback"], default_callback_url=f"http://localhost:{callback_port}/v0/callback", user_id=test_user_id, + billing_queue=None, + billing_tags=None, ).serialize() with NamedTemporaryFile(mode="w+") as f: f.write(endpoint_config_serialized) @@ -75,26 +77,26 @@ def endpoint_config_location(callback_port: int, test_user_id: str) -> Iterator[ @pytest.fixture(scope="session") -def llm_engine_celery_app( +def launch_celery_app( queue: str, user_config_location: str, endpoint_config_location: str ) -> Iterator[subprocess.Popen]: env = dict( - AWS_PROFILE="default", + AWS_PROFILE="default" if os.getenv("CIRCLECI") else infra_config().profile_ml_worker, BROKER_TYPE="redis", USE_REDIS_LOCALHOST=1, - CELERY_S3_BUCKET=ml_infra_config().s3_bucket, - RESULTS_S3_BUCKET=ml_infra_config().s3_bucket, + CELERY_S3_BUCKET=infra_config().s3_bucket, + RESULTS_S3_BUCKET=infra_config().s3_bucket, CHILD_FN_INFO="{}", BASE_PATH=str(BASE_PATH), PREWARM=True, - BUNDLE_URL=f"s3://{ml_infra_config().s3_bucket}/model_bundles/61a67d767bce560024c7eb96/f0142411-51e1-4357-a405-ee5fef87d977", + BUNDLE_URL=f"s3://{infra_config().s3_bucket}/model_bundles/61a67d767bce560024c7eb96/f0142411-51e1-4357-a405-ee5fef87d977", USER_CONFIG_LOCATION=user_config_location, ENDPOINT_CONFIG_LOCATION=endpoint_config_location, ) env_str = " ".join(f"{k}={v}" for k, v in env.items()) command = ( - f"{env_str} exec celery --app=llm_engine_server.inference.async_inference worker " + f"{env_str} exec celery --app=model_engine_server.inference.async_inference worker " f"--loglevel=INFO --concurrency=1 --queues={queue}" ) # Wait up to 10 seconds for process to start and be ready. diff --git a/server/tests/integration/inference/test_async_inference.py b/model-engine/tests/integration/inference/test_async_inference.py similarity index 94% rename from server/tests/integration/inference/test_async_inference.py rename to model-engine/tests/integration/inference/test_async_inference.py index 91010221..e96164d7 100644 --- a/server/tests/integration/inference/test_async_inference.py +++ b/model-engine/tests/integration/inference/test_async_inference.py @@ -9,15 +9,15 @@ import redis import requests from fastapi import FastAPI -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import ( CallbackAuth, EndpointPredictV1Request, ResponseSchema, TaskStatus, ) -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.infra.gateways import ( +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, LiveAsyncModelEndpointInferenceGateway, ) @@ -45,7 +45,7 @@ def redis_available() -> bool: ) def test_submit_and_get_tasks( queue: str, - llm_engine_celery_app: subprocess.Popen, + launch_celery_app: subprocess.Popen, callback_app: FastAPI, task_args: List[Any], cloudpickle: bool, @@ -94,7 +94,7 @@ def test_async_callbacks( queue: str, callback_port: int, test_user_id: str, - llm_engine_celery_app: subprocess.Popen, + launch_celery_app: subprocess.Popen, callback_app: FastAPI, callback_version: Optional[str], expected_callback_payload: Any, diff --git a/server/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py similarity index 86% rename from server/tests/unit/api/conftest.py rename to model-engine/tests/unit/api/conftest.py index f223c722..fa6f08fa 100644 --- a/server/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -5,16 +5,18 @@ from fastapi import Depends, HTTPException from fastapi.security import HTTPBasicCredentials from fastapi.testclient import TestClient -from llm_engine_server.api.app import app -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.app import app +from model_engine_server.api.dependencies import ( AUTH, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.core.auth.authentication_repository import AuthenticationRepository, User -from llm_engine_server.core.auth.fake_authentication_repository import FakeAuthenticationRepository -from llm_engine_server.domain.entities import ( +from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User +from model_engine_server.core.auth.fake_authentication_repository import ( + FakeAuthenticationRepository, +) +from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, BatchJobRecord, @@ -39,10 +41,11 @@ PytorchFramework, StreamingEnhancedRunnableImageFlavor, TensorflowFramework, + Trigger, ZipArtifactFlavor, ) -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) @@ -99,6 +102,10 @@ def get_test_client( fake_docker_image_batch_job_bundle_repository_contents=None, fake_docker_image_batch_job_gateway_contents=None, fake_llm_fine_tuning_service_contents=None, + fake_file_storage_gateway_contents=None, + fake_file_system_gateway_contents=None, + fake_trigger_repository_contents=None, + fake_cron_job_gateway_contents=None, ) -> TestClient: if fake_docker_image_batch_job_gateway_contents is None: fake_docker_image_batch_job_gateway_contents = {} @@ -116,6 +123,14 @@ def get_test_client( fake_model_bundle_repository_contents = {} if fake_llm_fine_tuning_service_contents is None: fake_llm_fine_tuning_service_contents = {} + if fake_file_storage_gateway_contents is None: + fake_file_storage_gateway_contents = {} + if fake_file_system_gateway_contents is None: + fake_file_system_gateway_contents = {} + if fake_trigger_repository_contents is None: + fake_trigger_repository_contents = {} + if fake_cron_job_gateway_contents is None: + fake_cron_job_gateway_contents = {} app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper( fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists, fake_model_bundle_repository_contents=fake_model_bundle_repository_contents, @@ -126,6 +141,10 @@ def get_test_client( fake_docker_image_batch_job_bundle_repository_contents=fake_docker_image_batch_job_bundle_repository_contents, fake_docker_image_batch_job_gateway_contents=fake_docker_image_batch_job_gateway_contents, fake_llm_fine_tuning_service_contents=fake_llm_fine_tuning_service_contents, + fake_file_storage_gateway_contents=fake_file_storage_gateway_contents, + fake_file_system_gateway_contents=fake_file_system_gateway_contents, + fake_trigger_repository_contents=fake_trigger_repository_contents, + fake_cron_job_gateway_contents=fake_cron_job_gateway_contents, ) app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[ get_external_interfaces @@ -147,6 +166,7 @@ def simple_client(get_test_client_wrapper) -> TestClient: fake_batch_job_record_repository_contents={}, fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, + fake_trigger_repository_contents={}, ) return client @@ -790,7 +810,18 @@ def model_endpoint_2( post_inference_hooks=None, default_callback_url=None, default_callback_auth=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, user_id=test_api_key, + billing_queue="some:arn:for:something", ), ), image="test_image_2", @@ -1050,6 +1081,45 @@ def docker_image_batch_job_bundle_2_v1(test_api_key) -> Tuple[DockerImageBatchJo return batch_bundle, batch_bundle_json +@pytest.fixture +def docker_image_batch_job_bundle_3_v1(test_api_key) -> Tuple[DockerImageBatchJobBundle, Any]: + batch_bundle = DockerImageBatchJobBundle( + id="test_docker_image_batch_job_bundle_id_31", + created_at=datetime.datetime(2022, 1, 2), + name="test_docker_image_batch_job_bundle_3", + created_by=test_api_key, + owner=test_api_key, + image_repository="image_repository", + image_tag="image_tag_git_sha", + command=["python", "script3.py", "--arg1"], + env=dict(ENV1="VAL1", ENV2="VAL2"), + mount_location="/mount2/location/to/config", + cpus="3", + memory="5G", + storage="5G", + gpus=None, + gpu_type=None, + public=None, + ) + batch_bundle_json = { + "id": "test_docker_image_batch_job_bundle_id_31", + "name": "test_docker_image_batch_job_bundle_3", + "created_at": "2022-01-02T00:00:00", + "image_repository": "image_repository", + "image_tag": "image_tag_git_sha", + "command": ["python", "script3.py", "--arg1"], + "env": {"ENV1": "VAL1", "ENV2": "VAL2"}, + "mount_location": "/mount2/location/to/config", + "cpus": "3", + "memory": "5G", + "storage": "5G", + "gpus": None, + "gpu_type": None, + "public": None, + } + return batch_bundle, batch_bundle_json + + @pytest.fixture def create_docker_image_batch_job_request() -> Dict[str, Any]: return dict( @@ -1114,14 +1184,83 @@ def create_llm_model_endpoint_request_sync() -> Dict[str, Any]: @pytest.fixture def completion_sync_request() -> Dict[str, Any]: - return {"prompts": ["what is 1+1?"], "max_new_tokens": 10, "temperature": 0.1} + return { + "prompt": "what is 1+1?", + "max_new_tokens": 10, + "temperature": 0.1, + } @pytest.fixture -def completion_sync_request_temperature_zero() -> Dict[str, Any]: - return {"prompts": ["what is 1+1?"], "max_new_tokens": 10, "temperature": 0} +def completion_stream_request() -> Dict[str, Any]: + return {"prompt": "what is 1+1?", "max_new_tokens": 10, "temperature": 0.1} @pytest.fixture -def completion_stream_request() -> Dict[str, Any]: - return {"prompt": "what is 1+1?", "max_new_tokens": 10, "temperature": 0.1} +def create_trigger_request() -> Dict[str, Any]: + return dict( + name="test_trigger_1", + cron_schedule="* * * * *", + bundle_id="test_docker_image_batch_job_bundle_id_31", + default_job_config={}, + default_job_metadata=dict(team="infra", product="my_product"), + ) + + +@pytest.fixture +def update_trigger_request() -> Dict[str, Any]: + return dict(cron_schedule="0 * * * *", suspend=True) + + +@pytest.fixture +def trigger_1(test_api_key) -> Tuple[Trigger, Any]: + trigger = Trigger( + id="test_trigger_id_1", + name="test_trigger_1", + owner=test_api_key, + created_by=test_api_key, + created_at=datetime.datetime(2022, 1, 2), + cron_schedule="* * * * *", + docker_image_batch_job_bundle_id="test_docker_image_batch_job_bundle_id_11", + default_job_config={}, + default_job_metadata=dict(team="infra", product="my_product_one"), + ) + trigger_json = { + "id": "test_trigger_id_1", + "name": "test_trigger_1", + "owner": "test_user_id", + "created_by": "test_user_id", + "created_at": "2022-01-02T00:00:00", + "cron_schedule": "* * * * *", + "docker_image_batch_job_bundle_id": "test_docker_image_batch_job_bundle_id_11", + "default_job_config": {}, + "default_job_metadata": {"team": "infra", "product": "my_product_one"}, + } + return trigger, trigger_json + + +@pytest.fixture +def trigger_2(test_api_key) -> Tuple[Trigger, Any]: + trigger = Trigger( + id="test_trigger_id_2", + name="test_trigger_2", + owner=test_api_key, + created_by=test_api_key, + created_at=datetime.datetime(2022, 2, 2), + cron_schedule="0 * * * *", + docker_image_batch_job_bundle_id="test_docker_image_batch_job_bundle_id_12", + default_job_config={}, + default_job_metadata=dict(team="infra", product="my_product_two"), + ) + trigger_json = { + "id": "test_trigger_id_2", + "name": "test_trigger_2", + "owner": "test_user_id", + "created_by": "test_user_id", + "created_at": "2022-02-02T00:00:00", + "cron_schedule": "0 * * * *", + "docker_image_batch_job_bundle_id": "test_docker_image_batch_job_bundle_id_12", + "default_job_config": {}, + "default_job_metadata": {"team": "infra", "product": "my_product_two"}, + } + return trigger, trigger_json diff --git a/server/tests/unit/api/test_app.py b/model-engine/tests/unit/api/test_app.py similarity index 100% rename from server/tests/unit/api/test_app.py rename to model-engine/tests/unit/api/test_app.py diff --git a/server/tests/unit/api/test_batch_jobs.py b/model-engine/tests/unit/api/test_batch_jobs.py similarity index 81% rename from server/tests/unit/api/test_batch_jobs.py rename to model-engine/tests/unit/api/test_batch_jobs.py index 87216d9a..5b638d1f 100644 --- a/server/tests/unit/api/test_batch_jobs.py +++ b/model-engine/tests/unit/api/test_batch_jobs.py @@ -2,9 +2,15 @@ import pytest from fastapi.testclient import TestClient -from llm_engine_server.domain.entities import BatchJob, GpuType, ModelBundle, ModelEndpoint -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities import ( + BatchJob, + DockerImageBatchJob, + GpuType, + ModelBundle, + ModelEndpoint, + Trigger, +) +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) @@ -35,6 +41,33 @@ def test_create_batch_job_success( assert "job_id" in response.json() +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +def test_create_batch_job_invalid_team_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + create_batch_job_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + create_batch_job_request["labels"]["team"] = "invalid_team" + response = client.post( + "/v1/batch-jobs", + auth=(test_api_key, ""), + json=create_batch_job_request, + ) + assert response.status_code == 400 + + def test_create_batch_job_bundle_not_found_returns_404( create_batch_job_request: Dict[str, Any], test_api_key: str, @@ -393,6 +426,26 @@ def test_create_docker_image_batch_job_no_image( assert response.status_code == 404 +def test_create_docker_image_batch_job_invalid_time_limit( + test_api_key: str, + get_test_client_wrapper, + create_docker_image_batch_job_request: Dict[str, Any], + docker_image_batch_job_bundle_1_v1: Tuple[DockerImageBatchJobBundle, Any], +): + client = get_test_client_wrapper( + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_1_v1[0].id: docker_image_batch_job_bundle_1_v1[0] + } + ) + create_docker_image_batch_job_request["override_job_max_runtime_s"] = -1 + response = client.post( + "/v1/docker-image-batch-jobs", + auth=(test_api_key, ""), + json=create_docker_image_batch_job_request, + ) + assert response.status_code == 400 + + def test_get_docker_image_batch_job_success( test_api_key: str, get_test_client_wrapper, @@ -440,6 +493,83 @@ def test_get_docker_image_batch_job_not_exist( assert response.status_code == 404 +def test_list_jobs_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + "/v1/docker-image-batch-jobs", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert "jobs" in response.json() + + +def test_list_jobs_by_trigger_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/docker-image-batch-jobs?trigger_id={trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert "jobs" in response.json() + + +def test_list_jobs_by_trigger_not_found_returns_404( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + "/v1/docker-image-batch-jobs?trigger_id=some_trigger_id", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_list_jobs_by_trigger_unauthorized_returns_404( + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/docker-image-batch-jobs?trigger_id={trigger_1[0].id}", + auth=("some_invalid_id", ""), + ) + assert response.status_code == 404 + + def test_update_docker_image_batch_job_noop( test_api_key: str, get_test_client_wrapper, diff --git a/server/tests/unit/api/test_docker_image_batch_job_bundles.py b/model-engine/tests/unit/api/test_docker_image_batch_job_bundles.py similarity index 99% rename from server/tests/unit/api/test_docker_image_batch_job_bundles.py rename to model-engine/tests/unit/api/test_docker_image_batch_job_bundles.py index 49e4d09a..2aa12a30 100644 --- a/server/tests/unit/api/test_docker_image_batch_job_bundles.py +++ b/model-engine/tests/unit/api/test_docker_image_batch_job_bundles.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Tuple from fastapi.testclient import TestClient -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) diff --git a/server/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py similarity index 78% rename from server/tests/unit/api/test_llms.py rename to model-engine/tests/unit/api/test_llms.py index 9b4065f6..2e909aeb 100644 --- a/server/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -1,9 +1,10 @@ import json +import re from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response -from llm_engine_server.domain.entities import ModelEndpoint +from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response +from model_engine_server.domain.entities import ModelEndpoint def test_create_llm_model_endpoint_success( @@ -108,35 +109,8 @@ def test_completion_sync_success( json=completion_sync_request, ) assert response_1.status_code == 200 - assert response_1.json() == {"outputs": [], "status": "SUCCESS", "traceback": None} - - -def test_completion_sync_raises_temperature_zero( - llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], - completion_sync_request_temperature_zero: Dict[str, Any], - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_docker_repository_image_always_exists=True, - fake_model_bundle_repository_contents={}, - fake_model_endpoint_record_repository_contents={ - llm_model_endpoint_sync[0].record.id: llm_model_endpoint_sync[0].record, - }, - fake_model_endpoint_infra_gateway_contents={ - llm_model_endpoint_sync[0] - .infra_state.deployment_name: llm_model_endpoint_sync[0] - .infra_state, - }, - fake_batch_job_record_repository_contents={}, - fake_batch_job_progress_gateway_contents={}, - fake_docker_image_batch_job_bundle_repository_contents={}, - ) - response_1 = client.post( - f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", - auth=("no_user", ""), - json=completion_sync_request_temperature_zero, - ) - assert response_1.status_code == 422 + assert response_1.json()["output"] is None + assert response_1.json().keys() == {"output", "request_id"} @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") @@ -166,6 +140,9 @@ def test_completion_stream_success( assert response_1.status_code == 200 count = 0 for message in response_1: - assert message == b'data: {"status": "SUCCESS", "output": null, "traceback": null}\r\n\r\n' + assert re.fullmatch( + 'data: {"request_id"}: ".*", "output": null}\r\n\r\n', + message.decode("utf-8"), + ) count += 1 assert count == 1 diff --git a/server/tests/unit/api/test_model_bundles.py b/model-engine/tests/unit/api/test_model_bundles.py similarity index 96% rename from server/tests/unit/api/test_model_bundles.py rename to model-engine/tests/unit/api/test_model_bundles.py index 66346f5f..c54e110c 100644 --- a/server/tests/unit/api/test_model_bundles.py +++ b/model-engine/tests/unit/api/test_model_bundles.py @@ -2,7 +2,7 @@ import pytest from fastapi.testclient import TestClient -from llm_engine_server.domain.entities import ModelBundle +from model_engine_server.domain.entities import ModelBundle @pytest.mark.parametrize("version", ["v1", "v2"]) @@ -84,10 +84,7 @@ def test_clone_model_bundle_success( response = client.post( f"/{version}/model-bundles/clone-with-changes", auth=(test_api_key, ""), - json={ - "original_model_bundle_id": model_bundle_1_v1[0].id, - "app_config": {"foo": "bar"}, - }, + json={"original_model_bundle_id": model_bundle_1_v1[0].id, "app_config": {"foo": "bar"}}, ) assert response.status_code == 200 response_json = response.json() @@ -116,10 +113,7 @@ def test_clone_model_bundle_unauthorized_returns_404( response = client.post( f"/{version}/model-bundles/clone-with-changes", auth=(test_api_key_2, ""), # Not the owner, should be unauthorized - json={ - "original_model_bundle_id": model_bundle_1_v1[0].id, - "app_config": {"foo": "bar"}, - }, + json={"original_model_bundle_id": model_bundle_1_v1[0].id, "app_config": {"foo": "bar"}}, ) assert response.status_code == 404 @@ -146,10 +140,7 @@ def test_clone_model_bundle_not_found_returns_404( response = client.post( f"/{version}/model-bundles/clone-with-changes", auth=(test_api_key, ""), - json={ - "original_model_bundle_id": "unknown model bundle id", - "app_config": {"foo": "bar"}, - }, + json={"original_model_bundle_id": "unknown model bundle id", "app_config": {"foo": "bar"}}, ) assert response.status_code == 404 diff --git a/server/tests/unit/api/test_model_endpoints.py b/model-engine/tests/unit/api/test_model_endpoints.py similarity index 88% rename from server/tests/unit/api/test_model_endpoints.py rename to model-engine/tests/unit/api/test_model_endpoints.py index 8961bbdc..614e5907 100644 --- a/server/tests/unit/api/test_model_endpoints.py +++ b/model-engine/tests/unit/api/test_model_endpoints.py @@ -3,8 +3,8 @@ import pytest from fastapi.testclient import TestClient -from llm_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus +from model_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus def test_create_model_endpoint_success( @@ -40,6 +40,42 @@ def test_create_model_endpoint_success( assert response_2.status_code == 200 +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +def test_create_model_endpoint_invalid_team_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + create_model_endpoint_request_sync: Dict[str, Any], + create_model_endpoint_request_async: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + create_model_endpoint_request_sync["labels"]["team"] = "some_invalid_team" + response_1 = client.post( + "/v1/model-endpoints", + auth=(test_api_key, ""), + json=create_model_endpoint_request_sync, + ) + assert response_1.status_code == 400 + + create_model_endpoint_request_async["labels"]["team"] = "some_invalid_team" + response_2 = client.post( + "/v1/model-endpoints", + auth=(test_api_key, ""), + json=create_model_endpoint_request_async, + ) + assert response_2.status_code == 400 + + def test_create_model_endpoint_invalid_streaming_bundle_returns_400( model_bundle_1_v1: Tuple[ModelBundle, Any], create_model_endpoint_request_streaming_invalid_bundle: Dict[str, Any], @@ -358,6 +394,42 @@ def test_update_model_endpoint_by_id_success( assert response.json()["endpoint_creation_task_id"] +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +def test_update_model_endpoint_by_id_invalid_team_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + model_endpoint_1: Tuple[ModelEndpoint, Any], + update_model_endpoint_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + assert model_endpoint_1[0].infra_state is not None + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={ + model_endpoint_1[0].record.id: model_endpoint_1[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + model_endpoint_1[0].infra_state.deployment_name: model_endpoint_1[0].infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + update_model_endpoint_request["labels"] = { + "team": "some_invalid_team", + "product": "my_product", + } + response = client.put( + "/v1/model-endpoints/test_model_endpoint_id_1", + auth=(test_api_key, ""), + json=update_model_endpoint_request, + ) + assert response.status_code == 400 + + def test_update_model_endpoint_by_id_endpoint_not_authorized_returns_404( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], diff --git a/server/tests/unit/api/test_model_endpoints_docs.py b/model-engine/tests/unit/api/test_model_endpoints_docs.py similarity index 97% rename from server/tests/unit/api/test_model_endpoints_docs.py rename to model-engine/tests/unit/api/test_model_endpoints_docs.py index 5ee1451b..04828d05 100644 --- a/server/tests/unit/api/test_model_endpoints_docs.py +++ b/model-engine/tests/unit/api/test_model_endpoints_docs.py @@ -1,6 +1,6 @@ from typing import Any, Tuple -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint def test_model_endpoints_schema_success( diff --git a/server/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py similarity index 96% rename from server/tests/unit/api/test_tasks.py rename to model-engine/tests/unit/api/test_tasks.py index 360658e3..db65a80f 100644 --- a/server/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -1,13 +1,13 @@ from typing import Any, Dict, Tuple from unittest.mock import AsyncMock, MagicMock, patch -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint -from llm_engine_server.domain.exceptions import UpstreamServiceError +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.exceptions import UpstreamServiceError def test_create_async_task_success( @@ -158,7 +158,7 @@ def test_get_async_task_raises_404_object_not_found( mock_use_case = MagicMock() mock_use_case.return_value.execute = MagicMock(side_effect=ObjectNotFoundException) with patch( - "llm_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", + "model_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", mock_use_case, ): response = client.get( @@ -193,7 +193,7 @@ def test_get_async_task_raises_404_object_not_authorized( mock_use_case = MagicMock() mock_use_case.return_value.execute = MagicMock(side_effect=ObjectNotAuthorizedException) with patch( - "llm_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", + "model_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", mock_use_case, ): response = client.get( @@ -325,7 +325,7 @@ def test_create_sync_task_returns_failure( side_effect=UpstreamServiceError(400, b"test_content") ) with patch( - "llm_engine_server.api.tasks_v1.CreateSyncInferenceTaskV1UseCase", + "model_engine_server.api.tasks_v1.CreateSyncInferenceTaskV1UseCase", mock_use_case, ): response = client.post( diff --git a/model-engine/tests/unit/api/test_triggers.py b/model-engine/tests/unit/api/test_triggers.py new file mode 100644 index 00000000..ea9170ba --- /dev/null +++ b/model-engine/tests/unit/api/test_triggers.py @@ -0,0 +1,312 @@ +from typing import Any, Dict, Tuple + +from fastapi.testclient import TestClient +from model_engine_server.domain.entities import Trigger +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( + DockerImageBatchJobBundle, +) + + +def test_create_trigger_success( + create_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + docker_image_batch_job_bundle_3_v1: Tuple[DockerImageBatchJobBundle, Any], +): + # populate docker image batch bundle repo + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_3_v1[0].id: docker_image_batch_job_bundle_3_v1[0], + }, + ) + + response_1 = client.post( + "/v1/triggers", + auth=(test_api_key, ""), + json=create_trigger_request, + ) + assert response_1.status_code == 200 + assert "trigger_id" in response_1.json() + + +def test_create_trigger_batch_bundle_not_found_returns_404( + create_trigger_request: Dict[str, Any], + test_api_key: str, + simple_client: TestClient, +): + response_1 = simple_client.post( + "/v1/triggers", + auth=(test_api_key, ""), + json=create_trigger_request, + ) + assert response_1.status_code == 404 + + +def test_create_trigger_batch_bundle_unauthorized_returns_400( + create_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + docker_image_batch_job_bundle_3_v1: Tuple[DockerImageBatchJobBundle, Any], +): + # populate docker image batch bundle repo + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_3_v1[0].id: docker_image_batch_job_bundle_3_v1[0], + }, + ) + + response_1 = client.post( + "/v1/triggers", + auth=("some_invalid_id", ""), + json=create_trigger_request, + ) + assert response_1.status_code == 404 + + +def test_create_trigger_bad_cron_returns_400( + create_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + docker_image_batch_job_bundle_3_v1: Tuple[DockerImageBatchJobBundle, Any], +): + # populate docker image batch bundle repo + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_3_v1[0].id: docker_image_batch_job_bundle_3_v1[0], + }, + ) + + create_trigger_request["cron_schedule"] = "field is wrong" + response_1 = client.post( + "/v1/triggers", + auth=(test_api_key, ""), + json=create_trigger_request, + ) + assert response_1.status_code == 400 + + +def test_list_triggers_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + "/v1/triggers", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert response.json() == { + "triggers": [trigger_1[1], trigger_2[1]], + } + + +def test_get_trigger_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert response.json() == trigger_1[1] + + +def test_get_trigger_not_found_returns_404( + test_api_key: str, + simple_client: TestClient, +): + response = simple_client.get( + "/v1/triggers/some_trigger_id", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_get_trigger_unauthorized_returns_404( + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=("some_invalid_id", ""), + ) + assert response.status_code == 404 + + +def test_update_trigger_success( + update_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.put( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + json=update_trigger_request, + ) + assert response.json().get("success") + + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert response.json().get("cron_schedule") == "0 * * * *" + + +def test_update_trigger_not_found_returns_404( + update_trigger_request: Dict[str, Any], + test_api_key: str, + simple_client: TestClient, +): + response = simple_client.put( + "/v1/triggers/some_trigger_id", + auth=(test_api_key, ""), + json=update_trigger_request, + ) + assert response.status_code == 404 + + +def test_update_trigger_unauthorized_returns_404( + update_trigger_request: Dict[str, Any], + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.put( + f"/v1/triggers/{trigger_1[0].id}", + auth=("some_invalid_id", ""), + json=update_trigger_request, + ) + assert response.status_code == 404 + + +def test_update_trigger_bad_cron_returns_400( + update_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + + update_trigger_request["cron_schedule"] = "field is wrong" + response = client.put( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + json=update_trigger_request, + ) + assert response.status_code == 400 + + +def test_delete_trigger_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.delete( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.json().get("success") + + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_delete_trigger_not_found_returns_404( + test_api_key: str, + simple_client: TestClient, +): + response = simple_client.delete( + "/v1/triggers/some_trigger_id", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_delete_trigger_unauthorized_returns_404( + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.delete( + f"/v1/triggers/{trigger_1[0].id}", + auth=("some_invalid_id", ""), + ) + assert response.status_code == 404 diff --git a/server/tests/integration/__init__.py b/model-engine/tests/unit/common/__init__.py similarity index 100% rename from server/tests/integration/__init__.py rename to model-engine/tests/unit/common/__init__.py diff --git a/server/tests/unit/common/test_batch_jobs_dtos.py b/model-engine/tests/unit/common/test_batch_jobs_dtos.py similarity index 95% rename from server/tests/unit/common/test_batch_jobs_dtos.py rename to model-engine/tests/unit/common/test_batch_jobs_dtos.py index f6eb384e..2ba5499d 100644 --- a/server/tests/unit/common/test_batch_jobs_dtos.py +++ b/model-engine/tests/unit/common/test_batch_jobs_dtos.py @@ -1,4 +1,4 @@ -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests def test_create_docker_image_batch_job_resource_requests_merge_requests(): diff --git a/server/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py similarity index 88% rename from server/tests/unit/conftest.py rename to model-engine/tests/unit/conftest.py index 705cf468..df1b5e2e 100644 --- a/server/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -11,40 +11,41 @@ List, Optional, Sequence, + Set, Tuple, ) from unittest.mock import mock_open from uuid import uuid4 import pytest -from llm_engine_server.api.dependencies import ExternalInterfaces -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.api.dependencies import ExternalInterfaces +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.common.dtos.model_endpoints import ( BrokerType, CpuSpecificationType, GpuType, ModelEndpointOrderBy, StorageSpecificationType, ) -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.common.settings import generate_destination -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from llm_engine_server.db.endpoint_row_lock import get_lock_key -from llm_engine_server.db.models import BatchJob as OrmBatchJob -from llm_engine_server.db.models import Endpoint as OrmModelEndpoint -from llm_engine_server.domain.entities import ( +from model_engine_server.common.settings import generate_destination +from model_engine_server.core.domain_exceptions import ObjectNotFoundException +from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway +from model_engine_server.db.endpoint_row_lock import get_lock_key +from model_engine_server.db.models import BatchJob as OrmBatchJob +from model_engine_server.db.models import Endpoint as OrmModelEndpoint +from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, BatchJobRecord, @@ -54,6 +55,9 @@ CallbackBasicAuth, CloudpickleArtifactFlavor, CustomFramework, + FileMetadata, + FineTuneHparamValueType, + LLMFineTuneEvent, ModelBundle, ModelBundleEnvironmentParams, ModelBundleFlavors, @@ -73,61 +77,74 @@ RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor, TensorflowFramework, + Trigger, TritonEnhancedRunnableImageFlavor, ZipArtifactFlavor, ) -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways import ( +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, + CronJobGateway, DockerImageBatchJobGateway, + FileStorageGateway, + LLMArtifactGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, TaskQueueGateway, ) -from llm_engine_server.domain.repositories import ( +from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, DockerRepository, + LLMFineTuneEventsRepository, ModelBundleRepository, + TriggerRepository, ) -from llm_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService -from llm_engine_server.domain.services.llm_fine_tuning_service import LLMFineTuningService -from llm_engine_server.infra.gateways import ( +from model_engine_server.domain.services import ( + LLMFineTuningService, + LLMModelEndpointService, + ModelEndpointService, +) +from model_engine_server.infra.gateways import ( BatchJobOrchestrationGateway, - FilesystemGateway, LiveBatchJobProgressGateway, LiveModelEndpointsSchemaGateway, ModelEndpointInfraGateway, ) -from llm_engine_server.infra.gateways.fake_model_primitive_gateway import FakeModelPrimitiveGateway -from llm_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( +from model_engine_server.infra.gateways.fake_model_primitive_gateway import ( + FakeModelPrimitiveGateway, +) +from model_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( FakeMonitoringMetricsGateway, ) -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, EndpointResourceGatewayCreateOrUpdateResourcesResponse, QueueInfo, ) -from llm_engine_server.infra.gateways.resources.image_cache_gateway import ( +from model_engine_server.infra.gateways.resources.image_cache_gateway import ( CachedImages, ImageCacheGateway, ) -from llm_engine_server.infra.repositories import ( +from model_engine_server.infra.repositories import ( BatchJobRecordRepository, FeatureFlagRepository, + LLMFineTuneRepository, ModelEndpointCacheRepository, ModelEndpointRecordRepository, ) -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.infra.repositories.db_model_bundle_repository import ( translate_kwargs_to_model_bundle_orm, translate_model_bundle_orm_to_model_bundle, ) -from llm_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService -from llm_engine_server.infra.services.image_cache_service import ImageCacheService -from llm_engine_server.infra.services.live_llm_model_endpoint_service import ( +from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService +from model_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.live_llm_model_endpoint_service import ( LiveLLMModelEndpointService, ) @@ -693,6 +710,124 @@ async def read_feature_flag_bool( return self.db.get(key, None) +class FakeLLMFineTuneRepository(LLMFineTuneRepository): + def __init__(self, db: Optional[Dict[Tuple[str, str], LLMFineTuneTemplate]] = None): + self.db = db + if self.db is None: + self.db = {} + + async def get_job_template_for_model( + self, model_name: str, fine_tuning_method: str + ) -> Optional[LLMFineTuneTemplate]: + return self.db.get((model_name, fine_tuning_method), None) + + async def write_job_template_for_model( + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate + ): + self.db[(model_name, fine_tuning_method)] = job_template + + +class FakeLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): + def __init__(self): + self.initialized_events = [] + self.all_events_list = [LLMFineTuneEvent(timestamp=1, message="message", level="info")] + + async def get_fine_tune_events(self, user_id: str, model_endpoint_name: str): + if (user_id, model_endpoint_name) in self.initialized_events: + return self.all_events_list + raise ObjectNotFoundException + + async def initialize_events(self, user_id: str, model_endpoint_name: str): + self.initialized_events.append((user_id, model_endpoint_name)) + + +class FakeLLMArtifactGateway(LLMArtifactGateway): + def __init__(self): + self.existing_models = [] + self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} + + def _add_model(self, owner: str, model_name: str): + self.existing_models.append((owner, model_name)) + + def get_model_weights_urls(self, owner: str, model_name: str): + if (owner, model_name) in self.existing_models: + return self.urls + raise ObjectNotFoundException + + +class FakeTriggerRepository(TriggerRepository): + def __init__(self, contents: Optional[Dict[str, Trigger]] = None): + self.db = {} if contents is None else contents + self.next_id = 0 + + def _get_new_id(self): + new_id = f"trig_{self.next_id}" + self.next_id += 1 + return new_id + + async def create_trigger( + self, + *, + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], + ) -> Trigger: + trigger_id = self._get_new_id() + trigger = Trigger( + id=trigger_id, + name=name, + owner=owner, + created_by=created_by, + created_at=datetime.now(), + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) + self.db[trigger_id] = trigger + return trigger + + async def list_triggers( + self, + owner: str, + ) -> Sequence[Trigger]: + def filter_fn(trig: Trigger) -> bool: + return trig.owner == owner + + return list(filter(filter_fn, self.db.values())) + + async def get_trigger( + self, + trigger_id: str, + ) -> Optional[Trigger]: + return self.db.get(trigger_id) + + async def update_trigger( + self, + trigger_id: str, + cron_schedule: str, + ) -> bool: + if trigger_id not in self.db: + return False + + self.db[trigger_id].cron_schedule = cron_schedule + return True + + async def delete_trigger( + self, + trigger_id: str, + ) -> bool: + if trigger_id not in self.db: + return False + + del self.db[trigger_id] + return True + + class FakeImageCacheGateway(ImageCacheGateway): def __init__(self): self.cached_images = CachedImages(cpu=[], a10=[], a100=[], t4=[]) @@ -839,6 +974,7 @@ def create_model_endpoint_infra( labels: Dict[str, str], prewarm: Optional[bool], high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -937,6 +1073,7 @@ async def update_model_endpoint_infra( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, ) -> str: @@ -1087,8 +1224,10 @@ async def create_docker_image_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, labels: Dict[str, str], mount_location: Optional[str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, ) -> str: - job_id = f"job-{self.id}" + job_id = f"ft-{self.id}" self.id += 1 self.db[job_id] = DockerImageBatchJob( @@ -1098,6 +1237,8 @@ async def create_docker_image_batch_job( created_at=datetime.now(), completed_at=None, status=BatchJobStatus.RUNNING, + annotations=annotations, + override_job_max_runtime_s=override_job_max_runtime_s, ) return job_id @@ -1118,54 +1259,136 @@ async def update_docker_image_batch_job(self, batch_job_id: str, cancel: bool) - return cancel +class FakeCronJobGateway(CronJobGateway): + def __init__(self, contents=None): + self.db = contents or {} + self.suspended_cronjobs: Set[str] = set() + self.id = 0 + + async def create_cronjob( + self, + *, + request_host: str, + trigger_id: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Dict[str, str], + ) -> None: + cron_job_id = f"cronjob-{trigger_id}" + self.id += 1 + + self.db[cron_job_id] = Trigger( + id=cron_job_id, + name=cron_job_id, + owner=owner, + created_by=created_by, + created_at=datetime.now(), + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) + + async def list_jobs( + self, + *, + owner: str, + trigger_id: Optional[str], + ) -> List[DockerImageBatchJob]: + return [] + + async def update_cronjob( + self, + *, + trigger_id: str, + cron_schedule: Optional[str], + suspend: Optional[bool], + ) -> None: + cron_job_id = f"cronjob-{trigger_id}" + if cron_job_id not in self.db: + return + + if cron_schedule is not None: + self.db[cron_job_id].cron_schedule = cron_schedule + if suspend is not None: + if suspend: + self.suspended_cronjobs.add(cron_job_id) + else: + self.suspended_cronjobs.discard(cron_job_id) + + async def delete_cronjob( + self, + *, + trigger_id: str, + ) -> None: + cron_job_id = f"cronjob-{trigger_id}" + self.db.pop(cron_job_id, None) + self.suspended_cronjobs.discard(cron_job_id) + + class FakeLLMFineTuningService(LLMFineTuningService): def __init__(self, contents=None): self.db: Dict[str, DockerImageBatchJob] = {} if contents is None else contents self.id = 0 - async def create_fine_tune_job( + async def create_fine_tune( self, created_by: str, owner: str, + model: str, training_file: str, - validation_file: str, - model_name: str, - base_model: str, + validation_file: Optional[str], fine_tuning_method: str, - hyperparameters: Dict[str, str], + hyperparameters: Dict[str, FineTuneHparamValueType], + fine_tuned_model: str, + wandb_config: Optional[Dict[str, Any]], ) -> str: - job_id = f"job-{self.id}" + job_id = f"ft-{self.id}" self.id += 1 + now = datetime.now() + self.db[job_id] = DockerImageBatchJob( id=job_id, created_by=created_by, owner=owner, - created_at=datetime.now(), + created_at=now, completed_at=None, status=BatchJobStatus.RUNNING, + annotations={ + "fine_tuned_model": fine_tuned_model, + }, ) return job_id - async def get_fine_tune_job( - self, owner: str, fine_tune_id: str - ) -> Optional[DockerImageBatchJob]: + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: di_batch_job = self.db.get(fine_tune_id) - if di_batch_job is None or di_batch_job["owner"] != owner: + if di_batch_job is None or di_batch_job.owner != owner: return None return di_batch_job - async def list_fine_tune_jobs(self, owner: str) -> List[DockerImageBatchJob]: - return [job for job in self.db.values() if job["owner"] == owner] + async def list_fine_tunes(self, owner: str) -> List[DockerImageBatchJob]: + return [job for job in self.db.values() if job.owner == owner] - async def cancel_fine_tune_job(self, owner: str, fine_tune_id: str) -> bool: - if fine_tune_id not in self.db or self.db.get(fine_tune_id)["owner"] != owner: + async def cancel_fine_tune(self, owner: str, fine_tune_id: str) -> bool: + if fine_tune_id not in self.db or self.db.get(fine_tune_id).owner != owner: return False del self.db[fine_tune_id] return True + async def get_fine_tune_model_name_from_id( + self, owner: str, fine_tune_id: str + ) -> Optional[str]: + fine_tune = self.db.get(fine_tune_id, None) + if fine_tune is not None and fine_tune.owner == owner: + return fine_tune.annotations["fine_tuned_model"] + return None + class FakeStreamingModelEndpointInferenceGateway(StreamingModelEndpointInferenceGateway): def __init__(self): @@ -1204,6 +1427,52 @@ async def predict( return self.response +class FakeFileStorageGateway(FileStorageGateway): + def __init__(self, contents=None): + self.db: Dict[str, FileMetadata] = {} if contents is None else contents + self.id = 0 + self.content = "Test content" + + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + return "dummy URL" + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + file_id = f"file-{self.id}" + self.id += 1 + + self.db[file_id] = FileMetadata( + id=file_id, + filename=f"{file_id}_name", + size=len(self.content), + owner=owner, + updated_at=datetime.now(), + ) + + return file_id + + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + file = self.db.get(file_id) + if file is None or file.owner != owner: + return None + return file + + async def list_files(self, owner: str) -> List[FileMetadata]: + return [file for file in self.db.values() if file.owner == owner] + + async def delete_file(self, owner: str, file_id: str) -> bool: + if file_id not in self.db or self.db.get(file_id).owner != owner: + return False + + del self.db[file_id] + return True + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + file = self.db.get(file_id) + if file is None or file.owner != owner: + return None + return self.content + + @dataclass class FakeAsyncTask: topic: str @@ -1337,6 +1606,7 @@ async def create_model_endpoint( results_s3_bucket: str, prewarm: Optional[bool], high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, owner: str, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, @@ -1393,7 +1663,9 @@ async def create_model_endpoint( bundle_name=current_model_bundle.name, endpoint_name=name, post_inference_hooks=post_inference_hooks, + billing_tags=billing_tags, user_id=created_by, + billing_queue="some:arn:for:something", default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ), @@ -1424,6 +1696,7 @@ async def update_model_endpoint( results_s3_bucket: Optional[str] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, public_inference: Optional[bool] = None, @@ -1621,6 +1894,23 @@ def fake_docker_image_batch_job_bundle_repository() -> FakeDockerImageBatchJobBu return repo +@pytest.fixture +def fake_llm_fine_tune_repository() -> FakeLLMFineTuneRepository: + repo = FakeLLMFineTuneRepository() + return repo + + +@pytest.fixture +def fake_llm_fine_tuning_events_repository() -> FakeLLMFineTuneEventsRepository: + repo = FakeLLMFineTuneEventsRepository() + return repo + + +def fake_trigger_repository() -> FakeTriggerRepository: + repo = FakeTriggerRepository() + return repo + + @pytest.fixture def fake_image_cache_gateway() -> FakeImageCacheGateway: gateway = FakeImageCacheGateway() @@ -1639,6 +1929,12 @@ def fake_batch_job_orchestration_gateway() -> FakeBatchJobOrchestrationGateway: return gateway +@pytest.fixture +def fake_docker_image_batch_job_gateway() -> FakeDockerImageBatchJobGateway: + gateway = FakeDockerImageBatchJobGateway() + return gateway + + @pytest.fixture def fake_monitoring_metrics_gateway() -> FakeMonitoringMetricsGateway: gateway = FakeMonitoringMetricsGateway() @@ -1693,6 +1989,23 @@ def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferen return gateway +@pytest.fixture +def fake_file_storage_gateway() -> FakeFileStorageGateway: + gateway = FakeFileStorageGateway() + return gateway + + +@pytest.fixture +def fake_llm_artifact_gateway() -> FakeLLMArtifactGateway: + gateway = FakeLLMArtifactGateway() + return gateway + + +def fake_cron_job_gateway() -> FakeCronJobGateway: + gateway = FakeCronJobGateway() + return gateway + + @pytest.fixture def fake_model_endpoint_service() -> FakeModelEndpointService: service = FakeModelEndpointService() @@ -1705,6 +2018,12 @@ def fake_llm_model_endpoint_service() -> FakeLLMModelEndpointService: return service +@pytest.fixture +def fake_llm_fine_tuning_service() -> FakeLLMFineTuningService: + service = FakeLLMFineTuningService() + return service + + @pytest.fixture def fake_image_cache_service( fake_image_cache_gateway, @@ -1727,11 +2046,16 @@ def get_repositories_generator( fake_model_endpoint_infra_gateway_contents, fake_batch_job_record_repository_contents, fake_batch_job_progress_gateway_contents, + fake_cron_job_gateway_contents, fake_docker_image_batch_job_bundle_repository_contents, fake_docker_image_batch_job_gateway_contents, fake_llm_fine_tuning_service_contents, + fake_file_storage_gateway_contents, + fake_trigger_repository_contents, + fake_file_system_gateway_contents, ): def get_test_repositories() -> Iterator[ExternalInterfaces]: + fake_file_system_gateway = FakeFilesystemGateway() fake_model_bundle_repository = FakeModelBundleRepository( contents=fake_model_bundle_repository_contents ) @@ -1777,9 +2101,14 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_docker_image_batch_job_bundle_repository = FakeDockerImageBatchJobBundleRepository( contents=fake_docker_image_batch_job_bundle_repository_contents ) + fake_trigger_repository = FakeTriggerRepository( + contents=fake_trigger_repository_contents + ) fake_docker_image_batch_job_gateway = FakeDockerImageBatchJobGateway( fake_docker_image_batch_job_gateway_contents ) + fake_llm_artifact_gateway = FakeLLMArtifactGateway() + fake_cron_job_gateway = FakeCronJobGateway(fake_cron_job_gateway_contents) fake_llm_model_endpoint_service = LiveLLMModelEndpointService( model_endpoint_record_repository=fake_model_endpoint_record_repository, model_endpoint_service=fake_model_endpoint_service, @@ -1787,6 +2116,8 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_llm_fine_tuning_service = FakeLLMFineTuningService( fake_llm_fine_tuning_service_contents ) + fake_llm_fine_tuning_events_repository = FakeLLMFineTuneEventsRepository() + fake_file_storage_gateway = FakeFileStorageGateway(fake_file_storage_gateway_contents) repositories = ExternalInterfaces( docker_repository=FakeDockerRepository( fake_docker_repository_image_always_exists, False @@ -1803,6 +2134,12 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: docker_image_batch_job_bundle_repository=fake_docker_image_batch_job_bundle_repository, docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, llm_fine_tuning_service=fake_llm_fine_tuning_service, + llm_fine_tune_events_repository=fake_llm_fine_tuning_events_repository, + file_storage_gateway=fake_file_storage_gateway, + trigger_repository=fake_trigger_repository, + cron_job_gateway=fake_cron_job_gateway, + filesystem_gateway=fake_file_system_gateway, + llm_artifact_gateway=fake_llm_artifact_gateway, ) try: yield repositories @@ -1985,7 +2322,7 @@ def model_bundle_4(test_api_key: str) -> ModelBundle: ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2022,7 +2359,7 @@ def model_bundle_5(test_api_key: str) -> ModelBundle: ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2063,7 +2400,7 @@ def model_bundle_6(test_api_key: str) -> ModelBundle: ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2108,7 +2445,7 @@ def model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage( ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2160,6 +2497,16 @@ def model_endpoint_1(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd bundle_name=model_bundle_1.name, endpoint_name="test_model_endpoint_name_1", post_inference_hooks=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ), ), image="000000000000.dkr.ecr.us-west-2.amazonaws.com/non-existent-repo:fake-tag", @@ -2217,7 +2564,7 @@ def model_endpoint_2(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd post_inference_hooks=None, ), ), - image="000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:40d3b5fb06d1a8c3d14903390a3b23ae388bdb19", + image="000000000000.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg222", ), ) return model_endpoint @@ -2272,7 +2619,7 @@ def model_endpoint_3(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd post_inference_hooks=None, ), ), - image="000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:e4ea48ddccfb9ca3ef6d846ae9b2d146d7e30b0f", + image="000000000000.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg111111111", ), ) return model_endpoint @@ -2327,7 +2674,7 @@ def model_endpoint_4(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd post_inference_hooks=None, ), ), - image="000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:9a319cd9b897f02291f3242b1395f2b669993cdf-fd", + image="000000000000.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg00000", ), ) return model_endpoint @@ -2380,6 +2727,16 @@ def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> Mod bundle_name=model_bundle_1.name, endpoint_name="test_model_endpoint_name_1", post_inference_hooks=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ), ), image="000000000000.dkr.ecr.us-west-2.amazonaws.com/non-existent-repo:fake-tag", @@ -2435,6 +2792,16 @@ def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) - bundle_name=model_bundle_1.name, endpoint_name="test_model_endpoint_name_1", post_inference_hooks=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ), ), image="000000000000.dkr.ecr.us-west-2.amazonaws.com/non-existent-repo:fake-tag", @@ -3069,6 +3436,7 @@ def llm_model_endpoint_async( "name": "test_llm_model_endpoint_name_1", "model_name": "llama-7b", "source": "hugging_face", + "status": "READY", "inference_framework": "deepspeed", "inference_framework_image_tag": "123", "num_shards": 4, @@ -3200,6 +3568,7 @@ def llm_model_endpoint_sync( "name": "test_llm_model_endpoint_name_1", "model_name": "llama-7b", "source": "hugging_face", + "status": "READY", "inference_framework": "deepspeed", "inference_framework_image_tag": "123", "num_shards": 4, diff --git a/server/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py similarity index 86% rename from server/tests/unit/domain/conftest.py rename to model-engine/tests/unit/domain/conftest.py index 5c993d0b..85e57ea4 100644 --- a/server/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -1,22 +1,22 @@ import pytest -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobResourceRequests, ) -from llm_engine_server.common.dtos.llms import ( +from model_engine_server.common.dtos.llms import ( CompletionStreamV1Request, CompletionSyncV1Request, CreateLLMModelEndpointV1Request, ) -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CreateModelBundleV1Request, CreateModelBundleV2Request, ) -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, UpdateModelEndpointV1Request, ) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( GpuType, ModelBundle, ModelBundleEnvironmentParams, @@ -75,7 +75,7 @@ def create_model_endpoint_request_sync( model_bundle_id=model_bundle_1.id, endpoint_type=ModelEndpointType.SYNC, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=1, memory="8G", @@ -99,7 +99,7 @@ def create_model_endpoint_request_streaming( model_bundle_id=model_bundle_5.id, endpoint_type=ModelEndpointType.STREAMING, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=1, memory="8G", @@ -123,7 +123,7 @@ def create_model_endpoint_request_async( model_bundle_id=model_bundle_1.id, endpoint_type=ModelEndpointType.ASYNC, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=1, memory="8G", @@ -179,7 +179,7 @@ def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request: num_shards=2, endpoint_type=ModelEndpointType.SYNC, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=2, memory="8G", @@ -205,7 +205,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request num_shards=2, endpoint_type=ModelEndpointType.ASYNC, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=2, memory="8G", @@ -231,7 +231,33 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req num_shards=2, endpoint_type=ModelEndpointType.STREAMING, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage=None, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_2", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="text_generation_inference", + inference_framework_image_tag="test_tag", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], cpus=1, gpus=2, memory="8G", @@ -260,7 +286,7 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( quantize=Quantization.BITSANDBYTES, endpoint_type=ModelEndpointType.STREAMING, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=2, memory="8G", @@ -289,7 +315,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( quantize=Quantization.BITSANDBYTES, endpoint_type=ModelEndpointType.ASYNC, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=2, memory="8G", @@ -315,7 +341,7 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp num_shards=2, endpoint_type=ModelEndpointType.SYNC, metadata={}, - post_inference_hooks=[], + post_inference_hooks=["billing"], cpus=1, gpus=2, memory="8G", @@ -333,9 +359,10 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp @pytest.fixture def completion_sync_request() -> CompletionSyncV1Request: return CompletionSyncV1Request( - prompts=["test_prompt_1", "test_prompt_2"], + prompt="test_prompt_1", max_new_tokens=10, temperature=0.5, + return_token_log_probs=True, ) diff --git a/server/tests/unit/domain/test_async_inference_use_cases.py b/model-engine/tests/unit/domain/test_async_inference_use_cases.py similarity index 91% rename from server/tests/unit/domain/test_async_inference_use_cases.py rename to model-engine/tests/unit/domain/test_async_inference_use_cases.py index 9b027907..7a122b3b 100644 --- a/server/tests/unit/domain/test_async_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_async_inference_use_cases.py @@ -1,14 +1,14 @@ from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.use_cases.async_inference_use_cases import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.use_cases.async_inference_use_cases import ( CreateAsyncInferenceTaskV1UseCase, GetAsyncInferenceTaskV1UseCase, ) diff --git a/server/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py b/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py similarity index 92% rename from server/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py rename to model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py index 51852f30..9522c9d5 100644 --- a/server/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py +++ b/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py @@ -1,16 +1,16 @@ import pytest -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, ) -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.repositories import DockerImageBatchJobBundleRepository -from llm_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( +from model_engine_server.domain.repositories import DockerImageBatchJobBundleRepository +from model_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( CreateDockerImageBatchJobBundleV1UseCase, GetDockerImageBatchJobBundleByIdV1UseCase, GetLatestDockerImageBatchJobBundleByNameV1UseCase, @@ -56,9 +56,7 @@ async def test_create_list_docker_image_batch_job_bundle_use_case( user=user, request=create_docker_image_batch_job_bundle_request ) response_2 = await use_case_list.execute( - user=user, - bundle_name=create_docker_image_batch_job_bundle_request.name, - order_by=None, + user=user, bundle_name=create_docker_image_batch_job_bundle_request.name, order_by=None ) assert len(response_2.docker_image_batch_job_bundles) == 1 assert ( @@ -103,9 +101,7 @@ async def test_create_list_docker_image_batch_job_bundle_team_use_case( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) user_other_team_1 = User( - user_id=test_api_key_user_on_other_team, - team_id=test_api_key_team, - is_privileged_user=True, + user_id=test_api_key_user_on_other_team, team_id=test_api_key_team, is_privileged_user=True ) user_other_team_2 = User( user_id=test_api_key_user_on_other_team_2, @@ -121,9 +117,7 @@ async def test_create_list_docker_image_batch_job_bundle_team_use_case( ) await use_case_create.execute(user=user, request=create_docker_image_batch_job_bundle_request) response_2 = await use_case_list.execute( - user=user, - bundle_name=create_docker_image_batch_job_bundle_request.name, - order_by=None, + user=user, bundle_name=create_docker_image_batch_job_bundle_request.name, order_by=None ) assert len(response_2.docker_image_batch_job_bundles) == 1 response_3 = await use_case_list.execute( @@ -209,9 +203,7 @@ async def test_create_get_docker_image_batch_job_bundle_by_id_unauthorized_use_c ): user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) user_other_team_1 = User( - user_id=test_api_key_user_on_other_team, - team_id=test_api_key_team, - is_privileged_user=True, + user_id=test_api_key_user_on_other_team, team_id=test_api_key_team, is_privileged_user=True ) use_case_create = CreateDockerImageBatchJobBundleV1UseCase( docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository diff --git a/server/tests/unit/domain/test_entities.py b/model-engine/tests/unit/domain/test_entities.py similarity index 86% rename from server/tests/unit/domain/test_entities.py rename to model-engine/tests/unit/domain/test_entities.py index dc9a8e56..41533afc 100644 --- a/server/tests/unit/domain/test_entities.py +++ b/model-engine/tests/unit/domain/test_entities.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CallbackBasicAuth, ModelBundle, @@ -15,6 +15,7 @@ bundle_name="test_bundle", post_inference_hooks=["test_hook"], user_id="test_user", + billing_queue="test_queue", default_callback_url="test_url", ), ModelEndpointConfig( @@ -22,6 +23,7 @@ bundle_name="test_bundle", post_inference_hooks=["test_hook"], user_id="test_user", + billing_queue="test_queue", default_callback_auth=CallbackAuth( __root__=CallbackBasicAuth( kind="basic", username="test_user", password="test_password" @@ -30,9 +32,7 @@ ), ], ) -def test_model_endpoint_config_serialization( - model_endpoint_config: ModelEndpointConfig, -): +def test_model_endpoint_config_serialization(model_endpoint_config: ModelEndpointConfig): serialized_config = model_endpoint_config.serialize() deserialized_config = ModelEndpointConfig.deserialize(serialized_config) assert model_endpoint_config == deserialized_config diff --git a/server/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py similarity index 67% rename from server/tests/unit/domain/test_llm_use_cases.py rename to model-engine/tests/unit/domain/test_llm_use_cases.py index d33c4fb9..4e30c41f 100644 --- a/server/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1,29 +1,43 @@ from typing import Any, Tuple import pytest -from llm_engine_server.common.dtos.llms import ( +from model_engine_server.common.dtos.llms import ( CompletionOutput, CompletionStreamV1Request, CompletionSyncV1Request, + CreateFineTuneRequest, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, + ModelDownloadRequest, + TokenOutput, ) -from llm_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint, ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( +from model_engine_server.domain.entities import ModelEndpoint, ModelEndpointType +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + InvalidRequestException, + LLMFineTuningQuotaReached, +) +from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( + MAX_LLM_ENDPOINTS_PER_INTERNAL_USER, + CreateFineTuneV1UseCase, + GetFineTuneEventsV1UseCase, + is_model_name_suffix_valid, +) +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, CreateLLMModelEndpointV1UseCase, GetLLMModelEndpointByNameV1UseCase, + ModelDownloadV1UseCase, ) -from llm_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @pytest.mark.asyncio @@ -36,6 +50,7 @@ async def test_create_model_endpoint_use_case_success( create_llm_model_endpoint_request_async: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_llama_2: CreateLLMModelEndpointV1Request, ): fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository bundle_use_case = CreateModelBundleV2UseCase( @@ -117,6 +132,16 @@ async def test_create_model_endpoint_use_case_success( } } + response_4 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_llama_2 + ) + assert response_4.endpoint_creation_task_id + assert isinstance(response_4, CreateLLMModelEndpointV1Response) + bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( + owner=user.team_id, name=create_llm_model_endpoint_request_llama_2.name + ) + assert "--max-total-tokens" in bundle.flavor.command and "4096" in bundle.flavor.command + @pytest.mark.asyncio async def test_create_model_endpoint_text_generation_inference_use_case_success( @@ -260,7 +285,20 @@ async def test_completion_sync_use_case_success( " of", " programming", ".", - ] + ], + "token_probs": [ + 0.1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], }, "tokens_consumed": 25, } @@ -279,13 +317,23 @@ async def test_completion_sync_use_case_success( model_endpoint_name=llm_model_endpoint_sync[0].record.name, request=completion_sync_request, ) - assert response_1.status == TaskStatus.SUCCESS - assert response_1.outputs == [ - CompletionOutput( - text="I am a newbie to the world of programming.", - num_completion_tokens=11, - ) - ] + assert response_1.output == CompletionOutput( + text="I am a newbie to the world of programming.", + num_completion_tokens=11, + tokens=[ + TokenOutput(token="I", log_prob=-2.3025850929940455), + TokenOutput(token=" am", log_prob=0), + TokenOutput(token=" a", log_prob=0), + TokenOutput(token=" new", log_prob=0), + TokenOutput(token="bie", log_prob=0), + TokenOutput(token=" to", log_prob=0), + TokenOutput(token=" the", log_prob=0), + TokenOutput(token=" world", log_prob=0), + TokenOutput(token=" of", log_prob=0), + TokenOutput(token=" programming", log_prob=0), + TokenOutput(token=".", log_prob=0), + ], + ) @pytest.mark.asyncio @@ -330,31 +378,40 @@ async def test_completion_sync_text_generation_inference_use_case_success( ], "tokens": [ { - "text": " Deep" + "text": " Deep", + "logprob": 0 }, { - "text": " Learning" + "text": " Learning", + "logprob": -1 }, { - "text": " is" + "text": " is", + "logprob": 0 }, { - "text": " a" + "text": " a", + "logprob": 0 }, { - "text": " new" + "text": " new", + "logprob": 0 }, { - "text": " type" + "text": " type", + "logprob": 0 }, { - "text": " of" + "text": " of", + "logprob": 0 }, { - "text": " machine" + "text": " machine", + "logprob": 0 }, { - "text": " learning" + "text": " learning", + "logprob": 0 } ] } @@ -373,18 +430,21 @@ async def test_completion_sync_text_generation_inference_use_case_success( model_endpoint_name=llm_model_endpoint_text_generation_inference.record.name, request=completion_sync_request, ) - assert response_1.status == TaskStatus.SUCCESS - print(response_1.outputs) - assert response_1.outputs == [ - CompletionOutput( - text=" Deep Learning is a new type of machine learning", - num_completion_tokens=9, - ), - CompletionOutput( - text=" Deep Learning is a new type of machine learning", - num_completion_tokens=9, - ), - ] + assert response_1.output == CompletionOutput( + text=" Deep Learning is a new type of machine learning", + num_completion_tokens=9, + tokens=[ + TokenOutput(token=" Deep", log_prob=0.0), + TokenOutput(token=" Learning", log_prob=-1.0), + TokenOutput(token=" is", log_prob=0.0), + TokenOutput(token=" a", log_prob=0.0), + TokenOutput(token=" new", log_prob=0.0), + TokenOutput(token=" type", log_prob=0.0), + TokenOutput(token=" of", log_prob=0.0), + TokenOutput(token=" machine", log_prob=0.0), + TokenOutput(token=" learning", log_prob=0.0), + ], + ) @pytest.mark.asyncio @@ -413,9 +473,7 @@ async def test_completion_sync_use_case_predict_failed( model_endpoint_name=llm_model_endpoint_sync[0].record.name, request=completion_sync_request, ) - assert response_1.status == TaskStatus.FAILURE - assert len(response_1.outputs) == 0 - assert response_1.traceback == "failed to predict" + assert response_1.output is None @pytest.mark.asyncio @@ -519,7 +577,7 @@ async def test_completion_stream_use_case_success( output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] i = 0 async for message in response_1: - assert message.dict()["status"] == "SUCCESS" + assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] if i == 6: assert message.dict()["output"]["num_completion_tokens"] == 6 @@ -580,8 +638,186 @@ async def test_completion_stream_text_generation_inference_use_case_success( output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] i = 0 async for message in response_1: - assert message.dict()["status"] == "SUCCESS" + assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] if i == 5: assert message.dict()["output"]["num_completion_tokens"] == 6 i += 1 + + +@pytest.mark.asyncio +async def test_create_llm_fine_tune_model_name_valid(): + assert is_model_name_suffix_valid("model-name") + assert not is_model_name_suffix_valid("Hi There! This is an invalid model name.") + + +@pytest.mark.asyncio +async def test_create_fine_tune_success( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + response = await use_case.execute(user=user, request=request) + assert response.id + + # This erroring code is part of the service anyways + # request.suffix = "Invalid model suffix *&^&%^$^%&^*" + # with pytest.raises(InvalidRequestException): + # await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_create_fine_tune_limit( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + for i in range(MAX_LLM_ENDPOINTS_PER_INTERNAL_USER): + if i == MAX_LLM_ENDPOINTS_PER_INTERNAL_USER: + with pytest.raises(LLMFineTuningQuotaReached): + await use_case.execute(user=user, request=request) + else: + response = await use_case.execute(user=user, request=request) + assert response.id + + +@pytest.mark.asyncio +async def test_create_fine_tune_long_suffix( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix="a" * 100, + ) + with pytest.raises(InvalidRequestException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_get_fine_tune_events_success( + fake_llm_fine_tuning_service, + fake_llm_fine_tuning_events_repository, + fake_model_endpoint_service, + fake_file_storage_gateway, + test_api_key: str, +): + populate_use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + response = await populate_use_case.execute(user=user, request=request) + + use_case = GetFineTuneEventsV1UseCase( + llm_fine_tune_events_repository=fake_llm_fine_tuning_events_repository, + llm_fine_tuning_service=fake_llm_fine_tuning_service, + ) + response_2 = await use_case.execute(user=user, fine_tune_id=response.id) + assert len(response_2.events) == len(fake_llm_fine_tuning_events_repository.all_events_list) + + +@pytest.mark.asyncio +async def test_download_model_success( + fake_model_endpoint_service, + fake_filesystem_gateway, + fake_llm_artifact_gateway, + model_endpoint_1: ModelEndpoint, + test_api_key: str, +): + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + model_endpoint_1.record.owner = test_api_key + model_endpoint_1.record.name = "base_model" + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_llm_artifact_gateway._add_model(user.team_id, model_endpoint_1.record.name) + use_case = ModelDownloadV1UseCase( + fake_filesystem_gateway, fake_model_endpoint_service, fake_llm_artifact_gateway + ) + request = ModelDownloadRequest( + model_name=model_endpoint_1.record.name, + download_format="huggingface", + ) + response = await use_case.execute(user=user, request=request) + assert response.urls != {} + + +@pytest.mark.asyncio +async def test_download_nonexistent_model_raises_not_found( + fake_model_endpoint_service, + fake_filesystem_gateway, + fake_llm_artifact_gateway, + model_endpoint_1: ModelEndpoint, + test_api_key: str, +): + model_endpoint_1.record.owner = test_api_key + model_endpoint_1.record.name = "base_model" + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_llm_artifact_gateway._add_model(test_api_key, "base_model") + use_case = ModelDownloadV1UseCase( + fake_filesystem_gateway, fake_model_endpoint_service, fake_llm_artifact_gateway + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = ModelDownloadRequest( + model_name="nonexistent_model", + download_format="huggingface", + ) + with pytest.raises(ObjectNotFoundException): + await use_case.execute(user=user, request=request) diff --git a/server/tests/unit/domain/test_model_bundle_use_cases.py b/model-engine/tests/unit/domain/test_model_bundle_use_cases.py similarity index 97% rename from server/tests/unit/domain/test_model_bundle_use_cases.py rename to model-engine/tests/unit/domain/test_model_bundle_use_cases.py index 820ffb93..d9b4bc25 100644 --- a/server/tests/unit/domain/test_model_bundle_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_bundle_use_cases.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CreateModelBundleV1Request, CreateModelBundleV1Response, @@ -9,15 +9,15 @@ ModelBundleOrderBy, ModelBundleV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( DockerImageNotFoundException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.gateways import ModelPrimitiveGateway -from llm_engine_server.domain.repositories import DockerRepository, ModelBundleRepository -from llm_engine_server.domain.use_cases.model_bundle_use_cases import ( +from model_engine_server.domain.gateways import ModelPrimitiveGateway +from model_engine_server.domain.repositories import DockerRepository, ModelBundleRepository +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV1UseCase, CreateModelBundleV1UseCase, CreateModelBundleV2UseCase, @@ -446,7 +446,7 @@ async def test_create_model_bundle_v2_full_url_use_case_success( model_primitive_gateway=fake_model_primitive_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - # will a full uri specification, image existence is not checked + # will a full uri specification, image existance is not checked create_model_bundle_v2_request.flavor.repository = "registry.hub.docker.com/library/busybox" response = await use_case.execute(user=user, request=create_model_bundle_v2_request) assert response.model_bundle_id diff --git a/server/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py similarity index 72% rename from server/tests/unit/domain/test_model_endpoint_use_cases.py rename to model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index 20a64885..1875d7d0 100644 --- a/server/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, DeleteModelEndpointV1Response, @@ -9,29 +9,33 @@ UpdateModelEndpointV1Request, UpdateModelEndpointV1Response, ) -from llm_engine_server.common.resource_limits import ( +from model_engine_server.common.resource_limits import ( FORWARDER_CPU_USAGE, FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_USAGE, REQUESTS_BY_GPU_TYPE, STORAGE_LIMIT, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.exceptions import ( + EndpointBillingTagsMalformedException, + EndpointLabelsException, + EndpointResourceInvalidRequestException, +) +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CreateModelEndpointV1UseCase, DeleteModelEndpointByIdV1UseCase, GetModelEndpointByIdV1UseCase, ListModelEndpointsV1UseCase, UpdateModelEndpointByIdV1UseCase, ) -from llm_engine_server.infra.gateways.k8s_resource_parser import parse_mem_request +from model_engine_server.infra.gateways.k8s_resource_parser import parse_mem_request @pytest.mark.asyncio @@ -279,6 +283,154 @@ async def test_create_model_endpoint_use_case_raises_resource_request_exception( await use_case.execute(user=user, request=request) +@pytest.mark.asyncio +async def test_create_model_endpoint_use_case_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_async: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = create_model_endpoint_request_async.copy() + request.labels = None # type: ignore + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = {} + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = {"team": "infra"} + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = {"product": "my_product"} + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "user_id": "test_labels_user", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "endpoint_name": "test_labels_endpoint_name", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +@pytest.mark.asyncio +async def test_create_model_endpoint_use_case_invalid_team_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_async: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = create_model_endpoint_request_async.copy() + request.labels = { + "team": "unknown_team", + "product": "my_product", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + # for team in ALLOWED_TEAMS: + # # Conversely, make sure that all the ALLOWED_TEAMS are, well, allowed. + # request = create_model_endpoint_request_async.copy() + # request.labels = { + # "team": team, + # "product": "my_product", + # } + # await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_create_model_endpoint_use_case_raises_billing_tags_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_async: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = None + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = { + "idempotencyKeyPrefix": "val1", + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = {"incomplete_labels": "hi"} + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = { + "idempotencyKeyPrefix": ["wrong", "type"], + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = "not_a_dict" # type: ignore + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute(user=user, request=request) + + @pytest.mark.asyncio async def test_create_model_endpoint_use_case_validates_post_inference_hooks( fake_model_bundle_repository, @@ -768,6 +920,201 @@ async def test_update_model_endpoint_raises_not_authorized( ) +@pytest.mark.asyncio +async def test_update_model_endpoint_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + + request = update_model_endpoint_request.copy() + request.labels = {"team": "infra"} + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.labels = {"product": "my_product"} + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "user_id": "test_labels_user", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "endpoint_name": "test_labels_endpoint_name", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +@pytest.mark.asyncio +async def test_update_model_endpoint_invalid_team_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + + request = update_model_endpoint_request.copy() + request.labels = { + "team": "invalid_team", + "product": "some_product", + } + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + # TODO: renable this part of the test once we figure out how to import this + # properly + + # for team in ALLOWED_TEAMS: + # # Conversely, make sure that all the ALLOWED_TEAMS are, well, allowed. + # request = update_model_endpoint_request.copy() + # request.labels = { + # "team": team, + # "product": "my_product", + # } + # await use_case.execute( + # user=user, model_endpoint_id=model_endpoint_1.record.id, request=request + # ) + + +@pytest.mark.asyncio +async def test_update_model_endpoint_raises_billing_tags_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = None + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = { + "idempotencyKeyPrefix": "val1", + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = {"incomplete_labels": "hi"} + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = { + "idempotencyKeyPrefix": ["wrong", "type"], + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = "not_a_dict" # type: ignore + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + @pytest.mark.asyncio async def test_delete_model_endpoint_success( fake_model_endpoint_service, diff --git a/server/tests/unit/domain/test_streaming_inference_use_cases.py b/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py similarity index 88% rename from server/tests/unit/domain/test_streaming_inference_use_cases.py rename to model-engine/tests/unit/domain/test_streaming_inference_use_cases.py index 5043b6a4..191fa0f4 100644 --- a/server/tests/unit/domain/test_streaming_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py @@ -1,15 +1,15 @@ from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.use_cases.streaming_inference_use_cases import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.use_cases.streaming_inference_use_cases import ( CreateStreamingInferenceTaskV1UseCase, ) diff --git a/server/tests/unit/domain/test_sync_inference_use_cases.py b/model-engine/tests/unit/domain/test_sync_inference_use_cases.py similarity index 90% rename from server/tests/unit/domain/test_sync_inference_use_cases.py rename to model-engine/tests/unit/domain/test_sync_inference_use_cases.py index ffb3637e..879d5345 100644 --- a/server/tests/unit/domain/test_sync_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_sync_inference_use_cases.py @@ -1,14 +1,14 @@ from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.domain_exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.use_cases.sync_inference_use_cases import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.use_cases.sync_inference_use_cases import ( CreateSyncInferenceTaskV1UseCase, ) diff --git a/server/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py similarity index 81% rename from server/tests/unit/inference/test_forwarding.py rename to model-engine/tests/unit/inference/test_forwarding.py index 5f343fcd..283af031 100644 --- a/server/tests/unit/inference/test_forwarding.py +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -4,9 +4,9 @@ from unittest import mock import pytest -from llm_engine_server.core.utils.env import environment -from llm_engine_server.domain.entities import ModelEndpointConfig -from llm_engine_server.inference.forwarding.forwarding import ( +from model_engine_server.core.utils.env import environment +from model_engine_server.domain.entities import ModelEndpointConfig +from model_engine_server.inference.forwarding.forwarding import ( ENV_SERIALIZE_RESULTS_AS_STRING, KEY_SERIALIZE_RESULTS_AS_STRING, Forwarder, @@ -14,12 +14,12 @@ LoadStreamingForwarder, StreamingForwarder, ) -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler -PAYLOAD: Mapping[str, Mapping[str, str]] = {"hello": "world"} +PAYLOAD: Mapping[str, str] = {"hello": "world"} def mocked_get(*args, **kwargs): # noqa @@ -67,6 +67,8 @@ def post_inference_hooks_handler(): bundle_name="test_bundle_name", post_inference_hooks=[], user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], default_callback_url=None, default_callback_auth=None, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), @@ -79,12 +81,12 @@ def post_inference_hooks_handler(): def test_forwarders(post_inference_hooks_handler): fwd = Forwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, ) - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check(json_response) @@ -115,12 +117,12 @@ def _check_streaming_serialized(streaming_response) -> None: def test_forwarders_serialize_results_as_string(post_inference_hooks_handler): fwd = Forwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, ) - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check_serialized(json_response) @@ -135,23 +137,23 @@ def _check_serialized(json_response) -> None: def test_forwarders_override_serialize_results(post_inference_hooks_handler): fwd = Forwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, ) - json_response = fwd({"args": {"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: False}}) + json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: False}) _check(json_response) assert json_response == {"result": PAYLOAD} fwd = Forwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, ) - json_response = fwd({"args": {"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: True}}) + json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: True}) _check_serialized(json_response) @@ -160,60 +162,60 @@ def test_forwarders_override_serialize_results(post_inference_hooks_handler): def test_forwarder_does_not_wrap_response(post_inference_hooks_handler): fwd = Forwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=False, ) - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check_responses_not_wrapped(json_response) @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) @mock.patch( - "llm_engine_server.inference.forwarding.forwarding.get_endpoint_config", + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", mocked_get_endpoint_config, ) def test_forwarder_loader(): fwd = LoadForwarder(serialize_results_as_string=True).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check_serialized(json_response) fwd = LoadForwarder(serialize_results_as_string=False).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check(json_response) fwd = LoadForwarder(wrap_response=False).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check_responses_not_wrapped(json_response) @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) @mock.patch( - "llm_engine_server.inference.forwarding.forwarding.get_endpoint_config", + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", mocked_get_endpoint_config, ) def test_forwarder_loader_env_serialize_behavior(post_inference_hooks_handler): with environment(**{ENV_SERIALIZE_RESULTS_AS_STRING: "false"}): fwd = LoadForwarder(serialize_results_as_string=True).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check(json_response) with environment(**{ENV_SERIALIZE_RESULTS_AS_STRING: "true"}): fwd = LoadForwarder(serialize_results_as_string=False).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check_serialized(json_response) @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) def test_forwarder_serialize_within_args(post_inference_hooks_handler): - # standard Spellbook-Serve-created forwarder + # standard Launch-created forwarder fwd = Forwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, @@ -231,7 +233,7 @@ def test_forwarder_serialize_within_args(post_inference_hooks_handler): # w/o unwrapping it won't "find" the `"serialize_results_as_string": False` directive fwd = Forwarder( "ignored", - llm_engine_unwrap=False, + model_engine_unwrap=False, serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, @@ -246,11 +248,11 @@ def test_forwarder_serialize_within_args(post_inference_hooks_handler): def test_streaming_forwarders(post_inference_hooks_handler): fwd = StreamingForwarder( "ignored", - llm_engine_unwrap=True, + model_engine_unwrap=True, serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, ) - response = fwd({"args": {"ignore": "me"}}) + response = fwd({"ignore": "me"}) _check_streaming(response) @@ -258,14 +260,14 @@ def test_streaming_forwarders(post_inference_hooks_handler): @mock.patch("requests.get", mocked_get) @mock.patch("sseclient.SSEClient", mocked_sse_client) @mock.patch( - "llm_engine_server.inference.forwarding.forwarding.get_endpoint_config", + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", mocked_get_endpoint_config, ) def test_streaming_forwarder_loader(): fwd = LoadStreamingForwarder(serialize_results_as_string=True).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) + json_response = fwd({"ignore": "me"}) _check_streaming_serialized(json_response) fwd = LoadStreamingForwarder(serialize_results_as_string=False).load(None, None) # type: ignore - response = fwd({"args": {"ignore": "me"}}) + response = fwd({"ignore": "me"}) _check_streaming(response) diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py new file mode 100644 index 00000000..43fbdfbd --- /dev/null +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -0,0 +1,46 @@ +import threading +import time + +import pytest +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.inference.forwarding.http_forwarder import ( + MultiprocessingConcurrencyLimiter, + predict, +) + + +class ExceptionCapturedThread(threading.Thread): + def __init__(self, target, args): + super().__init__(target=target, args=args) + self.ex = None + + def run(self): + try: + self._target(*self._args) + except Exception as e: + self.ex = e + + def join(self): + super().join() + if self.ex is not None: + raise self.ex + + +def mock_forwarder(dict): + time.sleep(1) + return dict + + +def test_http_service_429(): + limiter = MultiprocessingConcurrencyLimiter(1, True) + t1 = ExceptionCapturedThread( + target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter) + ) + t2 = ExceptionCapturedThread( + target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter) + ) + t1.start() + t2.start() + t1.join() + with pytest.raises(Exception): # 429 thrown + t2.join() diff --git a/server/tests/unit/infra/gateways/conftest.py b/model-engine/tests/unit/infra/gateways/conftest.py similarity index 51% rename from server/tests/unit/infra/gateways/conftest.py rename to model-engine/tests/unit/infra/gateways/conftest.py index fe82601e..4ca93044 100644 --- a/server/tests/unit/infra/gateways/conftest.py +++ b/model-engine/tests/unit/infra/gateways/conftest.py @@ -1,6 +1,6 @@ import pytest -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest @pytest.fixture @@ -15,33 +15,22 @@ def create_resources_request_sync_pytorch( @pytest.fixture -def create_resources_request_async_tensorflow( - test_api_key: str, build_endpoint_request_async_tensorflow: BuildEndpointRequest +def create_resources_request_async_runnable_image( + test_api_key: str, build_endpoint_request_async_runnable_image: BuildEndpointRequest ) -> CreateOrUpdateResourcesRequest: create_resources_request = CreateOrUpdateResourcesRequest( - build_endpoint_request=build_endpoint_request_async_tensorflow, + build_endpoint_request=build_endpoint_request_async_runnable_image, image="test_image", ) return create_resources_request @pytest.fixture -def create_resources_request_async_custom( - test_api_key: str, build_endpoint_request_async_custom: BuildEndpointRequest +def create_resources_request_sync_runnable_image( + test_api_key: str, build_endpoint_request_sync_runnable_image: BuildEndpointRequest ) -> CreateOrUpdateResourcesRequest: create_resources_request = CreateOrUpdateResourcesRequest( - build_endpoint_request=build_endpoint_request_async_custom, - image="test_image", - ) - return create_resources_request - - -@pytest.fixture -def create_resources_request_sync_custom( - test_api_key: str, build_endpoint_request_sync_custom: BuildEndpointRequest -) -> CreateOrUpdateResourcesRequest: - create_resources_request = CreateOrUpdateResourcesRequest( - build_endpoint_request=build_endpoint_request_sync_custom, + build_endpoint_request=build_endpoint_request_sync_runnable_image, image="test_image", ) return create_resources_request @@ -49,8 +38,7 @@ def create_resources_request_sync_custom( @pytest.fixture def create_resources_request_streaming_runnable_image( - test_api_key: str, - build_endpoint_request_streaming_runnable_image: BuildEndpointRequest, + test_api_key: str, build_endpoint_request_streaming_runnable_image: BuildEndpointRequest ) -> CreateOrUpdateResourcesRequest: create_resources_request = CreateOrUpdateResourcesRequest( build_endpoint_request=build_endpoint_request_streaming_runnable_image, diff --git a/server/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py similarity index 87% rename from server/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py rename to model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index 6dbefbe1..40fe6c14 100644 --- a/server/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -4,28 +4,29 @@ import pytest from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.domain.entities import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.domain.entities import ( ModelEndpointConfig, ModelEndpointType, ModelEndpointUserConfigState, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( DATADOG_ENV_VAR, K8SEndpointResourceDelegate, add_datadog_env_to_main_container, get_main_container_from_deployment_template, load_k8s_yaml, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( DictStrInt, DictStrStr, ResourceArguments, ) -MODULE_PATH = "llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate" +MODULE_PATH = "model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate" @dataclass @@ -39,6 +40,16 @@ class FakeK8sDeploymentContainer: env: List[FakeK8sEnvVar] +@pytest.fixture +def mock_get_kubernetes_cluster_version(): + mock_version = "1.26" + with patch( + f"{MODULE_PATH}.get_kubernetes_cluster_version", + return_value=mock_version, + ): + yield mock_version + + @pytest.fixture def mock_apps_client(): mock_client = AsyncMock() @@ -119,9 +130,7 @@ def k8s_endpoint_resource_delegate( @pytest.mark.parametrize("resource_arguments_type", ResourceArguments.__args__) -def test_resource_arguments_type_and_add_datadog_env_to_main_container( - resource_arguments_type, -): +def test_resource_arguments_type_and_add_datadog_env_to_main_container(resource_arguments_type): # Convert the name of the type to a kebab case string # e.g. "BatchJobOrchestrationJobArguments" -> "batch-job-orchestration-job-arguments" resource_arguments_type_name = resource_arguments_type.__name__ @@ -178,9 +187,8 @@ def _verify_deployment_labels( labels = build_endpoint_request.labels endpoint_name = model_endpoint_record.name env = "circleci" - git_tag = "54f8f73bfb1cce62a2b42326ccf9f49b5b145126" - k8s_resource_group_name = f"llm-engine-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" + k8s_resource_group_name = f"launch-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" assert body["metadata"]["name"] == k8s_resource_group_name assert body["metadata"]["namespace"] == hmi_config.endpoint_namespace @@ -191,15 +199,15 @@ def _verify_deployment_labels( "user_id": user_id, "endpoint_id": model_endpoint_record.id, "endpoint_name": endpoint_name, - "managed-by": "llm-engine", + "managed-by": "model-engine", "owner": user_id, "team": labels["team"], "product": labels["product"], "env": env, "tags.datadoghq.com/env": env, "tags.datadoghq.com/service": endpoint_name, - "tags.datadoghq.com/version": git_tag, - "use_scale_llm_engine_endpoint_network_policy": "true", + "tags.datadoghq.com/version": GIT_TAG, + "use_scale_launch_endpoint_network_policy": "true", } assert body["metadata"]["labels"] == expected_labels @@ -209,7 +217,7 @@ def _verify_deployment_labels( "user_id": user_id, "endpoint_id": model_endpoint_record.id, "endpoint_name": endpoint_name, - "managed-by": "llm-engine", + "managed-by": "model-engine", "owner": user_id, "team": labels["team"], "product": labels["product"], @@ -217,8 +225,8 @@ def _verify_deployment_labels( "version": "v1", "tags.datadoghq.com/env": env, "tags.datadoghq.com/service": endpoint_name, - "tags.datadoghq.com/version": git_tag, - "use_scale_llm_engine_endpoint_network_policy": "true", + "tags.datadoghq.com/version": GIT_TAG, + "use_scale_launch_endpoint_network_policy": "true", } if model_endpoint_record.endpoint_type == ModelEndpointType.ASYNC: @@ -237,9 +245,8 @@ def _verify_non_deployment_labels( labels = build_endpoint_request.labels endpoint_name = model_endpoint_record.name env = "circleci" - git_tag = "54f8f73bfb1cce62a2b42326ccf9f49b5b145126" - k8s_resource_group_name = f"llm-engine-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" + k8s_resource_group_name = f"launch-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" assert k8s_resource_group_name in body["metadata"]["name"] assert body["metadata"]["namespace"] == hmi_config.endpoint_namespace @@ -247,7 +254,7 @@ def _verify_non_deployment_labels( expected_labels = { "created_by": user_id, - "managed-by": "llm-engine", + "managed-by": "model-engine", "owner": user_id, "user_id": user_id, "endpoint_id": model_endpoint_record.id, @@ -257,8 +264,8 @@ def _verify_non_deployment_labels( "env": env, "tags.datadoghq.com/env": env, "tags.datadoghq.com/service": endpoint_name, - "tags.datadoghq.com/version": git_tag, - "use_scale_llm_engine_endpoint_network_policy": "true", + "tags.datadoghq.com/version": GIT_TAG, + "use_scale_launch_endpoint_network_policy": "true", } assert body["metadata"]["labels"] == expected_labels @@ -281,12 +288,11 @@ async def test_create_async_endpoint_has_correct_labels( mock_core_client, mock_autoscaling_client, mock_custom_objects_client, - create_resources_request_async_custom: CreateOrUpdateResourcesRequest, - create_resources_request_async_tensorflow: CreateOrUpdateResourcesRequest, + mock_get_kubernetes_cluster_version, + create_resources_request_async_runnable_image: CreateOrUpdateResourcesRequest, ): for request in [ - create_resources_request_async_custom, - create_resources_request_async_tensorflow, + create_resources_request_async_runnable_image, ]: await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue" @@ -345,6 +351,7 @@ async def test_create_streaming_endpoint_has_correct_labels( mock_core_client, mock_autoscaling_client, mock_custom_objects_client, + mock_get_kubernetes_cluster_version, create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest, ): request = create_resources_request_streaming_runnable_image @@ -385,7 +392,12 @@ async def test_create_streaming_endpoint_has_correct_labels( if optimize_costs: _verify_custom_object_plurals( call_args_list=create_custom_object_call_args_list, - expected_plurals=["verticalpodautoscalers"], + expected_plurals=["verticalpodautoscalers", "virtualservices", "destinationrules"], + ) + if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC: + _verify_custom_object_plurals( + call_args_list=create_custom_object_call_args_list, + expected_plurals=["virtualservices", "destinationrules"], ) mock_custom_objects_client.reset_mock() @@ -406,12 +418,11 @@ async def test_create_sync_endpoint_has_correct_labels( mock_core_client, mock_autoscaling_client, mock_custom_objects_client, - create_resources_request_sync_pytorch: CreateOrUpdateResourcesRequest, - create_resources_request_sync_custom: CreateOrUpdateResourcesRequest, + mock_get_kubernetes_cluster_version, + create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest, ): for request in [ - create_resources_request_sync_pytorch, - create_resources_request_sync_custom, + create_resources_request_sync_runnable_image, ]: await k8s_endpoint_resource_delegate.create_or_update_resources( request, @@ -450,13 +461,20 @@ async def test_create_sync_endpoint_has_correct_labels( if optimize_costs: _verify_custom_object_plurals( call_args_list=create_custom_object_call_args_list, - expected_plurals=["verticalpodautoscalers"], + expected_plurals=["verticalpodautoscalers", "virtualservices", "destinationrules"], + ) + if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC: + _verify_custom_object_plurals( + call_args_list=create_custom_object_call_args_list, + expected_plurals=["virtualservices", "destinationrules"], ) mock_custom_objects_client.reset_mock() # Make sure that an VPA is created if optimize_costs is True. - optimize_costs = create_resources_request_sync_pytorch.build_endpoint_request.optimize_costs + optimize_costs = ( + create_resources_request_sync_runnable_image.build_endpoint_request.optimize_costs + ) create_vpa_call_args = mock_custom_objects_client.create_namespaced_custom_objects.call_args if optimize_costs: assert create_vpa_call_args is not None @@ -471,10 +489,11 @@ async def test_create_sync_endpoint_has_correct_k8s_service_type( mock_core_client, mock_autoscaling_client, mock_custom_objects_client, - create_resources_request_sync_pytorch: CreateOrUpdateResourcesRequest, + mock_get_kubernetes_cluster_version, + create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest, ): await k8s_endpoint_resource_delegate.create_or_update_resources( - create_resources_request_sync_pytorch, + create_resources_request_sync_runnable_image, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue", ) @@ -550,7 +569,7 @@ async def test_get_resources_async_success( Mock(return_value=FakeK8sDeploymentContainer(env=[])), ) k8s_endpoint_resource_delegate.__setattr__( - "_get_llm_engine_container", + "_get_launch_container", Mock( return_value=FakeK8sDeploymentContainer( env=[FakeK8sEnvVar(name="PREWARM", value="true")] @@ -608,8 +627,7 @@ async def test_get_resources_sync_success( "_get_main_container", Mock(return_value=FakeK8sDeploymentContainer(env=[])) ) k8s_endpoint_resource_delegate.__setattr__( - "_get_llm_engine_container", - Mock(return_value=FakeK8sDeploymentContainer(env=[])), + "_get_launch_container", Mock(return_value=FakeK8sDeploymentContainer(env=[])) ) k8s_endpoint_resource_delegate.__setattr__( "_translate_k8s_config_maps_to_user_config_data", diff --git a/server/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py similarity index 88% rename from server/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py rename to model-engine/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py index 952712a5..1ab2143a 100644 --- a/server/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py @@ -4,14 +4,14 @@ import botocore.exceptions import pytest -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.domain.entities import ModelEndpointRecord -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( LiveSQSEndpointResourceDelegate, ) -MODULE_PATH = "llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate" +MODULE_PATH = "model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate" EXPECTED_QUEUE_POLICY = """ { @@ -25,7 +25,7 @@ "AWS": "arn:aws:iam::000000000000:root" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-test_model_endpoint_id_3" + "Resource": "arn:aws:sqs:us-west-2:000000000000:launch-endpoint-id-test_model_endpoint_id_3" }, { "Effect": "Allow", @@ -33,29 +33,21 @@ "AWS": "arn:aws:iam::000000000000:role/default" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-test_model_endpoint_id_3" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-test_model_endpoint_id_3" + "Resource": "arn:aws:sqs:us-west-2:000000000000:launch-endpoint-id-test_model_endpoint_id_3" } ] } """ EXPECTED_QUEUE_TAGS = { - "infra.scale.com/product": "MLInfraLLMEngineSQS", + "infra.scale.com/product": "MLInfraLaunchSQS", "infra.scale.com/team": "test_team", "infra.scale.com/contact": "yi.xu@scale.com", "infra.scale.com/customer": "AllCustomers", "infra.scale.com/financialOwner": "yi.xu@scale.com", - "Spellbook-Serve-Endpoint-Id": "test_model_endpoint_id_3", - "Spellbook-Serve-Endpoint-Name": "test_model_endpoint_name_3", - "Spellbook-Serve-Endpoint-Created-By": "test_user_id", + "Launch-Endpoint-Id": "test_model_endpoint_id_3", + "Launch-Endpoint-Name": "test_model_endpoint_name_3", + "Launch-Endpoint-Created-By": "test_user_id", } @@ -75,7 +67,7 @@ def _get_fake_botocore_exception(): @pytest.fixture def mock_create_async_sqs_client_create_queue(): create_queue_response = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-test_model_endpoint_id_3", "ResponseMetadata": { "RequestId": "9c05b1cc-d806-5cbd-bd4a-ea339c90e25f", "HTTPStatusCode": 200, @@ -108,7 +100,7 @@ def mock_create_async_sqs_client_create_queue(): @pytest.fixture def mock_create_async_sqs_client_get_queue_url(): get_queue_response = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-test_model_endpoint_id_3", } mock_sqs_client_session_val = AsyncMock() @@ -179,7 +171,7 @@ def mock_create_async_sqs_client_delete_queue(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } delete_response = { @@ -213,7 +205,7 @@ def mock_create_async_sqs_client_delete_queue_returns_non_200(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } delete_response = { @@ -247,7 +239,7 @@ def mock_create_async_sqs_client_delete_queue_throws_exception(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } mock_sqs_client_session_val.delete_queue = AsyncMock(side_effect=_get_fake_botocore_exception()) @@ -268,12 +260,12 @@ def mock_create_async_sqs_client_get_queue_attributes(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } get_queue_attributes_response = { "Attributes": { - "QueueArn": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-model_endpoint_id_1", + "QueueArn": "arn:aws:sqs:us-west-2:000000000000:launch-endpoint-id-model_endpoint_id_1", "ApproximateNumberOfMessages": "0", "ApproximateNumberOfMessagesNotVisible": "0", "ApproximateNumberOfMessagesDelayed": "0", @@ -326,7 +318,7 @@ def mock_create_async_sqs_client_get_queue_attributes_queue_throws_exception(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } mock_sqs_client_session_val.get_queue_attributes = AsyncMock( @@ -360,7 +352,7 @@ async def test_sqs_create_or_update_resources_endpoint_exists( mock_create_async_sqs_client_get_queue_url.__aenter__.assert_called_once() expected_get_queue_url_args: Dict[str, Any] = { - "QueueName": "llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueName": "launch-endpoint-id-test_model_endpoint_id_3", } actual_get_queue_kwargs = ( mock_create_async_sqs_client_get_queue_url.__aenter__.return_value.get_queue_url.call_args.kwargs @@ -388,7 +380,7 @@ async def test_sqs_create_or_update_resources( mock_create_async_sqs_client_create_queue.__aenter__.assert_called_once() expected_create_queue_args: Dict[str, Any] = { - "QueueName": "llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueName": "launch-endpoint-id-test_model_endpoint_id_3", "Attributes": { "VisibilityTimeout": "3600", "Policy": EXPECTED_QUEUE_POLICY, @@ -450,13 +442,13 @@ async def test_sqs_delete_resources(mock_create_async_sqs_client_delete_queue): mock_create_async_sqs_client_delete_queue.__aenter__.assert_called_once() mock_create_async_sqs_client_delete_queue.__aenter__.return_value.get_queue_url.assert_called_once_with( - QueueName="llm-engine-endpoint-id-model_endpoint_id_1" + QueueName="launch-endpoint-id-model_endpoint_id_1" ) delete_call_kwargs = ( mock_create_async_sqs_client_delete_queue.__aenter__.return_value.delete_queue.call_args.kwargs ) - assert delete_call_kwargs["QueueUrl"].endswith("llm-engine-endpoint-id-model_endpoint_id_1") + assert delete_call_kwargs["QueueUrl"].endswith("launch-endpoint-id-model_endpoint_id_1") @pytest.mark.asyncio @@ -478,25 +470,23 @@ async def test_sqs_delete_resources_non_200( @pytest.mark.asyncio -async def test_sqs_get_queue_attributes( - mock_create_async_sqs_client_get_queue_attributes, -): +async def test_sqs_get_queue_attributes(mock_create_async_sqs_client_get_queue_attributes): delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") response = await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") mock_create_async_sqs_client_get_queue_attributes.__aenter__.assert_called_once() mock_create_async_sqs_client_get_queue_attributes.__aenter__.return_value.get_queue_url.assert_called_once_with( - QueueName="llm-engine-endpoint-id-model_endpoint_id_1" + QueueName="launch-endpoint-id-model_endpoint_id_1" ) get_queue_attributes_call_kwargs = ( mock_create_async_sqs_client_get_queue_attributes.__aenter__.return_value.get_queue_attributes.call_args.kwargs ) assert get_queue_attributes_call_kwargs["QueueUrl"].endswith( - "llm-engine-endpoint-id-model_endpoint_id_1" + "launch-endpoint-id-model_endpoint_id_1" ) - assert response["Attributes"]["QueueArn"].endswith("llm-engine-endpoint-id-model_endpoint_id_1") + assert response["Attributes"]["QueueArn"].endswith("launch-endpoint-id-model_endpoint_id_1") @pytest.mark.asyncio diff --git a/server/tests/unit/infra/gateways/test_k8s_resource_parser.py b/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py similarity index 97% rename from server/tests/unit/infra/gateways/test_k8s_resource_parser.py rename to model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py index 91741b7a..7f59350a 100644 --- a/server/tests/unit/infra/gateways/test_k8s_resource_parser.py +++ b/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.infra.gateways.k8s_resource_parser import ( get_per_worker_value_from_target_concurrency, get_target_concurrency_from_per_worker_value, parse_cpu_request, diff --git a/server/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py similarity index 94% rename from server/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py index ce2309ef..2b4939df 100644 --- a/server/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py @@ -2,8 +2,8 @@ from typing import Any import pytest -from llm_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, TaskStatus -from llm_engine_server.infra.gateways import LiveAsyncModelEndpointInferenceGateway +from model_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, TaskStatus +from model_engine_server.infra.gateways import LiveAsyncModelEndpointInferenceGateway @pytest.fixture diff --git a/server/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py similarity index 91% rename from server/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py index 53716d53..2a3fe197 100644 --- a/server/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py @@ -1,5 +1,5 @@ -from llm_engine_server.domain.entities import BatchJobProgress -from llm_engine_server.infra.gateways import LiveBatchJobProgressGateway +from model_engine_server.domain.entities import BatchJobProgress +from model_engine_server.infra.gateways import LiveBatchJobProgressGateway def test_get_progress_empty(test_api_key: str, fake_filesystem_gateway): diff --git a/server/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py similarity index 94% rename from server/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py index e1652ceb..2f4c5c2a 100644 --- a/server/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py @@ -1,4 +1,4 @@ -from llm_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( +from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( K8sEnvDict, _add_list_values, _check_batch_job_id_valid, diff --git a/server/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py similarity index 84% rename from server/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py index b409463e..e03a9840 100644 --- a/server/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py @@ -2,8 +2,8 @@ from unittest.mock import Mock import pytest -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.infra.gateways import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.infra.gateways import ( LiveModelEndpointInfraGateway, live_model_endpoint_infra_gateway, ) @@ -91,6 +91,15 @@ async def test_update_model_endpoint_infra( ), ) assert creation_task_id_1 + # Test existing billing tags don't get lost + endpoint_config = model_endpoint_1.infra_state.user_config_state.endpoint_config # type: ignore + billing_tags = endpoint_config.billing_tags # type: ignore + assert ( + fake_task_queue_gateway.get_task_args(creation_task_id_1)["kwargs"][ + "build_endpoint_request_json" + ].get("billing_tags") + == billing_tags + ) creation_task_id_2 = await model_endpoint_infra_gateway.update_model_endpoint_infra( model_endpoint_record=model_endpoint_1.record, @@ -100,8 +109,28 @@ async def test_update_model_endpoint_infra( gpu_type=model_endpoint_2.infra_state.resource_state.gpu_type, child_fn_info=model_endpoint_2.infra_state.child_fn_info, labels=model_endpoint_2.infra_state.labels, + billing_tags={ + "idempotencyKeyPrefix": "new_value_1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ) assert creation_task_id_2 + # Inspect the value of billing_tags across the wire to make sure it's set correctly + # Test new billing tags overwrite existing ones + assert ( + fake_task_queue_gateway.get_task_args(creation_task_id_2)["kwargs"][ + "build_endpoint_request_json" + ] + .get("billing_tags") + .get("idempotencyKeyPrefix") + == "new_value_1" + ) @pytest.mark.asyncio diff --git a/server/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py similarity index 96% rename from server/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py index e60bbcff..9b3f2ad7 100644 --- a/server/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py @@ -1,14 +1,12 @@ import pytest -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.infra.gateways.live_model_endpoints_schema_gateway import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.infra.gateways.live_model_endpoints_schema_gateway import ( LiveModelEndpointsSchemaGateway, ) @pytest.fixture -def live_model_endpoints_schema_gateway( - fake_filesystem_gateway, -) -> LiveModelEndpointsSchemaGateway: +def live_model_endpoints_schema_gateway(fake_filesystem_gateway) -> LiveModelEndpointsSchemaGateway: return LiveModelEndpointsSchemaGateway(filesystem_gateway=fake_filesystem_gateway) diff --git a/server/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py similarity index 86% rename from server/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py index 58a735ee..35256fce 100644 --- a/server/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py @@ -4,12 +4,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from llm_engine_server.domain.exceptions import UpstreamServiceError -from llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway import ( +from model_engine_server.domain.exceptions import UpstreamServiceError +from model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway import ( LiveStreamingModelEndpointInferenceGateway, ) @@ -61,7 +61,7 @@ async def test_make_request_with_retries_success(): mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) @@ -80,7 +80,7 @@ async def test_make_request_with_retries_failed_429(): mock_client_session = _get_mock_client_session(fake_response) with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): @@ -95,7 +95,7 @@ async def test_make_request_with_retries_failed_traceback(): mock_client_session = _get_mock_client_session(fake_response) with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): @@ -111,7 +111,7 @@ async def test_streaming_predict_success( fake_response = FakeResponse(status=200) mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = gateway.streaming_predict( @@ -139,7 +139,7 @@ async def test_predict_raises_traceback_json( fake_response = FakeResponse(status=500, message_content=content) mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = gateway.streaming_predict( @@ -167,7 +167,7 @@ async def test_predict_raises_traceback_not_json( fake_response = FakeResponse(status=500, message_content=content) mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = gateway.streaming_predict( diff --git a/server/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py similarity index 84% rename from server/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py index d89ae7b7..758e6ce0 100644 --- a/server/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -4,12 +4,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from llm_engine_server.domain.exceptions import UpstreamServiceError -from llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway import ( +from model_engine_server.domain.exceptions import UpstreamServiceError +from model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway import ( LiveSyncModelEndpointInferenceGateway, ) @@ -45,7 +45,7 @@ async def test_make_request_with_retries_success(): mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) @@ -60,7 +60,7 @@ async def test_make_request_with_retries_failed_429(): mock_client_session = _get_mock_client_session(fake_response) with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) @@ -74,7 +74,7 @@ async def test_make_request_with_retries_failed_traceback(): mock_client_session = _get_mock_client_session(fake_response) with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) @@ -89,7 +89,7 @@ async def test_predict_success( fake_response = FakeResponse(status=200, body={"test_key": "test_value"}) mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = await gateway.predict( @@ -113,7 +113,7 @@ async def test_predict_raises_traceback_json( fake_response = FakeResponse(status=500, content=content) mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = await gateway.predict( @@ -137,7 +137,7 @@ async def test_predict_raises_traceback_not_json( fake_response = FakeResponse(status=500, content=content) mock_client_session = _get_mock_client_session(fake_response) with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): response = await gateway.predict( diff --git a/server/tests/unit/infra/repositories/conftest.py b/model-engine/tests/unit/infra/repositories/conftest.py similarity index 97% rename from server/tests/unit/infra/repositories/conftest.py rename to model-engine/tests/unit/infra/repositories/conftest.py index dcd0260a..12c550b6 100644 --- a/server/tests/unit/infra/repositories/conftest.py +++ b/model-engine/tests/unit/infra/repositories/conftest.py @@ -2,10 +2,10 @@ from typing import Callable, Optional, Union import pytest -from llm_engine_server.db.models import BatchJob, Bundle -from llm_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle -from llm_engine_server.db.models import Endpoint -from llm_engine_server.domain.entities import ( +from model_engine_server.db.models import BatchJob, Bundle +from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle +from model_engine_server.db.models import Endpoint +from model_engine_server.domain.entities import ( BatchJobRecord, GpuType, ModelBundle, @@ -170,7 +170,7 @@ def orm_model_bundle_4(test_api_key: str) -> Bundle: "ecr_repo": "test_repo", "image_tag": "test_tag", }, - packaging_type="cloudpickle", + packaging_type="lira", app_config=None, ) model_bundle.id = "test_model_bundle_id_4" @@ -205,7 +205,7 @@ def orm_model_bundle_5(test_api_key: str) -> Bundle: "ecr_repo": "test_repo", "image_tag": "test_tag", }, - packaging_type="cloudpickle", + packaging_type="lira", app_config=None, ) model_bundle.id = "test_model_bundle_id_5" diff --git a/server/tests/unit/infra/repositories/test_db_batch_job_record_repository.py b/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py similarity index 97% rename from server/tests/unit/infra/repositories/test_db_batch_job_record_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py index 95859158..d52d327b 100644 --- a/server/tests/unit/infra/repositories/test_db_batch_job_record_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py @@ -3,10 +3,10 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import BatchJob, Bundle -from llm_engine_server.domain.entities import BatchJobRecord -from llm_engine_server.infra.repositories.db_batch_job_record_repository import ( +from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.db.models import BatchJob, Bundle +from model_engine_server.domain.entities import BatchJobRecord +from model_engine_server.infra.repositories.db_batch_job_record_repository import ( DbBatchJobRecordRepository, OrmBatchJob, ) diff --git a/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py b/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py similarity index 94% rename from server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py index 579d13d0..b28bf81f 100644 --- a/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py @@ -3,16 +3,16 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import CorruptRecordInfraStateException -from llm_engine_server.infra.repositories import DbDockerImageBatchJobBundleRepository -from llm_engine_server.infra.repositories.db_docker_image_batch_job_bundle_repository import ( +from model_engine_server.domain.exceptions import CorruptRecordInfraStateException +from model_engine_server.infra.repositories import DbDockerImageBatchJobBundleRepository +from model_engine_server.infra.repositories.db_docker_image_batch_job_bundle_repository import ( translate_docker_image_batch_job_bundle_orm_to_entity, ) from sqlalchemy.ext.asyncio import AsyncSession diff --git a/server/tests/unit/infra/repositories/test_db_model_bundle_repository.py b/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py similarity index 97% rename from server/tests/unit/infra/repositories/test_db_model_bundle_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py index 0e45d4a4..dd73b221 100644 --- a/server/tests/unit/infra/repositories/test_db_model_bundle_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py @@ -3,16 +3,16 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import Bundle -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.db.models import Bundle +from model_engine_server.domain.entities import ( CloudpickleArtifactFlavor, ModelBundle, ModelBundlePackagingType, PytorchFramework, ) -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.infra.repositories.db_model_bundle_repository import ( DbModelBundleRepository, OrmModelBundle, ) diff --git a/server/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py b/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py similarity index 95% rename from server/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py index 38103bc2..8d751272 100644 --- a/server/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py @@ -3,13 +3,13 @@ from unittest.mock import AsyncMock, Mock import pytest -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import Bundle, Endpoint -from llm_engine_server.domain.entities import ModelEndpointRecord -from llm_engine_server.infra.gateways import FakeMonitoringMetricsGateway -from llm_engine_server.infra.repositories import db_model_endpoint_record_repository -from llm_engine_server.infra.repositories.db_model_endpoint_record_repository import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.db.models import Bundle, Endpoint +from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.infra.gateways import FakeMonitoringMetricsGateway +from model_engine_server.infra.repositories import db_model_endpoint_record_repository +from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, OrmModelEndpoint, ) @@ -140,7 +140,7 @@ async def test_list_llm_model_endpoint_records( orm_model_bundle: Bundle, fake_monitoring_metrics_gateway: FakeMonitoringMetricsGateway, ): - filter_content = "endpoint_metadata ? '_llm' AND llm_engine.endpoints.name = :name_1 AND (llm_engine.endpoints.owner = :owner_1 OR llm_engine.endpoints.public_inference = true)" + filter_content = "endpoint_metadata ? '_llm' AND hosted_model_inference.endpoints.name = :name_1 AND (hosted_model_inference.endpoints.owner = :owner_1 OR hosted_model_inference.endpoints.public_inference = true)" def mock_llm_model_endpoint_select_all_by_filters( session: AsyncSession, filters: Any @@ -169,7 +169,7 @@ def mock_llm_model_endpoint_select_all_by_filters( order_by=ModelEndpointOrderBy.NEWEST, ) - filter_content = "endpoint_metadata ? '_llm' AND (llm_engine.endpoints.owner = :owner_1 OR llm_engine.endpoints.public_inference = true)" + filter_content = "endpoint_metadata ? '_llm' AND (hosted_model_inference.endpoints.owner = :owner_1 OR hosted_model_inference.endpoints.public_inference = true)" await repo.list_llm_model_endpoint_records( owner="test_user_id", name=None, diff --git a/server/tests/unit/infra/repositories/test_redis_feature_flag_repository.py b/model-engine/tests/unit/infra/repositories/test_redis_feature_flag_repository.py similarity index 88% rename from server/tests/unit/infra/repositories/test_redis_feature_flag_repository.py rename to model-engine/tests/unit/infra/repositories/test_redis_feature_flag_repository.py index 50871f6e..5bf3a0e5 100644 --- a/server/tests/unit/infra/repositories/test_redis_feature_flag_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_redis_feature_flag_repository.py @@ -2,7 +2,7 @@ import aioredis import pytest -from llm_engine_server.infra.repositories.redis_feature_flag_repository import ( +from model_engine_server.infra.repositories.redis_feature_flag_repository import ( RedisFeatureFlagRepository, ) diff --git a/server/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py b/model-engine/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py similarity index 92% rename from server/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py rename to model-engine/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py index f7cdb743..eb1133fb 100644 --- a/server/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py @@ -2,7 +2,7 @@ import aioredis import pytest -from llm_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( RedisModelEndpointCacheRepository, ) diff --git a/server/tests/unit/infra/services/conftest.py b/model-engine/tests/unit/infra/services/conftest.py similarity index 92% rename from server/tests/unit/infra/services/conftest.py rename to model-engine/tests/unit/infra/services/conftest.py index a8c2edb6..9efc271f 100644 --- a/server/tests/unit/infra/services/conftest.py +++ b/model-engine/tests/unit/infra/services/conftest.py @@ -1,10 +1,10 @@ import pytest -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint -from llm_engine_server.infra.gateways import ( +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.infra.gateways import ( LiveBatchJobProgressGateway, LiveModelEndpointsSchemaGateway, ) -from llm_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService +from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService @pytest.fixture diff --git a/model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py new file mode 100644 index 00000000..598b5c1b --- /dev/null +++ b/model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py @@ -0,0 +1,65 @@ +import pytest +import pytest_asyncio +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.services import DockerImageBatchJobLLMFineTuningService + + +@pytest_asyncio.fixture +async def live_docker_image_batch_job_llm_fine_tuning_service( + fake_docker_image_batch_job_gateway, + fake_docker_image_batch_job_bundle_repository, + fake_llm_fine_tune_repository, +): + fake_bundle = ( + await fake_docker_image_batch_job_bundle_repository.create_docker_image_batch_job_bundle( + name="fake_fine_tune_bundle", + created_by="fake_egp_admin", + owner="fake_egp_admin", + image_repository="fake_image_repo", + image_tag="fake_image_tag", + command=["fake_command"], + env={"fake_env": "fake_env"}, + mount_location="/fake_mount_location", + cpus="1", + memory="0.1Gi", + storage="1Gi", + gpus=0, + gpu_type=None, + public=True, + ) + ) + await fake_llm_fine_tune_repository.write_job_template_for_model( + model_name="fake_model_name", + fine_tuning_method="fake_fine_tuning_method", + job_template=LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=fake_bundle.id, + launch_endpoint_config={}, + default_hparams={}, + required_params=[], + ), + ) + return DockerImageBatchJobLLMFineTuningService( + docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, + docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository, + llm_fine_tune_repository=fake_llm_fine_tune_repository, + ) + + +@pytest.mark.asyncio +async def test_create_fine_tune_success( + live_docker_image_batch_job_llm_fine_tuning_service, + fake_docker_image_batch_job_gateway, +): + batch_job_id = await live_docker_image_batch_job_llm_fine_tuning_service.create_fine_tune( + created_by="fake_user", + owner="fake_user", + model="fake_model_name", + training_file="fake_training_file_path", + validation_file="fake_validation_file_path", + fine_tuning_method="fake_fine_tuning_method", + hyperparameters={}, + fine_tuned_model="fake_fine_tuned_model_name", + wandb_config=None, + ) + assert batch_job_id is not None + assert fake_docker_image_batch_job_gateway.get_docker_image_batch_job(batch_job_id) is not None diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py new file mode 100644 index 00000000..f405ea21 --- /dev/null +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -0,0 +1,58 @@ +from typing import Any + +import pytest +from model_engine_server.common.config import hmi_config +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.core.config import infra_config +from model_engine_server.infra.services.image_cache_service import DockerImage, ImageCacheService + + +@pytest.mark.asyncio +async def test_image_cache_success( + fake_image_cache_service: ImageCacheService, + model_endpoint_1, + model_endpoint_2, + model_endpoint_3, + model_endpoint_4, +): + infra_states = { + model_endpoint_1.record.id: (bool, model_endpoint_1.infra_state), + model_endpoint_2.record.id: (bool, model_endpoint_2.infra_state), + model_endpoint_3.record.id: (bool, model_endpoint_3.infra_state), + model_endpoint_4.record.id: (bool, model_endpoint_4.infra_state), + } + repo: Any = fake_image_cache_service.model_endpoint_record_repository + repo.add_model_endpoint_record(model_endpoint_1.record) + repo.add_model_endpoint_record(model_endpoint_2.record) + repo.add_model_endpoint_record(model_endpoint_3.record) + repo.add_model_endpoint_record(model_endpoint_4.record) + + await fake_image_cache_service.execute(infra_states) # type: ignore + gateway: Any = fake_image_cache_service.image_cache_gateway + + assert f"{infra_config().docker_repo_prefix}/my-repo:abcdefg222" in gateway.cached_images["t4"] + assert ( + f"{infra_config().docker_repo_prefix}/my-repo:abcdefg111111111" + in gateway.cached_images["t4"] + ) + assert ( + f"{infra_config().docker_repo_prefix}/my-repo:abcdefg00000" in gateway.cached_images["t4"] + ) + + +@pytest.mark.asyncio +async def test_caching_finetune_llm_images( + fake_image_cache_service: ImageCacheService, +): + await fake_image_cache_service.execute({}) + gateway: Any = fake_image_cache_service.image_cache_gateway + + istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0") + tgi_image = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3" + ) + forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG) + + for key in ["a10", "a100"]: + for llm_image in [istio_image, tgi_image, forwarder_image]: + assert f"{llm_image.repo}:{llm_image.tag}" in gateway.cached_images[key] diff --git a/server/tests/unit/infra/services/test_live_batch_job_orchestration_service.py b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py similarity index 94% rename from server/tests/unit/infra/services/test_live_batch_job_orchestration_service.py rename to model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py index 9b3ff377..8d894622 100644 --- a/server/tests/unit/infra/services/test_live_batch_job_orchestration_service.py +++ b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py @@ -4,10 +4,10 @@ from unittest.mock import patch import pytest -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME -from llm_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, ResponseSchema, TaskStatus -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.domain.entities import ( +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, ResponseSchema, TaskStatus +from model_engine_server.core.domain_exceptions import ObjectNotFoundException +from model_engine_server.domain.entities import ( BatchJob, BatchJobSerializationFormat, BatchJobStatus, @@ -15,12 +15,12 @@ ModelEndpoint, ModelEndpointStatus, ) -from llm_engine_server.infra.gateways import LiveBatchJobProgressGateway -from llm_engine_server.infra.services import ( +from model_engine_server.infra.gateways import LiveBatchJobProgressGateway +from model_engine_server.infra.services import ( LiveBatchJobOrchestrationService, LiveModelEndpointService, ) -from llm_engine_server.infra.services.live_batch_job_orchestration_service import ( +from model_engine_server.infra.services.live_batch_job_orchestration_service import ( BatchEndpointInferencePredictionResponse, BatchEndpointInProgressTask, ) @@ -235,7 +235,7 @@ async def test_run_batch_job_wait_for_endpoint( ): model_endpoint_1.record.status = ModelEndpointStatus.UPDATE_PENDING with patch( - "llm_engine_server.infra.services.live_batch_job_orchestration_service.asyncio.sleep" + "model_engine_server.infra.services.live_batch_job_orchestration_service.asyncio.sleep" ) as mock_sleep: def set_record_ready(*args, **kwargs): diff --git a/server/tests/unit/infra/services/test_live_batch_job_service.py b/model-engine/tests/unit/infra/services/test_live_batch_job_service.py similarity index 94% rename from server/tests/unit/infra/services/test_live_batch_job_service.py rename to model-engine/tests/unit/infra/services/test_live_batch_job_service.py index b1b60a91..8d440a2a 100644 --- a/server/tests/unit/infra/services/test_live_batch_job_service.py +++ b/model-engine/tests/unit/infra/services/test_live_batch_job_service.py @@ -1,9 +1,9 @@ import pytest -from llm_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from llm_engine_server.domain.entities import BatchJobSerializationFormat, GpuType, ModelBundle -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.infra.services import LiveBatchJobService -from llm_engine_server.infra.services.live_batch_job_service import ( +from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests +from model_engine_server.domain.entities import BatchJobSerializationFormat, GpuType, ModelBundle +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.infra.services import LiveBatchJobService +from model_engine_server.infra.services.live_batch_job_service import ( DEFAULT_ENDPOINT_CPUS_BATCH_JOB, DEFAULT_ENDPOINT_GPU_TYPE_BATCH_JOB, DEFAULT_ENDPOINT_GPUS_BATCH_JOB, diff --git a/server/tests/unit/infra/services/test_live_endpoint_builder_service.py b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py similarity index 87% rename from server/tests/unit/infra/services/test_live_endpoint_builder_service.py rename to model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py index d87be073..6d5724fb 100644 --- a/server/tests/unit/infra/services/test_live_endpoint_builder_service.py +++ b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py @@ -2,22 +2,25 @@ from unittest.mock import Mock, mock_open import pytest -from llm_engine_server.common.dtos.docker_repository import BuildImageResponse -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.dtos.docker_repository import BuildImageResponse +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, BuildEndpointStatus, ) -from llm_engine_server.core.domain_exceptions import DockerBuildFailedException -from llm_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from llm_engine_server.core.notification_gateway import NotificationApp -from llm_engine_server.domain.entities.model_bundle_entity import RunnableImageFlavor -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( +from model_engine_server.core.domain_exceptions import DockerBuildFailedException +from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway +from model_engine_server.core.notification_gateway import NotificationApp +from model_engine_server.domain.entities.model_bundle_entity import ( + ArtifactLike, + RunnableImageFlavor, +) +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( FakeMonitoringMetricsGateway, ) -from llm_engine_server.infra.repositories import ModelEndpointCacheRepository -from llm_engine_server.infra.services import ( +from model_engine_server.infra.repositories import ModelEndpointCacheRepository +from model_engine_server.infra.services import ( LiveEndpointBuilderService, live_endpoint_builder_service, ) @@ -129,7 +132,8 @@ async def test_build_endpoint( ]: fake_monitoring_metrics_gateway.reset() repo.add_model_endpoint_record(request.model_endpoint_record) - response = await service.build_endpoint(request) + # Pass in a deep copy of request since LiveEndpointBuilderService.convert_artifact_like_bundle_to_runnable_image mutate the request + response = await service.build_endpoint(request.copy(deep=True)) assert response == BuildEndpointResponse(status=BuildEndpointStatus.OK) assert fake_model_endpoint_cache_repository.read_endpoint_info( endpoint_id=request.model_endpoint_record.id, @@ -138,6 +142,14 @@ async def test_build_endpoint( assert fake_monitoring_metrics_gateway.attempted_build == 1 assert fake_monitoring_metrics_gateway.docker_failed_build == 0 assert fake_monitoring_metrics_gateway.successful_build == 1 + assert fake_monitoring_metrics_gateway.build_time_seconds > 0 + if isinstance(request.model_endpoint_record.current_model_bundle.flavor, ArtifactLike): + if service == endpoint_builder_service_empty_docker_built: + assert sum(fake_monitoring_metrics_gateway.image_build_cache_hit.values()) > 0 + assert sum(fake_monitoring_metrics_gateway.image_build_cache_miss.values()) == 0 + else: + assert sum(fake_monitoring_metrics_gateway.image_build_cache_hit.values()) == 0 + assert sum(fake_monitoring_metrics_gateway.image_build_cache_miss.values()) > 0 @pytest.mark.asyncio diff --git a/server/tests/unit/infra/services/test_live_model_endpoint_service.py b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py similarity index 97% rename from server/tests/unit/infra/services/test_live_model_endpoint_service.py rename to model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py index 8b37b5a9..87cbab0f 100644 --- a/server/tests/unit/infra/services/test_live_model_endpoint_service.py +++ b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py @@ -2,21 +2,21 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.domain_exceptions import ( ObjectAlreadyExistsException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( ModelBundle, ModelEndpoint, ModelEndpointRecord, ModelEndpointStatus, ) -from llm_engine_server.domain.exceptions import ( +from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, ExistingEndpointOperationInProgressException, ) -from llm_engine_server.infra.services import LiveModelEndpointService +from model_engine_server.infra.services import LiveModelEndpointService async def _create_model_endpoint_helper( @@ -63,6 +63,7 @@ async def _create_model_endpoint_helper( results_s3_bucket=infra_state.results_s3_bucket, prewarm=prewarm, high_priority=high_priority, + billing_tags=infra_state.user_config_state.endpoint_config.billing_tags, owner=model_endpoint.record.owner, ) return model_endpoint_record @@ -112,6 +113,7 @@ async def test_create_get_model_endpoint_success( model_endpoint.record.created_at = model_endpoint_1.record.created_at model_endpoint.record.last_updated_at = model_endpoint_1.record.last_updated_at model_endpoint.record.id = model_endpoint_1.record.id + model_endpoint.infra_state.user_config_state.endpoint_config.billing_tags = model_endpoint_1.infra_state.user_config_state.endpoint_config.billing_tags # type: ignore # Use dict comparison because errors are more readable. assert model_endpoint.dict() == model_endpoint_1.dict() diff --git a/server/tests/unit/infra/services/test_model_endpoint_cache_service.py b/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py similarity index 92% rename from server/tests/unit/infra/services/test_model_endpoint_cache_service.py rename to model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py index 64f50eaa..fc3661b2 100644 --- a/server/tests/unit/infra/services/test_model_endpoint_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.infra.services.model_endpoint_cache_service import ( +from model_engine_server.infra.services.model_endpoint_cache_service import ( ModelEndpointCacheWriteService, ) @@ -18,9 +18,7 @@ async def test_model_endpoint_write_success( ) cache_write_service = ModelEndpointCacheWriteService( - fake_model_endpoint_cache_repository, - fake_resource_gateway, - fake_image_cache_service, + fake_model_endpoint_cache_repository, fake_resource_gateway, fake_image_cache_service ) await cache_write_service.execute(42) infra_state = await fake_model_endpoint_cache_repository.read_endpoint_info( diff --git a/requirements-dev.txt b/requirements-dev.txt index f293cb22..e3edc67e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # Make sure to update .pre-commit-config.yaml to match versions! -black==22.12.0 +black[jupyter]==22.12.0 ruff==0.0.278 +ipython==8.12.0 # 8.12.0 is the last version to support Python 3.8 isort==5.12.0 mypy==1.3.0 pip-tools==7.0.0 diff --git a/server/Dockerfile.openapi b/server/Dockerfile.openapi deleted file mode 100644 index f0c892de..00000000 --- a/server/Dockerfile.openapi +++ /dev/null @@ -1,6 +0,0 @@ -FROM openapitools/openapi-generator-cli:v6.4.0 as openapi -RUN apt-get update && apt-get install -y npm && rm -rf /var/lib/apt/lists/* -RUN npm install @openapitools/openapi-generator-cli -g -RUN openapi-generator-cli version-manager set 6.4.0 -WORKDIR /local -ENTRYPOINT ["openapi-generator-cli"] diff --git a/server/Makefile b/server/Makefile deleted file mode 100644 index c673b1ad..00000000 --- a/server/Makefile +++ /dev/null @@ -1,43 +0,0 @@ -install: - pip install -r requirements.txt - pip install -r requirements_override.txt - pip install -e . - -install-test: - pip install -r requirements-test.txt - -install-dev: - pip install -r ../requirements-dev.txt - -install-docs: - pip install -r ../requirements-docs.txt - pip install -e ../clients/python/ - -requirements: install-dev - pip-compile --allow-unsafe --no-emit-index-url --no-emit-trusted-host --output-file=requirements.txt requirements.in - -install-all: install install-test install-dev install-docs - -test: - WORKSPACE=.. pytest - -autogen-templates: - pushd charts && \ - helm template llm-engine llm-engine -f llm-engine/values_circleci.yaml \ - -s templates/service_template_config_map.yaml \ - --set message='# THIS FILE IS AUTOGENERATED USING `just autogen-templates`. PLEASE EDIT THE GOTEMPLATE FILE IN THE HELM CHART!!!' \ - > ../llm_engine/infra/gateways/resources/templates/service_template_config_map_circleci.yaml \ - && popd - -build: - docker compose build llm-engine - -dev: - # TODO: add env variables to make this work. - docker compose up llm-engine-gateway-dev llm-engine-service-builder-dev - -build-docs: - mkdocs build - -dev-docs: - mkdocs serve diff --git a/server/docker-compose.yml b/server/docker-compose.yml deleted file mode 100644 index ca172d2a..00000000 --- a/server/docker-compose.yml +++ /dev/null @@ -1,155 +0,0 @@ -version: "3.8" - -services: - llm-engine: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - image: "${ECR_HOST:-local}/llm-engine:${GIT_SHA:-latest}" - llm-engine-gateway-dev: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - command: - - python - - -m - - llm_engine_server.entrypoints.start_fastapi_server - - --port=5001 - - --debug - - --num-workers=1 - environment: - - AWS_PROFILE - - SQS_PROFILE - - KUBECONFIG=/workspace/.kube/kubeconfig:/workspace/.kube/config - - SERVICE_IDENTIFIER - - AWS_CONFIG_FILE=/creds/.aws/config - - AWS_SHARED_CREDENTIALS_FILE=/creds/.aws/credentials - - CELERY_ELASTICACHE_ENABLED=true - - "GIT_TAG=${GIT_SHA}" - - DD_ENV=training - - "DEPLOY_SERVICE_CONFIG_PATH=/workspace/server/service_configs/service_config_${ENV}.yaml" - - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH=/workspace/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_${ENV}.yaml" - - "ML_INFRA_SERVICES_CONFIG_PATH=/workspace/server/llm_engine_server/core/configs/${ENV}.yaml" - - "DB_SECRET_NAME=${DB_SECRET_NAME:-}" - - "ML_INFRA_DATABASE_URL=${ML_INFRA_DATABASE_URL:-}" - - "USE_REDIS_LOCALHOST=${USE_REDIS_LOCALHOST:-}" - - "SKIP_AUTH=${SKIP_AUTH:-}" - - "CIRCLECI=${CIRCLECI:-}" - - "LOCAL=${LOCAL:-false}" - network_mode: host - ports: - - 5001:5001 - stdin_open: true - tty: true - volumes: - - "${HOME}/.kube:/workspace/.kube" - - "${HOME}/.minikube:/workspace/.minikube" - - "${HOME}/.minikube:/home/circleci/.minikube" - - "${HOME}/.aws-mountable:/creds/.aws" - - "../llm_engine:/workspace/llm_engine" - llm-engine-service-builder-dev: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - command: - - celery - - --app=llm_engine_server.service_builder - - worker - - --loglevel=INFO - - --concurrency=4 - - "--queues=${QUEUE}" - environment: - - AWS_PROFILE - - SQS_PROFILE - - ECR_READ_AWS_PROFILE - - DB_SECRET_AWS_PROFILE - - "S3_BUCKET=${S3_BUCKET:-scale-ml}" - - "DEPLOY_SERVICE_CONFIG_PATH=/workspace/server/service_configs/service_config_${ENV}.yaml" - - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH=/workspace/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_${ENV}.yaml" - - "ML_INFRA_SERVICES_CONFIG_PATH=/workspace/ml_infra_core/ml_infra_services/ml_infra_services/configs/${ENV}.yaml" - - "GIT_TAG=${GIT_SHA}" - - DD_ENV=training - - SERVICE_IDENTIFIER - - KUBECONFIG=/workspace/.kube/kubeconfig:/workspace/.kube/config - - AWS_CONFIG_FILE=/creds/.aws/config - - AWS_SHARED_CREDENTIALS_FILE=/creds/.aws/credentials - - CELERY_ELASTICACHE_ENABLED=true - - "KANIKO_TEMPLATE=${KANIKO_TEMPLATE:-kaniko_template.yaml}" - - "DB_SECRET_NAME=${DB_SECRET_NAME:-}" - - "ML_INFRA_DATABASE_URL=${ML_INFRA_DATABASE_URL:-}" - - "USE_REDIS_LOCALHOST=${USE_REDIS_LOCALHOST:-}" - - "SKIP_AUTH=${SKIP_AUTH:-}" - - "CIRCLECI=${CIRCLECI:-}" - - "LOCAL=${LOCAL:-false}" - network_mode: host - stdin_open: true - tty: true - volumes: - - "${HOME}/.kube:/workspace/.kube" - - "${HOME}/.minikube:/workspace/.minikube" - - "${HOME}/.minikube:/home/circleci/.minikube" - - "${HOME}/.aws-mountable:/creds/.aws" - - "../llm_engine:/workspace/llm_engine" - llm-engine-bash: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - command: - - /bin/bash - - -c - - "'${BASH_COMMAND:-/bin/bash}'" - environment: - - AWS_PROFILE - - SQS_PROFILE - - ECR_READ_AWS_PROFILE - - DB_SECRET_AWS_PROFILE - - "DEPLOY_SERVICE_CONFIG_PATH=/workspace/server/service_configs/service_config_${ENV}.yaml" - - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH=/workspace/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_${ENV}.yaml" - - "ML_INFRA_SERVICES_CONFIG_PATH=/workspace/ml_infra_core/ml_infra_services/ml_infra_services/configs/${ENV}.yaml" - - "GIT_TAG=${GIT_SHA}" - - DD_ENV=training - - SERVICE_IDENTIFIER - - KUBECONFIG=/workspace/.kube/kubeconfig:/workspace/.kube/config - - AWS_CONFIG_FILE=/creds/.aws/config - - AWS_SHARED_CREDENTIALS_FILE=/creds/.aws/credentials - - CELERY_ELASTICACHE_ENABLED=true - - "DB_SECRET_NAME=${DB_SECRET_NAME:-}" - - "ML_INFRA_DATABASE_URL=${ML_INFRA_DATABASE_URL:-}" - - "USE_REDIS_LOCALHOST=${USE_REDIS_LOCALHOST:-}" - - "CIRCLECI=${CIRCLECI:-}" - - "LOCAL=${LOCAL:-false}" - network_mode: host - ports: - - 5002:5000 - volumes: - - "${HOME}/.kube:/workspace/.kube" - - "${HOME}/.minikube:/workspace/.minikube" - - "${HOME}/.minikube:/home/circleci/.minikube" - - "${HOME}/.aws-mountable:/creds/.aws" - - "../llm_engine:/workspace/llm_engine" - db: - image: "cimg/postgres:12.8-postgis" - ports: - - 5432:5432 - environment: - - POSTGRES_USER=ml_infra_test - - POSTGRES_DB=ml_infra_test - - POSTGRES_PASSWORD=ml_infra_test - redis: - image: redis - ports: - - 6379:6379 - openapi-generator-cli: - image: "${ECR_HOST:-local}/ml_infra_core/openapi:${GIT_SHA:-latest}" - build: - context: .. - dockerfile: ml_infra_core/Dockerfile.openapi - target: base - volumes: - - "../llm_engine/clients:/local" - command: - - generate diff --git a/server/llm_engine_server/api/app.py b/server/llm_engine_server/api/app.py deleted file mode 100644 index 72703909..00000000 --- a/server/llm_engine_server/api/app.py +++ /dev/null @@ -1,36 +0,0 @@ -from fastapi import FastAPI, Response -from llm_engine_server.api.batch_jobs_v1 import batch_job_router_v1 -from llm_engine_server.api.dependencies import get_or_create_aioredis_pool -from llm_engine_server.api.docker_image_batch_job_bundles_v1 import ( - docker_image_batch_job_bundle_router_v1, -) -from llm_engine_server.api.llms_v1 import llm_router_v1 -from llm_engine_server.api.model_bundles_v1 import model_bundle_router_v1 -from llm_engine_server.api.model_bundles_v2 import model_bundle_router_v2 -from llm_engine_server.api.model_endpoints_docs_v1 import model_endpoints_docs_router_v1 -from llm_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 -from llm_engine_server.api.tasks_v1 import inference_task_router_v1 - -app = FastAPI(title="llm_engine", version="1.0.0", redoc_url="/api") - -app.include_router(batch_job_router_v1) -app.include_router(inference_task_router_v1) -app.include_router(model_bundle_router_v1) -app.include_router(model_bundle_router_v2) -app.include_router(model_endpoint_router_v1) -app.include_router(model_endpoints_docs_router_v1) -app.include_router(docker_image_batch_job_bundle_router_v1) -app.include_router(llm_router_v1) - - -@app.on_event("startup") -def load_redis(): - get_or_create_aioredis_pool() - - -@app.get("/healthcheck") -@app.get("/healthz") -@app.get("/readyz") -def healthcheck() -> Response: - """Returns 200 if the app is healthy.""" - return Response(status_code=200) diff --git a/server/llm_engine_server/common/constants.py b/server/llm_engine_server/common/constants.py deleted file mode 100644 index 87048c20..00000000 --- a/server/llm_engine_server/common/constants.py +++ /dev/null @@ -1,13 +0,0 @@ -from pathlib import Path - -CALLBACK_POST_INFERENCE_HOOK: str = "callback" -READYZ_FPATH: str = "/tmp/readyz" -DEFAULT_CELERY_TASK_NAME: str = "llm_engine_server.inference.async_inference.tasks.predict" -LIRA_CELERY_TASK_NAME: str = "llm_engine_server.inference.celery_service.exec_func" # TODO: FIXME - -PROJECT_ROOT: Path = Path(__file__).parents[2].absolute() -HOSTED_MODEL_INFERENCE_ROOT: Path = PROJECT_ROOT / "llm_engine" - -FEATURE_FLAG_USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE: str = ( - "USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE" -) diff --git a/server/llm_engine_server/common/datadog_utils.py b/server/llm_engine_server/common/datadog_utils.py deleted file mode 100644 index a7ee6a4e..00000000 --- a/server/llm_engine_server/common/datadog_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from ddtrace import tracer - - -def add_trace_resource_name(tag: str): - """Adds a custom tag to a given dd trace corresponding to the route - (e.g. get_model_bundles for GET /model-bundles, etc.) so that we can filter in Datadog easier - """ - current_span = tracer.current_span() - if current_span: - current_span.set_tag("llm_engine_server.resource_name", tag) diff --git a/server/llm_engine_server/common/dtos/llms.py b/server/llm_engine_server/common/dtos/llms.py deleted file mode 100644 index 2739dc1f..00000000 --- a/server/llm_engine_server/common/dtos/llms.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -DTOs for LLM APIs. -""" - -from typing import Any, Dict, List, Optional - -from llm_engine_server.common.dtos.model_endpoints import ( - CpuSpecificationType, - GetModelEndpointV1Response, - GpuType, - ModelEndpointType, - StorageSpecificationType, -) -from llm_engine_server.domain.entities import ( - BatchJobStatus, - CallbackAuth, - LLMInferenceFramework, - LLMSource, - Quantization, -) -from pydantic import BaseModel, Field, HttpUrl - -from .tasks import TaskStatus - - -class CreateLLMModelEndpointV1Request(BaseModel): - name: str - - # LLM specific fields - model_name: str - source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED - inference_framework_image_tag: str - num_shards: int = 1 - """ - Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models - """ - - quantize: Optional[Quantization] = None - """ - Whether to quantize the model. Only affect behavior for text-generation-inference models - """ - - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models - """ - - # General endpoint fields - metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC - cpus: CpuSpecificationType - gpus: int - memory: StorageSpecificationType - gpu_type: GpuType - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] - min_workers: int - max_workers: int - per_worker: int - labels: Dict[str, str] - prewarm: Optional[bool] - high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] = True # LLM endpoints are public by default. - - -class CreateLLMModelEndpointV1Response(BaseModel): - endpoint_creation_task_id: str - - -class GetLLMModelEndpointV1Response(BaseModel): - id: str - """ - The autogenerated ID of the LLMEngine endpoint. - """ - - name: str - model_name: str - source: LLMSource - inference_framework: LLMInferenceFramework - inference_framework_image_tag: str - num_shards: int - quantize: Optional[Quantization] = None - spec: GetModelEndpointV1Response - - -class ListLLMModelEndpointsV1Response(BaseModel): - model_endpoints: List[GetLLMModelEndpointV1Response] - - -# Delete and update use the default LLMEngine endpoint APIs. - - -class CompletionSyncV1Request(BaseModel): - """ - Request object for a synchronous prompt completion task. - """ - - prompts: List[str] - max_new_tokens: int - temperature: float = Field(gt=0, le=100) - - -class CompletionOutput(BaseModel): - text: str - num_completion_tokens: int - - -class CompletionSyncV1Response(BaseModel): - """ - Response object for a synchronous prompt completion task. - """ - - status: TaskStatus - outputs: List[CompletionOutput] - traceback: Optional[str] = None - - -class CompletionStreamV1Request(BaseModel): - """ - Request object for a stream prompt completion task. - """ - - prompt: str - max_new_tokens: int - temperature: float = Field(gt=0, le=100) - - -class CompletionStreamOutput(BaseModel): - text: str - finished: bool - num_completion_tokens: Optional[int] = None - - -class CompletionStreamV1Response(BaseModel): - """ - Response object for a stream prompt completion task. - """ - - status: TaskStatus - output: Optional[CompletionStreamOutput] = None - traceback: Optional[str] = None - - -class CreateFineTuneJobRequest(BaseModel): - training_file: str - validation_file: str - model_name: str - base_model: str # TODO enum - fine_tuning_method: str # TODO enum - hyperparameters: Dict[str, str] # TODO validated somewhere else - - -class CreateFineTuneJobResponse(BaseModel): - id: str - - -class GetFineTuneJobResponse(BaseModel): - id: str - status: BatchJobStatus - - -class ListFineTuneJobResponse(BaseModel): - jobs: List[GetFineTuneJobResponse] - - -class CancelFineTuneJobResponse(BaseModel): - success: bool diff --git a/server/llm_engine_server/common/env_vars.py b/server/llm_engine_server/common/env_vars.py deleted file mode 100644 index fac2325c..00000000 --- a/server/llm_engine_server/common/env_vars.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -A place for defining, setting, and referencing all environment variables used in LLMEngine. -""" -import os -from typing import Optional, Sequence - -from llm_engine_server.common.constants import PROJECT_ROOT -from llm_engine_server.core.loggers import logger_name, make_logger - -__all__: Sequence[str] = ( - "CIRCLECI", - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH", - "LLM_ENGINE_SERVICE_TEMPLATE_FOLDER", - "LOCAL", - "WORKSPACE", - "get_boolean_env_var", -) - -logger = make_logger(logger_name()) - - -def get_boolean_env_var(name: str) -> bool: - """For all env vars that are either on or off. - - An env var is ON iff: - - it is defined - - its value is the literal string 'true' - - If it is present but not set to 'true', it is considered to be OFF. - """ - value = os.environ.get(name) - if value is None: - return False - value = value.strip().lower() - return "true" == value - - -CIRCLECI: bool = get_boolean_env_var("CIRCLECI") - -LOCAL: bool = get_boolean_env_var("LOCAL") -"""Indicates that LLMEngine is running in a local development environment. Also used for local testing. -""" - -WORKSPACE: str = os.environ.get("WORKSPACE", "~/models") -"""The working directory where llm_engine is installed. -""" - -LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH: str = os.environ.get( - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH", - os.path.join( - PROJECT_ROOT, - "llm_engine_server/infra/gateways/resources/templates", - "service_template_config_map_circleci.yaml", - ), -) -"""The path to the config map containing the LLMEngine service template. -""" - -LLM_ENGINE_SERVICE_TEMPLATE_FOLDER: Optional[str] = os.environ.get( - "LLM_ENGINE_SERVICE_TEMPLATE_FOLDER" -) -"""The path to the folder containing the LLMEngine service template. If set, this overrides -LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH. -""" - -if LOCAL: - logger.warning("LOCAL development & testing mode is ON") diff --git a/server/llm_engine_server/common/settings.py b/server/llm_engine_server/common/settings.py deleted file mode 100644 index 4d0af5e4..00000000 --- a/server/llm_engine_server/common/settings.py +++ /dev/null @@ -1,66 +0,0 @@ -# This file contains standard settings for ML serve. -# - -import hashlib -from typing import List - -from llm_engine_server.core.config import ml_infra_config - -DEPLOYMENT_PREFIX = "llm-engine" -SERVICE_BUILDER_QUEUE_PREFIX = "llm-engine" -SERVICE_BUILDER_QUEUE_SUFFIX = "service-builder" - -RESTRICTED_ENDPOINT_LABELS = set( - [ - "user_id", - "endpoint_name", - ] -) - -REQUIRED_ENDPOINT_LABELS = set( - [ - "team", - "product", - ] -) - -PRETRAINED_ENDPOINTS_CREATED_BY = ["nucleus-model-zoo", "bloom", "llm", "pretrained"] - - -def generate_deployment_name(user_id, endpoint_name): - return "-".join(_generate_deployment_name_parts(user_id, endpoint_name)) - - -def _generate_queue_name(user_id, endpoint_name): - return ".".join(_generate_deployment_name_parts(user_id, endpoint_name)) - - -def generate_destination(user_id: str, endpoint_name: str, endpoint_type: str) -> str: - if endpoint_type == "async": - return _generate_queue_name(user_id, endpoint_name) - elif endpoint_type in {"sync", "streaming"}: - return generate_deployment_name(user_id, endpoint_name) - else: - raise ValueError(f"Invalid endpoint_type: {endpoint_type}") - - -def _generate_deployment_name_parts(user_id: str, endpoint_name: str) -> List[str]: - user_endpoint_hash = hashlib.md5((user_id + endpoint_name).encode("utf-8")).hexdigest() - return [ - DEPLOYMENT_PREFIX, - user_id[:24], - endpoint_name[:8], - user_endpoint_hash[:8], - ] - - -def get_service_builder_queue(service_identifier=None): - return ( - f"{SERVICE_BUILDER_QUEUE_PREFIX}-{service_identifier}.{SERVICE_BUILDER_QUEUE_SUFFIX}" - if service_identifier - else f"{SERVICE_BUILDER_QUEUE_PREFIX}.{SERVICE_BUILDER_QUEUE_SUFFIX}" - ) - - -def get_service_builder_logs_location(user_id: str, endpoint_name: str): - return f"s3://{ml_infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" diff --git a/server/llm_engine_server/core/aws/sfn_client.py b/server/llm_engine_server/core/aws/sfn_client.py deleted file mode 100644 index 30b6593f..00000000 --- a/server/llm_engine_server/core/aws/sfn_client.py +++ /dev/null @@ -1,21 +0,0 @@ -"""This module provides a client for the AWS Step Functions service.""" -import os -from typing import Optional - -from botocore.client import BaseClient -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import logger_name, make_logger - -logger = make_logger(logger_name()) - - -def sync_sfn_client(**kwargs) -> Optional[BaseClient]: - is_testing_mode = os.environ.get("TESTING_DISABLE_SFN", "").lower() == "true" - if is_testing_mode: - logger.error( - "Not creating step function client as we are in testing mode." - "THIS SHOULD NOT HAPPEN IN PRODUCTION!" - ) - return None - return session(ml_infra_config().profile_ml_worker).client("stepfunctions", **kwargs) diff --git a/server/llm_engine_server/core/kubernetes.py b/server/llm_engine_server/core/kubernetes.py deleted file mode 100644 index 59589d7c..00000000 --- a/server/llm_engine_server/core/kubernetes.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging -from enum import Enum -from pathlib import Path -from string import Template -from typing import Iterator, Union - -import yaml -from kubeconfig import KubeConfig - -from .loggers import make_logger - -logger = make_logger(__file__, log_level=logging.DEBUG) -_config = KubeConfig() - -_K8S_CONFIGS = {} - - -class LifecycleSelector(str, Enum): - NORMAL = "normal" - SPOT = "spot" - - -def k8s_config() -> str: - """Returns the name of the current kubernetes context""" - return _config.view()["current-context"].strip() - - -def check_k8s_config(env_name: str) -> bool: - """ - Checks whether the current k8s context (i.e. which cluster you're on) - is the one given by the config. - """ - assert env_name in _K8S_CONFIGS - cur_config = k8s_config() - return cur_config.strip() == _K8S_CONFIGS[env_name].strip() - - -def substitute_yaml(fp: Union[str, Path], **kwargs) -> dict: - """Read a file from disk, substitute options, return yaml - - The yaml file must have the variables to substitute written as $VAR or ${VAR}. See documentation - for string.Template for more details. - - Args: - fp: path to a yaml file - **kwargs: all the keyword arguments needed to substitute flags in the yaml file - - Returns: - Returns a dict of parsed yaml - - Raises: - FileNotFoundError: If no file exists at the path - KeyError: If a keyword argument is specified for a key that doesn't exist, or a key is - specified and no corresponding argument is passed in. - """ - with open(fp, "r") as template_f: - config = yaml.safe_load(Template(template_f.read()).substitute(**kwargs)) - return config - - -def substitute_yamls(fp: Union[str, Path], **kwargs) -> Iterator: - """Read a file from disk, substitute options, return yaml - - The yaml file must have the variables to substitute written as $VAR or ${VAR}. See documentation - for string.Template for more details. - - Args: - fp: path to a yaml file - **kwargs: all the keyword arguments needed to substitute flags in the yaml file - - Returns: - Returns a list of dicts of parsed yaml - - Raises: - FileNotFoundError: If no file exists at the path - KeyError: If a keyword argument is specified for a key that doesn't exist, or a key is - specified and no corresponding argument is passed in. - """ - with open(fp, "r") as template_f: - config = yaml.safe_load_all(Template(template_f.read()).substitute(**kwargs)) - return config diff --git a/server/llm_engine_server/core/testing_utilities.py b/server/llm_engine_server/core/testing_utilities.py deleted file mode 100644 index 5b80dd9d..00000000 --- a/server/llm_engine_server/core/testing_utilities.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Utility functions for Python programs. Should not be used by other modules in this package.""" -import os -import platform -from functools import lru_cache -from tempfile import NamedTemporaryFile -from typing import Callable, Iterable, Optional, Sequence, Tuple, TypeVar - -from llm_engine_server.core.aws.storage_client import sync_storage_client -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.utils.url import parse_attachment_url - -In = TypeVar("In") -"""Type variable representing the function under test's input type. -""" - -Out = TypeVar("Out") -"""Type variable representing the function under test's output type. -""" - -__all__: Sequence[str] = ( - "table_tester", - "no_aws_r_creds", - "no_aws_rw_creds", - "env_var_is_true", -) - - -def table_tester( - fn: Callable[[In], Out], - i_o_pairs: Iterable[Tuple[In, Out]], - equality: Callable[[Out, Out], bool] = lambda a, b: a == b, -) -> None: - """Convenience function to apply a function against a series of input & expected output pairs. - This function `assert`s that the function applied to each input results in the associated output - value, where equality is checked by the :param:`equality` function, which defaults to Python's `==`. - """ - for i, (inp, expected) in enumerate(i_o_pairs): - msg_part = f"Failed on test pair # {i + 1}:\nINPUT: {inp}\nEXPECTED: {expected}\n" - try: - actual = fn(inp) - except Exception: # pylint: disable=broad-except - print(msg_part) - raise - assert equality(actual, expected), msg_part + f"ACTUAL: {actual}" - - -@lru_cache(1) -def no_aws_r_creds() -> bool: - """True if we don't have the read AWS access credentials to run tests. False means we do. - - Useful in a `@pytest.mark.skipif(condition=no_aws_r_creds(), reason="No AWS read credentials")` - marker on a `test_` unittest function. - """ - return _no_aws_creds(write_check=False) - - -@lru_cache(1) -def no_aws_rw_creds() -> bool: - """True if we don't have the read+write AWS access credentials to run tests. False means we do. - - Useful in a `@pytest.mark.skipif(condition=no_aws_rw_creds(), reason="No AWS read/write credentials")` - marker on a `test_` unittest function. - """ - return _no_aws_creds(write_check=True) - - -def _no_aws_creds(*, write_check: bool) -> bool: - try: - p = parse_attachment_url(f"s3://{ml_infra_config().s3_bucket}/testing/_keep_do_not_delete") - s3_client = sync_storage_client() - if not _exists(s3_client, p): - return True - - with NamedTemporaryFile() as f: - f.close() - # test read - with open(f.name, "wb") as wb: - s3_client.download_fileobj( - Bucket=p.bucket, - Key=p.key, - Fileobj=wb, - ) - if write_check: - # test write - with open(f.name, "rb") as rb: - s3_client.upload_fileobj( - Fileobj=rb, - Bucket=p.bucket, - Key=p.key, - ) - except Exception: # pylint: disable=broad-except - return True - else: - return False - - -def _exists(s3_client, p): - try: - # https://stackoverflow.com/questions/33842944/check-if-a-key-exists-in-a-bucket-in-s3-using-boto3 - s3_client.head_object(Bucket=p.bucket, Key=p.key) - except Exception as e: # type: ignore - try: - # pylint: disable=no-member - error_code = e.response["Error"]["Code"].strip() # type: ignore - if error_code in ("404", "NoSuchKey"): - return False - except (NameError, KeyError): - pass - raise e - else: - return True - - -def env_var_is_true(env_var_name: str) -> bool: - """Return true if the environment variable is currently set to a known truth value. - - True if the :param:`env_var_name` environment variable is present and contains a truth value. - The **only** accepted truth values are, case-insensitive: - - 'y' - - 'yes' - - 'true' - - - All other values are considered false. - Additionally, an unset environment variable will result in this function evaluating to false. - """ - if len(env_var_name) == 0: - raise ValueError("Need non-empty environment variable name!") - - try: - x: Optional[str] = os.environ.get(env_var_name, None) - if x is None: - return False - x = x.lower().strip() - return x in ("y", "true", "yes") - except Exception: # pylint: disable=broad-except - return False - - -def is_linux() -> bool: - return "Linux" in platform.platform() diff --git a/server/llm_engine_server/db/migrations/alembic.ini b/server/llm_engine_server/db/migrations/alembic.ini deleted file mode 100644 index 574eb9cc..00000000 --- a/server/llm_engine_server/db/migrations/alembic.ini +++ /dev/null @@ -1,85 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = alembic - -# template used to generate migration files -# file_template = %%(rev)s_%%(slug)s - -# timezone to use when rendering the date -# within the migration file as well as the filename. -# string value is passed to dateutil.tz.gettz() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; this defaults -# to alembic/versions. When using multiple version -# directories, initial revisions must be specified with --version-path -# version_locations = %(here)s/bar %(here)s/bat alembic/versions - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - -sqlalchemy.url = driver://user:pass@localhost/dbname - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks=black -# black.type=console_scripts -# black.entrypoint=black -# black.options=-l 79 - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = DEBUG -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/server/llm_engine_server/db/migrations/alembic/README b/server/llm_engine_server/db/migrations/alembic/README deleted file mode 100644 index 98e4f9c4..00000000 --- a/server/llm_engine_server/db/migrations/alembic/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/server/llm_engine_server/db/migrations/alembic/env.py b/server/llm_engine_server/db/migrations/alembic/env.py deleted file mode 100644 index c85751f9..00000000 --- a/server/llm_engine_server/db/migrations/alembic/env.py +++ /dev/null @@ -1,91 +0,0 @@ -import logging -import os -from logging.config import fileConfig - -from alembic import context -from llm_engine_server.db.base import get_engine_url -from sqlalchemy import engine_from_config, pool - -env = os.environ.get("ENV") -assert env is not None, "Expected ENV to be a nonempty environment variable." - -config = context.config - -config.set_main_option("sqlalchemy.url", get_engine_url(env, read_only=False)) - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = None - - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline(): - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - include_schemas=True, - ) - - try: - with context.begin_transaction(): - context.run_migrations() - except Exception as e: - logging.exception("Error during migration: %s", str(e)) - raise e - - -def run_migrations_online(): - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - try: - with context.begin_transaction(): - context.run_migrations() - except Exception as e: - logging.exception("Error during migration: %s", str(e)) - raise e - finally: - connection.close() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/server/llm_engine_server/db/migrations/alembic/script.py.mako b/server/llm_engine_server/db/migrations/alembic/script.py.mako deleted file mode 100644 index 2c015630..00000000 --- a/server/llm_engine_server/db/migrations/alembic/script.py.mako +++ /dev/null @@ -1,24 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision = ${repr(up_revision)} -down_revision = ${repr(down_revision)} -branch_labels = ${repr(branch_labels)} -depends_on = ${repr(depends_on)} - - -def upgrade(): - ${upgrades if upgrades else "pass"} - - -def downgrade(): - ${downgrades if downgrades else "pass"} diff --git a/server/llm_engine_server/db/ml_infra_pg.py b/server/llm_engine_server/db/ml_infra_pg.py deleted file mode 100644 index 0a2d4852..00000000 --- a/server/llm_engine_server/db/ml_infra_pg.py +++ /dev/null @@ -1,10 +0,0 @@ -from .base import Base, ml_infra_pg_engine - -# we need to import the following for sqlalchemy -# pylint: disable=unused-import -from .models.llm_engine import Bundle, Endpoint # noqa -from .models.model import Model, ModelArtifact, ModelVersion # noqa -from .models.train import Execution, Experiment, Job, Snapshot # noqa - -# run this file to create the db models imported -Base.metadata.create_all(ml_infra_pg_engine) diff --git a/server/llm_engine_server/domain/services/llm_fine_tuning_service.py b/server/llm_engine_server/domain/services/llm_fine_tuning_service.py deleted file mode 100644 index 0f71592e..00000000 --- a/server/llm_engine_server/domain/services/llm_fine_tuning_service.py +++ /dev/null @@ -1,30 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict - - -class LLMFineTuningService(ABC): - @abstractmethod - async def create_fine_tune_job( - self, - created_by: str, - owner: str, - training_file: str, - validation_file: str, - model_name: str, - base_model: str, - fine_tuning_method: str, - hyperparameters: Dict[str, str], - ): - pass - - @abstractmethod - async def get_fine_tune_job(self, owner: str, fine_tune_id: str): - pass - - @abstractmethod - async def list_fine_tune_jobs(self, owner: str): - pass - - @abstractmethod - async def cancel_fine_tune_job(self, owner: str, fine_tune_id: str): - pass diff --git a/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py deleted file mode 100644 index 78ca4bf1..00000000 --- a/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ /dev/null @@ -1,82 +0,0 @@ -from llm_engine_server.common.dtos.llms import ( - CancelFineTuneJobResponse, - CreateFineTuneJobRequest, - CreateFineTuneJobResponse, - GetFineTuneJobResponse, - ListFineTuneJobResponse, -) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.infra.services import DockerImageBatchJobLLMFineTuningService - - -class CreateFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute( - self, user: User, request: CreateFineTuneJobRequest - ) -> CreateFineTuneJobResponse: - fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune_job( - created_by=user.user_id, - owner=user.team_id, - training_file=request.training_file, - validation_file=request.validation_file, - model_name=request.model_name, - base_model=request.base_model, - fine_tuning_method=request.fine_tuning_method, - hyperparameters=request.hyperparameters, - ) - return CreateFineTuneJobResponse( - id=fine_tune_id, - ) - - -class GetFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneJobResponse: - di_batch_job = await self.llm_fine_tuning_service.get_fine_tune_job( - owner=user.team_id, - fine_tune_id=fine_tune_id, - ) - if di_batch_job is None: - raise ObjectNotFoundException - return GetFineTuneJobResponse( - id=fine_tune_id, - status=di_batch_job.status, - ) - - -class ListFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute(self, user: User) -> ListFineTuneJobResponse: - di_batch_jobs = await self.llm_fine_tuning_service.list_fine_tune_jobs( - owner=user.team_id, - ) - return ListFineTuneJobResponse( - jobs=[ - GetFineTuneJobResponse( - id=job.id, - status=job.status, - ) - for job in di_batch_jobs - ] - ) - - -class CancelFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute(self, user: User, fine_tune_id: str) -> CancelFineTuneJobResponse: - success = await self.llm_fine_tuning_service.cancel_fine_tune_job( - owner=user.team_id, - fine_tune_id=fine_tune_id, - ) - return CancelFineTuneJobResponse( - success=success, - ) diff --git a/server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml b/server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml deleted file mode 100644 index 84cc7e9c..00000000 --- a/server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml +++ /dev/null @@ -1,13 +0,0 @@ -forwarder: - model: - class_name: llm_engine.inference.forwarding.forwarding.LoadForwarder - args: - user_port: 5005 - user_hostname: "localhost" - use_grpc: false - predict_route: "/predict" - healthcheck_route: "/readyz" - batch_route: null - llm_engine_unwrap: false - serialize_results_as_string: false - wrap_response: false \ No newline at end of file diff --git a/server/llm_engine_server/inference/configs/service--forwarder.yaml b/server/llm_engine_server/inference/configs/service--forwarder.yaml deleted file mode 100644 index 3b7ff30e..00000000 --- a/server/llm_engine_server/inference/configs/service--forwarder.yaml +++ /dev/null @@ -1,13 +0,0 @@ -forwarder: - model: - class_name: llm_engine.inference.forwarding.forwarding.LoadForwarder - args: - user_port: 5005 - user_hostname: "localhost" - use_grpc: false - predict_route: "/predict" - healthcheck_route: "/readyz" - batch_route: null - llm_engine_unwrap: true - serialize_results_as_string: true - diff --git a/server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml b/server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml deleted file mode 100644 index 23f844e4..00000000 --- a/server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml +++ /dev/null @@ -1,10 +0,0 @@ -forwarder: - class_name: llm_engine.inference.forwarding.forwarding.LoadStreamingForwarder - args: - user_port: 5005 - user_hostname: "localhost" - predict_route: "/stream" - healthcheck_route: "/readyz" - batch_route: null - llm_engine_unwrap: true - serialize_results_as_string: false diff --git a/server/llm_engine_server/inference/forwarding/http_forwarder.py b/server/llm_engine_server/inference/forwarding/http_forwarder.py deleted file mode 100644 index efac1752..00000000 --- a/server/llm_engine_server/inference/forwarding/http_forwarder.py +++ /dev/null @@ -1,167 +0,0 @@ -import argparse -import json -import os -import subprocess -from functools import lru_cache -from typing import Any, List - -import yaml -from fastapi import Depends, FastAPI -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import logger_name, make_logger -from llm_engine_server.inference.forwarding.forwarding import LoadForwarder, LoadStreamingForwarder -from sse_starlette.sse import EventSourceResponse - -logger = make_logger(logger_name()) -app = FastAPI() - - -def _set_value(config: dict, key_path: List[str], value: Any) -> None: - """ - Modifies config by setting the value at config[key_path[0]][key_path[1]]... to be `value`. - """ - key = key_path[0] - if len(key_path) == 1: - config[key] = value - else: - if key not in config: - config[key] = dict() - _set_value(config[key], key_path[1:], value) - - -def _substitute_config_overrides(config: dict, config_overrides: List[str]) -> None: - """ - Modifies config based on config_overrides. - - config_overrides should be a list of strings of the form `key=value`, - where `key` can be of the form `key1.key2` to denote a substitution for config[key1][key2] - (nesting can be arbitrarily deep). - """ - for override in config_overrides: - split = override.split("=") - if len(split) != 2: - raise ValueError(f"Config override {override} must contain exactly one =") - key_path, value = split - try: - _set_value(config, key_path.split("."), value) - except Exception as e: - raise ValueError(f"Error setting {key_path} to {value} in {config}") from e - - -def _load_named_config(config_uri, config_overrides=None): - with open(config_uri, "rt") as rt: - if config_uri.endswith(".json"): - return json.load(rt) - else: - c = yaml.safe_load(rt) - if config_overrides: - _substitute_config_overrides(c, config_overrides) - if len(c) == 1: - name = list(c.keys())[0] - c = c[name] - if "name" not in c: - c["name"] = name - return c - - -@app.get("/healthz") -@app.get("/readyz") -def healthcheck(): - return "OK" - - -def get_config(): - overrides = os.getenv("CONFIG_OVERRIDES") - config_overrides = None - if overrides is not None: - config_overrides = overrides.split(";") - return _load_named_config( - os.getenv("CONFIG_FILE"), - config_overrides, - ) - - -def get_forwarder_loader(): - config = get_config() - forwarder_loader = LoadForwarder(**config["sync"]) - return forwarder_loader - - -def get_streaming_forwarder_loader(): - config = get_config() - streaming_forwarder_loader = LoadStreamingForwarder(**config["stream"]) - return streaming_forwarder_loader - - -@lru_cache() -def load_forwarder(): - return get_forwarder_loader().load(None, None) - - -@lru_cache() -def load_streaming_forwarder(): - return get_streaming_forwarder_loader().load(None, None) - - -@app.post("/predict") -def predict(request: EndpointPredictV1Request, forwarder=Depends(load_forwarder)): - return forwarder(request.dict()) - - -@app.post("/stream") -async def stream(request: EndpointPredictV1Request, forwarder=Depends(load_streaming_forwarder)): - try: - payload = request.dict() - except Exception: - logger.error(f"Failed to decode payload from: {request}") - raise - else: - logger.debug(f"Received request: {payload}") - - # has internal error logging for each processing stage - responses = forwarder(payload) - - async def event_generator(): - for response in responses: - yield {"data": json.dumps(response)} - - return EventSourceResponse(event_generator()) - - -def entrypoint(): - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, required=True) - parser.add_argument("--num-workers", type=int, required=True) - parser.add_argument("--host", type=str, default="[::]") - parser.add_argument("--port", type=int, default=5000) - parser.add_argument("--set", type=str, action="append") - - args = parser.parse_args() - - values = [f"CONFIG_FILE={args.config}"] - if args.set is not None: - values.append(f"CONFIG_OVERRIDES={';'.join(args.set)}") - envs = [] - for v in values: - envs.extend(["--env", v]) - - command = [ - "gunicorn", - "--bind", - f"{args.host}:{args.port}", - "--timeout", - "1200", - "--keep-alive", - "2", - "--worker-class", - "uvicorn.workers.UvicornWorker", - "--workers", - str(args.num_workers), - *envs, - "llm_engine_server.inference.forwarding.http_forwarder:app", - ] - subprocess.run(command) - - -if __name__ == "__main__": - entrypoint() diff --git a/server/llm_engine_server/inference/limits.conf b/server/llm_engine_server/inference/limits.conf deleted file mode 100644 index a22a6bc1..00000000 --- a/server/llm_engine_server/inference/limits.conf +++ /dev/null @@ -1,2 +0,0 @@ -llmengine hard nproc 2000 -llmengine soft nproc 1000 diff --git a/server/llm_engine_server/inference/pytorch_or_tf.Dockerfile b/server/llm_engine_server/inference/pytorch_or_tf.Dockerfile deleted file mode 100644 index 999a0564..00000000 --- a/server/llm_engine_server/inference/pytorch_or_tf.Dockerfile +++ /dev/null @@ -1,81 +0,0 @@ -### THIS FILE IS DEPRECATED IN V1. INSTEAD, USE pytorch_or_tf.base.Dockerfile -### and pytorch_or_tf.user.Dockerfile -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -WORKDIR /app - -# Install basic packages. -# TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. -# It's probably more correct to add support for arbitrary user-specified base images, -# otherwise this base image gets bloated over time. -RUN apt-get update && apt-get install -y \ - apt-utils \ - dumb-init \ - git \ - ssh \ - emacs-nox \ - htop \ - iftop \ - vim \ - ffmpeg \ - libsm6 \ - libxext6 \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev \ - gcc \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -# Apparently wget has a vulnerability so we remove it here -RUN apt-get remove wget -y - -# Create a virtualenv for python so we install our packages in the right place -# Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs -RUN python3 -m venv /venv -ENV PATH=/venv/bin:$PATH - -# Run everything as not-root user -RUN useradd -m llmengine -s /bin/bash -RUN chown -R llmengine /venv -RUN chown -R llmengine /app -# Limits for nproc and consequently number of files open -ADD llm_engine/llm_engine/inference/limits.conf /etc/security/limits.conf -USER llmengine - -RUN mkdir -p /app/ml_infra_core/llm_engine.core -RUN chown -R llmengine /app/ml_infra_core - -COPY --chown=llmengine ml_infra_core/llm_engine.core/requirements.txt ml_infra_core/llm_engine.core/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r ml_infra_core/llm_engine.core/requirements.txt --no-cache-dir -COPY --chown=llmengine ml_infra_core/llm_engine.core ml_infra_core/llm_engine.core -RUN pip install -e ml_infra_core/llm_engine.core - -# Not good for layer caching oh well -# The inference code should only need these few files/directories to function (hopefully) -# Don't copy the entire folder for security reasons - -RUN mkdir -p /app/llm_engine -RUN mkdir -p /app/llm_engine/llm_engine - -RUN chown -R llmengine /app/llm_engine - -COPY --chown=llmengine llm_engine/setup.py /app/llm_engine/setup.py -COPY --chown=llmengine llm_engine/llm_engine.egg-info /app/llm_engine/llm_engine.egg-info -COPY --chown=llmengine llm_engine/llm_engine/__init__.py /app/llm_engine/llm_engine/__init__.py -COPY --chown=llmengine llm_engine/llm_engine/common /app/llm_engine/llm_engine/common -COPY --chown=llmengine llm_engine/llm_engine/domain /app/llm_engine/llm_engine/domain -COPY --chown=llmengine llm_engine/llm_engine/infra /app/llm_engine/llm_engine/infra -COPY --chown=llmengine llm_engine/llm_engine/inference /app/llm_engine/llm_engine/inference -WORKDIR /app/llm_engine -RUN pip install -e . -WORKDIR /app - -RUN pip install -r /app/llm_engine/llm_engine/inference/requirements_base.txt -ARG REQUIREMENTS_FILE -COPY --chown=llmengine ${REQUIREMENTS_FILE} /app/llm_engine/llm_engine/inference/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/llm_engine/llm_engine/inference/requirements.txt - - -ENV PYTHONPATH /app diff --git a/server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile b/server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile deleted file mode 100644 index 72b711ad..00000000 --- a/server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile +++ /dev/null @@ -1,78 +0,0 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -WORKDIR /app - -# Install basic packages. -# TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. -# It's probably more correct to add support for arbitrary user-specified base images, -# otherwise this base image gets bloated over time. -RUN apt-get update && apt-get install -y \ - apt-utils \ - dumb-init \ - git \ - ssh \ - emacs-nox \ - htop \ - iftop \ - vim \ - ffmpeg \ - libsm6 \ - libxext6 \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev \ - gcc \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -# Apparently wget has a vulnerability so we remove it here -RUN apt-get remove wget -y - -# Create a virtualenv for python so we install our packages in the right place -# Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs -RUN python3 -m venv /venv -ENV PATH=/venv/bin:$PATH - -# Run everything as not-root user -RUN useradd -m llmengine -s /bin/bash -RUN chown -R llmengine /venv -RUN chown -R llmengine /app -# Limits for nproc and consequently number of files open -ADD llm_engine/llm_engine/inference/limits.conf /etc/security/limits.conf -USER llmengine - -RUN mkdir -p /app/ml_infra_core/llm_engine.core -RUN chown -R llmengine /app/ml_infra_core - -COPY --chown=llmengine ml_infra_core/llm_engine.core/requirements.txt ml_infra_core/llm_engine.core/requirements.txt -RUN --mount=type=secret,id=codeartifact-pip-conf,target=/etc/pip.conf,mode=0444 \ - PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf \ - pip install -r ml_infra_core/llm_engine.core/requirements.txt --no-cache-dir -COPY --chown=llmengine ml_infra_core/llm_engine.core ml_infra_core/llm_engine.core -RUN pip install -e ml_infra_core/llm_engine.core - -# Not good for layer caching oh well -# The inference code should only need these few files/directories to function (hopefully) -# Don't copy the entire folder for security reasons - -RUN mkdir -p /app/llm_engine -RUN mkdir -p /app/llm_engine/llm_engine - -RUN chown -R llmengine /app/llm_engine - -COPY --chown=llmengine \ - llm_engine/llm_engine/inference/requirements_base.txt \ - /app/llm_engine/llm_engine/inference/requirements_base.txt -RUN pip install -r /app/llm_engine/llm_engine/inference/requirements_base.txt - -COPY --chown=llmengine llm_engine/setup.py /app/llm_engine/setup.py -COPY --chown=llmengine llm_engine/llm_engine.egg-info /app/llm_engine/llm_engine.egg-info -COPY --chown=llmengine llm_engine/llm_engine/__init__.py /app/llm_engine/llm_engine/__init__.py -COPY --chown=llmengine llm_engine/llm_engine/common /app/llm_engine/llm_engine/common -COPY --chown=llmengine llm_engine/llm_engine/domain /app/llm_engine/llm_engine/domain -COPY --chown=llmengine llm_engine/llm_engine/infra /app/llm_engine/llm_engine/infra -COPY --chown=llmengine llm_engine/llm_engine/inference /app/llm_engine/llm_engine/inference -WORKDIR /app/llm_engine -RUN pip install -e . -WORKDIR /app diff --git a/server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile b/server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile deleted file mode 100644 index eb3c35df..00000000 --- a/server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -ARG REQUIREMENTS_FILE -COPY --chown=llmengine ${REQUIREMENTS_FILE} /app/llm_engine/llm_engine/inference/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/llm_engine/llm_engine/inference/requirements.txt - -ENV PYTHONPATH /app diff --git a/server/llm_engine_server/inference/user.Dockerfile b/server/llm_engine_server/inference/user.Dockerfile deleted file mode 100644 index 595097ed..00000000 --- a/server/llm_engine_server/inference/user.Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -ARG REQUIREMENTS_FILE -COPY --chown=root ${REQUIREMENTS_FILE} /app/llm_engine/llm_engine/inference/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/llm_engine/llm_engine/inference/requirements.txt - -ENV PYTHONPATH /app diff --git a/server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py deleted file mode 100644 index 9c3860cc..00000000 --- a/server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py +++ /dev/null @@ -1,23 +0,0 @@ -from datadog import statsd -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway - - -class DatadogMonitoringMetricsGateway(MonitoringMetricsGateway): - def __init__(self): - self.tags = [f"env:{ml_infra_config().env}"] - - def emit_attempted_build_metric(self): - statsd.increment("scale_llm_engine_server.service_builder.attempt", tags=self.tags) - - def emit_successful_build_metric(self): - statsd.increment("scale_llm_engine_server.service_builder.success", tags=self.tags) - - def emit_docker_failed_build_metric(self): - statsd.increment("scale_llm_engine_server.service_builder.docker_failed", tags=self.tags) - - def emit_database_cache_hit_metric(self): - statsd.increment("scale_llm_engine_server.database_cache.hit", tags=self.tags) - - def emit_database_cache_miss_metric(self): - statsd.increment("scale_llm_engine_server.database_cache.miss", tags=self.tags) diff --git a/server/llm_engine_server/scripts/autogenerate_client_and_docs.py b/server/llm_engine_server/scripts/autogenerate_client_and_docs.py deleted file mode 100644 index 973e7008..00000000 --- a/server/llm_engine_server/scripts/autogenerate_client_and_docs.py +++ /dev/null @@ -1,39 +0,0 @@ -import json -import subprocess -from pathlib import Path - -from llm_engine_server.api.app import app - -MODULE_PATH = Path(__file__).resolve() -LLM_ENGINE_SERVICE_BASE = MODULE_PATH.parents[2].resolve() -OPENAPI_PATH = (LLM_ENGINE_SERVICE_BASE / "clients/openapi.json").resolve() -LANGUAGE_TO_GENERATOR_NAME = dict(python="python", typescript="typescript-axios") - - -def dump_openapi(openapi_path: str): - """Writes the OpenAPI schema to the specified path.""" - with open(openapi_path, "w") as file: - schema = app.openapi() - file.write(json.dumps(schema, indent=4, sort_keys=True)) - - -def run_openapi_generator(): - """Launches a subprocess with the OpenAPI generator.""" - print("🏭 Generating client") - command = ["docker-compose run openapi-generator-cli"] - subprocess.run( - command, - cwd=str((LLM_ENGINE_SERVICE_BASE / "../ml_infra_core").resolve()), - check=True, - shell=True, - ) - - -def entrypoint(): - """Entrypoint for autogenerating client and documentation.""" - dump_openapi(str(OPENAPI_PATH)) - run_openapi_generator() - - -if __name__ == "__main__": - entrypoint() diff --git a/server/llm_engine_server/scripts/copy_to_public_client.sh b/server/llm_engine_server/scripts/copy_to_public_client.sh deleted file mode 100755 index a40afddf..00000000 --- a/server/llm_engine_server/scripts/copy_to_public_client.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -# Usage: bash build_and_publish_to_codeartifact.sh $PATH_TO_PRIVATE_CLIENT $PATH_TO_PUBLIC_CLIENT - -set -e -PRIVATE_CLIENT_ROOT=$1 -PUBLIC_CLIENT_ROOT=$2 - -rm -rf $PUBLIC_CLIENT_ROOT/launch/api_client/* -cp -r $PRIVATE_CLIENT_ROOT/launch_client/* $PUBLIC_CLIENT_ROOT/launch/api_client/ - -sed -i '' 's/launch_client/launch.api_client/g' $(find $PUBLIC_CLIENT_ROOT/launch/api_client -type f -name '*\.py') diff --git a/server/llm_engine_server/service_builder/celery.py b/server/llm_engine_server/service_builder/celery.py deleted file mode 100644 index 57c9f623..00000000 --- a/server/llm_engine_server/service_builder/celery.py +++ /dev/null @@ -1,13 +0,0 @@ -from llm_engine_server.core.celery import celery_app -from llm_engine_server.core.config import ml_infra_config - -service_builder_service = celery_app( - name="llm_engine_server.service_builder", - modules=[ - "llm_engine_server.service_builder.tasks_v1", - ], - s3_bucket=ml_infra_config().s3_bucket, -) - -if __name__ == "__main__": - service_builder_service.start() diff --git a/server/pyproject.toml b/server/pyproject.toml deleted file mode 100644 index 0a7ba88b..00000000 --- a/server/pyproject.toml +++ /dev/null @@ -1,6 +0,0 @@ -[build-system] -requires = [ - "setuptools", - "wheel", -] -build-backend = 'setuptools.build_meta' diff --git a/server/requirements_override.txt b/server/requirements_override.txt deleted file mode 100644 index 0c282e6e..00000000 --- a/server/requirements_override.txt +++ /dev/null @@ -1,4 +0,0 @@ -# Consists of packages that are incompatible with requirements.txt -aioboto3==10.0.0 -aiobotocore[boto3]~=2.3.4 -urllib3==1.26.11 diff --git a/server/service_configs/service_config.yaml b/server/service_configs/service_config.yaml deleted file mode 100644 index 1f9c4ef2..00000000 --- a/server/service_configs/service_config.yaml +++ /dev/null @@ -1,66 +0,0 @@ -# Default Configs - -# Endpoint config -# K8s namespace the endpoints will be created in -endpoint_namespace: llm-engine - -# Asynchronous endpoints -sqs_profile: default -sqs_queue_policy_template: > - { - "Version": "2012-10-17", - "Id": "__default_policy_ID", - "Statement": [ - { - "Sid": "__owner_statement", - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:root" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/default" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - } - ] - } - -sqs_queue_tag_template: > - { - "infra.scale.com/product": "MLInfraLLMEngineSQS", - "infra.scale.com/team": "${team}", - "infra.scale.com/contact": "yi.xu@scale.com", - "infra.scale.com/customer": "AllCustomers", - "infra.scale.com/financialOwner": "yi.xu@scale.com", - "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", - "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", - "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" - } - -# resultsS3Bucket (i.e. where HMI will store model inference results) is currently determined on endpoint creation -# via a request - -# modelBundleS3Bucket (i.e. where model bundles are stored) is not determined by any HMI code, but instead -# by some scaleapi routing layer code for scale-hosted HMI, and by request parameters in general. - -# Currently, the celery redis used is defaulted to scale's celery redis, and is hardcoded inside scaleml's celery impl. -# We'll need to bundle this celery implementation along for open-source hosting. - -# There's a separate piece of infra that caches k8s state onto redis, so we need a url to it -cache_redis_url: redis://redis-elasticache-message-broker.ml-internal.scale.com:6379/15 -s3_file_llm_fine_tuning_job_repository: "s3://scale-ml/hosted-model-inference/llm-ft-job-repository/circleci" -datadog_trace_enabled: false diff --git a/server/service_configs/service_config_circleci.yaml b/server/service_configs/service_config_circleci.yaml deleted file mode 100644 index ca755dea..00000000 --- a/server/service_configs/service_config_circleci.yaml +++ /dev/null @@ -1,65 +0,0 @@ -# Config for CircleCI - -# Endpoint config -# K8s namespace the endpoints will be created in -endpoint_namespace: llm-engine - -# Asynchronous endpoints -sqs_profile: nonexistent_sqs_profile -sqs_queue_policy_template: > - { - "Version": "2012-10-17", - "Id": "__default_policy_ID", - "Statement": [ - { - "Sid": "__owner_statement", - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:root" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/default" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - } - ] - } - -sqs_queue_tag_template: > - { - "infra.scale.com/product": "MLInfraLLMEngineSQS", - "infra.scale.com/team": "${team}", - "infra.scale.com/contact": "yi.xu@scale.com", - "infra.scale.com/customer": "AllCustomers", - "infra.scale.com/financialOwner": "yi.xu@scale.com", - "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", - "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", - "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" - } - -# resultsS3Bucket (i.e. where HMI will store model inference results) is currently determined on endpoint creation -# via a request - -# modelBundleS3Bucket (i.e. where model bundles are stored) is not determined by any HMI code, but instead -# by some scaleapi routing layer code for scale-hosted HMI, and by request parameters in general. - -# Currently, the celery redis used is defaulted to scale's celery redis, and is hardcoded inside scaleml's celery impl. -# We'll need to bundle this celery implementation along for open-source hosting. - -# There's a separate piece of infra that caches k8s state onto redis, so we need a url to it -cache_redis_url: redis://127.0.0.1:6379/15 -s3_file_llm_fine_tuning_job_repository: "s3://scale-ml-circleci/hosted-model-inference/llm-ft-job-repository/circleci" diff --git a/server/setup.cfg b/server/setup.cfg deleted file mode 100644 index 69610053..00000000 --- a/server/setup.cfg +++ /dev/null @@ -1,18 +0,0 @@ -[aliases] -test=pytest - -[coverage:run] -omit = - llm_engine/start_server.py, - llm_engine/start_service_builder.py - -[tool:pytest] -addopts = - --verbose - --durations=0 - --cache-clear - --cov=llm_engine - --cov-report=term-missing - --mypy - --mypy-ini-file=mypy.ini - --ignore=clients diff --git a/server/setup.py b/server/setup.py deleted file mode 100644 index 5377b054..00000000 --- a/server/setup.py +++ /dev/null @@ -1,19 +0,0 @@ -# To get circleci to work -from setuptools import find_packages, setup - -setup( - name="scale-llm-engine-server", - version="1.0.0", - packages=[p for p in find_packages() if "tests" not in p], - install_requires=[], - entry_points={ - "console_scripts": [ - "start-service-builder=llm_engine_server.start_service_builder:entrypoint", - "start-server=llm_engine_server.start_server:entrypoint", - "start-fastapi-server=llm_engine_server.entrypoints.start_fastapi_server:entrypoint", - "start-batch-job-orchestration=llm_engine_server.entrypoints.start_batch_job_orchestration:entrypoint", - "hosted-inference-server=llm_engine_server.entrypoints.hosted_inference_server:entrypoint", - "autogen=llm_engine_server.scripts.autogenerate_client_and_docs:entrypoint", - ], - }, -) diff --git a/server/tests/README.md b/server/tests/README.md deleted file mode 100644 index c9ddba03..00000000 --- a/server/tests/README.md +++ /dev/null @@ -1,7 +0,0 @@ -## To Run Unit Tests: - -Inside `server/` folder, run - -```shell -PYTHONPATH=llm_engine_server WORKSPACE=. python3 -m pytest tests --cov=llm_engine_server -``` \ No newline at end of file diff --git a/server/tests/unit/common/__init__.py b/server/tests/unit/common/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server/tests/unit/core/test_env.py b/server/tests/unit/core/test_env.py deleted file mode 100644 index 8765d4ed..00000000 --- a/server/tests/unit/core/test_env.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -from typing import Any, Callable, Dict, Optional, Sequence -from uuid import uuid4 - -import pytest -from llm_engine_server.core.utils.env import environment - -# DO NOT EXPORT ANYTHING -__all__: Sequence[str] = () - - -def expect_not_present(e: str) -> None: - assert ( - e not in os.environ - ), f"Not expecting env var {e} to be present, instead found {os.environ[e]}" - - -def expect_present(e: str, value: Any) -> None: - assert e in os.environ, f"Expecting env var {e} to be present with {value}" - assert ( - os.environ[e] == value - ), f"Expected env var {e} to have value {value} but instead found {os.environ[e]}" - - -def prepare(e: str, existing: Optional[str]) -> Callable[[], None]: - if existing is not None: - os.environ[e] = existing - return lambda: expect_present(e, existing) - else: - return lambda: expect_not_present(e) - - -def test_environment_kwarg(): - e = "ENV_VAR_TEST" - expect_not_present(e) - # NOTE: This is to test keyword argument use. - # Make sure this **literal value** is the same as `e`'s contents. - with environment(ENV_VAR_TEST="x"): - expect_present(e, "x") - expect_not_present(e) - - -@pytest.mark.parametrize("existing", ["env var has prior value", None]) -def test_environment_normal_cases(existing): - e = f"___{uuid4()}-test_env_var" - check = prepare(e, existing) - - check() - new = f"{uuid4()}--hello_world" - with environment(**{e: new}): - expect_present(e, new) - check() - - -@pytest.mark.parametrize("existing", ["env var has prior value", None]) -def test_environment_with_exception(existing): - e = f"___{uuid4()}-test_env_var" - check = prepare(e, existing) - - check() - new = f"{uuid4()}--hello_world" - with pytest.raises(ValueError): - with environment(**{e: new}): - expect_present(e, new) - raise ValueError("Uh oh! Something went wrong in our context!") - check() - - -def test_environment_multi(): - env_vars: Dict[str, str] = {f"___{uuid4()}-test_env_var--{i}": f"value_{i}" for i in range(25)} - - def ok(): - for e in env_vars.keys(): - expect_not_present(e) - - ok() - with environment(**env_vars): - for e, v in env_vars.items(): - expect_present(e, v) - ok() - - -def test_environment_invalid_states(): - with pytest.raises(ValueError): - environment(**{"": "2"}) - - -def test_environment_unset(): - k = f"___{uuid4()}___--test_unset_env_var--" - v = "hello world! :)" - # when there is a previous value - try: - os.environ[k] = v - with environment(**{k: None}): - assert k not in os.environ - assert k in os.environ - assert os.environ[k] == v - finally: - del os.environ[k] - - # when there is not a previous value - assert k not in os.environ - with environment(**{k: None}): - assert k not in os.environ - assert k not in os.environ diff --git a/server/tests/unit/db/common/test_query.py b/server/tests/unit/db/common/test_query.py deleted file mode 100644 index 0d9b173a..00000000 --- a/server/tests/unit/db/common/test_query.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass - -from llm_engine_server.db.models.common.query import Query - - -@dataclass -class ExampleQuery(Query): - """ - Example query - """ - - id: str - name: str - - -def test_query(): - query = ExampleQuery(id="123", name="test") - assert query.to_sqlalchemy_query() == {"id": "123", "name": "test"} diff --git a/server/tests/unit/db/common/test_repository.py b/server/tests/unit/db/common/test_repository.py deleted file mode 100644 index c3dfacf5..00000000 --- a/server/tests/unit/db/common/test_repository.py +++ /dev/null @@ -1,69 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from llm_engine_server.db.models.common.record import Record - - -@pytest.fixture -def mock_session(): - return MagicMock() - - -@pytest.fixture -def mock_query(): - return MagicMock() - - -class TestRecord: - """ - Test the Record class. - """ - - def test_create(self, mock_session): - item = MagicMock() - Record.create(session=mock_session, record=item) - mock_session.add.assert_called_once_with(item) - mock_session.commit.assert_called_once_with() - - @patch("llm_engine_server.db.models.common.record.select") - def test_select_all(self, mock_select, mock_session, mock_query): - mock_query.to_sqlalchemy_query.return_value = {"id": "123", "name": "test"} - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - Record.select_all(session=mock_session, query=mock_query) - mock_select.assert_called_once_with(Record) - mock_select_obj.filter_by.assert_called_once_with(id="123", name="test") - mock_session.execute.assert_called_once_with(mock_select_obj.filter_by.return_value) - mock_session.execute.return_value.scalars.assert_called_once_with() - mock_session.execute.return_value.scalars.return_value.all.assert_called_once_with() - - @patch("llm_engine_server.db.models.common.record.select") - def test_select_by_id(self, mock_select, mock_session): - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - Record.select_by_id(session=mock_session, record_id="123") - mock_select.assert_called_once_with(Record) - mock_select_obj.filter_by.assert_called_once_with(id="123") - mock_session.execute.assert_called_once_with(mock_select_obj.filter_by.return_value) - mock_session.execute.return_value.scalar_one_or_none.assert_called_once_with() - - @patch("llm_engine_server.db.models.common.record.select") - def test_update(self, mock_select, mock_session, mock_query): - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - mock_query.to_sqlalchemy_query.return_value = {"name": "test"} - item = MagicMock() - mock_session.execute.return_value.scalar_one_or_none.return_value = item - Record.update(session=mock_session, record_id="123", query=mock_query) - mock_select.assert_called_once_with(Record) - mock_select_obj.filter_by.assert_called_once_with(id="123") - mock_session.execute.assert_called_once_with(mock_select_obj.filter_by.return_value) - mock_session.execute.return_value.scalar_one_or_none.assert_called_once_with() - item.name = "test" - mock_session.commit.assert_called_once_with() - - def test_delete(self, mock_session): - item = MagicMock() - Record.delete(session=mock_session, record=item) - mock_session.delete.assert_called_once_with(item) - mock_session.commit.assert_called_once_with() diff --git a/server/tests/unit/db/conftest.py b/server/tests/unit/db/conftest.py deleted file mode 100644 index 1a01c532..00000000 --- a/server/tests/unit/db/conftest.py +++ /dev/null @@ -1,467 +0,0 @@ -import datetime -import os -from typing import List - -import psycopg2 -import pytest -import pytest_asyncio -import testing.postgresql -from llm_engine_server.db.base import Session, SessionAsync -from llm_engine_server.db.local_setup import init_database, init_database_and_engine -from llm_engine_server.db.models import ( - BatchJob, - Bundle, - DockerImageBatchJobBundle, - Endpoint, - Model, - ModelArtifact, - ModelVersion, -) -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import create_async_engine - - -def init_testing_postgresql(postgresql: testing.postgresql.Postgresql) -> None: - """Initializes local postgresql server.""" - conn = psycopg2.connect(**postgresql.dsn()) - init_database(postgresql.url(), conn) # type: ignore - - -@pytest.fixture(scope="session") -def engine() -> Engine: - if os.getenv("ML_INFRA_DATABASE_URL"): - url = os.getenv("ML_INFRA_DATABASE_URL") - db_engine = init_database_and_engine(url) - yield db_engine - else: - Postgresql = testing.postgresql.PostgresqlFactory( - cache_initialized_db=True, - on_initialized=init_testing_postgresql, - ) - postgresql = Postgresql().__enter__() - yield create_engine(postgresql.url(), echo=False, future=True) - - -@pytest.fixture(scope="function") -def dbsession(engine: Engine) -> Session: - """Returns a sqlalchemy session, and after the test tears down everything properly.""" - connection = engine.connect() - transaction = connection.begin() - session = Session(bind=connection) - - yield session - - session.close() - transaction.rollback() - connection.close() - - -@pytest_asyncio.fixture(scope="function") -async def dbsession_async(engine: Engine) -> SessionAsync: - """Returns a sqlalchemy session, and after the test tears down everything properly.""" - url = str(engine.url).replace("postgresql://", "postgresql+asyncpg://") - engine = create_async_engine(url) - async with engine.connect() as connection: - async with connection.begin() as transaction: - session = SessionAsync(bind=connection) - yield session - await session.close() - await transaction.rollback() - await connection.close() - - -@pytest_asyncio.fixture(scope="function") -async def bundles(dbsession_async: SessionAsync) -> List[Bundle]: - bundle1 = Bundle( - name="test_bundle_1", - created_by="test_user_1", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="cloudpickle_artifact", - # Artifact fields - artifact_requirements=["test_requirement_1"], - artifact_location="test_location_1", - artifact_app_config=None, - artifact_framework_type="pytorch", - artifact_pytorch_image_tag="test_tag_1", - # Cloudpickle artifact fields - cloudpickle_artifact_load_predict_fn="test_load_predict_fn", - cloudpickle_artifact_load_model_fn="test_load_model_fn", - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle2 = Bundle( - name="test_bundle_2", - created_by="test_user_1", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="zip_artifact", - # Artifact fields - artifact_requirements=["test_requirement_1"], - artifact_location="test_location_2", - artifact_app_config={"test_key": "test_value"}, - artifact_framework_type="custom_base_image", - artifact_image_repository="test_repo_1", - artifact_image_tag="test_tag_1", - # Zip artifact fields - zip_artifact_load_predict_fn_module_path="test_path_1", - zip_artifact_load_model_fn_module_path="test_path_2", - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle3 = Bundle( - name="test_bundle_3", - created_by="test_user_2", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="runnable_image", - # Runnable Image fields - runnable_image_repository="test_repository_1", - runnable_image_tag="test_tag_1", - runnable_image_command=["test_command_1"], - runnable_image_predict_route="/test_predict_route", - runnable_image_healthcheck_route="/test_healthcheck_route", - runnable_image_env={"test_key": "test_value"}, - runnable_image_protocol="http", - runnable_image_readiness_initial_delay_seconds=300, - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle4 = Bundle( - name="test_bundle_4", - created_by="test_user_2", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="triton_enhanced_runnable_image", - # Runnable Image fields - runnable_image_repository="test_repository_1", - runnable_image_tag="test_tag_1", - runnable_image_command=["test_command_1"], - runnable_image_predict_route="/test_predict_route", - runnable_image_healthcheck_route="/test_healthcheck_route", - runnable_image_env={"test_key": "test_value"}, - runnable_image_protocol="http", - runnable_image_readiness_initial_delay_seconds=300, - # Triton enhanced runnable image fields - triton_enhanced_runnable_image_model_repository="test_model_repository_1", - triton_enhanced_runnable_image_model_replicas={"test_model_1": "test_val"}, - triton_enhanced_runnable_image_num_cpu=3.5, - triton_enhanced_runnable_image_commit_tag="test_commit_tag_1", - triton_enhanced_runnable_image_storage="test_storage_1", - triton_enhanced_runnable_image_readiness_initial_delay_seconds=350, - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle5 = Bundle( - name="test_bundle_5", - created_by="test_user_2", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="streaming_enhanced_runnable_image", - # Runnable Image fields - runnable_image_repository="test_repository_1", - runnable_image_tag="test_tag_1", - runnable_image_command=["test_command_1"], - runnable_image_predict_route="/test_predict_route", - runnable_image_healthcheck_route="/test_healthcheck_route", - runnable_image_env={"test_key": "test_value"}, - runnable_image_protocol="http", - runnable_image_readiness_initial_delay_seconds=300, - # Streaming Enhanced Runnable Image fields - streaming_enhanced_runnable_image_streaming_command=["test_streaming_command_1"], - streaming_enhanced_runnable_image_streaming_predict_route="/test_streaming_predict_route", - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundles = [bundle1, bundle2, bundle3, bundle4, bundle5] - for bundle in bundles: - await Bundle.create(dbsession_async, bundle) - return bundles - - -@pytest_asyncio.fixture(scope="function") -async def endpoints(dbsession_async: SessionAsync, bundles: List[Bundle]) -> List[Endpoint]: - endpoint1 = Endpoint( - name="test_endpoint_1", - created_by="test_user_1", - current_bundle_id=bundles[0].id, - endpoint_metadata=None, - creation_task_id="test_creation_task_id_1", - endpoint_type="async", - destination="test_destination_1", - endpoint_status="READY", - owner="test_user_1", - ) - endpoint2 = Endpoint( - name="test_endpoint_2", - created_by="test_user_1", - current_bundle_id=bundles[0].id, - endpoint_metadata=None, - creation_task_id="test_creation_task_id_1", - endpoint_type="async", - destination="test_destination_1", - endpoint_status="READY", - owner="test_user_1", - ) - endpoint3 = Endpoint( - name="test_endpoint_3", - created_by="test_user_1", - current_bundle_id=bundles[1].id, - endpoint_metadata=None, - creation_task_id="test_creation_task_id_1", - endpoint_type="async", - destination="test_destination_1", - endpoint_status="READY", - owner="test_user_1", - ) - endpoints = [endpoint1, endpoint2, endpoint3] - for endpoint in endpoints: - await Endpoint.create(dbsession_async, endpoint) - return endpoints - - -@pytest_asyncio.fixture(scope="function") -async def batch_jobs( - dbsession_async: SessionAsync, bundles: List[Bundle], endpoints: List[Endpoint] -) -> List[BatchJob]: - batch_job1 = BatchJob( - batch_job_status="READY", - created_by="test_user_1", - owner="test_user_1", - model_bundle_id=bundles[0].id, - model_endpoint_id=endpoints[0].id, - task_ids_location=None, - ) - batch_job2 = BatchJob( - batch_job_status="READY", - created_by="test_user_1", - owner="test_user_1", - model_bundle_id=bundles[0].id, - model_endpoint_id=endpoints[0].id, - task_ids_location=None, - ) - batch_job3 = BatchJob( - batch_job_status="READY", - created_by="test_user_2", - owner="test_user_2", - model_bundle_id=bundles[1].id, - model_endpoint_id=endpoints[2].id, - task_ids_location=None, - ) - jobs = [batch_job1, batch_job2, batch_job3] - for batch_job in jobs: - await BatchJob.create(dbsession_async, batch_job) - return jobs - - -@pytest_asyncio.fixture(scope="function") -async def docker_image_batch_job_bundles( - dbsession_async: SessionAsync, -) -> List[DockerImageBatchJobBundle]: - batch_bundle_1 = DockerImageBatchJobBundle( - name="test_docker_image_batch_job_bundle_1", - created_by="test_user_1", - owner="test_user_1", - image_repository="image_repository", - image_tag="image_tag_git_sha", - command=["python", "script.py", "--arg1"], - env=dict(ENV1="VAL1", ENV2="VAL2"), - mount_location="/mount/location/to/config", - cpus="1", - memory=None, - storage=None, - gpus=None, - gpu_type=None, - public=None, - ) - batch_bundle_2 = DockerImageBatchJobBundle( - name="test_docker_image_batch_job_bundle_1", - created_by="test_user_1", - owner="test_user_1", - image_repository="image_repository", - image_tag="image_tag_git_sha", - command=["python", "script.py", "--arg2"], - env=dict(ENV1="VAL3", ENV2="VAL4"), - mount_location="/mount/location/to/config2", - cpus="2", - memory=None, - storage=None, - gpus=None, - gpu_type=None, - public=None, - ) - batch_bundle_3 = DockerImageBatchJobBundle( - name="test_docker_image_batch_job_bundle_2", - created_by="test_user_2", - owner="test_user_2", - image_repository="image_repository", - image_tag="image_tag_git_sha", - command=["python", "script2.py", "--arg1"], - env=dict(ENV1="VAL1", ENV2="VAL2"), - mount_location="/mount2/location/to/config", - cpus="3", - memory=None, - storage=None, - gpus=None, - gpu_type=None, - public=None, - ) - batch_bundle_1.created_at = datetime.datetime(2022, 1, 1) - batch_bundle_2.created_at = datetime.datetime(2022, 1, 3) - batch_bundle_3.created_at = datetime.datetime(2022, 1, 2) - batch_bundles = [batch_bundle_1, batch_bundle_2, batch_bundle_3] - for batch_bundle in batch_bundles: - await DockerImageBatchJobBundle.create(dbsession_async, batch_bundle) - return batch_bundles - - -@pytest.fixture(scope="function") -def models(dbsession: Session) -> List[Model]: - model1 = Model( - name="test_model_1", - description="test_description_1", - task_types=["test_task_type_1", "test_task_type_2"], - created_by="test_user_id_1", - owner="test_user_id_1", - ) - model2 = Model( - name="test_model_2", - description="test_description_2", - task_types=["test_task_type_1", "test_task_type_3"], - created_by="test_user_id_1", - owner="test_user_id_1", - ) - model3 = Model( - name="test_model_1", - description="test_description_1", - task_types=["test_task_type_2", "test_task_type_3"], - created_by="test_user_id_2", - owner="test_user_id_2", - ) - models = [model1, model2, model3] - for model in models: - Model.create(dbsession, model) - return models - - -@pytest_asyncio.fixture(scope="function") -async def model_versions( - dbsession: Session, models: List[Model], bundles: List[Bundle] -) -> List[ModelVersion]: - model_version1 = ModelVersion( - model_id=models[0].id, - version_number=0, - tags=["test_tag_1", "test_tag_2"], - metadata={"key1": "value1"}, - created_by="test_user_id_1", - ) - model_version2 = ModelVersion( - model_id=models[0].id, - version_number=1, - llm_engine_model_bundle_id=bundles[0].id, - tags=["test_tag_1", "test_tag_3"], - metadata={"key1": "value2"}, - created_by="test_user_id_1", - ) - model_version3 = ModelVersion( - model_id=models[2].id, - version_number=0, - llm_engine_model_bundle_id=bundles[1].id, - nucleus_model_id="test_nucleus_model_id_1", - tags=["test_tag_1", "test_tag_2"], - metadata={"key2": "value3"}, - created_by="test_user_id_1", - ) - model_versions = [model_version1, model_version2, model_version3] - for model_version in model_versions: - ModelVersion.create(dbsession, model_version) - return model_versions - - -@pytest.fixture(scope="function") -def model_artifacts(dbsession: Session) -> List[ModelArtifact]: - model_artifact1 = ModelArtifact( - name="test_model_artifact_1", - description="test_description_1", - is_public=True, - created_by="test_user_id_1", - owner="test_user_id_1", - input_schema={"test_schema_key": "test_schema_value"}, - output_schema={"test_schema_key": "test_schema_value"}, - config={"test_config_key": "test_config_value"}, - location="test_location", - format="pytorch", - format_metadata={"test_format_key": "test_format_value"}, - source="huggingface", - source_metadata={"test_source_key": "test_source_value"}, - ) - model_artifact2 = ModelArtifact( - name="test_model_artifact_2", - description="test_description_2", - is_public=False, - created_by="test_user_id_1", - owner="test_user_id_1", - input_schema={"test_schema_key": "test_schema_value"}, - output_schema={"test_schema_key": "test_schema_value"}, - config={"test_config_key": "test_config_value"}, - location="test_location", - format="pytorch", - format_metadata={"test_format_key": "test_format_value"}, - source="huggingface", - source_metadata={"test_source_key": "test_source_value"}, - ) - model_artifact3 = ModelArtifact( - name="test_model_artifact_3", - description="test_description_3", - is_public=True, - created_by="test_user_id_2", - owner="test_user_id_2", - input_schema={"test_schema_key": "test_schema_value"}, - output_schema={"test_schema_key": "test_schema_value"}, - config={"test_config_key": "test_config_value"}, - location="test_location", - format="tensorflow", - format_metadata={"test_format_key": "test_format_value"}, - source="mlflow", - source_metadata={"test_source_key": "test_source_value"}, - ) - model_artifacts = [model_artifact1, model_artifact2, model_artifact3] - for model_artifact in model_artifacts: - ModelArtifact.create(dbsession, model_artifact) - return model_artifacts diff --git a/server/tests/unit/db/test_endpoint_row_lock.py b/server/tests/unit/db/test_endpoint_row_lock.py deleted file mode 100644 index ed7879e0..00000000 --- a/server/tests/unit/db/test_endpoint_row_lock.py +++ /dev/null @@ -1,22 +0,0 @@ -# Since the bulk of the file involves actually connecting to postgres, we're only gonna test that the -# `get_lock_key` function doesn't error and returns nonnegative ints from 0 to 2**64-1 - -from llm_engine_server.db.base import Session -from llm_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key - - -def test_get_lock_key(): - pairs = [ - ("userid1", "endpointname1"), - ("userid2", "endpointname2"), - ("userid", "1endpointname1"), - ("endpointname1", "userid1"), - ] + [(str(i), str(i)) for i in range(10000)] - keys = [get_lock_key(uid, name) for uid, name in pairs] - assert len(keys) == len(set(keys)), "Key collision found" - assert all([-(2**63) <= key < 2**63 for key in keys]), "Key falls outside of range" - - -def test_lock_context_manager(dbsession: Session): - with AdvisoryLockContextManager(session=dbsession, lock_id=10) as lock: - assert lock.lock_acquired() diff --git a/server/tests/unit/db/test_llm_engine.py b/server/tests/unit/db/test_llm_engine.py deleted file mode 100644 index 51811464..00000000 --- a/server/tests/unit/db/test_llm_engine.py +++ /dev/null @@ -1,160 +0,0 @@ -from datetime import datetime -from typing import List - -import pytest -from llm_engine_server.db.base import SessionAsync -from llm_engine_server.db.models import BatchJob, Bundle, DockerImageBatchJobBundle, Endpoint - - -@pytest.mark.asyncio -async def test_bundle_select(dbsession_async: SessionAsync, bundles: List[Bundle]): - bundle_by_name_created_by = await Bundle.select_by_name_created_by( - dbsession_async, name="test_bundle_1", created_by="test_user_1" - ) - assert bundle_by_name_created_by is not None - - bundle_by_name_owner = await Bundle.select_by_name_owner( - dbsession_async, name="test_bundle_1", owner="test_user_1" - ) - assert bundle_by_name_owner is not None - - bundles_by_name_created_by = await Bundle.select_all_by_name_created_by( - dbsession_async, name="test_bundle_1", created_by="test_user_1" - ) - assert len(bundles_by_name_created_by) == 1 - - bundles_by_name_owner = await Bundle.select_all_by_name_owner( - dbsession_async, name="test_bundle_1", owner="test_user_1" - ) - assert len(bundles_by_name_owner) == 1 - - bundle_by_id = await Bundle.select_by_id(dbsession_async, bundle_id=bundles[0].id) - assert bundle_by_id is not None - - bundles_by_owner = await Bundle.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(bundles_by_owner) == 2 - - -@pytest.mark.asyncio -async def test_bundle_select_delete(dbsession_async: SessionAsync, bundles: List[Bundle]): - bundles_by_owner = await Bundle.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - prev_num_bundles = len(bundles_by_owner) - - await Bundle.delete(dbsession_async, bundles_by_owner[0]) - - # After deletion, there should now be 1 fewer bundles for this user. - bundles_by_owner = await Bundle.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(bundles_by_owner) == prev_num_bundles - 1 - - -@pytest.mark.asyncio -async def test_endpoint_select( - dbsession_async: SessionAsync, bundles: List[Bundle], endpoints: List[Endpoint] -): - endpoint_by_name_created_by = await Endpoint.select_by_name_created_by( - dbsession_async, name="test_endpoint_1", created_by="test_user_1" - ) - assert endpoint_by_name_created_by is not None - - endpoints_by_created_by = await Endpoint.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(endpoints_by_created_by) == 3 - - endpoints_by_owner = await Endpoint.select_all_by_owner(dbsession_async, owner="test_user_1") - assert len(endpoints_by_owner) == 3 - - endpoints_by_bundle_owner = await Endpoint.select_all_by_bundle_created_by( - dbsession_async, current_bundle_id=bundles[0].id, created_by="test_user_1" - ) - assert len(endpoints_by_bundle_owner) == 2 - - -@pytest.mark.asyncio -async def test_endpoint_select_delete( - dbsession_async: SessionAsync, bundles: List[Bundle], endpoints: List[Endpoint] -): - endpoints_by_user_id = await Endpoint.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - prev_num_endpoints = len(endpoints_by_user_id) - - await Endpoint.delete(dbsession_async, endpoints_by_user_id[0]) - - # After deletion, there should now be 1 fewer endpoints for this user. - endpoints_by_user_id = await Endpoint.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(endpoints_by_user_id) == prev_num_endpoints - 1 - - -@pytest.mark.asyncio -async def test_batch_job_select(dbsession_async: SessionAsync, batch_jobs: List[BatchJob]): - batch_job_by_id = await BatchJob.select_by_id(dbsession_async, batch_job_id=batch_jobs[0].id) - assert batch_job_by_id is not None - - batch_jobs_by_owner = await BatchJob.select_all_by_owner(dbsession_async, owner="test_user_1") - assert len(batch_jobs_by_owner) == 2 - - batch_jobs_by_owner = await BatchJob.select_all_by_bundle_owner( - dbsession_async, - model_bundle_id=batch_jobs[0].model_bundle_id, - owner="test_user_1", - ) - assert len(batch_jobs_by_owner) == 2 - - -@pytest.mark.asyncio -async def test_batch_job_update(dbsession_async: SessionAsync, batch_jobs: List[BatchJob]): - update_kwargs = {"status": "FAILED", "completed_at": datetime.now()} - await BatchJob.update_by_id( - session=dbsession_async, batch_job_id=batch_jobs[0].id, kwargs=update_kwargs - ) - batch_job = await BatchJob.select_by_id(dbsession_async, batch_job_id=batch_jobs[0].id) - assert batch_job is not None - assert batch_job.batch_job_status == update_kwargs["status"] - assert batch_job.completed_at.second == update_kwargs["completed_at"].second # type: ignore - - -@pytest.mark.asyncio -async def test_docker_image_batch_job_bundle_select( - dbsession_async: SessionAsync, - docker_image_batch_job_bundles: List[DockerImageBatchJobBundle], -): - batch_job_by_id = await DockerImageBatchJobBundle.select_by_id( - dbsession_async, batch_bundle_id=docker_image_batch_job_bundles[0].id - ) - assert batch_job_by_id is not None - - batch_jobs_by_owner = await DockerImageBatchJobBundle.select_all_by_owner( - dbsession_async, owner="test_user_1" - ) - assert len(batch_jobs_by_owner) == 2 - - batch_jobs_by_owner = await DockerImageBatchJobBundle.select_all_by_name_owner( - dbsession_async, - name=docker_image_batch_job_bundles[0].name, - owner="test_user_1", - ) - assert len(batch_jobs_by_owner) == 2 - - batch_jobs_by_owner = await DockerImageBatchJobBundle.select_all_by_name_owner( - dbsession_async, - name=docker_image_batch_job_bundles[2].name, - owner="test_user_2", - ) - assert len(batch_jobs_by_owner) == 1 - - batch_job_latest_by_name_owner = await DockerImageBatchJobBundle.select_latest_by_name_owner( - dbsession_async, - name=docker_image_batch_job_bundles[0].name, - owner="test_user_1", - ) - assert batch_job_latest_by_name_owner is not None - assert batch_job_latest_by_name_owner.id == docker_image_batch_job_bundles[1].id diff --git a/server/tests/unit/db/test_model.py b/server/tests/unit/db/test_model.py deleted file mode 100644 index 1bd6fab3..00000000 --- a/server/tests/unit/db/test_model.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import List - -from llm_engine_server.db.base import Session -from llm_engine_server.db.models import Bundle, Model, ModelArtifact, ModelVersion - - -def test_model_select(dbsession: Session, models: List[Model]): - models_by_owner = Model.select(dbsession, owner="test_user_id_1") - assert len(models_by_owner) == 2 - - models_by_name = Model.select(dbsession, owner="test_user_id_1", name="test_model_1") - assert len(models_by_name) == 1 - - models_by_created_by = Model.select( - dbsession, owner="test_user_id_1", created_by="test_user_id_1" - ) - assert len(models_by_created_by) == 2 - - models_by_task_types = Model.select( - dbsession, owner="test_user_id_1", task_types=["test_task_type_1"] - ) - assert len(models_by_task_types) == 2 - - model_by_id = Model.select_by_id(dbsession, model_id=models[0].id) - assert model_by_id is not None - - -def test_model_update(dbsession: Session, models: List[Model]): - Model.update_by_id(dbsession, models[0].id, description="new description") - model = Model.select_by_id(dbsession, models[0].id) - assert model is not None - assert model.description == "new description" - - -def test_model_version_select( - dbsession: Session, models: List[Model], model_versions: List[ModelVersion] -): - model_versions_by_owner = ModelVersion.select(dbsession, owner="test_user_id_1") - assert len(model_versions_by_owner) == 2 - - model_versions_by_model_id = ModelVersion.select( - dbsession, owner="test_user_id_1", model_id=models[0].id - ) - assert len(model_versions_by_model_id) == 2 - - model_versions_by_model_name = ModelVersion.select( - dbsession, owner="test_user_id_1", model_name="test_model_1" - ) - assert len(model_versions_by_model_name) == 2 - - model_versions_by_tags = ModelVersion.select( - dbsession, owner="test_user_id_1", tags=["test_tag_1"] - ) - assert len(model_versions_by_tags) == 2 - - model_version_by_id = ModelVersion.select_by_id( - dbsession, model_version_id=model_versions[0].id - ) - assert model_version_by_id is not None - - -def test_model_version_select_by_model_id( - dbsession: Session, - bundles: List[Bundle], - models: List[Model], - model_versions: List[ModelVersion], -): - model_version_by_bundle_id = ModelVersion.select_by_llm_engine_model_bundle_id( - dbsession, bundles[0].id - ) - assert model_version_by_bundle_id is not None - assert model_version_by_bundle_id.llm_engine_model_bundle_id == bundles[0].id - - model_version_by_nucleus_model_id = ModelVersion.select_by_nucleus_model_id( - dbsession, model_versions[2].nucleus_model_id # type: ignore - ) - assert model_version_by_nucleus_model_id is not None - - -def test_model_version_get_highest_version_number( - dbsession: Session, models: List[Model], model_versions: List[ModelVersion] -): - version_number = ModelVersion.get_highest_version_number_for_model( - dbsession, - model_id=models[0].id, - ) - assert version_number == 1 - - version_number = ModelVersion.get_highest_version_number_for_model( - dbsession, - model_id=models[1].id, - ) - assert version_number is None - - version_number = ModelVersion.get_highest_version_number_for_model( - dbsession, - model_id="unknown id", - ) - assert version_number is None - - -def test_model_version_update( - dbsession: Session, models: List[Model], model_versions: List[ModelVersion] -): - ModelVersion.update_by_id( - dbsession, model_versions[0].id, nucleus_model_id="test_nucleus_model_id_upd" - ) - model_version = ModelVersion.select_by_id(dbsession, model_versions[0].id) - assert model_version is not None - assert model_version.nucleus_model_id == "test_nucleus_model_id_upd" - - -def test_model_artifact_select(dbsession: Session, model_artifacts: List[ModelArtifact]): - model_artifacts_by_owner = ModelArtifact.select(dbsession, owner="test_user_id_1") - assert len(model_artifacts_by_owner) == 3 - - model_artifacts_by_no_owner = ModelArtifact.select(dbsession) - assert len(model_artifacts_by_no_owner) == 2 - - model_artifacts_by_name = ModelArtifact.select( - dbsession, owner="test_user_id_1", name="test_model_artifact_1" - ) - assert len(model_artifacts_by_name) == 1 - - model_artifacts_by_created_by = ModelArtifact.select( - dbsession, owner="test_user_id_1", created_by="test_user_id_1" - ) - assert len(model_artifacts_by_created_by) == 2 - - model_artifact_by_id = ModelArtifact.select_by_id( - dbsession, model_artifact_id=model_artifacts[0].id - ) - assert model_artifact_by_id is not None - - -def test_model_artifact_update(dbsession: Session, model_artifacts: List[ModelArtifact]): - ModelArtifact.update_by_id(dbsession, model_artifacts[0].id, description="new description") - updated_model_artifact = ModelArtifact.select_by_id(dbsession, model_artifacts[0].id) - assert updated_model_artifact is not None - assert updated_model_artifact.description == "new description" diff --git a/server/tests/unit/infra/services/test_image_cache_service.py b/server/tests/unit/infra/services/test_image_cache_service.py deleted file mode 100644 index 3e04a89f..00000000 --- a/server/tests/unit/infra/services/test_image_cache_service.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Any - -import pytest -from llm_engine_server.infra.services.image_cache_service import ImageCacheService - - -@pytest.mark.asyncio -async def test_image_cache_success( - fake_image_cache_service: ImageCacheService, - model_endpoint_1, - model_endpoint_2, - model_endpoint_3, - model_endpoint_4, -): - infra_states = { - model_endpoint_1.record.id: (bool, model_endpoint_1.infra_state), - model_endpoint_2.record.id: (bool, model_endpoint_2.infra_state), - model_endpoint_3.record.id: (bool, model_endpoint_3.infra_state), - model_endpoint_4.record.id: (bool, model_endpoint_4.infra_state), - } - repo: Any = fake_image_cache_service.model_endpoint_record_repository - repo.add_model_endpoint_record(model_endpoint_1.record) - repo.add_model_endpoint_record(model_endpoint_2.record) - repo.add_model_endpoint_record(model_endpoint_3.record) - repo.add_model_endpoint_record(model_endpoint_4.record) - - await fake_image_cache_service.execute(infra_states) # type: ignore - gateway: Any = fake_image_cache_service.image_cache_gateway - assert gateway.cached_images == { - "a10": [], - "a100": [], - "cpu": [], - "t4": [ - "000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:40d3b5fb06d1a8c3d14903390a3b23ae388bdb19", - "000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:e4ea48ddccfb9ca3ef6d846ae9b2d146d7e30b0f", - "000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:9a319cd9b897f02291f3242b1395f2b669993cdf-fd", - ], - } From 4877373ec67da42e77469a9f2cfbcff1e4f39dd3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 16:26:53 -0700 Subject: [PATCH 051/425] Bump aiohttp from 3.8.4 to 3.8.5 in /model-engine (#217) Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.8.4 to 3.8.5. - [Release notes](https://github.com/aio-libs/aiohttp/releases) - [Changelog](https://github.com/aio-libs/aiohttp/blob/v3.8.5/CHANGES.rst) - [Commits](https://github.com/aio-libs/aiohttp/compare/v3.8.4...v3.8.5) --- updated-dependencies: - dependency-name: aiohttp dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- model-engine/requirements.txt | 154 +++++++++++++++------------------- 1 file changed, 66 insertions(+), 88 deletions(-) diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 2cf51929..174fb1ec 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -6,16 +6,16 @@ # aiofiles==23.1.0 # via quart -aiohttp==3.8.4 +aiohttp==3.8.5 # via - # -r model-engine/requirements.in + # -r requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r model-engine/requirements.in + # via -r requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r model-engine/requirements.in + # via -r requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 @@ -28,17 +28,12 @@ async-timeout==4.0.2 # via # aiohttp # aioredis - # redis asyncpg==0.27.0 - # via -r model-engine/requirements.in + # via -r requirements.in attrs==23.1.0 # via # aiohttp # ddtrace -backports-zoneinfo[tzdata]==0.2.1 - # via - # celery - # kombu billiard==4.1.0 # via celery bleach==6.0.0 @@ -47,37 +42,39 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via -r model-engine/requirements.in + # via -r requirements.in botocore==1.31.1 # via - # -r model-engine/requirements.in + # -r requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r model-engine/requirements.in + # via -r requirements.in cachetools==5.3.1 # via google-auth celery[redis,sqs,tblib]==5.3.1 - # via -r model-engine/requirements.in + # via -r requirements.in certifi==2023.5.7 # via # datadog-api-client # kubernetes # kubernetes-asyncio # requests +cffi==1.15.1 + # via cryptography charset-normalizer==3.2.0 # via # aiohttp # requests click==8.1.4 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # click-didyoumean # click-plugins @@ -91,31 +88,31 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r model-engine/requirements.in + # via -r requirements.in +cryptography==41.0.3 + # via secretstorage dataclasses-json==0.5.9 - # via -r model-engine/requirements.in + # via -r requirements.in datadog==0.46.0 - # via -r model-engine/requirements.in + # via -r requirements.in datadog-api-client==2.11.0 - # via -r model-engine/requirements.in + # via -r requirements.in ddtrace==0.49.2 - # via -r model-engine/requirements.in + # via -r requirements.in deprecation==2.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in docker==5.0.3 - # via -r model-engine/requirements.in + # via -r requirements.in docutils==0.20.1 # via readme-renderer -exceptiongroup==1.1.2 - # via anyio fastapi==0.78.0 - # via -r model-engine/requirements.in + # via -r requirements.in frozenlist==1.3.3 # via # aiohttp @@ -123,15 +120,15 @@ frozenlist==1.3.3 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r model-engine/requirements.in + # via -r requirements.in gitpython==3.1.32 - # via -r model-engine/requirements.in + # via -r requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in h11==0.14.0 # via # hypercorn @@ -142,7 +139,7 @@ h2==4.1.0 hpack==4.0.0 # via h2 httptools==0.5.0 - # via -r model-engine/requirements.in + # via -r requirements.in hypercorn==0.14.4 # via quart hyperframe==6.0.1 @@ -154,38 +151,36 @@ idna==3.4 # yarl importlib-metadata==6.8.0 # via - # alembic # keyring - # quart # twine -importlib-resources==6.0.0 - # via - # alembic - # keyring itsdangerous==2.1.2 # via quart jaraco-classes==3.3.0 # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.0.3 # via - # -r model-engine/requirements.in + # -r requirements.in # quart jmespath==1.0.1 # via # boto3 # botocore json-log-formatter==0.5.2 - # via -r model-engine/requirements.in + # via -r requirements.in keyring==24.2.0 # via twine kombu[sqs]==5.3.1 # via celery kubeconfig==1.1.1 - # via -r model-engine/requirements.in + # via -r requirements.in kubernetes==25.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in kubernetes-asyncio==24.2.2 - # via -r model-engine/requirements.in + # via -r requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -225,7 +220,7 @@ mypy-extensions==1.0.0 oauthlib==3.2.2 # via requests-oauthlib orjson==3.8.6 - # via -r model-engine/requirements.in + # via -r requirements.in packaging==23.1 # via # build @@ -244,26 +239,28 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r model-engine/requirements.in + # -r requirements.in # ddtrace psycopg2-binary==2.9.3 - # via -r model-engine/requirements.in + # via -r requirements.in py-xid==0.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in pyasn1==0.5.0 # via # pyasn1-modules # rsa pyasn1-modules==0.3.0 # via google-auth +pycparser==2.21 + # via cffi pycurl==7.45.2 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # kombu pydantic==1.10.11 # via - # -r model-engine/requirements.in + # -r requirements.in # fastapi pygments==2.15.1 # via @@ -279,21 +276,21 @@ python-dateutil==2.8.2 # kubernetes-asyncio # pg8000 python-multipart==0.0.6 - # via -r model-engine/requirements.in + # via -r requirements.in pyyaml==6.0 # via # kubeconfig # kubernetes # kubernetes-asyncio quart==0.18.3 - # via -r model-engine/requirements.in + # via -r requirements.in readme-renderer==40.0 # via twine redis==4.6.0 # via celery requests==2.31.0 # via - # -r model-engine/requirements.in + # -r requirements.in # datadog # docker # kubernetes @@ -302,7 +299,7 @@ requests==2.31.0 # requests-toolbelt # twine requests-auth-aws-sigv4==0.7 - # via -r model-engine/requirements.in + # via -r requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -310,15 +307,17 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r model-engine/requirements.in + # via -r requirements.in rsa==4.9 # via google-auth s3transfer==0.6.1 # via boto3 scramp==1.4.4 # via pg8000 +secretstorage==3.3.3 + # via keyring sh==1.14.3 - # via -r model-engine/requirements.in + # via -r requirements.in six==1.16.0 # via # bleach @@ -329,7 +328,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r model-engine/requirements.in + # via -r requirements.in smmap==5.0.0 # via # gitdb @@ -340,12 +339,12 @@ sniffio==1.3.0 # via anyio sqlalchemy[asyncio]==2.0.4 # via - # -r model-engine/requirements.in + # -r requirements.in # alembic sse-starlette==1.6.1 - # via -r model-engine/requirements.in + # via -r requirements.in sseclient-py==1.7.2 - # via -r model-engine/requirements.in + # via -r requirements.in starlette==0.19.1 # via # fastapi @@ -354,23 +353,18 @@ tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r model-engine/requirements.in + # -r requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r model-engine/requirements.in -tomli==2.0.1 - # via - # build - # hypercorn - # pep517 + # via -r requirements.in tqdm==4.65.0 # via - # -r model-engine/requirements.in + # -r requirements.in # twine twine==3.7.1 - # via -r model-engine/requirements.in + # via -r requirements.in types-awscrt==0.16.23 # via # botocore-stubs @@ -380,29 +374,15 @@ types-s3transfer==0.6.1 typing-extensions==4.7.1 # via # aioredis - # asgiref # boto3-stubs - # botocore-stubs # datadog-api-client - # kombu - # mypy-boto3-cloudformation - # mypy-boto3-dynamodb - # mypy-boto3-ec2 - # mypy-boto3-lambda - # mypy-boto3-rds - # mypy-boto3-s3 - # mypy-boto3-sqs # pydantic - # rich # sqlalchemy - # starlette # typing-inspect typing-inspect==0.9.0 # via dataclasses-json tzdata==2023.3 - # via - # backports-zoneinfo - # celery + # via celery urllib3==1.26.16 # via # botocore @@ -414,9 +394,9 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r model-engine/requirements.in + # via -r requirements.in uvloop==0.17.0 - # via -r model-engine/requirements.in + # via -r requirements.in vine==5.0.0 # via # amqp @@ -436,12 +416,10 @@ wsproto==1.2.0 # via hypercorn yarl==1.9.2 # via - # -r model-engine/requirements.in + # -r requirements.in # aiohttp zipp==3.16.0 - # via - # importlib-metadata - # importlib-resources + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: setuptools==68.0.0 From 917998728a4186e0801373ca8dd7a3dc00200132 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 16:34:17 -0700 Subject: [PATCH 052/425] Bump waitress from 2.0.0 to 2.1.2 in /model-engine/model_engine_server/inference (#215) --- .../inference/requirements_base.txt | 1 - .../inference/sync_inference/server.py | 96 ------------------- 2 files changed, 97 deletions(-) delete mode 100644 model-engine/model_engine_server/inference/sync_inference/server.py diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt index aa3acad0..a352a14a 100644 --- a/model-engine/model_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -10,4 +10,3 @@ tqdm==4.65.0 # Pin typing-extensions so aioitertools doesn't break typing-extensions>=4.1.1 uvicorn==0.17.6 -waitress==2.0.0 diff --git a/model-engine/model_engine_server/inference/sync_inference/server.py b/model-engine/model_engine_server/inference/sync_inference/server.py deleted file mode 100644 index 1713a394..00000000 --- a/model-engine/model_engine_server/inference/sync_inference/server.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -from functools import wraps -from threading import BoundedSemaphore -from typing import Optional - -import waitress -from flask import Flask, Response, abort, request -from model_engine_server.common.dtos.tasks import EndpointPredictV1Request -from model_engine_server.core.loggers import filename_wo_ext, make_logger -from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict - -logger = make_logger(filename_wo_ext(__file__)) - -NAME = "hosted-inference-sync-service" -CONCURRENCY = 2 # TODO read from env var?? what's our api -NUM_THREADS = CONCURRENCY + 1 # Extra thread for rejecting above-concurrency requests -FAIL_ON_CONCURRENCY_LIMIT = True # TODO read from env var?? -PORT = os.environ["PORT"] - - -class FlaskConcurrencyLimiter: - def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): - if concurrency is not None: - if concurrency < 1: - raise ValueError("Concurrency should be at least 1") - self.semaphore: Optional[BoundedSemaphore] = BoundedSemaphore(value=concurrency) - self.blocking = ( - not fail_on_concurrency_limit - ) # we want to block if we want to queue up requests - else: - self.semaphore = None - self.blocking = False # Unused - - def __enter__(self): - logger.debug("Entering concurrency limiter semaphore") - if self.semaphore and not self.semaphore.acquire(blocking=self.blocking): - logger.warning("Too many requests, returning 429") - abort(429) - # Just raises an HTTPException. - # __exit__ should not run; otherwise the release() doesn't have an acquire() - - def __exit__(self, type, value, traceback): - logger.debug("Exiting concurrency limiter semaphore") - if self.semaphore: - self.semaphore.release() - - -def with_concurrency_limit(concurrency_limiter: FlaskConcurrencyLimiter): - def _inner(flask_func): - @wraps(flask_func) - def _inner_2(*args, **kwargs): - with concurrency_limiter: - return flask_func(*args, **kwargs) - - return _inner_2 - - return _inner - - -app = Flask(NAME) -concurrency_limiter = FlaskConcurrencyLimiter(CONCURRENCY, FAIL_ON_CONCURRENCY_LIMIT) - -# How does this interact with threads? -# Analogous to init_worker() inside async_inference -predict_fn = load_predict_fn_or_cls() - - -@app.route("/healthcheck", methods=["GET"]) -@app.route("/healthz", methods=["GET"]) -@app.route("/readyz", methods=["GET"]) -def healthcheck(): - return Response(status=200, headers={}) - - -@app.route("/predict", methods=["POST"]) -@with_concurrency_limit(concurrency_limiter) -def predict(): - """ - Assumption: payload is a JSON with format {"url": , "args": , "returned_pickled": boolean} - Returns: Results of running the predict function on the request url. See `run_predict`. - - """ - try: - payload = request.get_json() - payload_pydantic = EndpointPredictV1Request.parse_obj(payload) - except Exception: - logger.error(f"Failed to decode payload from: {request}") - raise - else: - logger.debug(f"Received request: {payload}") - - return run_predict(predict_fn, payload_pydantic) - - -if __name__ == "__main__": - waitress.serve(app, port=PORT, url_scheme="https", threads=NUM_THREADS) From 719ccc3e4327f7a6ba47ef4fb6f4df24ab8432e5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:11:24 -0700 Subject: [PATCH 053/425] Bump certifi from 2023.5.7 to 2023.7.22 in /clients/python (#216) Bumps [certifi](https://github.com/certifi/python-certifi) from 2023.5.7 to 2023.7.22. - [Commits](https://github.com/certifi/python-certifi/compare/2023.05.07...2023.07.22) --- updated-dependencies: - dependency-name: certifi dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Phil Chen <92065453+phil-scale@users.noreply.github.com> --- clients/python/poetry.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index bee5a561..f2d221f7 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -182,13 +182,13 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte [[package]] name = "certifi" -version = "2023.5.7" +version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, ] [[package]] From 06fe771e2ecbdb64fc5a9a8c0fe26a08ccb4092b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:20:07 -0700 Subject: [PATCH 054/425] Bump certifi from 2023.5.7 to 2023.7.22 in /model-engine (#218) Bumps [certifi](https://github.com/certifi/python-certifi) from 2023.5.7 to 2023.7.22. - [Commits](https://github.com/certifi/python-certifi/compare/2023.05.07...2023.07.22) --- updated-dependencies: - dependency-name: certifi dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Phil Chen <92065453+phil-scale@users.noreply.github.com> --- model-engine/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 174fb1ec..1e37fe0f 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -60,7 +60,7 @@ cachetools==5.3.1 # via google-auth celery[redis,sqs,tblib]==5.3.1 # via -r requirements.in -certifi==2023.5.7 +certifi==2023.7.22 # via # datadog-api-client # kubernetes From 89b9ce98b7f119a70e31fbc0c121d017d859fede Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:24:05 -0700 Subject: [PATCH 055/425] add readme to model-engine folder (#220) --- model-engine/README.md | 43 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 model-engine/README.md diff --git a/model-engine/README.md b/model-engine/README.md new file mode 100644 index 00000000..7d87e120 --- /dev/null +++ b/model-engine/README.md @@ -0,0 +1,43 @@ +# Model Engine + +The Model Engine is an API server that allows users to create, deploy, edit, +and delete machine learning endpoints. It consists of two main architectural +components: + +- The [gateway](./model_engine_server/entrypoints/start_fastapi_server.py) + provides a REST API for users to interact with. The routes of the REST API are + defined in [`model_engine_server.api`](./model_engine_server/api). +- The [`model_engine_server.service_builder`](./model_engine_server/service_builder) + package is the part of the code that creates the inference pods. It is the + endpoint builder. When we do a `POST` request to `/endpoints`, this gets run. + It gets run when users create or edit endpoints with `[POST,PUT] /v1/model-endpoints` + +There are two other microservices: + +- The [kubernetes cache](./model_engine_server/entrypoints/k8s_cache.py) + stores endpoint metadata on Redis so that Model Engine does not overload the API + server. +- The celery autoscaler (link TBD) automatically scales + the number of inference pods based on the number of requests for async endpoints. + +## Getting started + +Be sure to install the global `../requirements-dev.txt` first prior +to any installations of requirements in this directory +(`pip install -r ../requirements-dev.txt`), as well as the pre-commit hooks +(`pre-commit install` in the `llm-engine` root folder). Then, install the +requirements files and this folder as editable + +```bash +pip install -r requirements.txt && \ + pip install -r requirements-test.txt && \ + pip install -r requirements_override.txt && \ + pip install -e . +``` + +Run `mypy . --install-types` to set up mypy. + +## Testing + +Most of the business logic in Model Engine should contain unit tests, located in +[`tests/unit`](./tests/unit). To run the tests, run `pytest`. From 5414c025ea384e117d2ca5da14678f54956eeb23 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:38:20 -0700 Subject: [PATCH 056/425] add pre-commit hooks for mypy semgrep and trufflehog (#219) --- .pre-commit-config.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a560c8e3..3449d73a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +fail_fast: false repos: - repo: https://github.com/psf/black # Make sure to update requirements-dev-extra.txt to match versions! @@ -49,3 +50,30 @@ repos: language: python - id: check-toml language: python + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.3.0' # Make sure this matches the version in requirements-dev.txt! + hooks: + - id: mypy + name: mypy-clients-python + entry: mypy --config-file clients/python/mypy.ini clients/python + language: system + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.3.0' # Make sure this matches the version in requirements-dev.txt! + hooks: + - id: mypy + name: mypy-server + entry: mypy --config-file model-engine/mypy.ini model-engine + language: system + - repo: local + hooks: + - id: trufflehog + name: TruffleHog + description: Detect secrets in your data. + entry: bash -c 'docker run --rm -v "$(pwd)/..:/workdir" -i --rm trufflesecurity/trufflehog:latest git file:///workdir/llm-engine --since-commit HEAD --only-verified --fail' + language: system + stages: ["commit", "push"] + - repo: https://github.com/returntocorp/semgrep + rev: 'v1.36.0' + hooks: + - id: semgrep + args: [ '--config', 'p/python', '--error' ] From 600c7b451f14697bf2ecff6bdee337874c2e118c Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Fri, 25 Aug 2023 08:46:27 -0700 Subject: [PATCH 057/425] Fix docs building (#223) --- .pre-commit-config.yaml | 4 ++-- docs/api/data_types.md | 52 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3449d73a..3fe2075c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,14 +55,14 @@ repos: hooks: - id: mypy name: mypy-clients-python - entry: mypy --config-file clients/python/mypy.ini clients/python + entry: mypy --config-file clients/python/mypy.ini language: system - repo: https://github.com/pre-commit/mirrors-mypy rev: 'v1.3.0' # Make sure this matches the version in requirements-dev.txt! hooks: - id: mypy name: mypy-server - entry: mypy --config-file model-engine/mypy.ini model-engine + entry: mypy --config-file model-engine/mypy.ini language: system - repo: local hooks: diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 0663607f..12594058 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -16,16 +16,37 @@ - num_completion_tokens ::: llmengine.CompletionSyncResponse + selection: + members: + - request_id + - output ::: llmengine.CompletionStreamResponse + selection: + members: + - request_id + - output ::: llmengine.CreateFineTuneResponse + selection: + members: + - id ::: llmengine.GetFineTuneResponse + selection: + members: + - id + - fine_tuned_model ::: llmengine.ListFineTunesResponse + selection: + members: + - jobs ::: llmengine.CancelFineTuneResponse + selection: + members: + - success ::: llmengine.GetLLMEndpointResponse selection: @@ -42,19 +63,50 @@ - spec ::: llmengine.ListLLMEndpointsResponse + selection: + members: + - model_endpoints ::: llmengine.DeleteLLMEndpointResponse + selection: + members: + - deleted ::: llmengine.ModelDownloadRequest + selection: + members: + - model_name + - download_format ::: llmengine.ModelDownloadResponse + selection: + members: + - urls ::: llmengine.UploadFileResponse + selection: + members: + - id ::: llmengine.GetFileResponse + selection: + members: + - id + - filename + - size ::: llmengine.GetFileContentResponse + selection: + members: + - id + - content ::: llmengine.ListFilesResponse + selection: + members: + - files ::: llmengine.DeleteFileResponse + selection: + members: + - deleted From 053acd00477bf0bcd4ad2ee513e8518b45f15bbb Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:39:22 -0700 Subject: [PATCH 058/425] Update service template (#222) * Update service template * fix --------- Co-authored-by: Phil Chen <92065453+phil-scale@users.noreply.github.com> --- charts/llm-engine/templates/_helpers.tpl | 8 +- .../service_template_config_map.yaml | 112 +++++++++++++++++- 2 files changed, 113 insertions(+), 7 deletions(-) diff --git a/charts/llm-engine/templates/_helpers.tpl b/charts/llm-engine/templates/_helpers.tpl index 08af45f4..01b63b8d 100644 --- a/charts/llm-engine/templates/_helpers.tpl +++ b/charts/llm-engine/templates/_helpers.tpl @@ -330,21 +330,21 @@ volumeMounts: mountPath: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates {{- if .Values.aws }} - name: config-volume - mountPath: /home/user/.aws/config + mountPath: /root/.aws/config subPath: config {{- end }} {{- if .Values.config.values }} - name: llm-engine-service-config-volume mountPath: /workspace/llm_engine/service_configs - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs {{- end }} {{- end }} {{- define "llmEngine.forwarderVolumeMounts" }} volumeMounts: - name: config-volume - mountPath: /home/user/.aws/config + mountPath: /root/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -354,7 +354,7 @@ volumeMounts: subPath: raw_data {{- if .Values.config.values }} - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs {{- end }} {{- end }} diff --git a/charts/llm-engine/templates/service_template_config_map.yaml b/charts/llm-engine/templates/service_template_config_map.yaml index 0f277c45..af78b38f 100644 --- a/charts/llm-engine/templates/service_template_config_map.yaml +++ b/charts/llm-engine/templates/service_template_config_map.yaml @@ -180,7 +180,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/server/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --http - production_threads - --port @@ -223,7 +223,7 @@ data: - -m - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/server/model_engine_server/inference/configs/service--http_forwarder.yaml + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers @@ -266,7 +266,7 @@ data: - ddtrace-run - run-service - --config - - /workspace/server/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility @@ -483,6 +483,62 @@ data: protocol: TCP name: http ${NODE_PORT_DICT} + virtual-service.yaml: |- + apiVersion: networking.istio.io/v1alpha3 + kind: VirtualService + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + hosts: + - ${RESOURCE_NAME}.${DNS_HOST_DOMAIN} + gateways: + - default/internal-gateway + http: + - route: + - destination: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + port: + number: 80 + destination-rule.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: DestinationRule + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + trafficPolicy: + loadBalancer: + simple: LEAST_REQUEST vertical-pod-autoscaler.yaml: |- apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler @@ -742,3 +798,53 @@ data: command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] terminationGracePeriodSeconds: 0 {{- end }} + cron-trigger.yaml: |- + apiVersion: batch/v1 + kind: CronJob + metadata: + name: ${NAME} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + schedule: "${CRON_SCHEDULE}" + successfulJobsHistoryLimit: 0 + failedJobsHistoryLimit: 0 + jobTemplate: + spec: + backoffLimit: 0 + activeDeadlineSeconds: ${BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS} + template: + metadata: + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + containers: + - name: ${NAME} + image: curlimages/curl:7.72.0 + imagePullPolicy: IfNotPresent + command: + - curl + - -X + - 'POST' + - '${HOST}/v1/docker-image-batch-jobs' + - -H + - 'accept: application/json' + - -H + - 'Content-Type: application/json' + - -d + - '{ "docker_image_batch_job_bundle_id": "${DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID}", "job_config": ${JOB_CONFIG}, "labels": ${JOB_METADATA} }' + - -u + - '${OWNER}:' + restartPolicy: Never From 6ac69959e07d05e73e19dc05ec5eb440e5d47003 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:52:05 -0700 Subject: [PATCH 059/425] Ensure successful helm installation in integration test (#224) * update integration test * fix * add missing files * fix * fix * fix * try updating ImagePullPolicy * fix * add custom configmap * updates * fix * fix * fix * add registry creds * fix * fix * Update service template * fix * fix * fix docs building * fix * fix * fix * final fix hopefully --------- Co-authored-by: Yunfeng Bai --- .circleci/config.yml | 43 ++++++++++++--- .circleci/resources/.minikube-config-map | 4 ++ .circleci/resources/.minikube-registry-creds | 15 ++++++ .circleci/resources/postgres-k8s.yaml | 50 +++++++++++++++++ .circleci/resources/redis-k8s.yaml | 43 +++++++++++++++ charts/llm-engine/templates/_helpers.tpl | 12 ++--- .../templates/balloon_a100_deployment.yaml | 2 +- .../templates/balloon_a10_deployment.yaml | 2 +- .../templates/balloon_cpu_deployment.yaml | 2 +- .../templates/balloon_t4_deployment.yaml | 2 +- .../endpoint_builder_deployment.yaml | 2 +- .../templates/llm_engine_init_job.yaml | 2 +- ...oportional_a100_autoscaler_deployment.yaml | 2 +- ...roportional_a10_autoscaler_deployment.yaml | 2 +- ...proportional_t4_autoscaler_deployment.yaml | 2 +- .../service_template_config_map.yaml | 19 +++---- charts/llm-engine/values_circleci.yaml | 53 ++++++++++++------- .../service_template_config_map_circleci.yaml | 6 +-- 18 files changed, 209 insertions(+), 54 deletions(-) create mode 100644 .circleci/resources/.minikube-config-map create mode 100644 .circleci/resources/.minikube-registry-creds create mode 100644 .circleci/resources/postgres-k8s.yaml create mode 100644 .circleci/resources/redis-k8s.yaml diff --git a/.circleci/config.yml b/.circleci/config.yml index 0deee84f..186b5eba 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -93,34 +93,63 @@ jobs: - run: name: Build Docker Image command: | - docker build . -f model-engine/Dockerfile -t llm-engine:$CIRCLE_SHA1 + docker build . -f model-engine/Dockerfile -t model-engine:$CIRCLE_SHA1 integration_tests: executor: ubuntu-large steps: - checkout + - run: + name: Build Docker Image + command: | + docker build . -f model-engine/Dockerfile -t model-engine:$CIRCLE_SHA1 - run: name: Install minikube command: | cd $HOME curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube_latest_amd64.deb sudo dpkg -i minikube_latest_amd64.deb - minikube start --vm-driver=docker --kubernetes-version=v1.23.0 --memory=14336 --cpus=8 + minikube start --vm-driver=docker --kubernetes-version=v1.23.0 --memory=49152 --cpus=14 - run: - name: Install helm + name: Install kubectl, helm command: | - cd $HOME + cd $HOME/bin curl https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash + curl -LO "https://dl.k8s.io/release/v1.23.0/bin/linux/amd64/kubectl" + chmod +x kubectl + - run: + name: Install helm chart dependencies (Redis, Postgres, Istio) + command: | + sudo apt-get update && sudo apt-get install -y expect + pushd $HOME/project/.circleci/resources + kubectl apply -f redis-k8s.yaml + kubectl apply -f postgres-k8s.yaml + kubectl create secret generic model-engine-postgres-credentials --from-literal=database_url=postgresql://postgres:circle_test@postgres.default:5432/circle_test + export ISTIO_VERSION=1.15.0 + popd + curl -L https://istio.io/downloadIstio | TARGET_ARCH=x86_64 sh - + install istio-${ISTIO_VERSION}/bin/istioctl $HOME/bin + $HOME/bin/istioctl install --set profile=demo -y + kubectl create namespace model-engine + kubectl create configmap default-config --from-literal=config="$(cat $HOME/project/.circleci/resources/.minikube-config-map | envsubst)" + kubectl create configmap default-config --namespace model-engine --from-literal=config="$(cat $HOME/project/.circleci/resources/.minikube-config-map | envsubst)" + cat $HOME/project/.circleci/resources/.minikube-registry-creds | envsubst | expect + minikube addons enable registry-creds + - run: + name: Pre-load model-engine image to minikube + command: | + minikube --logtostderr -v 1 image load model-engine:$CIRCLE_SHA1 - run: name: Install helm chart command: | - cd $HOME/project/charts - helm install llm-engine llm-engine --values llm-engine/values_circleci.yaml --set tag=$CIRCLE_SHA1 + pushd $HOME/project/charts + cat llm-engine/values_circleci.yaml | envsubst > llm-engine/values_circleci_subst.yaml + helm install llm-engine llm-engine --values llm-engine/values_circleci_subst.yaml --set tag=$CIRCLE_SHA1 --atomic --debug executors: ubuntu-large: machine: image: "ubuntu-2004:202201-02" - resource_class: xlarge + resource_class: 2xlarge commands: environment_setup: diff --git a/.circleci/resources/.minikube-config-map b/.circleci/resources/.minikube-config-map new file mode 100644 index 00000000..37ef6f32 --- /dev/null +++ b/.circleci/resources/.minikube-config-map @@ -0,0 +1,4 @@ +# Configmap for AWS credentials inside minikube. +[default] +aws_access_key_id = $CIRCLECI_AWS_ACCESS_KEY +aws_secret_access_key = $CIRCLECI_AWS_SECRET_KEY diff --git a/.circleci/resources/.minikube-registry-creds b/.circleci/resources/.minikube-registry-creds new file mode 100644 index 00000000..a1ef51f2 --- /dev/null +++ b/.circleci/resources/.minikube-registry-creds @@ -0,0 +1,15 @@ +# Script to send the registry-creds addon configuration to minikube +# Source: https://github.com/kubernetes/minikube/issues/8283 +# See expect syntax here: https://manpages.ubuntu.com/manpages/trusty/man1/expect.1.html +spawn minikube addons configure registry-creds +expect "Do you want to enable AWS Elastic Container Registry?" { send "y\r" } +expect "Enter AWS Access Key ID:" { send "$CIRCLECI_AWS_ACCESS_KEY\r" } +expect "Enter AWS Secret Access Key:" { send "$CIRCLECI_AWS_SECRET_KEY\r" } +expect "Enter AWS Session Token:" { send "\r" } +expect "Enter AWS Region:" { send "us-west-2\r" } +expect "Enter 12 digit AWS Account ID (Comma separated list):" { send "$CIRCLECI_AWS_ACCOUNT_ID\r" } +expect "Enter ARN of AWS role to assume:" { send "\r" } +expect "Do you want to enable Google Container Registry?" { send "n\r" } +expect "Do you want to enable Docker Registry?" { send "n\r" } +expect "Do you want to enable Azure Container Registry?" { send "n\r" } +expect eof diff --git a/.circleci/resources/postgres-k8s.yaml b/.circleci/resources/postgres-k8s.yaml new file mode 100644 index 00000000..13d33fe9 --- /dev/null +++ b/.circleci/resources/postgres-k8s.yaml @@ -0,0 +1,50 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: postgres + labels: + app: postgres +spec: + replicas: 1 + selector: + matchLabels: + app: postgres + template: + metadata: + labels: + app: postgres + spec: + containers: + - name: main + image: "cimg/postgres:12.8-postgis" + imagePullPolicy: IfNotPresent + resources: + requests: + memory: 1Gi + cpu: 1 + ports: + - containerPort: 5432 + env: + - name: POSTGRES_USER + value: postgres + - name: POSTGRES_DB + value: circle_test + - name: POSTGRES_PASSWORD + value: circle_test + +--- + +kind: Service +apiVersion: v1 +metadata: + name: postgres + labels: + app: postgres +spec: + type: ClusterIP + selector: + app: postgres + ports: + - name: redis + port: 5432 + targetPort: 5432 diff --git a/.circleci/resources/redis-k8s.yaml b/.circleci/resources/redis-k8s.yaml new file mode 100644 index 00000000..1d3207fe --- /dev/null +++ b/.circleci/resources/redis-k8s.yaml @@ -0,0 +1,43 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: redis-message-broker-master + labels: + app: redis-message-broker-master +spec: + replicas: 1 + selector: + matchLabels: + app: redis-message-broker-master + template: + metadata: + labels: + app: redis-message-broker-master + spec: + containers: + - name: main + image: redis + imagePullPolicy: IfNotPresent + resources: + requests: + memory: 1Gi + cpu: 1 + ports: + - containerPort: 6379 + +--- + +kind: Service +apiVersion: v1 +metadata: + name: redis-message-broker-master + labels: + app: redis-message-broker-master +spec: + type: ClusterIP + selector: + app: redis-message-broker-master + ports: + - name: redis + port: 6379 + targetPort: 6379 diff --git a/charts/llm-engine/templates/_helpers.tpl b/charts/llm-engine/templates/_helpers.tpl index 01b63b8d..eab5d63d 100644 --- a/charts/llm-engine/templates/_helpers.tpl +++ b/charts/llm-engine/templates/_helpers.tpl @@ -150,9 +150,9 @@ env: value: "${PREWARM}" - name: ML_INFRA_SERVICES_CONFIG_PATH {{- if .Values.config.file }} - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/{{ .Values.config.file.infra }}" + value: "${BASE_PATH}/model-engine/model_engine_server/core/configs/{{ .Values.config.file.infra }}" {{- else }} - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "${BASE_PATH}/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} {{- end }} @@ -198,9 +198,9 @@ env: value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH {{- if .Values.config.file }} - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/{{ .Values.config.file.infra }}" + value: "/workspace/model-engine/model_engine_server/core/configs/{{ .Values.config.file.infra }}" {{- else }} - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} {{- end }} @@ -262,12 +262,12 @@ env: - name: DEPLOY_SERVICE_CONFIG_PATH value: "/workspace/llm_engine/service_configs/{{ .Values.config.file.llm_engine }}" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/{{ .Values.config.file.infra }}" + value: "/workspace/model-engine/model_engine_server/core/configs/{{ .Values.config.file.infra }}" {{- else }} - name: DEPLOY_SERVICE_CONFIG_PATH value: "/workspace/llm_engine/service_configs/service_config.yaml" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} - name: CELERY_ELASTICACHE_ENABLED value: "true" diff --git a/charts/llm-engine/templates/balloon_a100_deployment.yaml b/charts/llm-engine/templates/balloon_a100_deployment.yaml index 0559c2f6..471a1599 100644 --- a/charts/llm-engine/templates/balloon_a100_deployment.yaml +++ b/charts/llm-engine/templates/balloon_a100_deployment.yaml @@ -32,7 +32,7 @@ spec: effect: "NoSchedule" containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: limits: diff --git a/charts/llm-engine/templates/balloon_a10_deployment.yaml b/charts/llm-engine/templates/balloon_a10_deployment.yaml index 183392f1..db0f49b3 100644 --- a/charts/llm-engine/templates/balloon_a10_deployment.yaml +++ b/charts/llm-engine/templates/balloon_a10_deployment.yaml @@ -32,7 +32,7 @@ spec: effect: "NoSchedule" containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: limits: diff --git a/charts/llm-engine/templates/balloon_cpu_deployment.yaml b/charts/llm-engine/templates/balloon_cpu_deployment.yaml index 6849bc61..21c1c6c3 100644 --- a/charts/llm-engine/templates/balloon_cpu_deployment.yaml +++ b/charts/llm-engine/templates/balloon_cpu_deployment.yaml @@ -27,7 +27,7 @@ spec: node-lifecycle: normal containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: limits: diff --git a/charts/llm-engine/templates/balloon_t4_deployment.yaml b/charts/llm-engine/templates/balloon_t4_deployment.yaml index 8e871d06..6c549853 100644 --- a/charts/llm-engine/templates/balloon_t4_deployment.yaml +++ b/charts/llm-engine/templates/balloon_t4_deployment.yaml @@ -32,7 +32,7 @@ spec: effect: "NoSchedule" containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: limits: diff --git a/charts/llm-engine/templates/endpoint_builder_deployment.yaml b/charts/llm-engine/templates/endpoint_builder_deployment.yaml index a88e07c0..f42afc07 100644 --- a/charts/llm-engine/templates/endpoint_builder_deployment.yaml +++ b/charts/llm-engine/templates/endpoint_builder_deployment.yaml @@ -49,7 +49,7 @@ spec: - ddtrace-run args: - celery - - --app=server.model_engine_server.service_builder + - --app=model_engine_server.service_builder - worker - --loglevel=INFO - --concurrency=2 diff --git a/charts/llm-engine/templates/llm_engine_init_job.yaml b/charts/llm-engine/templates/llm_engine_init_job.yaml index 25d1e6c3..c975355b 100644 --- a/charts/llm-engine/templates/llm_engine_init_job.yaml +++ b/charts/llm-engine/templates/llm_engine_init_job.yaml @@ -1,4 +1,4 @@ -{{- if .Values.secrets.kubernetesDatabaseSecretName }} +{{- if (and .Values.llmEngineInitJob .Values.llmEngineInitJob.enabled) }} apiVersion: batch/v1 kind: Job metadata: diff --git a/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml b/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml index f288bdf1..4e99558f 100644 --- a/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml +++ b/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml @@ -28,7 +28,7 @@ spec: operator: "Exists" containers: - image: registry.k8s.io/cpa/cluster-proportional-autoscaler:1.8.5 - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: requests: diff --git a/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml b/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml index d6fd7594..e2dbe1c2 100644 --- a/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml +++ b/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml @@ -28,7 +28,7 @@ spec: operator: "Exists" containers: - image: registry.k8s.io/cpa/cluster-proportional-autoscaler:1.8.5 - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: requests: diff --git a/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml b/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml index 29e5a8e9..bdb535e0 100644 --- a/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml +++ b/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml @@ -28,7 +28,7 @@ spec: operator: "Exists" containers: - image: registry.k8s.io/cpa/cluster-proportional-autoscaler:1.8.5 - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ .Values.image.pullPolicy }} name: main resources: requests: diff --git a/charts/llm-engine/templates/service_template_config_map.yaml b/charts/llm-engine/templates/service_template_config_map.yaml index af78b38f..b344d3cd 100644 --- a/charts/llm-engine/templates/service_template_config_map.yaml +++ b/charts/llm-engine/templates/service_template_config_map.yaml @@ -20,6 +20,7 @@ {{- $service_template_service_account_name := .Values.serviceTemplate.serviceAccountName }} {{- $service_template_aws_config_map_name := .Values.serviceTemplate.awsConfigMapName }} {{- $celery_broker_type := .Values.celeryBrokerType }} +{{- $image_pull_policy := .Values.image.pullPolicy }} {{- if .Values.message }} {{- .Values.message }} @@ -95,7 +96,7 @@ data: containers: {{- if eq $flavor "artifact" }} - image: ${IMAGE} - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} name: main {{- with $security_context }} securityContext: @@ -173,7 +174,7 @@ data: {{- if eq $mode "sync" }} - name: http-forwarder image: {{ $forwarder_repository }}:${FORWARDER_IMAGE_TAG} - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} command: - /usr/bin/dumb-init - -- @@ -214,7 +215,7 @@ data: {{- else if eq $mode "streaming" }} - name: http-forwarder image: {{ $forwarder_repository }}:{{ $tag }} - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} command: - /usr/bin/dumb-init - -- @@ -259,7 +260,7 @@ data: {{- else if eq $mode "async" }} - name: celery-forwarder image: {{ $forwarder_repository }}:${FORWARDER_IMAGE_TAG} - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} command: - /usr/bin/dumb-init - -- @@ -296,7 +297,7 @@ data: {{- if eq $flavor "triton-enhanced-runnable-image" }} - name: tritonserver image: {{ $triton_repository }}:${TRITON_COMMIT_TAG}-triton - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} command: - /usr/bin/dumb-init - -- @@ -345,7 +346,7 @@ data: {{- toYaml . | nindent 16 }} {{- end }} image: ${IMAGE} - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} command: ${COMMAND} env: ${MAIN_ENV} readinessProbe: @@ -607,7 +608,7 @@ data: {{- tuple $env_var | toYaml | nindent 16 }} {{- end }} {{- end }} - imagePullPolicy: Always + imagePullPolicy: {{ $image_pull_policy }} command: - dumb-init - -- @@ -696,7 +697,7 @@ data: {{- tuple $env_var | toYaml | nindent 16 }} {{- end }} {{- end }} - imagePullPolicy: Always + imagePullPolicy: {{ $image_pull_policy }} command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. @@ -793,7 +794,7 @@ data: {{- end }} containers: - image: public.ecr.aws/docker/library/busybox:latest - imagePullPolicy: IfNotPresent + imagePullPolicy: {{ $image_pull_policy }} name: busybox command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] terminationGracePeriodSeconds: 0 diff --git a/charts/llm-engine/values_circleci.yaml b/charts/llm-engine/values_circleci.yaml index b57e0ec1..fb170ae9 100644 --- a/charts/llm-engine/values_circleci.yaml +++ b/charts/llm-engine/values_circleci.yaml @@ -14,16 +14,17 @@ replicaCount: # tag: context: circleci image: - gatewayRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - builderRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - cacherRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - forwarderRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - pullPolicy: Always + gatewayRepository: model-engine + builderRepository: model-engine + cacherRepository: model-engine + forwarderRepository: model-engine + pullPolicy: IfNotPresent # serviceIdentifier: secrets: - awsDatabaseSecretName: prod/llm_engine.db + kubernetesDatabaseSecretName: model-engine-postgres-credentials + service: type: ClusterIP @@ -33,7 +34,7 @@ virtualservice: enabled: true annotations: { } hostDomains: - - ml-internal.scale.com + - example.com gateways: - default/internal-gateway @@ -75,16 +76,17 @@ config: k8s_cluster_name: minikube dns_host_domain: localhost default_region: us-west-2 - ml_account_id: "000000000000" - docker_repo_prefix: "000000000000.dkr.ecr.us-west-2.amazonaws.com" + ml_account_id: "$CIRCLECI_AWS_ACCOUNT_ID" + docker_repo_prefix: "CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com" redis_host: redis-message-broker-master.default - s3_bucket: "scale-ml-circleci" + s3_bucket: "$CIRCLECI_AWS_S3_BUCKET" profile_ml_worker: "default" profile_ml_inference_worker: "default" llm_engine: # Endpoint config # K8s namespace the endpoints will be created in - endpoint_namespace: scale-deploy + endpoint_namespace: model-engine + model_primitive_host: none # Asynchronous endpoints sqs_profile: default @@ -97,26 +99,26 @@ config: "Sid": "__owner_statement", "Effect": "Allow", "Principal": { - "AWS": "arn:aws:iam::000000000000:root" + "AWS": "arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:root" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + "Resource": "arn:aws:sqs:us-west-2:$CIRCLECI_AWS_ACCOUNT_ID:${queue_name}" }, { "Effect": "Allow", "Principal": { - "AWS": "arn:aws:iam::000000000000:role/default" + "AWS": "arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:role/default" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + "Resource": "arn:aws:sqs:us-west-2:$CIRCLECI_AWS_ACCOUNT_ID:${queue_name}" }, { "Effect": "Allow", "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" + "AWS": "arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:role/ml_llm_engine" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + "Resource": "arn:aws:sqs:us-west-2:$CIRCLECI_AWS_ACCOUNT_ID:${queue_name}" } ] } @@ -127,12 +129,23 @@ config: "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" } + billing_queue_arn: none cache_redis_url: redis://redis-message-broker-master.default/15 + s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET" + datadog_trace_enabled: false + istio_enabled: true + tgi_repository: "text-generation-inference" + hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET" # Service Account serviceAccount: annotations: - eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/eks-default2 + "eks.amazonaws.com/role-arn": arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:role/default + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" + namespaces: + - default + - model-engine aws: configMap: @@ -145,8 +158,8 @@ forwarder: triton: image: - repository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv - tag: e83eccbc8959f90ebbe4bda618b61ec6ee2d8394-triton + repository: nvidia/tritonserver + tag: latest serviceTemplate: securityContext: diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 1f712fdb..85f312b8 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -2770,7 +2770,7 @@ data: value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: - dumb-init - -- @@ -2893,7 +2893,7 @@ data: value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. @@ -3037,7 +3037,7 @@ data: value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. From 2895d7d0583a101bd3d50929caef7cc097f45203 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Fri, 25 Aug 2023 14:26:15 -0700 Subject: [PATCH 060/425] Update helm charts (#226) * update helm chart to latest and rename to model-engine * fix * renames * fix maybe * fix again * fix * update names * fix * fix --- .circleci/config.yml | 4 +- charts/llm-engine/README.md | 3 - .../llm-engine/templates/aws_config_map.yaml | 18 -- .../llm-engine/templates/gateway_service.yaml | 15 -- .../launch_default_priority_class.yaml | 11 - .../{llm-engine => model-engine}/.helmignore | 0 .../{llm-engine => model-engine}/Chart.yaml | 2 +- charts/model-engine/README.md | 3 + .../templates/_helpers.tpl | 171 ++++++++------ .../templates/aws_config_map.yaml | 26 +++ .../templates/balloon_a100_deployment.yaml | 14 +- .../templates/balloon_a10_deployment.yaml | 14 +- .../templates/balloon_cpu_deployment.yaml | 14 +- .../templates/balloon_t4_deployment.yaml | 14 +- .../templates/cacher_deployment.yaml | 28 +-- .../templates/cacher_vpa.yaml | 8 +- .../templates/cluster_rolebinding.yaml | 8 +- .../templates/database_init_job.yaml | 16 +- .../endpoint_builder_deployment.yaml | 32 +-- .../templates/endpoint_builder_vpa.yaml | 8 +- .../templates/gateway_deployment.yaml | 28 +-- .../templates/gateway_hpa.yaml | 6 +- .../templates/gateway_service.yaml | 18 ++ .../templates/gateway_vpa.yaml | 8 +- .../templates/istio-destinationrule.yaml | 18 ++ .../templates/istio-virtualservice.yaml | 31 +++ .../model_engine_default_priority_class.yaml | 11 + .../model_engine_high_priority_class.yaml} | 4 +- .../model_engine_low_priority_class.yaml} | 4 +- ...oportional_a100_autoscaler_deployment.yaml | 12 +- ...roportional_a10_autoscaler_deployment.yaml | 12 +- ...proportional_t4_autoscaler_deployment.yaml | 12 +- .../templates/service_account.yaml | 6 +- .../templates/service_config_map.yaml | 9 +- .../service_template_config_map.yaml | 220 ++++++------------ .../templates/spellbook_init_job.yaml} | 22 +- charts/model-engine/values.yaml | 9 + .../values_circleci.yaml | 11 +- .../values_sample.yaml | 0 39 files changed, 437 insertions(+), 413 deletions(-) delete mode 100644 charts/llm-engine/README.md delete mode 100644 charts/llm-engine/templates/aws_config_map.yaml delete mode 100644 charts/llm-engine/templates/gateway_service.yaml delete mode 100644 charts/llm-engine/templates/launch_default_priority_class.yaml rename charts/{llm-engine => model-engine}/.helmignore (100%) rename charts/{llm-engine => model-engine}/Chart.yaml (98%) create mode 100644 charts/model-engine/README.md rename charts/{llm-engine => model-engine}/templates/_helpers.tpl (63%) create mode 100644 charts/model-engine/templates/aws_config_map.yaml rename charts/{llm-engine => model-engine}/templates/balloon_a100_deployment.yaml (76%) rename charts/{llm-engine => model-engine}/templates/balloon_a10_deployment.yaml (76%) rename charts/{llm-engine => model-engine}/templates/balloon_cpu_deployment.yaml (72%) rename charts/{llm-engine => model-engine}/templates/balloon_t4_deployment.yaml (76%) rename charts/{llm-engine => model-engine}/templates/cacher_deployment.yaml (61%) rename charts/{llm-engine => model-engine}/templates/cacher_vpa.yaml (74%) rename charts/{llm-engine => model-engine}/templates/cluster_rolebinding.yaml (58%) rename charts/{llm-engine => model-engine}/templates/database_init_job.yaml (70%) rename charts/{llm-engine => model-engine}/templates/endpoint_builder_deployment.yaml (61%) rename charts/{llm-engine => model-engine}/templates/endpoint_builder_vpa.yaml (74%) rename charts/{llm-engine => model-engine}/templates/gateway_deployment.yaml (67%) rename charts/{llm-engine => model-engine}/templates/gateway_hpa.yaml (78%) create mode 100644 charts/model-engine/templates/gateway_service.yaml rename charts/{llm-engine => model-engine}/templates/gateway_vpa.yaml (77%) create mode 100644 charts/model-engine/templates/istio-destinationrule.yaml create mode 100644 charts/model-engine/templates/istio-virtualservice.yaml create mode 100644 charts/model-engine/templates/model_engine_default_priority_class.yaml rename charts/{llm-engine/templates/launch_high_priority_class.yaml => model-engine/templates/model_engine_high_priority_class.yaml} (53%) rename charts/{llm-engine/templates/launch_low_priority_class.yaml => model-engine/templates/model_engine_low_priority_class.yaml} (53%) rename charts/{llm-engine => model-engine}/templates/proportional_a100_autoscaler_deployment.yaml (77%) rename charts/{llm-engine => model-engine}/templates/proportional_a10_autoscaler_deployment.yaml (77%) rename charts/{llm-engine => model-engine}/templates/proportional_t4_autoscaler_deployment.yaml (77%) rename charts/{llm-engine => model-engine}/templates/service_account.yaml (66%) rename charts/{llm-engine => model-engine}/templates/service_config_map.yaml (64%) rename charts/{llm-engine => model-engine}/templates/service_template_config_map.yaml (78%) rename charts/{llm-engine/templates/llm_engine_init_job.yaml => model-engine/templates/spellbook_init_job.yaml} (60%) create mode 100644 charts/model-engine/values.yaml rename charts/{llm-engine => model-engine}/values_circleci.yaml (97%) rename charts/{llm-engine => model-engine}/values_sample.yaml (100%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 186b5eba..78751757 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -142,8 +142,8 @@ jobs: name: Install helm chart command: | pushd $HOME/project/charts - cat llm-engine/values_circleci.yaml | envsubst > llm-engine/values_circleci_subst.yaml - helm install llm-engine llm-engine --values llm-engine/values_circleci_subst.yaml --set tag=$CIRCLE_SHA1 --atomic --debug + cat model-engine/values_circleci.yaml | envsubst > model-engine/values_circleci_subst.yaml + helm install model-engine model-engine --values model-engine/values_circleci_subst.yaml --set tag=$CIRCLE_SHA1 --atomic --debug executors: ubuntu-large: diff --git a/charts/llm-engine/README.md b/charts/llm-engine/README.md deleted file mode 100644 index 9281c374..00000000 --- a/charts/llm-engine/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# LLM Engine Helm Chart - -This chart contains k8s templates for deploying LLM Engine to a k8s cluster. diff --git a/charts/llm-engine/templates/aws_config_map.yaml b/charts/llm-engine/templates/aws_config_map.yaml deleted file mode 100644 index 48e2c30a..00000000 --- a/charts/llm-engine/templates/aws_config_map.yaml +++ /dev/null @@ -1,18 +0,0 @@ -{{- if .Values.aws }} -{{- if eq .Values.aws.configMap.create true }} -apiVersion: v1 -kind: ConfigMap -metadata: - name: {{ .Values.aws.configMap.name }} - labels: - {{- include "llmEngine.labels" . | nindent 4 }} - annotations: - "helm.sh/hook": pre-install,pre-upgrade - "helm.sh/hook-weight": "-2" -data: - config: |- - [profile {{ .Values.aws.profileName }}] - role_arn = {{ index .Values.serviceAccount.annotations "eks.amazonaws.com/role-arn" }} - web_identity_token_file = /var/run/secrets/eks.amazonaws.com/serviceaccount/token -{{- end }} -{{- end }} diff --git a/charts/llm-engine/templates/gateway_service.yaml b/charts/llm-engine/templates/gateway_service.yaml deleted file mode 100644 index 9a3497c1..00000000 --- a/charts/llm-engine/templates/gateway_service.yaml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: {{ include "llmEngine.fullname" . }} - labels: - {{- include "llmEngine.labels" . | nindent 4 }} -spec: - type: {{ .Values.service.type }} - ports: - - port: {{ .Values.service.port }} - targetPort: http - protocol: TCP - name: http - selector: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 4 }} diff --git a/charts/llm-engine/templates/launch_default_priority_class.yaml b/charts/llm-engine/templates/launch_default_priority_class.yaml deleted file mode 100644 index 1217c7c1..00000000 --- a/charts/llm-engine/templates/launch_default_priority_class.yaml +++ /dev/null @@ -1,11 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: scheduling.k8s.io/v1 -kind: PriorityClass -metadata: - name: "{{ include "llmEngine.fullname" . }}-default-priority" -value: 1 -# This ensures that the default llm-engine pods will never preempt any pods, which means -# they cannot take advantage of the dummy nodes. -preemptionPolicy: Never -description: "Default Priority Class for LLMEngine" -{{- end }} diff --git a/charts/llm-engine/.helmignore b/charts/model-engine/.helmignore similarity index 100% rename from charts/llm-engine/.helmignore rename to charts/model-engine/.helmignore diff --git a/charts/llm-engine/Chart.yaml b/charts/model-engine/Chart.yaml similarity index 98% rename from charts/llm-engine/Chart.yaml rename to charts/model-engine/Chart.yaml index 40300d18..16f2c405 100644 --- a/charts/llm-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -1,5 +1,5 @@ apiVersion: v2 -name: llm-engine +name: model-engine description: A Helm chart for Kubernetes # A chart can be either an 'application' or a 'library' chart. diff --git a/charts/model-engine/README.md b/charts/model-engine/README.md new file mode 100644 index 00000000..19c826ce --- /dev/null +++ b/charts/model-engine/README.md @@ -0,0 +1,3 @@ +# Scale Launch Helm Chart + +This chart contains k8s templates for the gateway, endpoint builder, and k8s cacher. diff --git a/charts/llm-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl similarity index 63% rename from charts/llm-engine/templates/_helpers.tpl rename to charts/model-engine/templates/_helpers.tpl index eab5d63d..1a6155ce 100644 --- a/charts/llm-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -1,7 +1,7 @@ {{/* Expand the name of the chart. */}} -{{- define "llmEngine.name" -}} +{{- define "modelEngine.name" -}} {{- default .Chart.Name | trunc 63 | trimSuffix "-" }} {{- end }} @@ -10,7 +10,7 @@ Create a default fully qualified app name. We truncate at 40 chars because some Kubernetes name fields are limited to 63 (by the DNS naming spec). If release name contains chart name it will be used as a full name. */}} -{{- define "llmEngine.fullname" -}} +{{- define "modelEngine.fullname" -}} {{- if .Values.serviceIdentifier }} {{- printf "%s-%s" .Chart.Name .Values.serviceIdentifier | trunc 40 | trimSuffix "-" }} {{- else }} @@ -18,73 +18,77 @@ If release name contains chart name it will be used as a full name. {{- end }} {{- end }} -{{- define "llmEngine.buildername" -}} -"{{ include "llmEngine.fullname" . }}-endpoint-builder" +{{- define "modelEngine.buildername" -}} +"{{ include "modelEngine.fullname" . }}-endpoint-builder" {{- end }} -{{- define "llmEngine.cachername" -}} -"{{ include "llmEngine.fullname" . }}-cacher" +{{- define "modelEngine.cachername" -}} +"{{ include "modelEngine.fullname" . }}-cacher" +{{- end }} + +{{- define "modelEngine.gatewayurl" -}} +{{ .Values.hostDomain.prefix }}{{ include "modelEngine.fullname" . }}.{{ .Release.Namespace }}:{{ .Values.service.port }} {{- end }} {{/* Create chart name and version as used by the chart label. */}} -{{- define "llmEngine.chart" -}} +{{- define "modelEngine.chart" -}} {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} {{- end }} {{/* Common labels */}} -{{- define "llmEngine.labels" -}} +{{- define "modelEngine.labels" -}} team: infra -product: llm-engine -helm.sh/chart: {{ include "llmEngine.chart" . }} +product: launch +helm.sh/chart: {{ include "modelEngine.chart" . }} app.kubernetes.io/managed-by: {{ .Release.Service }} app.kubernetes.io/version: {{ .Values.tag }} tags.datadoghq.com/version: {{ .Values.tag }} tags.datadoghq.com/env: {{ .Values.context }} {{- end }} -{{- define "llmEngine.selectorLabels.builder" -}} -app: {{ include "llmEngine.buildername" . }} +{{- define "modelEngine.selectorLabels.builder" -}} +app: {{ include "modelEngine.buildername" . }} {{- end }} -{{- define "llmEngine.selectorLabels.cacher" -}} -app: {{ include "llmEngine.cachername" . }} +{{- define "modelEngine.selectorLabels.cacher" -}} +app: {{ include "modelEngine.cachername" . }} {{- end }} -{{- define "llmEngine.selectorLabels.gateway" -}} -app: {{ include "llmEngine.fullname" . -}} +{{- define "modelEngine.selectorLabels.gateway" -}} +app: {{ include "modelEngine.fullname" . -}} {{- end }} -{{- define "llmEngine.baseTemplateLabels" -}} +{{- define "modelEngine.baseTemplateLabels" -}} user_id: ${OWNER} team: ${TEAM} product: ${PRODUCT} created_by: ${CREATED_BY} owner: ${OWNER} env: {{- .Values.context | printf " %s" }} -managed-by: {{- include "llmEngine.fullname" . | printf " %s\n" -}} -use_scale_llm_engine_endpoint_network_policy: "true" +managed-by: {{- include "modelEngine.fullname" . | printf " %s\n" -}} +use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: {{- .Values.context | printf " %s" }} -tags.datadoghq.com/version: {{- .Values.tag | printf " %s" }} +tags.datadoghq.com/version: ${GIT_TAG} {{- end }} -{{- define "llmEngine.serviceTemplateLabels" -}} -{{- include "llmEngine.baseTemplateLabels" . | printf "%s\n" -}} +{{- define "modelEngine.serviceTemplateLabels" -}} +{{- include "modelEngine.baseTemplateLabels" . | printf "%s\n" -}} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} {{- end }} -{{- define "llmEngine.jobTemplateLabels" -}} -{{- include "llmEngine.baseTemplateLabels" . | printf "%s\n" -}} -llm_engine_job_id: ${JOB_ID} +{{- define "modelEngine.jobTemplateLabels" -}} +{{- include "modelEngine.baseTemplateLabels" . | printf "%s\n" -}} +launch_job_id: ${JOB_ID} tags.datadoghq.com/service: ${JOB_ID} {{- end }} -{{- define "llmEngine.serviceTemplateAsyncAnnotations" -}} +{{- define "modelEngine.serviceTemplateAsyncAnnotations" -}} celery.scaleml.autoscaler/queue: ${QUEUE} celery.scaleml.autoscaler/broker: ${BROKER_NAME} celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" @@ -93,7 +97,7 @@ celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" {{- end }} -{{- define "llmEngine.serviceTemplateAffinity" -}} +{{- define "modelEngine.serviceTemplateAffinity" -}} podAffinity: preferredDuringSchedulingIgnoredDuringExecution: - weight: 1 @@ -116,7 +120,7 @@ podAffinity: topologyKey: kubernetes.io/hostname {{- end }} -{{- define "llmEngine.baseServiceTemplateEnv" -}} +{{- define "modelEngine.baseServiceTemplateEnv" -}} env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" @@ -125,7 +129,7 @@ env: - name: DD_ENV value: {{ .Values.context }} - name: DD_VERSION - value: {{ .Values.tag }} + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -150,20 +154,20 @@ env: value: "${PREWARM}" - name: ML_INFRA_SERVICES_CONFIG_PATH {{- if .Values.config.file }} - value: "${BASE_PATH}/model-engine/model_engine_server/core/configs/{{ .Values.config.file.infra }}" + value: {{ .Values.config.file.infra | quote }} {{- else }} value: "${BASE_PATH}/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} {{- end }} -{{- define "llmEngine.syncServiceTemplateEnv" -}} -{{- include "llmEngine.baseServiceTemplateEnv" . }} +{{- define "modelEngine.syncServiceTemplateEnv" -}} +{{- include "modelEngine.baseServiceTemplateEnv" . }} - name: PORT value: "${ARTIFACT_LIKE_CONTAINER_PORT}" {{- end }} -{{- define "llmEngine.asyncServiceTemplateEnv" -}} -{{- include "llmEngine.baseServiceTemplateEnv" . }} +{{- define "modelEngine.asyncServiceTemplateEnv" -}} +{{- include "modelEngine.baseServiceTemplateEnv" . }} - name: CELERY_S3_BUCKET value: "${CELERY_S3_BUCKET}" - name: BROKER_TYPE @@ -176,7 +180,7 @@ env: value: "${SQS_QUEUE_URL}" {{- end }} -{{- define "llmEngine.baseForwarderTemplateEnv" -}} +{{- define "modelEngine.baseForwarderTemplateEnv" -}} env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" @@ -185,7 +189,7 @@ env: - name: DD_ENV value: {{ .Values.context }} - name: DD_VERSION - value: {{ .Values.tag }} + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: @@ -198,22 +202,22 @@ env: value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH {{- if .Values.config.file }} - value: "/workspace/model-engine/model_engine_server/core/configs/{{ .Values.config.file.infra }}" + value: {{ .Values.config.file.infra | quote }} {{- else }} value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} {{- end }} -{{- define "llmEngine.syncForwarderTemplateEnv" -}} -{{- include "llmEngine.baseForwarderTemplateEnv" . }} +{{- define "modelEngine.syncForwarderTemplateEnv" -}} +{{- include "modelEngine.baseForwarderTemplateEnv" . }} {{- if and .Values.forwarder .Values.forwarder.forceUseIPv4 }} - name: HTTP_HOST value: "0.0.0.0" {{- end }} {{- end }} -{{- define "llmEngine.asyncForwarderTemplateEnv" -}} -{{- include "llmEngine.baseForwarderTemplateEnv" . }} +{{- define "modelEngine.asyncForwarderTemplateEnv" -}} +{{- include "modelEngine.baseForwarderTemplateEnv" . }} - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -222,29 +226,29 @@ env: value: "${CELERY_S3_BUCKET}" {{- end }} -{{- define "llmEngine.serviceEnv" }} +{{- define "modelEngine.serviceEnvBase" }} env: - name: DATADOG_TRACE_ENABLED value: "{{ .Values.datadog_trace_enabled }}" - name: DD_ENV value: {{ .Values.context }} - - name: DD_VERSION - value: {{ .Values.tag }} - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: {{ .Values.tag }} - name: SERVICE_IDENTIFIER {{- if .Values.serviceIdentifier }} value: {{ .Values.serviceIdentifier }} {{- end }} + - name: GATEWAY_URL + value: {{ include "modelEngine.gatewayurl" . }} {{- if .Values.aws }} - name: AWS_PROFILE value: {{ .Values.aws.profileName }} - name: ECR_READ_AWS_PROFILE value: {{ .Values.aws.profileName }} + - name: S3_WRITE_AWS_PROFILE + value: {{ .Values.aws.s3WriteProfileName }} {{- end }} {{- with .Values.secrets }} {{- if .kubernetesDatabaseSecretName }} @@ -260,88 +264,109 @@ env: {{- end }} {{- if .Values.config.file }} - name: DEPLOY_SERVICE_CONFIG_PATH - value: "/workspace/llm_engine/service_configs/{{ .Values.config.file.llm_engine }}" + value: {{ .Values.config.file.launch | quote }} - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/model-engine/model_engine_server/core/configs/{{ .Values.config.file.infra }}" + value: {{ .Values.config.file.infra | quote }} {{- else }} - name: DEPLOY_SERVICE_CONFIG_PATH - value: "/workspace/llm_engine/service_configs/service_config.yaml" + value: "/workspace/model-engine/service_configs/service_config.yaml" - name: ML_INFRA_SERVICES_CONFIG_PATH value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: "/workspace/llm_engine/llm_engine/infra/gateways/resources/templates" + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates" + {{- if .Values.redis.auth}} + - name: REDIS_AUTH_TOKEN + value: {{ .Values.redis.auth }} + {{- end }} {{- end }} -{{- define "llmEngine.gatewayEnv" }} -{{- include "llmEngine.serviceEnv" . }} +{{- define "modelEngine.serviceEnvGitTagFromHelmVar" }} +{{- include "modelEngine.serviceEnvBase" . }} + - name: DD_VERSION + value: {{ .Values.tag }} + - name: GIT_TAG + value: {{ .Values.tag }} +{{- end }} + +{{- define "modelEngine.serviceEnvGitTagFromPythonReplace" }} +{{- include "modelEngine.serviceEnvBase" . }} + - name: DD_VERSION + value: "${GIT_TAG}" + - name: GIT_TAG + value: "${GIT_TAG}" +{{- end }} + + +{{- define "modelEngine.gatewayEnv" }} +{{- include "modelEngine.serviceEnvGitTagFromHelmVar" . }} - name: DD_SERVICE - value: {{- printf " %s" (include "llmEngine.fullname" .) }} + value: {{- printf " %s" (include "modelEngine.fullname" .) }} {{- end }} -{{- define "llmEngine.builderEnv" }} -{{- include "llmEngine.serviceEnv" . }} +{{- define "modelEngine.builderEnv" }} +{{- include "modelEngine.serviceEnvGitTagFromHelmVar" . }} - name: DD_SERVICE - value: {{- printf " %s" (include "llmEngine.buildername" .) }} + value: {{- printf " %s" (include "modelEngine.buildername" .) }} {{- end }} -{{- define "llmEngine.cacherEnv" }} -{{- include "llmEngine.serviceEnv" . }} +{{- define "modelEngine.cacherEnv" }} +{{- include "modelEngine.serviceEnvGitTagFromHelmVar" . }} - name: DD_SERVICE - value: {{- printf " %s" (include "llmEngine.cachername" .) }} + value: {{- printf " %s" (include "modelEngine.cachername" .) }} {{- end }} -{{- define "llmEngine.volumes" }} +{{- define "modelEngine.volumes" }} volumes: - name: dshm emptyDir: medium: Memory - name: service-template-config configMap: - name: {{ include "llmEngine.fullname" . }}-service-template-config + name: {{ include "modelEngine.fullname" . }}-service-template-config {{- if .Values.aws }} - name: config-volume configMap: name: {{ .Values.aws.configMap.name }} {{- end }} {{- if .Values.config.values }} - - name: llm-engine-service-config-volume + - name: {{ .Chart.Name }}-service-config-volume configMap: - name: {{ include "llmEngine.fullname" . }}-service-config + name: {{ include "modelEngine.fullname" . }}-service-config items: - - key: llm_engine_service_config + - key: launch_service_config path: service_config.yaml - name: infra-service-config-volume configMap: - name: {{ include "llmEngine.fullname" . }}-service-config + name: {{ include "modelEngine.fullname" . }}-service-config items: - key: infra_service_config path: config.yaml {{- end }} {{- end }} -{{- define "llmEngine.volumeMounts" }} +{{- define "modelEngine.volumeMounts" }} volumeMounts: - name: dshm mountPath: /dev/shm - name: service-template-config - mountPath: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates + mountPath: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates {{- if .Values.aws }} - name: config-volume - mountPath: /root/.aws/config + mountPath: {{ .Values.aws.configMap.mountPath }} subPath: config {{- end }} {{- if .Values.config.values }} - - name: llm-engine-service-config-volume - mountPath: /workspace/llm_engine/service_configs + - name: {{ .Chart.Name }}-service-config-volume + mountPath: /workspace/model-engine/service_configs - name: infra-service-config-volume mountPath: /workspace/model-engine/model_engine_server/core/configs {{- end }} {{- end }} -{{- define "llmEngine.forwarderVolumeMounts" }} +{{- define "modelEngine.forwarderVolumeMounts" }} volumeMounts: - name: config-volume mountPath: /root/.aws/config @@ -358,7 +383,7 @@ volumeMounts: {{- end }} {{- end }} -{{- define "llmEngine.serviceAccountNamespaces" }} +{{- define "modelEngine.serviceAccountNamespaces" }} namespaces: - {{ .Release.Namespace }} {{- range .Values.serviceAccount.namespaces }} diff --git a/charts/model-engine/templates/aws_config_map.yaml b/charts/model-engine/templates/aws_config_map.yaml new file mode 100644 index 00000000..60b91c97 --- /dev/null +++ b/charts/model-engine/templates/aws_config_map.yaml @@ -0,0 +1,26 @@ +{{- if .Values.aws }} +{{- if eq .Values.aws.configMap.create true }} +{{- $name := .Values.aws.configMap.name }} +{{- $profileName := .Values.aws.profileName }} +{{- $annotations := .Values.serviceAccount.annotations }} +{{- $labels := include "modelEngine.labels" . }} +{{- range $namespace := .Values.aws.configMap.namespaces }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ $name }} + namespace: {{- printf " %s" $namespace }} + labels: + {{- $labels | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" +data: + config: |- + [profile {{ $profileName }}] + role_arn = {{ index $annotations "eks.amazonaws.com/role-arn" }} + web_identity_token_file = /var/run/secrets/eks.amazonaws.com/serviceaccount/token +--- +{{- end }} +{{- end }} +{{- end }} diff --git a/charts/llm-engine/templates/balloon_a100_deployment.yaml b/charts/model-engine/templates/balloon_a100_deployment.yaml similarity index 76% rename from charts/llm-engine/templates/balloon_a100_deployment.yaml rename to charts/model-engine/templates/balloon_a100_deployment.yaml index 471a1599..50dbfea4 100644 --- a/charts/llm-engine/templates/balloon_a100_deployment.yaml +++ b/charts/model-engine/templates/balloon_a100_deployment.yaml @@ -2,7 +2,7 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-balloon-a100 + name: {{ .Chart.Name }}-balloon-a100 labels: team: infra product: common-warm-nodes @@ -10,12 +10,12 @@ spec: replicas: {{ .Values.replicaCount.balloonA100 }} selector: matchLabels: - app: llm-engine-balloon-a100 + app: {{ .Chart.Name }}-balloon-a100 version: v1 template: metadata: labels: - app: llm-engine-balloon-a100 + app: {{ .Chart.Name }}-balloon-a100 product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -25,14 +25,16 @@ spec: spec: nodeSelector: k8s.amazonaws.com/accelerator: nvidia-ampere-a100 - node-lifecycle: normal + {{- with .Values.balloonNodeSelector }} + {{- toYaml . | nindent 8 }} + {{- end }} tolerations: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: limits: @@ -44,5 +46,5 @@ spec: - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority + priorityClassName: {{ .Chart.Name }}-low-priority {{- end }} diff --git a/charts/llm-engine/templates/balloon_a10_deployment.yaml b/charts/model-engine/templates/balloon_a10_deployment.yaml similarity index 76% rename from charts/llm-engine/templates/balloon_a10_deployment.yaml rename to charts/model-engine/templates/balloon_a10_deployment.yaml index db0f49b3..5e71af2b 100644 --- a/charts/llm-engine/templates/balloon_a10_deployment.yaml +++ b/charts/model-engine/templates/balloon_a10_deployment.yaml @@ -2,7 +2,7 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-balloon-a10 + name: {{ .Chart.Name }}-balloon-a10 labels: team: infra product: common-warm-nodes @@ -10,12 +10,12 @@ spec: replicas: {{ .Values.replicaCount.balloonA10 }} selector: matchLabels: - app: llm-engine-balloon-a10 + app: {{ .Chart.Name }}-balloon-a10 version: v1 template: metadata: labels: - app: llm-engine-balloon-a10 + app: {{ .Chart.Name }}-balloon-a10 product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -25,14 +25,16 @@ spec: spec: nodeSelector: k8s.amazonaws.com/accelerator: nvidia-ampere-a10 - node-lifecycle: normal + {{- with .Values.balloonNodeSelector }} + {{- toYaml . | nindent 8 }} + {{- end }} tolerations: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: limits: @@ -44,5 +46,5 @@ spec: - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority + priorityClassName: {{ .Chart.Name }}-low-priority {{- end }} diff --git a/charts/llm-engine/templates/balloon_cpu_deployment.yaml b/charts/model-engine/templates/balloon_cpu_deployment.yaml similarity index 72% rename from charts/llm-engine/templates/balloon_cpu_deployment.yaml rename to charts/model-engine/templates/balloon_cpu_deployment.yaml index 21c1c6c3..1fd9e6c1 100644 --- a/charts/llm-engine/templates/balloon_cpu_deployment.yaml +++ b/charts/model-engine/templates/balloon_cpu_deployment.yaml @@ -2,7 +2,7 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-balloon-cpu + name: {{ .Chart.Name }}-balloon-cpu labels: team: infra product: common-warm-nodes @@ -10,12 +10,12 @@ spec: replicas: {{ .Values.replicaCount.balloonCpu }} selector: matchLabels: - app: llm-engine-balloon-cpu + app: {{ .Chart.Name }}-balloon-cpu version: v1 template: metadata: labels: - app: llm-engine-balloon-cpu + app: {{ .Chart.Name }}-balloon-cpu product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -23,11 +23,13 @@ spec: annotations: sidecar.istio.io/inject: "false" spec: + {{- with .Values.balloonNodeSelector }} nodeSelector: - node-lifecycle: normal + {{- toYaml . | nindent 8 }} + {{- end }} containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: limits: @@ -38,5 +40,5 @@ spec: - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority + priorityClassName: {{ .Chart.Name }}-low-priority {{- end }} diff --git a/charts/llm-engine/templates/balloon_t4_deployment.yaml b/charts/model-engine/templates/balloon_t4_deployment.yaml similarity index 76% rename from charts/llm-engine/templates/balloon_t4_deployment.yaml rename to charts/model-engine/templates/balloon_t4_deployment.yaml index 6c549853..6a5e8292 100644 --- a/charts/llm-engine/templates/balloon_t4_deployment.yaml +++ b/charts/model-engine/templates/balloon_t4_deployment.yaml @@ -2,7 +2,7 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-balloon-t4 + name: {{ .Chart.Name }}-balloon-t4 labels: team: infra product: common-warm-nodes @@ -10,12 +10,12 @@ spec: replicas: {{ .Values.replicaCount.balloonT4 }} selector: matchLabels: - app: llm-engine-balloon-t4 + app: {{ .Chart.Name }}-balloon-t4 version: v1 template: metadata: labels: - app: llm-engine-balloon-t4 + app: {{ .Chart.Name }}-balloon-t4 product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -25,14 +25,16 @@ spec: spec: nodeSelector: k8s.amazonaws.com/accelerator: nvidia-tesla-t4 - node-lifecycle: normal + {{- with .Values.balloonNodeSelector }} + {{- toYaml . | nindent 8 }} + {{- end }} tolerations: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" containers: - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: limits: @@ -44,5 +46,5 @@ spec: - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority + priorityClassName: {{ .Chart.Name }}-low-priority {{- end }} diff --git a/charts/llm-engine/templates/cacher_deployment.yaml b/charts/model-engine/templates/cacher_deployment.yaml similarity index 61% rename from charts/llm-engine/templates/cacher_deployment.yaml rename to charts/model-engine/templates/cacher_deployment.yaml index 1191cb40..4cb2a9c2 100644 --- a/charts/llm-engine/templates/cacher_deployment.yaml +++ b/charts/model-engine/templates/cacher_deployment.yaml @@ -1,28 +1,28 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: {{ include "llmEngine.cachername" . }} + name: {{ include "modelEngine.cachername" . }} labels: - {{- include "llmEngine.selectorLabels.cacher" . | nindent 4 }} - {{- include "llmEngine.labels" . | nindent 4 }} - tags.datadoghq.com/service: {{ include "llmEngine.cachername" . }} + {{- include "modelEngine.selectorLabels.cacher" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} + tags.datadoghq.com/service: {{ include "modelEngine.cachername" . }} spec: replicas: {{ .Values.replicaCount.cacher }} selector: matchLabels: - {{- include "llmEngine.selectorLabels.cacher" . | nindent 6 }} + {{- include "modelEngine.selectorLabels.cacher" . | nindent 6 }} template: metadata: annotations: ad.datadoghq.com/main.logs: | [{ - "service": {{ include "llmEngine.cachername" . | quote }}, + "service": {{ include "modelEngine.cachername" . | quote }}, "source": "python" }] labels: - {{- include "llmEngine.selectorLabels.cacher" . | nindent 8 }} - {{- include "llmEngine.labels" . | nindent 8 }} - tags.datadoghq.com/service: {{ include "llmEngine.cachername" . }} + {{- include "modelEngine.selectorLabels.cacher" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} + tags.datadoghq.com/service: {{ include "modelEngine.cachername" . }} sidecar.istio.io/inject: "false" spec: {{- with .Values.imagePullSecrets }} @@ -30,7 +30,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.cachername" . }} + - name: {{ include "modelEngine.cachername" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: @@ -52,10 +52,10 @@ spec: - model_engine_server.entrypoints.k8s_cache resources: {{- toYaml .Values.resources | nindent 12 }} - {{- include "llmEngine.cacherEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.cacherEnv" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/cacher_vpa.yaml b/charts/model-engine/templates/cacher_vpa.yaml similarity index 74% rename from charts/llm-engine/templates/cacher_vpa.yaml rename to charts/model-engine/templates/cacher_vpa.yaml index 0b79d1d5..4a07b3df 100644 --- a/charts/llm-engine/templates/cacher_vpa.yaml +++ b/charts/model-engine/templates/cacher_vpa.yaml @@ -2,19 +2,19 @@ apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler metadata: - name: {{ include "llmEngine.cachername" . }} + name: {{ include "modelEngine.cachername" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: targetRef: apiVersion: "apps/v1" kind: Deployment - name: {{ include "llmEngine.cachername" . }} + name: {{ include "modelEngine.cachername" . }} updatePolicy: updateMode: "Auto" resourcePolicy: containerPolicies: - - containerName: {{ include "llmEngine.cachername" . }} + - containerName: {{ include "modelEngine.cachername" . }} minAllowed: cpu: {{ .Values.autoscaling.vertical.minAllowed.cpu }} memory: {{ .Values.autoscaling.vertical.minAllowed.memory }} diff --git a/charts/llm-engine/templates/cluster_rolebinding.yaml b/charts/model-engine/templates/cluster_rolebinding.yaml similarity index 58% rename from charts/llm-engine/templates/cluster_rolebinding.yaml rename to charts/model-engine/templates/cluster_rolebinding.yaml index b438ae93..bdafd94b 100644 --- a/charts/llm-engine/templates/cluster_rolebinding.yaml +++ b/charts/model-engine/templates/cluster_rolebinding.yaml @@ -1,11 +1,11 @@ -{{- $serviceAccountName := include "llmEngine.fullname" . }} -{{- $serviceAccountNamespaces := (include "llmEngine.serviceAccountNamespaces" . | fromYaml) }} +{{- $serviceAccountName := include "modelEngine.fullname" . }} +{{- $serviceAccountNamespaces := (include "modelEngine.serviceAccountNamespaces" . | fromYaml) }} apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole diff --git a/charts/llm-engine/templates/database_init_job.yaml b/charts/model-engine/templates/database_init_job.yaml similarity index 70% rename from charts/llm-engine/templates/database_init_job.yaml rename to charts/model-engine/templates/database_init_job.yaml index 571dd1f8..c87b7e92 100644 --- a/charts/llm-engine/templates/database_init_job.yaml +++ b/charts/model-engine/templates/database_init_job.yaml @@ -2,9 +2,9 @@ apiVersion: batch/v1 kind: Job metadata: - name: {{ include "llmEngine.fullname" . }}-database-setup + name: {{ include "modelEngine.fullname" . }}-database-setup labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} annotations: "helm.sh/hook": pre-install "helm.sh/hook-weight": "-1" @@ -16,7 +16,7 @@ spec: metadata: labels: sidecar.istio.io/inject: "false" - {{- include "llmEngine.labels" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} spec: restartPolicy: Never {{- with .Values.imagePullSecrets }} @@ -24,7 +24,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.fullname" . }} + - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} command: @@ -34,10 +34,10 @@ spec: - python - -m - model_engine_server.entrypoints.init_database - {{- include "llmEngine.serviceEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/endpoint_builder_deployment.yaml b/charts/model-engine/templates/endpoint_builder_deployment.yaml similarity index 61% rename from charts/llm-engine/templates/endpoint_builder_deployment.yaml rename to charts/model-engine/templates/endpoint_builder_deployment.yaml index f42afc07..2f62a11a 100644 --- a/charts/llm-engine/templates/endpoint_builder_deployment.yaml +++ b/charts/model-engine/templates/endpoint_builder_deployment.yaml @@ -1,29 +1,29 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: {{ include "llmEngine.buildername" . }} + name: {{ include "modelEngine.buildername" . }} labels: - {{- include "llmEngine.selectorLabels.builder" . | nindent 4 }} - {{- include "llmEngine.labels" . | nindent 4 }} - tags.datadoghq.com/service: {{ include "llmEngine.buildername" . }} + {{- include "modelEngine.selectorLabels.builder" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} + tags.datadoghq.com/service: {{ include "modelEngine.buildername" . }} spec: replicas: {{ .Values.replicaCount.builder }} selector: matchLabels: - {{- include "llmEngine.selectorLabels.builder" . | nindent 6 }} + {{- include "modelEngine.selectorLabels.builder" . | nindent 6 }} template: metadata: annotations: cluster-autoscaler.kubernetes.io/safe-to-evict: "false" ad.datadoghq.com/main.logs: | [{ - "service": {{ include "llmEngine.buildername" . | quote }}, + "service": {{ include "modelEngine.buildername" . | quote }}, "source": "python" }] labels: - {{- include "llmEngine.selectorLabels.builder" . | nindent 8 }} - {{- include "llmEngine.labels" . | nindent 8 }} - tags.datadoghq.com/service: {{ include "llmEngine.buildername" . }} + {{- include "modelEngine.selectorLabels.builder" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} + tags.datadoghq.com/service: {{ include "modelEngine.buildername" . }} sidecar.istio.io/inject: "false" spec: {{- with .Values.imagePullSecrets }} @@ -31,7 +31,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.buildername" . }} + - name: {{ include "modelEngine.buildername" . }} image: "{{ .Values.image.builderRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: @@ -54,16 +54,16 @@ spec: - --loglevel=INFO - --concurrency=2 {{- if .Values.serviceIdentifier }} - - --queues=llm-engine-{{ .Values.serviceIdentifier }}.service-builder + - --queues=model-engine-{{ .Values.serviceIdentifier }}.service-builder {{- else }} - - --queues=llm-engine.service-builder + - --queues=model-engine.service-builder {{- end }} resources: {{- toYaml .Values.resources | nindent 12 }} - {{- include "llmEngine.builderEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.builderEnv" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/endpoint_builder_vpa.yaml b/charts/model-engine/templates/endpoint_builder_vpa.yaml similarity index 74% rename from charts/llm-engine/templates/endpoint_builder_vpa.yaml rename to charts/model-engine/templates/endpoint_builder_vpa.yaml index e467e53a..64983d94 100644 --- a/charts/llm-engine/templates/endpoint_builder_vpa.yaml +++ b/charts/model-engine/templates/endpoint_builder_vpa.yaml @@ -2,19 +2,19 @@ apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler metadata: - name: {{ include "llmEngine.buildername" . }} + name: {{ include "modelEngine.buildername" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: targetRef: apiVersion: "apps/v1" kind: Deployment - name: {{ include "llmEngine.buildername" . }} + name: {{ include "modelEngine.buildername" . }} updatePolicy: updateMode: "Auto" resourcePolicy: containerPolicies: - - containerName: {{ include "llmEngine.buildername" . }} + - containerName: {{ include "modelEngine.buildername" . }} minAllowed: cpu: {{ .Values.autoscaling.vertical.minAllowed.cpu }} memory: {{ .Values.autoscaling.vertical.minAllowed.memory }} diff --git a/charts/llm-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml similarity index 67% rename from charts/llm-engine/templates/gateway_deployment.yaml rename to charts/model-engine/templates/gateway_deployment.yaml index f727d2d2..a58717a3 100644 --- a/charts/llm-engine/templates/gateway_deployment.yaml +++ b/charts/model-engine/templates/gateway_deployment.yaml @@ -1,11 +1,11 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 4 }} - {{- include "llmEngine.labels" . | nindent 4 }} - tags.datadoghq.com/service: {{ include "llmEngine.fullname" . }} + {{- include "modelEngine.selectorLabels.gateway" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} + tags.datadoghq.com/service: {{ include "modelEngine.fullname" . }} spec: {{- if not .Values.autoscaling.horizontal.enabled }} replicas: {{ .Values.replicaCount.gateway }} @@ -17,26 +17,26 @@ spec: maxSurge: 25% selector: matchLabels: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 6 }} + {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} template: metadata: annotations: ad.datadoghq.com/main.logs: | [{ - "service": {{ include "llmEngine.fullname" . | quote }}, + "service": {{ include "modelEngine.fullname" . | quote }}, "source": "python" }] labels: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 8 }} - {{- include "llmEngine.labels" . | nindent 8 }} - tags.datadoghq.com/service: {{ include "llmEngine.fullname" . }} + {{- include "modelEngine.selectorLabels.gateway" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} + tags.datadoghq.com/service: {{ include "modelEngine.fullname" . }} spec: {{- with .Values.imagePullSecrets }} imagePullSecrets: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.fullname" . }} + - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: @@ -66,10 +66,10 @@ spec: - model_engine_server.entrypoints.start_fastapi_server resources: {{- toYaml .Values.resources | nindent 12 }} - {{- include "llmEngine.gatewayEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.gatewayEnv" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/gateway_hpa.yaml b/charts/model-engine/templates/gateway_hpa.yaml similarity index 78% rename from charts/llm-engine/templates/gateway_hpa.yaml rename to charts/model-engine/templates/gateway_hpa.yaml index f9cd542e..9238b538 100644 --- a/charts/llm-engine/templates/gateway_hpa.yaml +++ b/charts/model-engine/templates/gateway_hpa.yaml @@ -2,14 +2,14 @@ apiVersion: autoscaling/v2 kind: HorizontalPodAutoscaler metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: scaleTargetRef: apiVersion: apps/v1 kind: Deployment - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} minReplicas: {{ .Values.autoscaling.horizontal.minReplicas }} maxReplicas: {{ .Values.autoscaling.horizontal.maxReplicas }} metrics: diff --git a/charts/model-engine/templates/gateway_service.yaml b/charts/model-engine/templates/gateway_service.yaml new file mode 100644 index 00000000..1407ebef --- /dev/null +++ b/charts/model-engine/templates/gateway_service.yaml @@ -0,0 +1,18 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "modelEngine.fullname" . }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ .Values.service.port }} + targetPort: http + protocol: TCP + name: http + {{- with .Values.service.nodePort }} + nodePort: {{ . }} + {{- end }} + selector: + {{- include "modelEngine.selectorLabels.gateway" . | nindent 4 }} diff --git a/charts/llm-engine/templates/gateway_vpa.yaml b/charts/model-engine/templates/gateway_vpa.yaml similarity index 77% rename from charts/llm-engine/templates/gateway_vpa.yaml rename to charts/model-engine/templates/gateway_vpa.yaml index 4e93cd8a..061ed8cf 100644 --- a/charts/llm-engine/templates/gateway_vpa.yaml +++ b/charts/model-engine/templates/gateway_vpa.yaml @@ -2,21 +2,21 @@ apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: targetRef: apiVersion: "apps/v1" kind: Deployment - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} updatePolicy: updateMode: {{ .Values.autoscaling.vertical.updateMode }} resourcePolicy: containerPolicies: - containerName: istio-proxy mode: "Off" - - containerName: {{ include "llmEngine.fullname" . }} + - containerName: {{ include "modelEngine.fullname" . }} minAllowed: cpu: {{ .Values.autoscaling.vertical.minAllowed.cpu }} memory: {{ .Values.autoscaling.vertical.minAllowed.memory }} diff --git a/charts/model-engine/templates/istio-destinationrule.yaml b/charts/model-engine/templates/istio-destinationrule.yaml new file mode 100644 index 00000000..12b51afb --- /dev/null +++ b/charts/model-engine/templates/istio-destinationrule.yaml @@ -0,0 +1,18 @@ +{{- if .Values.destinationrule.enabled -}} +{{- $fullName := include "modelEngine.fullname" . -}} +apiVersion: networking.istio.io/v1beta1 +kind: DestinationRule +metadata: + name: {{ $fullName }} + labels: + {{- include "modelEngine.labels" . | nindent 4}} + {{- with .Values.destinationrule.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + host: "{{ $fullName }}.{{ .Release.Namespace }}.svc.cluster.local" + trafficPolicy: + loadBalancer: + simple: LEAST_REQUEST # Requires later version of Istio, which we have on the new clusters +{{- end }} diff --git a/charts/model-engine/templates/istio-virtualservice.yaml b/charts/model-engine/templates/istio-virtualservice.yaml new file mode 100644 index 00000000..1bd26e14 --- /dev/null +++ b/charts/model-engine/templates/istio-virtualservice.yaml @@ -0,0 +1,31 @@ +{{- if .Values.virtualservice.enabled -}} +{{- $fullName := include "modelEngine.fullname" . -}} +apiVersion: networking.istio.io/v1alpha3 +kind: VirtualService +metadata: + name: {{ $fullName }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + {{- with .Values.virtualservice.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + hosts: + {{- range .Values.virtualservice.hostDomains }} + - "{{ $fullName }}.{{ . }}" + {{- end }} + gateways: + {{- range .Values.virtualservice.gateways }} + - {{ . | quote }} + {{- end }} + http: + - route: + - destination: + host: "{{ $fullName }}.{{ .Release.Namespace }}.svc.cluster.local" + port: + number: 80 + retries: + attempts: 3 + retryOn: connect-failure,unavailable,gateway-error +{{- end }} diff --git a/charts/model-engine/templates/model_engine_default_priority_class.yaml b/charts/model-engine/templates/model_engine_default_priority_class.yaml new file mode 100644 index 00000000..a2d2dbb9 --- /dev/null +++ b/charts/model-engine/templates/model_engine_default_priority_class.yaml @@ -0,0 +1,11 @@ +{{- if not .Values.serviceIdentifier }} +apiVersion: scheduling.k8s.io/v1 +kind: PriorityClass +metadata: + name: "{{ include "modelEngine.fullname" . }}-default-priority" +value: 1 +# This ensures that the default launch pods will never preempt any pods, which means +# they cannot take advantage of the dummy nodes. +preemptionPolicy: Never +description: "Default Priority Class for Launch" +{{- end }} diff --git a/charts/llm-engine/templates/launch_high_priority_class.yaml b/charts/model-engine/templates/model_engine_high_priority_class.yaml similarity index 53% rename from charts/llm-engine/templates/launch_high_priority_class.yaml rename to charts/model-engine/templates/model_engine_high_priority_class.yaml index dd088b91..5dbfa7f0 100644 --- a/charts/llm-engine/templates/launch_high_priority_class.yaml +++ b/charts/model-engine/templates/model_engine_high_priority_class.yaml @@ -2,7 +2,7 @@ apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: - name: "{{ include "llmEngine.fullname" . }}-high-priority" + name: "{{ include "modelEngine.fullname" . }}-high-priority" value: 100000 -description: "High Priority Class for LLMEngine" +description: "High Priority Class for Launch" {{- end }} diff --git a/charts/llm-engine/templates/launch_low_priority_class.yaml b/charts/model-engine/templates/model_engine_low_priority_class.yaml similarity index 53% rename from charts/llm-engine/templates/launch_low_priority_class.yaml rename to charts/model-engine/templates/model_engine_low_priority_class.yaml index f40db336..71deb6c2 100644 --- a/charts/llm-engine/templates/launch_low_priority_class.yaml +++ b/charts/model-engine/templates/model_engine_low_priority_class.yaml @@ -2,7 +2,7 @@ apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: - name: "{{ include "llmEngine.fullname" . }}-low-priority" + name: "{{ include "modelEngine.fullname" . }}-low-priority" value: 0 -description: "Low Priority Class for LLMEngine" +description: "Low Priority Class for Launch" {{- end }} diff --git a/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml b/charts/model-engine/templates/proportional_a100_autoscaler_deployment.yaml similarity index 77% rename from charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml rename to charts/model-engine/templates/proportional_a100_autoscaler_deployment.yaml index 4e99558f..f89f298e 100644 --- a/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml +++ b/charts/model-engine/templates/proportional_a100_autoscaler_deployment.yaml @@ -3,19 +3,19 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-proportional-a100-autoscaler-deployment + name: {{ .Chart.Name }}-proportional-a100-autoscaler-deployment labels: team: infra product: common-warm-nodes spec: selector: matchLabels: - app: llm-engine-proportional-a100-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a100-autoscaler-deployment version: v1 template: metadata: labels: - app: llm-engine-proportional-a100-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a100-autoscaler-deployment product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -28,7 +28,7 @@ spec: operator: "Exists" containers: - image: registry.k8s.io/cpa/cluster-proportional-autoscaler:1.8.5 - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: requests: @@ -38,12 +38,12 @@ spec: - /cluster-proportional-autoscaler - --namespace={{ .Release.Namespace }} - --configmap=cluster-proportional-autoscaler - - --target=deployment/llm-engine-balloon-a100 + - --target=deployment/{{ .Chart.Name }}-balloon-a100 - --default-params={"linear":{"nodesPerReplica":10,"preventSinglePointFailure":false,"includeUnschedulableNodes":false}} - --nodelabels=k8s.amazonaws.com/accelerator=nvidia-ampere-a100 - --logtostderr=true - --v=2 priorityClassName: system-cluster-critical - serviceAccountName: {{ include "llmEngine.fullname" . }} + serviceAccountName: {{ include "modelEngine.fullname" . }} {{- end }} {{- end }} diff --git a/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml b/charts/model-engine/templates/proportional_a10_autoscaler_deployment.yaml similarity index 77% rename from charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml rename to charts/model-engine/templates/proportional_a10_autoscaler_deployment.yaml index e2dbe1c2..70274d26 100644 --- a/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml +++ b/charts/model-engine/templates/proportional_a10_autoscaler_deployment.yaml @@ -3,19 +3,19 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-proportional-a10-autoscaler-deployment + name: {{ .Chart.Name }}-proportional-a10-autoscaler-deployment labels: team: infra product: common-warm-nodes spec: selector: matchLabels: - app: llm-engine-proportional-a10-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a10-autoscaler-deployment version: v1 template: metadata: labels: - app: llm-engine-proportional-a10-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a10-autoscaler-deployment product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -28,7 +28,7 @@ spec: operator: "Exists" containers: - image: registry.k8s.io/cpa/cluster-proportional-autoscaler:1.8.5 - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: requests: @@ -38,12 +38,12 @@ spec: - /cluster-proportional-autoscaler - --namespace={{ .Release.Namespace }} - --configmap=cluster-proportional-autoscaler - - --target=deployment/llm-engine-balloon-a10 + - --target=deployment/{{ .Chart.Name }}-balloon-a10 - --default-params={"linear":{"nodesPerReplica":10,"preventSinglePointFailure":false,"includeUnschedulableNodes":false}} - --nodelabels=k8s.amazonaws.com/accelerator=nvidia-ampere-a10 - --logtostderr=true - --v=2 priorityClassName: system-cluster-critical - serviceAccountName: {{ include "llmEngine.fullname" . }} + serviceAccountName: {{ include "modelEngine.fullname" . }} {{- end }} {{- end }} diff --git a/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml b/charts/model-engine/templates/proportional_t4_autoscaler_deployment.yaml similarity index 77% rename from charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml rename to charts/model-engine/templates/proportional_t4_autoscaler_deployment.yaml index bdb535e0..7175d985 100644 --- a/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml +++ b/charts/model-engine/templates/proportional_t4_autoscaler_deployment.yaml @@ -3,19 +3,19 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-proportional-t4-autoscaler-deployment + name: {{ .Chart.Name }}-proportional-t4-autoscaler-deployment labels: team: infra product: common-warm-nodes spec: selector: matchLabels: - app: llm-engine-proportional-t4-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-t4-autoscaler-deployment version: v1 template: metadata: labels: - app: llm-engine-proportional-t4-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-t4-autoscaler-deployment product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -28,7 +28,7 @@ spec: operator: "Exists" containers: - image: registry.k8s.io/cpa/cluster-proportional-autoscaler:1.8.5 - imagePullPolicy: {{ .Values.image.pullPolicy }} + imagePullPolicy: IfNotPresent name: main resources: requests: @@ -38,12 +38,12 @@ spec: - /cluster-proportional-autoscaler - --namespace={{ .Release.Namespace }} - --configmap=cluster-proportional-autoscaler - - --target=deployment/llm-engine-balloon-t4 + - --target=deployment/{{ .Chart.Name }}-balloon-t4 - --default-params={"linear":{"nodesPerReplica":10,"preventSinglePointFailure":false,"includeUnschedulableNodes":false}} - --nodelabels=k8s.amazonaws.com/accelerator=nvidia-tesla-t4 - --logtostderr=true - --v=2 priorityClassName: system-cluster-critical - serviceAccountName: {{ include "llmEngine.fullname" . }} + serviceAccountName: {{ include "modelEngine.fullname" . }} {{- end }} {{- end }} diff --git a/charts/llm-engine/templates/service_account.yaml b/charts/model-engine/templates/service_account.yaml similarity index 66% rename from charts/llm-engine/templates/service_account.yaml rename to charts/model-engine/templates/service_account.yaml index 73be82d7..1d0d7d3b 100644 --- a/charts/llm-engine/templates/service_account.yaml +++ b/charts/model-engine/templates/service_account.yaml @@ -1,7 +1,7 @@ -{{- $serviceAccountName := include "llmEngine.fullname" . }} -{{- $serviceAccountNamespaces := (include "llmEngine.serviceAccountNamespaces" . | fromYaml) }} +{{- $serviceAccountName := include "modelEngine.fullname" . }} +{{- $serviceAccountNamespaces := (include "modelEngine.serviceAccountNamespaces" . | fromYaml) }} {{- $annotations := .Values.serviceAccount.annotations }} -{{- $labels := include "llmEngine.labels" . }} +{{- $labels := include "modelEngine.labels" . }} {{- range $namespace := (index $serviceAccountNamespaces "namespaces") }} apiVersion: v1 kind: ServiceAccount diff --git a/charts/llm-engine/templates/service_config_map.yaml b/charts/model-engine/templates/service_config_map.yaml similarity index 64% rename from charts/llm-engine/templates/service_config_map.yaml rename to charts/model-engine/templates/service_config_map.yaml index 003447dd..9234296d 100644 --- a/charts/llm-engine/templates/service_config_map.yaml +++ b/charts/model-engine/templates/service_config_map.yaml @@ -2,15 +2,16 @@ apiVersion: v1 kind: ConfigMap metadata: - name: {{ include "llmEngine.fullname" . }}-service-config + name: {{ include "modelEngine.fullname" . }}-service-config labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} annotations: "helm.sh/hook": pre-install,pre-upgrade "helm.sh/hook-weight": "-2" data: - llm_engine_service_config: |- - {{- with .Values.config.values.llm_engine }} + launch_service_config: |- + datadog_trace_enabled: {{ .Values.datadog_trace_enabled | default false | quote }} + {{- with .Values.config.values.launch }} {{- range $key, $value := . }} {{ $key }}: {{ $value | quote }} {{- end }} diff --git a/charts/llm-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml similarity index 78% rename from charts/llm-engine/templates/service_template_config_map.yaml rename to charts/model-engine/templates/service_template_config_map.yaml index b344d3cd..84c26206 100644 --- a/charts/llm-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -1,17 +1,17 @@ -{{- $llm_engine_name := include "llmEngine.fullname" . }} +{{- $launch_name := include "modelEngine.fullname" . }} {{- $config_values := .Values.config.values }} {{- $forwarder_repository := .Values.image.forwarderRepository -}} {{- $triton_repository := .Values.triton.image.repository -}} {{- $triton_tag := .Values.triton.image.tag -}} {{- $env := .Values.context -}} -{{- $service_template_labels := include "llmEngine.serviceTemplateLabels" . }} -{{- $job_template_labels := include "llmEngine.jobTemplateLabels" . }} -{{- $service_env := include "llmEngine.serviceEnv" . }} -{{- $async_service_template_env := include "llmEngine.asyncServiceTemplateEnv" . }} -{{- $sync_service_template_env := include "llmEngine.syncServiceTemplateEnv" . }} -{{- $async_forwarder_template_env := include "llmEngine.asyncForwarderTemplateEnv" . }} -{{- $sync_forwarder_template_env := include "llmEngine.syncForwarderTemplateEnv" . }} -{{- $forwarder_volume_mounts := include "llmEngine.forwarderVolumeMounts" . }} +{{- $service_template_labels := include "modelEngine.serviceTemplateLabels" . }} +{{- $job_template_labels := include "modelEngine.jobTemplateLabels" . }} +{{- $service_env := include "modelEngine.serviceEnvGitTagFromPythonReplace" . }} +{{- $async_service_template_env := include "modelEngine.asyncServiceTemplateEnv" . }} +{{- $sync_service_template_env := include "modelEngine.syncServiceTemplateEnv" . }} +{{- $async_forwarder_template_env := include "modelEngine.asyncForwarderTemplateEnv" . }} +{{- $sync_forwarder_template_env := include "modelEngine.syncForwarderTemplateEnv" . }} +{{- $forwarder_volume_mounts := include "modelEngine.forwarderVolumeMounts" . }} {{- $gateway_repository := .Values.image.gatewayRepository -}} {{- $tag := .Values.tag -}} {{- $aws_config_map_name := .Values.aws.configMap.name }} @@ -20,7 +20,7 @@ {{- $service_template_service_account_name := .Values.serviceTemplate.serviceAccountName }} {{- $service_template_aws_config_map_name := .Values.serviceTemplate.awsConfigMapName }} {{- $celery_broker_type := .Values.celeryBrokerType }} -{{- $image_pull_policy := .Values.image.pullPolicy }} +{{- $node_selector := .Values.nodeSelector }} {{- if .Values.message }} {{- .Values.message }} @@ -28,16 +28,16 @@ apiVersion: v1 kind: ConfigMap metadata: - name: {{ $llm_engine_name }}-service-template-config + name: {{ $launch_name }}-service-template-config labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} annotations: "helm.sh/hook": pre-install,pre-upgrade "helm.sh/hook-weight": "-2" data: {{- range $device := tuple "cpu" "gpu" }} {{- range $mode := tuple "async" "sync" "streaming"}} - {{- range $flavor := tuple "triton-enhanced-runnable-image" "runnable-image" "artifact" }} + {{- range $flavor := tuple "triton-enhanced-runnable-image" "runnable-image" }} {{- if or (ne $mode "streaming") (eq $flavor "runnable-image") }} deployment-{{ $flavor }}-{{ $mode }}-{{ $device }}.yaml: |- apiVersion: apps/v1 @@ -49,7 +49,7 @@ data: {{- $service_template_labels | nindent 8 }} {{- if eq $mode "async" }} annotations: - {{- include "llmEngine.serviceTemplateAsyncAnnotations" . | nindent 8 }} + {{- include "modelEngine.serviceTemplateAsyncAnnotations" . | nindent 8 }} {{- end }} spec: strategy: @@ -76,15 +76,17 @@ data: kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" spec: affinity: - {{- include "llmEngine.serviceTemplateAffinity" . | nindent 12 }} + {{- include "modelEngine.serviceTemplateAffinity" . | nindent 12 }} terminationGracePeriodSeconds: 600 {{- if $service_template_service_account_name }} serviceAccount: {{ $service_template_service_account_name }} {{- else }} - serviceAccount: {{ $llm_engine_name }} + serviceAccount: {{ $launch_name }} {{- end }} + {{- with $node_selector }} nodeSelector: - node-lifecycle: normal + {{- toYaml . | nindent 12 }} + {{- end }} {{- if eq $device "gpu" }} k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: @@ -94,104 +96,30 @@ data: {{- end }} priorityClassName: ${PRIORITY} containers: - {{- if eq $flavor "artifact" }} - - image: ${IMAGE} - imagePullPolicy: {{ $image_pull_policy }} - name: main - {{- with $security_context }} - securityContext: - {{- toYaml . | nindent 16 }} - {{- end }} - {{- if eq $mode "async" }} - {{- $async_service_template_env | nindent 14 }} - {{- else if eq $mode "sync" }} - {{- $sync_service_template_env | nindent 14 }} - {{- end }} - readinessProbe: - {{- if eq $mode "async" }} - exec: - command: - - cat - - /tmp/readyz - {{- else if eq $mode "sync" }} - httpGet: - path: /readyz - port: ${ARTIFACT_LIKE_CONTAINER_PORT} - {{- end }} - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - {{- if eq $mode "async" }} - # Not including --pool=solo means there's a worker process and a separate supervisor process - # meaning if the worker crashes (because of OOM or something) the supervisor process can mark the task as - # failed, which should get rid of infinite task retries - args: - - celery - - --app=llm_engine.inference.async_inference - - worker - - --loglevel=INFO - - --concurrency=1 - - --queues=${QUEUE} - - -O - - fair - {{- else if eq $mode "sync" }} - args: - - python - - -m - - llm_engine.inference.sync_inference.start_fastapi_server - {{- end }} - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - {{- if eq $device "gpu" }} - nvidia.com/gpu: ${GPUS} - {{- end }} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: ${BASE_PATH}/user_config - subPath: raw_data - - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config - subPath: raw_data - {{- if $config_values }} - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs - {{- end }} - {{- else if contains "runnable-image" $flavor }} + {{- if contains "runnable-image" $flavor }} {{- if eq $mode "sync" }} - name: http-forwarder - image: {{ $forwarder_repository }}:${FORWARDER_IMAGE_TAG} - imagePullPolicy: {{ $image_pull_policy }} + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads - --port - "${FORWARDER_PORT}" - - --concurrency + - --num-workers - "${PER_WORKER}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" {{- $sync_forwarder_template_env | nindent 14 }} readinessProbe: httpGet: @@ -215,7 +143,7 @@ data: {{- else if eq $mode "streaming" }} - name: http-forwarder image: {{ $forwarder_repository }}:{{ $tag }} - imagePullPolicy: {{ $image_pull_policy }} + imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- @@ -259,13 +187,15 @@ data: name: http {{- else if eq $mode "async" }} - name: celery-forwarder - image: {{ $forwarder_repository }}:${FORWARDER_IMAGE_TAG} - imagePullPolicy: {{ $image_pull_policy }} + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue @@ -280,7 +210,7 @@ data: - --sqs-url - "${SQS_QUEUE_URL}" {{- end }} - - --concurrency + - --num-workers - "${PER_WORKER}" {{- $async_forwarder_template_env | nindent 14 }} resources: @@ -297,7 +227,7 @@ data: {{- if eq $flavor "triton-enhanced-runnable-image" }} - name: tritonserver image: {{ $triton_repository }}:${TRITON_COMMIT_TAG}-triton - imagePullPolicy: {{ $image_pull_policy }} + imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- @@ -346,7 +276,7 @@ data: {{- toYaml . | nindent 16 }} {{- end }} image: ${IMAGE} - imagePullPolicy: {{ $image_pull_policy }} + imagePullPolicy: IfNotPresent command: ${COMMAND} env: ${MAIN_ENV} readinessProbe: @@ -379,7 +309,7 @@ data: {{- end }} # LIRA: For compatibility with runnable image converted from artifactlike bundle - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /home/modelengine/.aws/config subPath: config - name: user-config mountPath: /app/user_config @@ -414,7 +344,7 @@ data: {{- if $config_values }} - name: infra-service-config-volume configMap: - name: {{ $llm_engine_name }}-service-config + name: {{ $launch_name }}-service-config items: - key: infra_service_config path: config.yaml @@ -484,6 +414,7 @@ data: protocol: TCP name: http ${NODE_PORT_DICT} + {{- if .Values.virtualservice.enabled }} virtual-service.yaml: |- apiVersion: networking.istio.io/v1alpha3 kind: VirtualService @@ -491,19 +422,7 @@ data: name: ${RESOURCE_NAME} namespace: ${NAMESPACE} labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: model-engine - use_scale_launch_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: ${GIT_TAG} - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} + {{- $service_template_labels | nindent 8 }} spec: hosts: - ${RESOURCE_NAME}.${DNS_HOST_DOMAIN} @@ -515,6 +434,8 @@ data: host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" port: number: 80 + {{- end }} + {{- if .Values.destinationrule.enabled }} destination-rule.yaml: |- apiVersion: networking.istio.io/v1beta1 kind: DestinationRule @@ -522,24 +443,13 @@ data: name: ${RESOURCE_NAME} namespace: ${NAMESPACE} labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: model-engine - use_scale_launch_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: ${GIT_TAG} - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} + {{- $service_template_labels | nindent 8 }} spec: host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" trafficPolicy: loadBalancer: simple: LEAST_REQUEST + {{- end }} vertical-pod-autoscaler.yaml: |- apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler @@ -584,31 +494,33 @@ data: sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "launch_job_id:${JOB_ID}"]}]' cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never + {{- with $node_selector }} nodeSelector: - node-lifecycle: normal - serviceAccountName: {{ $llm_engine_name }} + {{- toYaml . | nindent 12 }} + {{- end }} + serviceAccountName: {{ $launch_name }} volumes: - name: config-volume configMap: name: {{ $aws_config_map_name }} containers: - name: main - image: {{ $gateway_repository }}:{{ $tag }} + image: {{ $gateway_repository }}:${GIT_TAG} env: - name: DD_SERVICE value: ${RESOURCE_NAME} - {{- $env_vars := include "llmEngine.serviceEnv" . | fromYaml }} + {{- $env_vars := $service_env | fromYaml }} {{- range $env_var := index $env_vars "env" }} {{- $env_var_name := index $env_var "name" }} {{- if ne $env_var_name "DD_SERVICE" }} {{- tuple $env_var | toYaml | nindent 16 }} {{- end }} {{- end }} - imagePullPolicy: {{ $image_pull_policy }} + imagePullPolicy: Always command: - dumb-init - -- @@ -658,11 +570,13 @@ data: sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "launch_job_id:${JOB_ID}"]}]' spec: restartPolicy: Never + {{- with $node_selector }} nodeSelector: - node-lifecycle: normal + {{- toYaml . | nindent 12 }} + {{- end }} {{- if eq $device "gpu" }} k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: @@ -673,7 +587,7 @@ data: {{- if $service_template_service_account_name }} serviceAccountName: {{ $service_template_service_account_name }} {{- else }} - serviceAccountName: {{ $llm_engine_name }} + serviceAccountName: {{ $launch_name }} {{- end }} volumes: - name: config-volume @@ -697,7 +611,7 @@ data: {{- tuple $env_var | toYaml | nindent 16 }} {{- end }} {{- end }} - imagePullPolicy: {{ $image_pull_policy }} + imagePullPolicy: Always command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. @@ -722,7 +636,7 @@ data: name: dshm initContainers: - name: input-downloader - image: {{ $gateway_repository }}:{{ $tag }} + image: {{ $gateway_repository }}:${GIT_TAG} command: - python - -m @@ -759,8 +673,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -774,8 +688,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: launch + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -794,7 +708,7 @@ data: {{- end }} containers: - image: public.ecr.aws/docker/library/busybox:latest - imagePullPolicy: {{ $image_pull_policy }} + imagePullPolicy: IfNotPresent name: busybox command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] terminationGracePeriodSeconds: 0 diff --git a/charts/llm-engine/templates/llm_engine_init_job.yaml b/charts/model-engine/templates/spellbook_init_job.yaml similarity index 60% rename from charts/llm-engine/templates/llm_engine_init_job.yaml rename to charts/model-engine/templates/spellbook_init_job.yaml index c975355b..ed23f4e6 100644 --- a/charts/llm-engine/templates/llm_engine_init_job.yaml +++ b/charts/model-engine/templates/spellbook_init_job.yaml @@ -1,10 +1,10 @@ -{{- if (and .Values.llmEngineInitJob .Values.llmEngineInitJob.enabled) }} +{{- if and (.Values.secrets.kubernetesDatabaseSecretName) (.Values.spellbook.enabled) }} apiVersion: batch/v1 kind: Job metadata: - name: {{ include "llmEngine.fullname" . }}-init-job + name: {{ include "modelEngine.fullname" . }}-spellbook-setup labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "0" @@ -16,7 +16,7 @@ spec: metadata: labels: sidecar.istio.io/inject: "false" - {{- include "llmEngine.labels" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} spec: restartPolicy: Never {{- with .Values.imagePullSecrets }} @@ -24,7 +24,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.fullname" . }} + - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} command: @@ -33,13 +33,13 @@ spec: args: - python - -m - - model_engine_server.entrypoints.init_llm_engine_models + - model_engine_server.entrypoints.init_spellbook_models - --gateway-url - - 'http://{{- include "llmEngine.fullname" . }}.{{ .Release.Namespace }}:{{ .Values.service.port }}' - {{- include "llmEngine.serviceEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + - '{{- include "modelEngine.gatewayurl" . }}' + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml new file mode 100644 index 00000000..c228a34a --- /dev/null +++ b/charts/model-engine/values.yaml @@ -0,0 +1,9 @@ +datadog_trace_enabled: true +spellbook: + enabled: false +redis: + auth: +balloonNodeSelector: + node-lifecycle: normal +nodeSelector: + node-lifecycle: normal diff --git a/charts/llm-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml similarity index 97% rename from charts/llm-engine/values_circleci.yaml rename to charts/model-engine/values_circleci.yaml index fb170ae9..d31665ef 100644 --- a/charts/llm-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -38,6 +38,9 @@ virtualservice: gateways: - default/internal-gateway +hostDomain: + prefix: http:// + destinationrule: enabled: true annotations: { } @@ -64,7 +67,9 @@ resources: requests: cpu: 2 -nodeSelector: { } +nodeSelector: null + +balloonNodeSelector: null tolerations: [ ] @@ -82,7 +87,7 @@ config: s3_bucket: "$CIRCLECI_AWS_S3_BUCKET" profile_ml_worker: "default" profile_ml_inference_worker: "default" - llm_engine: + launch: # Endpoint config # K8s namespace the endpoints will be created in endpoint_namespace: model-engine @@ -151,7 +156,9 @@ aws: configMap: name: default-config create: false + mountPath: /root/.aws/config profileName: default + s3WriteProfileName: default forwarder: forceUseIPv4: true diff --git a/charts/llm-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml similarity index 100% rename from charts/llm-engine/values_sample.yaml rename to charts/model-engine/values_sample.yaml From 331349627e36a33e0d87488fb8b39309ea1a872f Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 25 Aug 2023 16:12:35 -0700 Subject: [PATCH 061/425] Fix API calls on Windows (#225) * replace os.path.join with urllib.parse.urljoin * bump version * bump versions --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/api_engine.py | 19 ++++++++++--------- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 3dab0728..4b1b86fb 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta11" +__version__ = "0.0.0.beta12" from typing import Sequence diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index 089138b7..aa857183 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -5,12 +5,13 @@ from functools import wraps from io import BufferedReader from typing import Any, AsyncIterable, Dict, Iterator, Optional +from urllib.parse import urljoin import requests from aiohttp import ClientSession, ClientTimeout from llmengine.errors import parse_error -SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine" +SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine/" LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) DEFAULT_TIMEOUT: int = 10 @@ -51,7 +52,7 @@ def validate_api_key(cls): def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: api_key = get_api_key() response = requests.get( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(LLM_ENGINE_BASE_PATH, resource_name), timeout=timeout, headers={"x-api-key": api_key}, auth=(api_key, ""), @@ -67,7 +68,7 @@ def put( ) -> Dict[str, Any]: api_key = get_api_key() response = requests.put( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -82,7 +83,7 @@ def put( def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: api_key = get_api_key() response = requests.delete( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(LLM_ENGINE_BASE_PATH, resource_name), timeout=timeout, headers={"x-api-key": api_key}, auth=(api_key, ""), @@ -96,7 +97,7 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Dict[str, Any]: api_key = get_api_key() response = requests.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -113,7 +114,7 @@ def post_stream( ) -> Iterator[Dict[str, Any]]: api_key = get_api_key() response = requests.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -145,7 +146,7 @@ def post_file( ) -> Dict[str, Any]: api_key = get_api_key() response = requests.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(LLM_ENGINE_BASE_PATH, resource_name), files=files, timeout=timeout, headers={"x-api-key": api_key}, @@ -164,7 +165,7 @@ async def apost_sync( timeout=ClientTimeout(timeout), headers={"x-api-key": api_key} ) as session: async with session.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), json=data + urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data ) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) @@ -180,7 +181,7 @@ async def apost_stream( timeout=ClientTimeout(timeout), headers={"x-api-key": api_key} ) as session: async with session.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), json=data + urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data ) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 4adfdb19..6352055d 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta11" +version = "0.0.0.beta12" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index b8559ba8..6af08c9a 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta11", + version="0.0.0.beta12", packages=find_packages(), ) From c5f0f221c205dd3823574a2c8f8d56466d25acda Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Mon, 28 Aug 2023 16:29:51 -0700 Subject: [PATCH 062/425] Sync scale from zero, part 1 (#229) * add patch files * fix ruff --- .../model_engine_server/api/dependencies.py | 5 ++ .../model_engine_server/api/tasks_v1.py | 5 +- .../model_engine_server/common/config.py | 2 +- .../model_engine_server/common/dtos/tasks.py | 9 ++- .../model_engine_server/domain/exceptions.py | 7 +++ .../domain/gateways/__init__.py | 2 + .../inference_autoscaling_metrics_gateway.py | 22 +++++++ ...eaming_model_endpoint_inference_gateway.py | 4 +- .../sync_model_endpoint_inference_gateway.py | 4 +- .../domain/services/model_endpoint_service.py | 11 ++++ .../use_cases/llm_model_endpoint_use_cases.py | 41 ++++++++++-- .../streaming_inference_use_cases.py | 10 ++- .../use_cases/sync_inference_use_cases.py | 10 ++- .../start_batch_job_orchestration.py | 5 ++ .../infra/gateways/__init__.py | 2 + ...eaming_model_endpoint_inference_gateway.py | 62 +++++++++++++++---- ...e_sync_model_endpoint_inference_gateway.py | 62 +++++++++++++++---- ...s_inference_autoscaling_metrics_gateway.py | 48 ++++++++++++++ .../services/live_model_endpoint_service.py | 10 +++ model-engine/tests/unit/conftest.py | 41 ++++++++++++ ...eaming_model_endpoint_inference_gateway.py | 14 ++--- ...e_sync_model_endpoint_inference_gateway.py | 14 ++--- .../tests/unit/infra/services/conftest.py | 2 + 23 files changed, 337 insertions(+), 55 deletions(-) create mode 100644 model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py create mode 100644 model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index c9e00eb9..2997542a 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -49,6 +49,7 @@ LiveStreamingModelEndpointInferenceGateway, LiveSyncModelEndpointInferenceGateway, ModelEndpointInfraGateway, + RedisInferenceAutoscalingMetricsGateway, S3FilesystemGateway, S3LLMArtifactGateway, ) @@ -179,6 +180,9 @@ def _get_external_interfaces( model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) + inference_autoscaling_metrics_gateway = RedisInferenceAutoscalingMetricsGateway( + redis_client=redis_client + ) # we can just reuse the existing redis client, we shouldn't get key collisions because of the prefix model_endpoint_service = LiveModelEndpointService( model_endpoint_record_repository=model_endpoint_record_repo, model_endpoint_infra_gateway=model_endpoint_infra_gateway, @@ -187,6 +191,7 @@ def _get_external_interfaces( streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, ) llm_model_endpoint_service = LiveLLMModelEndpointService( model_endpoint_record_repository=model_endpoint_record_repo, diff --git a/model-engine/model_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py index 74b5b634..443b7fb7 100644 --- a/model-engine/model_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -11,6 +11,7 @@ CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) @@ -97,7 +98,7 @@ def get_async_inference_task( @inference_task_router_v1.post("/sync-tasks", response_model=SyncEndpointPredictV1Response) async def create_sync_inference_task( model_endpoint_id: str, - request: EndpointPredictV1Request, + request: SyncEndpointPredictV1Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> SyncEndpointPredictV1Response: @@ -137,7 +138,7 @@ async def create_sync_inference_task( @inference_task_router_v1.post("/streaming-tasks") async def create_streaming_inference_task( model_endpoint_id: str, - request: EndpointPredictV1Request, + request: SyncEndpointPredictV1Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> EventSourceResponse: diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 4022ceb5..deeb4477 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -45,7 +45,7 @@ def get_model_cache_directory_name(model_name: str): class HostedModelInferenceServiceConfig: endpoint_namespace: str billing_queue_arn: str - cache_redis_url: str + cache_redis_url: str # also using this to store sync autoscaling metrics sqs_profile: str sqs_queue_policy_template: str sqs_queue_tag_template: str diff --git a/model-engine/model_engine_server/common/dtos/tasks.py b/model-engine/model_engine_server/common/dtos/tasks.py index 5b0bf580..36c20903 100644 --- a/model-engine/model_engine_server/common/dtos/tasks.py +++ b/model-engine/model_engine_server/common/dtos/tasks.py @@ -6,7 +6,7 @@ from typing import Any, Optional from model_engine_server.domain.entities import CallbackAuth -from pydantic import BaseModel +from pydantic import BaseModel, Field class ResponseSchema(BaseModel): @@ -49,3 +49,10 @@ class EndpointPredictV1Request(BaseModel): callback_url: Optional[str] = None callback_auth: Optional[CallbackAuth] = None return_pickled: bool = False + + +class SyncEndpointPredictV1Request(EndpointPredictV1Request): + timeout_seconds: Optional[float] = Field(default=None, gt=0) + num_retries: Optional[int] = Field(default=None, ge=0) + # See live_{sync,streaming}_model_endpoint_inference_gateway to see how timeout_seconds/num_retries interact. + # Also these fields are only relevant for sync endpoints diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 66b6f708..c31eb0ad 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -59,6 +59,13 @@ class TooManyRequestsException(DomainException): """ +class NoHealthyUpstreamException(DomainException): + """ + Thrown if an endpoint returns a 503 exception for no healthy upstream. This can happen if there are zero pods + available to serve the request. + """ + + class CorruptRecordInfraStateException(DomainException): """ Thrown if the data from existing state (i.e. the db, k8s, etc.) is somehow uninterpretable diff --git a/model-engine/model_engine_server/domain/gateways/__init__.py b/model-engine/model_engine_server/domain/gateways/__init__.py index fea6d2b5..9550da56 100644 --- a/model-engine/model_engine_server/domain/gateways/__init__.py +++ b/model-engine/model_engine_server/domain/gateways/__init__.py @@ -2,6 +2,7 @@ from .cron_job_gateway import CronJobGateway from .docker_image_batch_job_gateway import DockerImageBatchJobGateway from .file_storage_gateway import FileStorageGateway +from .inference_autoscaling_metrics_gateway import InferenceAutoscalingMetricsGateway from .llm_artifact_gateway import LLMArtifactGateway from .model_endpoints_schema_gateway import ModelEndpointsSchemaGateway from .model_primitive_gateway import ModelPrimitiveGateway @@ -15,6 +16,7 @@ "CronJobGateway", "DockerImageBatchJobGateway", "FileStorageGateway", + "InferenceAutoscalingMetricsGateway", "LLMArtifactGateway", "ModelEndpointsSchemaGateway", "ModelPrimitiveGateway", diff --git a/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py new file mode 100644 index 00000000..da603b4d --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + + +class InferenceAutoscalingMetricsGateway(ABC): + """ + Abstract Base Class for a gateway that emits autoscaling metrics for inference requests. Can be used in conjunction + with various autoscaler resources, e.g. a Keda ScaledObject, to autoscale inference endpoints. + """ + + @abstractmethod + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + """ + On an inference request, emit a metric + """ + pass + + @abstractmethod + async def emit_prewarm_metric(self, endpoint_id: str): + """ + If you want to prewarm an endpoint, emit a metric here + """ + pass diff --git a/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py index 565cbe45..8b80a525 100644 --- a/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py @@ -2,7 +2,7 @@ from typing import AsyncIterable from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) @@ -17,7 +17,7 @@ class StreamingModelEndpointInferenceGateway(ABC): @abstractmethod def streaming_predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, topic: str, predict_request: SyncEndpointPredictV1Request ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a streaming response. diff --git a/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py index 99ec36fa..90d77950 100644 --- a/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) @@ -16,7 +16,7 @@ class SyncModelEndpointInferenceGateway(ABC): @abstractmethod async def predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, topic: str, predict_request: SyncEndpointPredictV1Request ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. diff --git a/model-engine/model_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py index 90b50983..8492ae45 100644 --- a/model-engine/model_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -18,6 +18,9 @@ StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, ) +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) class ModelEndpointService(ABC): @@ -49,6 +52,14 @@ def get_streaming_model_endpoint_inference_gateway( Returns the sync model endpoint inference gateway. """ + @abstractmethod + def get_inference_auto_scaling_metrics_gateway( + self, + ) -> InferenceAutoscalingMetricsGateway: + """ + Returns the inference autoscaling metrics gateway. + """ + @abstractmethod async def create_model_endpoint( self, diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index fa4eef0f..638fd9e2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -29,7 +29,7 @@ ) from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User from model_engine_server.core.domain_exceptions import ( @@ -105,6 +105,10 @@ } +NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes +DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes + + def _model_endpoint_entity_to_get_llm_model_endpoint_response( model_endpoint: ModelEndpoint, ) -> GetLLMModelEndpointV1Response: @@ -495,6 +499,10 @@ async def execute( post_inference_hooks=request.post_inference_hooks, ) + await self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway().emit_prewarm_metric( + model_endpoint_record.id + ) + return CreateLLMModelEndpointV1Response( endpoint_creation_task_id=model_endpoint_record.creation_task_id # type: ignore ) @@ -694,6 +702,12 @@ async def execute( ) inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED: args: Any = { @@ -710,7 +724,11 @@ async def execute( # Deepspeed models only accepts one stop sequence args["stop_sequence"] = request.stop_sequences[0] - inference_request = EndpointPredictV1Request(args=args) + inference_request = SyncEndpointPredictV1Request( + args=args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request ) @@ -745,7 +763,11 @@ async def execute( tgi_args["parameters"]["temperature"] = request.temperature tgi_args["parameters"]["do_sample"] = True - inference_request = EndpointPredictV1Request(args=tgi_args) + inference_request = SyncEndpointPredictV1Request( + args=tgi_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request ) @@ -834,6 +856,12 @@ async def execute( inference_gateway = ( self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) @@ -865,8 +893,11 @@ async def execute( args["parameters"]["temperature"] = request.temperature args["parameters"]["do_sample"] = True - inference_request = EndpointPredictV1Request(args=args) - + inference_request = SyncEndpointPredictV1Request( + args=args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) predict_result = inference_gateway.streaming_predict( topic=model_endpoint.record.destination, predict_request=inference_request ) diff --git a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index f4dfce40..1fb70023 100644 --- a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -1,7 +1,7 @@ from typing import AsyncIterable from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) from model_engine_server.core.auth.authentication_repository import User @@ -27,7 +27,7 @@ def __init__(self, model_endpoint_service: ModelEndpointService): self.authz_module = LiveAuthorizationModule() async def execute( - self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request + self, user: User, model_endpoint_id: str, request: SyncEndpointPredictV1Request ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs the use case to create a sync inference task. @@ -61,6 +61,12 @@ async def execute( inference_gateway = ( self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint_id + ) return inference_gateway.streaming_predict( topic=model_endpoint.record.destination, predict_request=request ) diff --git a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index 7ef1f8bd..d785beed 100644 --- a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -1,5 +1,5 @@ from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) from model_engine_server.core.auth.authentication_repository import User @@ -25,7 +25,7 @@ def __init__(self, model_endpoint_service: ModelEndpointService): self.authz_module = LiveAuthorizationModule() async def execute( - self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request + self, user: User, model_endpoint_id: str, request: SyncEndpointPredictV1Request ) -> SyncEndpointPredictV1Response: """ Runs the use case to create a sync inference task. @@ -65,6 +65,12 @@ async def execute( ) inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint_id + ) return await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=request ) diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 01a03445..6139a2a0 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -20,6 +20,7 @@ LiveModelEndpointsSchemaGateway, LiveStreamingModelEndpointInferenceGateway, LiveSyncModelEndpointInferenceGateway, + RedisInferenceAutoscalingMetricsGateway, S3FilesystemGateway, ) from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( @@ -95,6 +96,9 @@ async def run_batch_job( model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) + inference_autoscaling_metrics_gateway = RedisInferenceAutoscalingMetricsGateway( + redis_client=redis, + ) model_endpoint_service = LiveModelEndpointService( model_endpoint_record_repository=model_endpoint_record_repo, model_endpoint_infra_gateway=model_endpoint_infra_gateway, @@ -103,6 +107,7 @@ async def run_batch_job( streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, ) batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=False) batch_job_progress_gateway = LiveBatchJobProgressGateway(filesystem_gateway=filesystem_gateway) diff --git a/model-engine/model_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py index 0417527d..0f2b5faa 100644 --- a/model-engine/model_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -18,6 +18,7 @@ ) from .live_sync_model_endpoint_inference_gateway import LiveSyncModelEndpointInferenceGateway from .model_endpoint_infra_gateway import ModelEndpointInfraGateway +from .redis_inference_autoscaling_metrics_gateway import RedisInferenceAutoscalingMetricsGateway from .s3_filesystem_gateway import S3FilesystemGateway from .s3_llm_artifact_gateway import S3LLMArtifactGateway @@ -38,6 +39,7 @@ "LiveStreamingModelEndpointInferenceGateway", "LiveSyncModelEndpointInferenceGateway", "ModelEndpointInfraGateway", + "RedisInferenceAutoscalingMetricsGateway", "S3FilesystemGateway", "S3LLMArtifactGateway", ] diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 9103b3e9..dd3aec47 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -6,14 +6,18 @@ import sseclient from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) from model_engine_server.common.env_vars import CIRCLECI, LOCAL from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import filename_wo_ext, make_logger -from model_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError +from model_engine_server.domain.exceptions import ( + NoHealthyUpstreamException, + TooManyRequestsException, + UpstreamServiceError, +) from model_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( StreamingModelEndpointInferenceGateway, ) @@ -25,13 +29,19 @@ RetryError, retry_if_exception_type, stop_after_attempt, + stop_after_delay, + stop_any, wait_exponential, ) logger = make_logger(filename_wo_ext(__file__)) -SYNC_ENDPOINT_RETRIES = 5 # Must be an integer >= 0 +SYNC_ENDPOINT_RETRIES = 8 # Must be an integer >= 0 SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 +SYNC_ENDPOINT_MAX_RETRY_WAIT = 5 +SYNC_ENDPOINT_EXP_BACKOFF_BASE = ( + 1.2 # Must be a float > 1.0, lower number means more retries but less time waiting. +) def _get_streaming_endpoint_url(deployment_name: str) -> str: @@ -107,6 +117,8 @@ async def make_single_request(self, request_url: str, payload_json: Dict[str, An if errored: if status == 429: raise TooManyRequestsException("429 returned") + if status == 503: + raise NoHealthyUpstreamException("503 returned") else: raise UpstreamServiceError(status_code=status, content=content) @@ -126,9 +138,18 @@ async def make_request_with_retries( try: async for attempt in AsyncRetrying( - stop=stop_after_attempt(num_retries + 1), - retry=retry_if_exception_type(TooManyRequestsException), - wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), + stop=stop_any( + stop_after_attempt(num_retries + 1), stop_after_delay(timeout_seconds) + ), + retry=retry_if_exception_type( + (TooManyRequestsException, NoHealthyUpstreamException) + ), + wait=wait_exponential( + multiplier=1, + min=1, + max=SYNC_ENDPOINT_MAX_RETRY_WAIT, + exp_base=SYNC_ENDPOINT_EXP_BACKOFF_BASE, + ), ): with attempt: logger.info(f"Retry number {attempt.retry_state.attempt_number}") @@ -136,9 +157,16 @@ async def make_request_with_retries( async for item in response: yield orjson.loads(item) return - except RetryError: - logger.warning("Hit max # of retries, returning 429 to client") - raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") + except RetryError as e: + if type(e.last_attempt.exception()) == TooManyRequestsException: + logger.warning("Hit max # of retries, returning 429 to client") + raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") + elif type(e.last_attempt.exception()) == NoHealthyUpstreamException: + logger.warning("Pods didn't spin up in time, returning 503 to client") + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + else: + logger.error("Unknown Exception Type") + raise UpstreamServiceError(status_code=500, content=b"Unknown error") except JSONDecodeError: logger.exception("JSONDecodeError") raise UpstreamServiceError(status_code=500, content=b"JSONDecodeError") @@ -149,16 +177,26 @@ async def make_request_with_retries( raise Exception("Should never reach this line") async def streaming_predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, topic: str, predict_request: SyncEndpointPredictV1Request ) -> AsyncIterable[SyncEndpointPredictV1Response]: deployment_url = _get_streaming_endpoint_url(topic) try: + timeout_seconds = ( + SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS + if predict_request.timeout_seconds is None + else predict_request.timeout_seconds + ) + num_retries = ( + SYNC_ENDPOINT_RETRIES + if predict_request.num_retries is None + else predict_request.num_retries + ) response = self.make_request_with_retries( request_url=deployment_url, payload_json=predict_request.dict(), - timeout_seconds=SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS, - num_retries=SYNC_ENDPOINT_RETRIES, + timeout_seconds=timeout_seconds, + num_retries=num_retries, ) async for item in response: yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index c29fcf53..0a763b1f 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -5,14 +5,18 @@ import requests from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) from model_engine_server.common.env_vars import CIRCLECI, LOCAL from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import filename_wo_ext, make_logger -from model_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError +from model_engine_server.domain.exceptions import ( + NoHealthyUpstreamException, + TooManyRequestsException, + UpstreamServiceError, +) from model_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( SyncModelEndpointInferenceGateway, ) @@ -23,13 +27,19 @@ RetryError, retry_if_exception_type, stop_after_attempt, + stop_after_delay, + stop_any, wait_exponential, ) logger = make_logger(filename_wo_ext(__file__)) -SYNC_ENDPOINT_RETRIES = 5 # Must be an integer >= 0 +SYNC_ENDPOINT_RETRIES = 8 # Must be an integer >= 0 SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 +SYNC_ENDPOINT_MAX_RETRY_WAIT = 5 +SYNC_ENDPOINT_EXP_BACKOFF_BASE = ( + 1.2 # Must be a float > 1.0, lower number means more retries but less time waiting. +) def _get_sync_endpoint_url(deployment_name: str) -> str: @@ -90,6 +100,8 @@ async def make_single_request(self, request_url: str, payload_json: Dict[str, An # tenacity can properly capture them. if status == 429: raise TooManyRequestsException("429 returned") + if status == 503: + raise NoHealthyUpstreamException("503 returned") else: raise UpstreamServiceError(status_code=status, content=content) @@ -109,16 +121,32 @@ async def make_request_with_retries( try: async for attempt in AsyncRetrying( - stop=stop_after_attempt(num_retries + 1), - retry=retry_if_exception_type(TooManyRequestsException), - wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), + stop=stop_any( + stop_after_attempt(num_retries + 1), stop_after_delay(timeout_seconds) + ), + retry=retry_if_exception_type( + (TooManyRequestsException, NoHealthyUpstreamException) + ), + wait=wait_exponential( + multiplier=1, + min=1, + max=SYNC_ENDPOINT_MAX_RETRY_WAIT, + exp_base=SYNC_ENDPOINT_EXP_BACKOFF_BASE, + ), ): with attempt: logger.info(f"Retry number {attempt.retry_state.attempt_number}") return await self.make_single_request(request_url, payload_json) - except RetryError: - logger.warning("Hit max # of retries, returning 429 to client") - raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") + except RetryError as e: + if type(e.last_attempt.exception()) == TooManyRequestsException: + logger.warning("Hit max # of retries, returning 429 to client") + raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") + elif type(e.last_attempt.exception()) == NoHealthyUpstreamException: + logger.warning("Pods didn't spin up in time, returning 503 to client") + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + else: + logger.error("Unknown Exception Type") + raise UpstreamServiceError(status_code=500, content=b"Unknown error") # Never reached because tenacity should throw a RetryError if we exit the for loop. # This is for mypy. @@ -126,16 +154,26 @@ async def make_request_with_retries( return {} async def predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, topic: str, predict_request: SyncEndpointPredictV1Request ) -> SyncEndpointPredictV1Response: deployment_url = _get_sync_endpoint_url(topic) try: + timeout_seconds = ( + SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS + if predict_request.timeout_seconds is None + else predict_request.timeout_seconds + ) + num_retries = ( + SYNC_ENDPOINT_RETRIES + if predict_request.num_retries is None + else predict_request.num_retries + ) response = await self.make_request_with_retries( request_url=deployment_url, payload_json=predict_request.dict(), - timeout_seconds=SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS, - num_retries=SYNC_ENDPOINT_RETRIES, + timeout_seconds=timeout_seconds, + num_retries=num_retries, ) except UpstreamServiceError as exc: logger.error(f"Service error on sync task: {exc.content!r}") diff --git a/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py new file mode 100644 index 00000000..5761493e --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py @@ -0,0 +1,48 @@ +from typing import Optional + +import aioredis +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) + +EXPIRY_SECONDS = 60 # 1 minute; this gets added to the cooldown time present in the keda ScaledObject to get total +# scaledown time. This also needs to be larger than the keda ScaledObject's refresh rate. +PREWARM_EXPIRY_SECONDS = 60 * 60 # 1 hour + + +class RedisInferenceAutoscalingMetricsGateway(InferenceAutoscalingMetricsGateway): + def __init__( + self, redis_info: Optional[str] = None, redis_client: Optional[aioredis.Redis] = None + ): + assert redis_info or redis_client, "Either redis_info or redis_client must be defined." + if redis_info: + # If aioredis cannot create a connection pool, reraise that as an error because the + # default error message is cryptic and not obvious. + try: + self._redis = aioredis.from_url(redis_info, health_check_interval=60) + except Exception as exc: + raise RuntimeError( + "If redis_info is specified, RedisInferenceAutoscalingMetricsGateway must be" + "initialized within a coroutine. Please specify the redis_client directly." + ) from exc + else: + assert redis_client is not None # for mypy + self._redis = redis_client + + @staticmethod + def _find_redis_key(endpoint_id: str): + return f"launch-endpoint-autoscaling:{endpoint_id}" + + async def _emit_metric(self, endpoint_id: str, expiry_time: int): + key = self._find_redis_key(endpoint_id) + await self._redis.expire(key, expiry_time) # does nothing if key doesn't exist, + # but avoids a race condition where the key expires in between the lpush and subsequent expire commands + await self._redis.lpush(key, 1) # we only care about the length of the list, not the values + await self._redis.ltrim(key, 0, 0) # we only want to scale from 0 to 1 for redis + await self._redis.expire(key, expiry_time) + + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, EXPIRY_SECONDS) + + async def emit_prewarm_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, PREWARM_EXPIRY_SECONDS) diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index ced2a6e1..ab88886c 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -27,6 +27,9 @@ StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, ) +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) from model_engine_server.domain.services import ModelEndpointService from model_engine_server.domain.use_cases.model_endpoint_use_cases import MODEL_BUNDLE_CHANGED_KEY from model_engine_server.infra.gateways import ModelEndpointInfraGateway @@ -51,6 +54,7 @@ def __init__( streaming_model_endpoint_inference_gateway: StreamingModelEndpointInferenceGateway, sync_model_endpoint_inference_gateway: SyncModelEndpointInferenceGateway, model_endpoints_schema_gateway: ModelEndpointsSchemaGateway, + inference_autoscaling_metrics_gateway: InferenceAutoscalingMetricsGateway, ): self.model_endpoint_record_repository = model_endpoint_record_repository self.model_endpoint_infra_gateway = model_endpoint_infra_gateway @@ -59,6 +63,7 @@ def __init__( self.streaming_model_endpoint_inference_gateway = streaming_model_endpoint_inference_gateway self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway self.model_endpoints_schema_gateway = model_endpoints_schema_gateway + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway def get_async_model_endpoint_inference_gateway( self, @@ -75,6 +80,11 @@ def get_streaming_model_endpoint_inference_gateway( ) -> StreamingModelEndpointInferenceGateway: return self.streaming_model_endpoint_inference_gateway + def get_inference_auto_scaling_metrics_gateway( + self, + ) -> InferenceAutoscalingMetricsGateway: + return self.inference_autoscaling_metrics_gateway + async def _get_model_endpoint_infra_state( self, record: ModelEndpointRecord, use_cache: bool ) -> Optional[ModelEndpointInfraState]: diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index df1b5e2e..9714fd9d 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -36,6 +36,7 @@ CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) @@ -92,6 +93,7 @@ CronJobGateway, DockerImageBatchJobGateway, FileStorageGateway, + InferenceAutoscalingMetricsGateway, LLMArtifactGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, @@ -1519,6 +1521,14 @@ def get_last_request(self): return self.tasks[-1] +class FakeInferenceAutoscalingMetricsGateway(InferenceAutoscalingMetricsGateway): + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + pass + + async def emit_prewarm_metric(self, endpoint_id: str): + pass + + class FakeModelEndpointService(ModelEndpointService): db: Dict[str, ModelEndpoint] @@ -1531,6 +1541,7 @@ def __init__( StreamingModelEndpointInferenceGateway ] = None, sync_model_endpoint_inference_gateway: Optional[SyncModelEndpointInferenceGateway] = None, + inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway] = None, ): if contents: self.db = contents @@ -1560,6 +1571,11 @@ def __init__( if sync_model_endpoint_inference_gateway is None: sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway + + if inference_autoscaling_metrics_gateway is None: + inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway + self.model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway() ) @@ -1579,6 +1595,11 @@ def get_sync_model_endpoint_inference_gateway( ) -> SyncModelEndpointInferenceGateway: return self.sync_model_endpoint_inference_gateway + def get_inference_auto_scaling_metrics_gateway( + self, + ) -> InferenceAutoscalingMetricsGateway: + return self.inference_autoscaling_metrics_gateway + def add_model_endpoint(self, model_endpoint: ModelEndpoint): self.db[model_endpoint.record.id] = model_endpoint @@ -1989,6 +2010,12 @@ def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferen return gateway +@pytest.fixture +def fake_inference_autoscaling_metrics_gateway() -> FakeInferenceAutoscalingMetricsGateway: + gateway = FakeInferenceAutoscalingMetricsGateway() + return gateway + + @pytest.fixture def fake_file_storage_gateway() -> FakeFileStorageGateway: gateway = FakeFileStorageGateway() @@ -2073,6 +2100,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: FakeStreamingModelEndpointInferenceGateway() ) sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() + inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway(), ) @@ -2083,6 +2111,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: async_model_endpoint_inference_gateway=async_model_endpoint_inference_gateway, streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, ) fake_batch_job_service = LiveBatchJobService( @@ -3360,6 +3389,18 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An return request, request_dict +@pytest.fixture +def sync_endpoint_predict_request_1() -> Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]: + request = SyncEndpointPredictV1Request( + url="test_url", + return_pickled=False, + timeout_seconds=10, + num_retries=5, + ) + request_dict = request.dict() + return request, request_dict + + @pytest.fixture def llm_model_endpoint_async( test_api_key: str, model_bundle_1: ModelBundle diff --git a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py index 35256fce..13980fb9 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py @@ -5,7 +5,7 @@ import pytest from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) from model_engine_server.domain.exceptions import UpstreamServiceError @@ -104,7 +104,7 @@ async def test_make_request_with_retries_failed_traceback(): @pytest.mark.asyncio async def test_streaming_predict_success( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] ): gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) @@ -115,7 +115,7 @@ async def test_streaming_predict_success( mock_client_session, ): response = gateway.streaming_predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] ) count = 0 async for message in response: @@ -131,7 +131,7 @@ async def test_streaming_predict_success( @pytest.mark.asyncio async def test_predict_raises_traceback_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] ): gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) @@ -143,7 +143,7 @@ async def test_predict_raises_traceback_json( mock_client_session, ): response = gateway.streaming_predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] ) count = 0 async for message in response: @@ -159,7 +159,7 @@ async def test_predict_raises_traceback_json( @pytest.mark.asyncio async def test_predict_raises_traceback_not_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] ): gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) @@ -171,7 +171,7 @@ async def test_predict_raises_traceback_not_json( mock_client_session, ): response = gateway.streaming_predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] ) count = 0 async for message in response: diff --git a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py index 758e6ce0..6241fe6e 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -5,7 +5,7 @@ import pytest from model_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) from model_engine_server.domain.exceptions import UpstreamServiceError @@ -82,7 +82,7 @@ async def test_make_request_with_retries_failed_traceback(): @pytest.mark.asyncio async def test_predict_success( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] ): gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) @@ -93,7 +93,7 @@ async def test_predict_success( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -105,7 +105,7 @@ async def test_predict_success( @pytest.mark.asyncio async def test_predict_raises_traceback_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] ): gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) @@ -117,7 +117,7 @@ async def test_predict_raises_traceback_json( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -129,7 +129,7 @@ async def test_predict_raises_traceback_json( @pytest.mark.asyncio async def test_predict_raises_traceback_not_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] ): gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) @@ -141,7 +141,7 @@ async def test_predict_raises_traceback_not_json( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { diff --git a/model-engine/tests/unit/infra/services/conftest.py b/model-engine/tests/unit/infra/services/conftest.py index 9efc271f..acd43e5a 100644 --- a/model-engine/tests/unit/infra/services/conftest.py +++ b/model-engine/tests/unit/infra/services/conftest.py @@ -16,6 +16,7 @@ def fake_live_model_endpoint_service( fake_async_model_endpoint_inference_gateway, fake_streaming_model_endpoint_inference_gateway, fake_sync_model_endpoint_inference_gateway, + fake_inference_autoscaling_metrics_gateway, fake_filesystem_gateway, model_bundle_1: ModelBundle, model_bundle_2: ModelBundle, @@ -37,6 +38,7 @@ def fake_live_model_endpoint_service( async_model_endpoint_inference_gateway=fake_async_model_endpoint_inference_gateway, streaming_model_endpoint_inference_gateway=fake_streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=fake_sync_model_endpoint_inference_gateway, + inference_autoscaling_metrics_gateway=fake_inference_autoscaling_metrics_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, ) return service From daf232883bf32e9a75399d7227ad5aa7efa1567d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 28 Aug 2023 18:12:57 -0700 Subject: [PATCH 063/425] Doc add fine tune support for llama 2 70b (#232) --- docs/model_zoo.md | 2 +- requirements-dev.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 264196a6..5c0bab7c 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -9,7 +9,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `llama-2-7b-chat` | ✅ | | | `llama-2-13b` | ✅ | | | `llama-2-13b-chat` | ✅ | | -| `llama-2-70b` | ✅ | | +| `llama-2-70b` | ✅ | ✅ | | `llama-2-70b-chat` | ✅ | | | `falcon-7b` | ✅ | | | `falcon-7b-instruct` | ✅ | | diff --git a/requirements-dev.txt b/requirements-dev.txt index e3edc67e..f6e4d22c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,3 +6,4 @@ isort==5.12.0 mypy==1.3.0 pip-tools==7.0.0 poetry==1.5.1 +pre-commit==3.3.3 \ No newline at end of file From c7bd0c62946f6d0322af530da94f8c73ce30ac8a Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 28 Aug 2023 18:37:26 -0700 Subject: [PATCH 064/425] add file upload guidance to fine-tune docs (#231) --- clients/python/llmengine/fine_tuning.py | 18 +++++++++--------- docs/guides/fine_tuning.md | 24 ++++++++++++++---------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index 6a19dc3b..91f9d05f 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -107,14 +107,14 @@ def create( writer.writerows(data) ``` - Currently, data needs to be uploaded to a publicly accessible web URL so that it can be read - for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. - Support for privately sharing data with the LLM Engine API is coming shortly. For quick - iteration, you can look into tools like Pastebin or GitHub Gists to quickly host your CSV - files in a public manner. An example Github Gist can be found - [here](https://gist.github.com/tigss/7cec73251a37de72756a3b15eace9965). To use the gist, - you can use the URL given when you click the “Raw” button - ([URL](https://gist.githubusercontent.com/tigss/7cec73251a37de72756a3b15eace9965/raw/85d9742890e1e6b0c06468507292893b820c13c9/llm_sample_data.csv)). + Currently, data needs to be uploaded to either a publicly accessible web URL or to LLM Engine's + private file server so that it can be read for fine-tuning. Publicly accessible HTTP and HTTPS + URLs are currently supported. + + To privately share data with the LLM Engine API, use LLM Engine's [File.upload](../../api/python_client/#llmengine.File.upload) + API. You can upload data in local file to LLM Engine's private file server and then use the + returned file ID to reference your data in the FineTune API. The file ID is generally in the + form of `file-`, e.g. "file-7DLVeLdN2Ty4M2m". Example code for fine-tuning: === "Fine-tuning in Python" @@ -123,7 +123,7 @@ def create( response = FineTune.create( model="llama-2-7b", - training_file="https://my-bucket.s3.us-west-2.amazonaws.com/path/to/training-file.csv", + training_file="file-7DLVeLdN2Ty4M2m", ) print(response.json()) diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index b5d43673..af1fe703 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -103,14 +103,18 @@ with open('customer_service_data.csv', 'w', newline='') as file: ## Making your data accessible to LLM Engine -Currently, data needs to be uploaded to a publicly accessible web URL so that it can be read -for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. -Support for privately sharing data with the LLM Engine API is coming shortly. For quick -iteration, you can look into tools like Pastebin or GitHub Gists to quickly host your CSV -files in a public manner. An example Github Gist can be found -[here](https://gist.github.com/tigss/7cec73251a37de72756a3b15eace9965). To use the gist, -you can use the URL given when you click the “Raw” button -([URL](https://gist.githubusercontent.com/tigss/7cec73251a37de72756a3b15eace9965/raw/85d9742890e1e6b0c06468507292893b820c13c9/llm_sample_data.csv)). +Currently, data needs to be uploaded to either a publicly accessible web URL or to LLM Engine's private file server so that it can be read for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. + +To privately share data with the LLM Engine API, use LLM Engine's [File.upload](../../api/python_client/#llmengine.File.upload) API. You can upload data in local file to LLM Engine's private file server and then use the returned file ID to reference your data in the FineTune API. The file ID is generally in the form of `file-`, e.g. "file-7DLVeLdN2Ty4M2m". + +=== "Upload to LLM Engine's private file server" + +```python +from llmengine import File + +response = File.upload(open("customer_service_data.csv", "r")) +print(response.json()) +``` ## Launching the fine-tune @@ -137,8 +141,8 @@ from llmengine import FineTune response = FineTune.create( model="llama-2-7b", - training_file="s3://my-bucket/path/to/training-file.csv", - validation_file="s3://my-bucket/path/to/validation-file.csv", + training_file="file-7DLVeLdN2Ty4M2m", + training_file="file-ezSRtpgKQyItI26", ) print(response.json()) From 783ae5d41fbb6c06ec947e147ad7d32732913fda Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:19:52 -0700 Subject: [PATCH 065/425] Update LICENSE (#235) * Update LICENSE * update copyright --- LICENSE | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 197 insertions(+), 10 deletions(-) diff --git a/LICENSE b/LICENSE index d803528b..b8106a2f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,14 +1,201 @@ -Copyright [2023] [Scale AI] + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - http://www.apache.org/licenses/LICENSE-2.0 + 1. Definitions. -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 Scale AI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. From c4395e9962c8f414c4efb601bf75e62413914a90 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 29 Aug 2023 17:33:05 -0700 Subject: [PATCH 066/425] Update README.md (#236) --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 68a4f3ef..39fc7098 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ -# ⚡ LLM Engine ⚡ +# LLM Engine -**The open source engine for fine-tuning and serving large language models**. +[![LICENSE](https://img.shields.io/github/license/scaleapi/llm-engine.svg)](https://github.com/scaleapi/llm-engine/blob/master/LICENSE) +[![Release Notes](https://img.shields.io/github/release/scaleapi/llm-engine)](https://github.com/scaleapi/llm-engine/releases) +[![CircleCI](https://circleci.com/gh/scaleapi/llm-engine.svg?style=shield)](https://circleci.com/gh/scaleapi/llm-engine) + +🚀 **The open source engine for fine-tuning and serving large language models**. 🚀 Scale's LLM Engine is the easiest way to customize and serve LLMs. In LLM Engine, models can be accessed via Scale's hosted version or by using the Helm charts in this repository to run model inference and fine-tuning in your own infrastructure. @@ -87,4 +91,4 @@ print(response.output.text) You should see a successful completion of your given prompt! _What's next?_ Visit the [LLM Engine documentation pages](https://scaleapi.github.io/llm-engine/) for more on -the `Completion` and `FineTune` APIs and how to use them. +the `Completion` and `FineTune` APIs and how to use them. Check out this [blog post](https://scale.com/blog/fine-tune-llama-2) for an end-to-end example. From c846089ae3697eca750ed0919ae0e656a0dc894d Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 30 Aug 2023 10:32:29 -0700 Subject: [PATCH 067/425] Ianmacleod/fix download artifact gateway (#237) * fixing prefix * fixing llm artifact gateway --- .../infra/gateways/s3_llm_artifact_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index 9ebb84e6..d46be385 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -32,5 +32,5 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[ ) prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" for obj in bucket.objects.filter(Prefix=prefix): - model_files.append(f"s3://{hmi_config.s3_bucket_name}/{obj.key}") + model_files.append(f"s3://{bucket_name}/{obj.key}") return model_files From b314ad32927adb861fc8d3a918e23559c0b7a4da Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Wed, 30 Aug 2023 11:17:05 -0700 Subject: [PATCH 068/425] add peft config documentation (#238) --- clients/python/llmengine/fine_tuning.py | 1 + docs/guides/fine_tuning.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index 91f9d05f..e15f36a7 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -66,6 +66,7 @@ def create( * `warmup_ratio`: Ratio of training steps used for learning rate warmup. (Default: 0.03) * `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) * `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001) + * `peft_config`: A dict of parameters for the PEFT algorithm. See [LoraConfig](https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.LoraConfig) for more information. wandb_config (`Optional[Dict[str, Any]]`): A dict of configuration parameters for Weights & Biases. See [Weights & Biases](https://docs.wandb.ai/ref/python/init) for more information. diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index af1fe703..24b7babd 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -142,7 +142,7 @@ from llmengine import FineTune response = FineTune.create( model="llama-2-7b", training_file="file-7DLVeLdN2Ty4M2m", - training_file="file-ezSRtpgKQyItI26", + validation_file="file-ezSRtpgKQyItI26", ) print(response.json()) From 8b6aaebd1259ea5741db3f37667db66213b805ec Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 30 Aug 2023 12:47:01 -0700 Subject: [PATCH 069/425] Update client completion timeout (#239) * update client completion timeout * isort * [skip ci] * don't skip ci? --- clients/python/llmengine/completion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 8178867c..8cecd765 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -8,6 +8,8 @@ CompletionSyncV1Request, ) +COMPLETION_TIMEOUT = 300 + class Completion(APIEngine): """ @@ -31,7 +33,7 @@ async def acreate( temperature: float = 0.2, stop_sequences: Optional[List[str]] = None, return_token_log_probs: Optional[bool] = False, - timeout: int = 10, + timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: """ @@ -193,7 +195,7 @@ def create( temperature: float = 0.2, stop_sequences: Optional[List[str]] = None, return_token_log_probs: Optional[bool] = False, - timeout: int = 10, + timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: """ From 139a4ddb46f794fdfbe6e1d89d77f6b2bce32d92 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 30 Aug 2023 13:25:49 -0700 Subject: [PATCH 070/425] Add nvidia.com/gpu in requests (#240) --- charts/model-engine/templates/_helpers.tpl | 2 +- .../templates/service_template_config_map.yaml | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 1a6155ce..8bcebe2c 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -42,7 +42,7 @@ Common labels */}} {{- define "modelEngine.labels" -}} team: infra -product: launch +product: model-engine helm.sh/chart: {{ include "modelEngine.chart" . }} app.kubernetes.io/managed-by: {{ .Release.Service }} app.kubernetes.io/version: {{ .Values.tag }} diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 84c26206..fa52f51f 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -287,6 +287,9 @@ data: periodSeconds: 5 resources: requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -616,6 +619,9 @@ data: resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -673,7 +679,7 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: @@ -688,7 +694,7 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 From a64f3155dd25966acbc408ad6aff8f1404dd0076 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 30 Aug 2023 17:37:15 -0700 Subject: [PATCH 071/425] Add new image to image cache (#242) --- .../infra/services/image_cache_service.py | 5 ++++- .../tests/unit/infra/services/test_image_cache_service.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index db395a54..53b14980 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -71,11 +71,14 @@ def _cache_finetune_llm_images( tgi_image = DockerImage( f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3" ) + tgi_image_2 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.4" + ) forwarder_image = DockerImage( f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG ) - for llm_image in [istio_image, tgi_image, forwarder_image]: + for llm_image in [istio_image, tgi_image, tgi_image_2, forwarder_image]: if self.docker_repository.is_repo_name( llm_image.repo ) and not self.docker_repository.image_exists(llm_image.tag, llm_image.repo): diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py index f405ea21..c2ce5243 100644 --- a/model-engine/tests/unit/infra/services/test_image_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -51,8 +51,11 @@ async def test_caching_finetune_llm_images( tgi_image = DockerImage( f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3" ) + tgi_image_2 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.4" + ) forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG) for key in ["a10", "a100"]: - for llm_image in [istio_image, tgi_image, forwarder_image]: + for llm_image in [istio_image, tgi_image, tgi_image_2, forwarder_image]: assert f"{llm_image.repo}:{llm_image.tag}" in gateway.cached_images[key] From dfc81969f5de696acb6c3bc6f0a95a1034333829 Mon Sep 17 00:00:00 2001 From: William Song Date: Thu, 31 Aug 2023 08:38:44 -0700 Subject: [PATCH 072/425] Remove plugins from endpoint containers (#241) * various fixes for bundle endpoints --- .../core/docker/remote_build.py | 7 +- .../model_engine_server/core/loggers.py | 9 +- .../inference/post_inference_hooks.py | 60 +----- .../inference/pytorch_or_tf.base.Dockerfile | 1 - .../inference/requirements_base.txt | 20 +- .../repositories/ecr_docker_repository.py | 4 + .../services/live_endpoint_builder_service.py | 3 + .../service_builder/tasks_v1.py | 11 +- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 178 ++++++++++++------ 10 files changed, 160 insertions(+), 135 deletions(-) diff --git a/model-engine/model_engine_server/core/docker/remote_build.py b/model-engine/model_engine_server/core/docker/remote_build.py index 8b250cfd..6261334e 100644 --- a/model-engine/model_engine_server/core/docker/remote_build.py +++ b/model-engine/model_engine_server/core/docker/remote_build.py @@ -70,7 +70,8 @@ def zip_context( assert len(folders_to_include) > 0 assert s3_file_name.endswith(".gz") - print(f"Uploading to s3 at: {s3_file_name}") + s3_uri = f"s3://{S3_BUCKET}/{s3_file_name}" + print(f"Uploading to s3 at: {s3_uri}") try: # Need to gimme_okta_aws_creds (you can export AWS_PROFILE='ml-admin' right after) tar_command = _build_tar_cmd(context, ignore_file, folders_to_include) @@ -83,7 +84,7 @@ def zip_context( ) as proc: assert proc.stdout is not None with storage_client.open( - f"s3://{S3_BUCKET}/{s3_file_name}", + s3_uri, "wb", ) as out_file: shutil.copyfileobj(proc.stdout, out_file) @@ -429,6 +430,7 @@ def build_remote_block( :param ignore_file: File (e.g. .dockerignore) containing things to ignore when preparing docker context. Relative to context :return: BuildResult representing if docker image has successfully built/pushed """ + logger.info(f"build_remote_block args {locals()}") job_name = build_remote( context, dockerfile, @@ -439,6 +441,7 @@ def build_remote_block( build_args, custom_tags, ) + logger.info(f"Waiting for job {job_name} to finish") result = get_pod_status_and_log(job_name) return result diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index 91b69758..e8245199 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -10,7 +10,7 @@ import ddtrace import json_log_formatter import tqdm -from ddtrace.helpers import get_correlation_ids +from ddtrace import tracer # DO NOT CHANGE LOGGING FORMAT LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s" @@ -82,11 +82,12 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d if request_id: extra["request_id"] = request_id - trace_id, span_id = get_correlation_ids() + context = tracer.current_trace_context() + trace_id, span_id = (context.trace_id, context.span_id) if context else (0, 0) # add ids to event dictionary - extra["dd.trace_id"] = trace_id or 0 - extra["dd.span_id"] = span_id or 0 + extra["dd.trace_id"] = trace_id + extra["dd.span_id"] = span_id # add the env, service, and version configured for the tracer. # If tracing is not set up, then this should pull values from DD_ENV, DD_SERVICE, and DD_VERSION. diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index cd460a27..00abaa5d 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -1,32 +1,19 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from uuid import uuid4 import requests -from model_engine_server.common.constants import ( - BILLING_POST_INFERENCE_HOOK, - CALLBACK_POST_INFERENCE_HOOK, -) +from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth -from model_engine_server.inference.common import _write_to_s3 from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( InferenceMonitoringMetricsGateway, ) -from model_engine_server.inference.domain.gateways.usage_metrics_gateway import UsageMetricsGateway -from model_engine_server.inference.infra.gateways.fake_usage_metrics_gateway import ( - FakeUsageMetricsGateway, -) from tenacity import Retrying, stop_after_attempt, wait_exponential logger = make_logger(filename_wo_ext(__file__)) -def _upload_data(data: Any): - return _write_to_s3(data).get("result_url") - - class PostInferenceHook(ABC): def __init__( self, @@ -48,41 +35,6 @@ def handle( pass -class BillingHook(PostInferenceHook): - def __init__( - self, - endpoint_name: str, - bundle_name: str, - user_id: str, - billing_queue: Optional[str], - billing_tags: Optional[Dict[str, Any]], - ): - super().__init__(endpoint_name, bundle_name, user_id) - self._billing_queue = billing_queue - self._billing_tags = billing_tags or {} - - def handle( - self, - request_payload: EndpointPredictV1Request, - response: Dict[str, Any], - task_id: Optional[str], - ): - if not self._user_id or not self._billing_queue: - logger.error("Usage inputs could not be found for billing hook, aborting") - return - if not task_id: - task_id = str(uuid4()) - - events_queue: UsageMetricsGateway - try: - from plugins.eventbridge_usage_metrics_gateway import EventbridgeUsageMetricsGateway - - events_queue = EventbridgeUsageMetricsGateway(self._billing_queue) - except ModuleNotFoundError: - events_queue = FakeUsageMetricsGateway() - events_queue.emit_task_call_metric(idempotency_token=task_id, tags=self._billing_tags) - - class CallbackHook(PostInferenceHook): def __init__( self, @@ -142,15 +94,7 @@ def __init__( # TODO: Ensure that this process gracefully handles errors in # initializing each post-inference hook. hook_lower = hook.lower() - if hook_lower == BILLING_POST_INFERENCE_HOOK: - self._hooks[hook_lower] = BillingHook( - endpoint_name, - bundle_name, - user_id, - billing_queue, - billing_tags, - ) - elif hook_lower == CALLBACK_POST_INFERENCE_HOOK: + if hook_lower == CALLBACK_POST_INFERENCE_HOOK: self._hooks[hook_lower] = CallbackHook( endpoint_name, bundle_name, diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile index edac54c9..01cbdf0c 100644 --- a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile @@ -57,7 +57,6 @@ COPY --chown=modelengine \ RUN pip install -r /app/model-engine/model_engine_server/inference/requirements_base.txt COPY --chown=modelengine model-engine/setup.py /app/model-engine/setup.py -COPY --chown=modelengine model-engine/model_engine_server.egg-info /app/model-engine/model_engine_server.egg-info COPY --chown=modelengine model-engine/model_engine_server/__init__.py /app/model-engine/model_engine_server/__init__.py COPY --chown=modelengine model-engine/model_engine_server/common /app/model-engine/model_engine_server/common COPY --chown=modelengine model-engine/model_engine_server/core /app/model-engine/model_engine_server/core diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt index a352a14a..cedabe42 100644 --- a/model-engine/model_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -1,12 +1,22 @@ -aioredis==2.0.1 +aioredis~=2.0 +boto3>=1.28.38 celery[redis,sqs,tblib]==5.3.1 +datadog-api-client==2.11.0 +datadog~=0.46.0 fastapi==0.78.0 -gunicorn==20.1.0 # Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it importlib-metadata<5.0;python_version<"3.8" -json-log-formatter==0.5.2 +scale-launch>=0.1.0 smart_open==5.1.0 -tqdm==4.65.0 -# Pin typing-extensions so aioitertools doesn't break typing-extensions>=4.1.1 uvicorn==0.17.6 +waitress==2.0.0 + +# HACK: at time of adding, these deps are imported by model-engine/model_engine_server files +# add here to to prevent `ModuleNotFoundError` error on container startup, these should be in sync with server reqs +# long term: consider having slimmer deps and seperating inference container deps from server container deps +ddtrace==1.8.3 # required for ddtrace-run entrypoint command as well +json-log-formatter~=0.3 # model_engine_server/core/loggers.py +tenacity>=6.0.0,<=6.2.0 # model_engine_server/core/loggers.py +tqdm~=4.64 # model_engine_server/common/service_requests.py +gunicorn~=20.0 diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index ca5f7469..d2277ab3 100644 --- a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -4,8 +4,11 @@ from model_engine_server.core.config import infra_config from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists from model_engine_server.core.docker.remote_build import build_remote_block +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.repositories import DockerRepository +logger = make_logger(logger_name()) + class ECRDockerRepository(DockerRepository): def image_exists( @@ -21,6 +24,7 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str: return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + logger.info(f"build_image args {locals()}") folders_to_include = ["model-engine"] if image_params.requirements_folder: folders_to_include.append(image_params.requirements_folder) diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 61b381e0..eabbf034 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -172,6 +172,7 @@ async def build_endpoint( base_image_params = self.get_base_image_params( build_endpoint_request, logger_adapter ) + logger.info(f"base_image_params: {base_image_params}") base_image = await self._build_image( base_image_params, build_endpoint_request, @@ -490,6 +491,8 @@ def get_base_image_params( inference_folder = "model-engine/model_engine_server/inference" base_path: str = os.getenv("WORKSPACE") # type: ignore + logger.info(f"inference_folder: {inference_folder}") + logger.info(f"dockerfile: {inference_folder}/{dockerfile}") return BuildImageRequest( repo="launch/inference", image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index 539b6803..772a5297 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -90,16 +90,7 @@ async def _build_endpoint( session = SessionAsyncNullPool pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) redis = aioredis.Redis(connection_pool=pool) - - service: LiveEndpointBuilderService - try: - from plugins.dependencies import ( - get_live_endpoint_builder_service as get_custom_live_endpoint_builder_service, - ) - - service = get_custom_live_endpoint_builder_service(session, redis) - except ModuleNotFoundError: - service = get_live_endpoint_builder_service(session, redis) + service: LiveEndpointBuilderService = get_live_endpoint_builder_service(session, redis) response = await service.build_endpoint(build_endpoint_request) await redis.close() diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 5caed45c..e173eeef 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -15,7 +15,7 @@ croniter==1.4.1 dataclasses-json>=0.5.7 datadog-api-client==2.11.0 datadog~=0.46.0 -ddtrace~=0.49.2 +ddtrace==1.8.3 deprecation~=2.1 docker~=5.0 fastapi==0.78.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 1e37fe0f..c367da1e 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -8,14 +8,14 @@ aiofiles==23.1.0 # via quart aiohttp==3.8.5 # via - # -r requirements.in + # -r model-engine/requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r requirements.in + # via -r model-engine/requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r requirements.in + # via -r model-engine/requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 @@ -28,12 +28,20 @@ async-timeout==4.0.2 # via # aiohttp # aioredis + # redis asyncpg==0.27.0 - # via -r requirements.in + # via -r model-engine/requirements.in attrs==23.1.0 # via # aiohttp + # cattrs # ddtrace + # jsonschema + # referencing +backports-zoneinfo[tzdata]==0.2.1 + # via + # celery + # kombu billiard==4.1.0 # via celery bleach==6.0.0 @@ -42,24 +50,28 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via -r requirements.in + # via -r model-engine/requirements.in botocore==1.31.1 # via - # -r requirements.in + # -r model-engine/requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r requirements.in + # via -r model-engine/requirements.in +bytecode==0.14.2 + # via ddtrace cachetools==5.3.1 # via google-auth +cattrs==23.1.2 + # via ddtrace celery[redis,sqs,tblib]==5.3.1 - # via -r requirements.in + # via -r model-engine/requirements.in certifi==2023.7.22 # via # datadog-api-client @@ -74,7 +86,7 @@ charset-normalizer==3.2.0 # requests click==8.1.4 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # click-didyoumean # click-plugins @@ -88,31 +100,39 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r requirements.in + # via -r model-engine/requirements.in cryptography==41.0.3 # via secretstorage dataclasses-json==0.5.9 - # via -r requirements.in + # via -r model-engine/requirements.in datadog==0.46.0 - # via -r requirements.in + # via -r model-engine/requirements.in datadog-api-client==2.11.0 - # via -r requirements.in -ddtrace==0.49.2 - # via -r requirements.in + # via -r model-engine/requirements.in +ddsketch==2.0.4 + # via ddtrace +ddtrace==1.8.3 + # via -r model-engine/requirements.in deprecation==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in docker==5.0.3 - # via -r requirements.in + # via -r model-engine/requirements.in docutils==0.20.1 # via readme-renderer +envier==0.4.0 + # via ddtrace +exceptiongroup==1.1.3 + # via + # anyio + # cattrs fastapi==0.78.0 - # via -r requirements.in + # via -r model-engine/requirements.in frozenlist==1.3.3 # via # aiohttp @@ -120,15 +140,15 @@ frozenlist==1.3.3 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r requirements.in + # via -r model-engine/requirements.in gitpython==3.1.32 - # via -r requirements.in + # via -r model-engine/requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in h11==0.14.0 # via # hypercorn @@ -139,7 +159,7 @@ h2==4.1.0 hpack==4.0.0 # via h2 httptools==0.5.0 - # via -r requirements.in + # via -r model-engine/requirements.in hypercorn==0.14.4 # via quart hyperframe==6.0.1 @@ -151,8 +171,16 @@ idna==3.4 # yarl importlib-metadata==6.8.0 # via + # alembic # keyring + # quart # twine +importlib-resources==6.0.1 + # via + # alembic + # jsonschema + # jsonschema-specifications + # keyring itsdangerous==2.1.2 # via quart jaraco-classes==3.3.0 @@ -163,24 +191,28 @@ jeepney==0.8.0 # secretstorage jinja2==3.0.3 # via - # -r requirements.in + # -r model-engine/requirements.in # quart jmespath==1.0.1 # via # boto3 # botocore json-log-formatter==0.5.2 - # via -r requirements.in + # via -r model-engine/requirements.in +jsonschema==4.19.0 + # via ddtrace +jsonschema-specifications==2023.7.1 + # via jsonschema keyring==24.2.0 # via twine kombu[sqs]==5.3.1 # via celery kubeconfig==1.1.1 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes==25.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes-asyncio==24.2.2 - # via -r requirements.in + # via -r model-engine/requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -220,7 +252,7 @@ mypy-extensions==1.0.0 oauthlib==3.2.2 # via requests-oauthlib orjson==3.8.6 - # via -r requirements.in + # via -r model-engine/requirements.in packaging==23.1 # via # build @@ -233,18 +265,21 @@ pg8000==1.29.8 # via testing-postgresql pkginfo==1.9.6 # via twine +pkgutil-resolve-name==1.3.10 + # via jsonschema priority==2.0.0 # via hypercorn prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r requirements.in + # -r model-engine/requirements.in + # ddsketch # ddtrace psycopg2-binary==2.9.3 - # via -r requirements.in + # via -r model-engine/requirements.in py-xid==0.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -255,12 +290,12 @@ pycparser==2.21 # via cffi pycurl==7.45.2 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu pydantic==1.10.11 # via - # -r requirements.in + # -r model-engine/requirements.in # fastapi pygments==2.15.1 # via @@ -276,21 +311,25 @@ python-dateutil==2.8.2 # kubernetes-asyncio # pg8000 python-multipart==0.0.6 - # via -r requirements.in + # via -r model-engine/requirements.in pyyaml==6.0 # via # kubeconfig # kubernetes # kubernetes-asyncio quart==0.18.3 - # via -r requirements.in + # via -r model-engine/requirements.in readme-renderer==40.0 # via twine redis==4.6.0 # via celery +referencing==0.30.2 + # via + # jsonschema + # jsonschema-specifications requests==2.31.0 # via - # -r requirements.in + # -r model-engine/requirements.in # datadog # docker # kubernetes @@ -299,7 +338,7 @@ requests==2.31.0 # requests-toolbelt # twine requests-auth-aws-sigv4==0.7 - # via -r requirements.in + # via -r model-engine/requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -307,7 +346,11 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r requirements.in + # via -r model-engine/requirements.in +rpds-py==0.10.0 + # via + # jsonschema + # referencing rsa==4.9 # via google-auth s3transfer==0.6.1 @@ -317,10 +360,11 @@ scramp==1.4.4 secretstorage==3.3.3 # via keyring sh==1.14.3 - # via -r requirements.in + # via -r model-engine/requirements.in six==1.16.0 # via # bleach + # ddsketch # ddtrace # google-auth # kubernetes @@ -328,7 +372,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r requirements.in + # via -r model-engine/requirements.in smmap==5.0.0 # via # gitdb @@ -339,12 +383,12 @@ sniffio==1.3.0 # via anyio sqlalchemy[asyncio]==2.0.4 # via - # -r requirements.in + # -r model-engine/requirements.in # alembic sse-starlette==1.6.1 - # via -r requirements.in + # via -r model-engine/requirements.in sseclient-py==1.7.2 - # via -r requirements.in + # via -r model-engine/requirements.in starlette==0.19.1 # via # fastapi @@ -353,18 +397,23 @@ tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r requirements.in + # -r model-engine/requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in +tomli==2.0.1 + # via + # build + # hypercorn + # pep517 tqdm==4.65.0 # via - # -r requirements.in + # -r model-engine/requirements.in # twine twine==3.7.1 - # via -r requirements.in + # via -r model-engine/requirements.in types-awscrt==0.16.23 # via # botocore-stubs @@ -374,15 +423,32 @@ types-s3transfer==0.6.1 typing-extensions==4.7.1 # via # aioredis + # asgiref # boto3-stubs + # botocore-stubs + # bytecode + # cattrs # datadog-api-client + # ddtrace + # kombu + # mypy-boto3-cloudformation + # mypy-boto3-dynamodb + # mypy-boto3-ec2 + # mypy-boto3-lambda + # mypy-boto3-rds + # mypy-boto3-s3 + # mypy-boto3-sqs # pydantic + # rich # sqlalchemy + # starlette # typing-inspect typing-inspect==0.9.0 # via dataclasses-json tzdata==2023.3 - # via celery + # via + # backports-zoneinfo + # celery urllib3==1.26.16 # via # botocore @@ -394,9 +460,9 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r requirements.in + # via -r model-engine/requirements.in uvloop==0.17.0 - # via -r requirements.in + # via -r model-engine/requirements.in vine==5.0.0 # via # amqp @@ -414,12 +480,16 @@ werkzeug==2.3.6 # via quart wsproto==1.2.0 # via hypercorn +xmltodict==0.13.0 + # via ddtrace yarl==1.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # aiohttp zipp==3.16.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: setuptools==68.0.0 From 6c4f376213b05198ccebfaaa80f93e086cbc8fd5 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:54:02 -0700 Subject: [PATCH 073/425] Add vLLM as an inference framework (#228) * Add vLLM dockerfile * fix * fixes * fixes * fix unit test --- charts/model-engine/values_circleci.yaml | 1 + .../model_engine_server/common/config.py | 1 + .../domain/entities/llm_entity.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 215 +++++++++++++----- .../inference/vllm/Dockerfile | 7 + .../inference/vllm/build_and_upload_image.sh | 21 ++ .../inference/vllm/requirements.txt | 1 + .../inference/vllm/vllm_server.py | 99 ++++++++ .../service_config_circleci.yaml | 1 + .../tests/unit/domain/test_llm_use_cases.py | 2 +- 10 files changed, 290 insertions(+), 59 deletions(-) create mode 100644 model-engine/model_engine_server/inference/vllm/Dockerfile create mode 100755 model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh create mode 100644 model-engine/model_engine_server/inference/vllm/requirements.txt create mode 100644 model-engine/model_engine_server/inference/vllm/vllm_server.py diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index d31665ef..8279f00e 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -140,6 +140,7 @@ config: datadog_trace_enabled: false istio_enabled: true tgi_repository: "text-generation-inference" + vllm_repository: "vllm" hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET" # Service Account diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index deeb4477..250a51e6 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -55,6 +55,7 @@ class HostedModelInferenceServiceConfig: istio_enabled: bool datadog_trace_enabled: bool tgi_repository: str + vllm_repository: str @classmethod def from_yaml(cls, yaml_path): diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index f9062709..80344b54 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -10,6 +10,7 @@ class LLMSource(str, Enum): class LLMInferenceFramework(str, Enum): DEEPSPEED = "deepspeed" TEXT_GENERATION_INFERENCE = "text_generation_inference" + VLLM = "vllm" class Quantization(str, Enum): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 638fd9e2..8abdf322 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -102,6 +102,21 @@ "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", }, + LLMInferenceFramework.VLLM: { + "mpt-7b": "mosaicml/mpt-7b", + "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", + "llama-7b": "decapoda-research/llama-7b-hf", + "llama-2-7b": "meta-llama/Llama-2-7b-hf", + "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", + "llama-2-13b": "meta-llama/Llama-2-13b-hf", + "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", + "llama-2-70b": "meta-llama/Llama-2-70b-hf", + "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", + "falcon-7b": "tiiuae/falcon-7b", + "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", + "falcon-40b": "tiiuae/falcon-40b", + "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", + }, } @@ -195,6 +210,15 @@ async def create_model_bundle( quantize, checkpoint_path, ) + elif framework == LLMInferenceFramework.VLLM: + bundle_id = await self.create_vllm_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + checkpoint_path, + ) else: raise ObjectHasInvalidValueException( f"Framework {framework} is not supported for source {source}." @@ -226,73 +250,37 @@ async def create_text_generation_inference_bundle( max_input_length = 4095 max_total_tokens = 4096 + subcommands = [] if checkpoint_path is not None: if checkpoint_path.startswith("s3://"): - base_path = checkpoint_path.split("/")[-1] final_weights_folder = "model_files" - subcommands = [] - - s5cmd = "s5cmd" - # This is a hack for now to skip installing s5cmd for text-generation-inference:0.9.3-launch_s3, - # which has s5cmd binary already baked in. Otherwise, install s5cmd if it's not already available - if framework_image_tag != "0.9.3-launch_s3": - subcommands.append( - f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}" - ) - else: - s5cmd = "./s5cmd" - - if base_path.endswith(".tar"): - # If the checkpoint file is a tar file, extract it into final_weights_folder - subcommands.extend( - [ - f"{s5cmd} cp {checkpoint_path} .", - f"mkdir -p {final_weights_folder}", - f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}", - ] - ) - else: - subcommands.append( - f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" - ) - subcommands.append( - f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + framework_image_tag, + checkpoint_path, + final_weights_folder, ) - - if quantize: - subcommands[-1] = subcommands[-1] + f" --quantize {quantize}" - command = [ - "/bin/bash", - "-c", - ";".join(subcommands), - ] else: raise ObjectHasInvalidValueException( f"Not able to load checkpoint path {checkpoint_path}." ) else: - hf_model_name = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.TEXT_GENERATION_INFERENCE][ - model_name - ] + final_weights_folder = _SUPPORTED_MODEL_NAMES[ + LLMInferenceFramework.TEXT_GENERATION_INFERENCE + ][model_name] - command = [ - "text-generation-launcher", - "--model-id", - hf_model_name, - "--num-shard", - str(num_shards), - "--port", - "5005", - "--hostname", - "::", - "--max-input-length", - str(max_input_length), - "--max-total-tokens", - str(max_total_tokens), - ] - if quantize: - command = command + [f"--quantize {quantize}"] + subcommands.append( + f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}" + ) + + if quantize: + subcommands[-1] = subcommands[-1] + f" --quantize {quantize}" + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] return ( await self.create_model_bundle_use_case.execute( @@ -322,6 +310,44 @@ async def create_text_generation_inference_bundle( ) ).model_bundle_id + def load_model_weights_sub_commands( + self, framework, framework_image_tag, checkpoint_path, final_weights_folder + ): + subcommands = [] + s5cmd = "s5cmd" + + base_path = checkpoint_path.split("/")[-1] + + # This is a hack for now to skip installing s5cmd for text-generation-inference:0.9.3-launch_s3, + # which has s5cmd binary already baked in. Otherwise, install s5cmd if it's not already available + if ( + framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + and framework_image_tag != "0.9.3-launch_s3" + ): + subcommands.append(f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}") + else: + s5cmd = "./s5cmd" + + if base_path.endswith(".tar"): + # If the checkpoint file is a tar file, extract it into final_weights_folder + subcommands.extend( + [ + f"{s5cmd} cp {checkpoint_path} .", + f"mkdir -p {final_weights_folder}", + f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}", + ] + ) + else: + if framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + subcommands.append( + f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '*.bin' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + else: + subcommands.append( + f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '*.safetensors' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + return subcommands + async def create_deepspeed_bundle( self, user: User, @@ -401,6 +427,76 @@ async def create_deepspeed_bundle( ) ).model_bundle_id + async def create_vllm_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + checkpoint_path: Optional[str], + ): + command = [] + + max_num_batched_tokens = 2560 # vLLM's default + if "llama-2" in model_name: + max_num_batched_tokens = 4096 # Need to be bigger than model's context window + + subcommands = [] + if checkpoint_path is not None: + if checkpoint_path.startswith("s3://"): + final_weights_folder = "model_files" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.VLLM, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) + else: + raise ObjectHasInvalidValueException( + f"Not able to load checkpoint path {checkpoint_path}." + ) + else: + final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] + + subcommands.append( + f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens}" + ) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.vllm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/predict", + streaming_predict_route="/stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: @@ -417,10 +513,13 @@ async def execute( validate_model_name(request.model_name, request.inference_framework) validate_num_shards(request.num_shards, request.inference_framework, request.gpus) - if request.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + if request.inference_framework in [ + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( - f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference." + f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference and vLLM." ) bundle = await self.create_model_bundle( diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile new file mode 100644 index 00000000..a7f34f3e --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -0,0 +1,7 @@ +FROM nvcr.io/nvidia/pytorch:22.12-py3 + +RUN pip uninstall torch -y +RUN pip install llm==0.1.3 ray[air]==2.6.3 +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz +COPY vllm_server.py /workspace/vllm_server.py diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh new file mode 100755 index 00000000..750da0e0 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Build and push vLLM docker image to AWS ECR. + +set -eo pipefail + +if [ -z "$1" ]; then + echo "Must supply AWS account ID" + exit 1; +fi + +if [ -z "$2" ]; then + echo "Must supply the image tag" + exit 1; +fi + +IMAGE_TAG=$2 +ACCOUNT=$1 +aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com +DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/vllm:$IMAGE_TAG . +docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/vllm:$IMAGE_TAG diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt new file mode 100644 index 00000000..afd523e3 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -0,0 +1 @@ +ray[air]==2.6.3 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py new file mode 100644 index 00000000..62c71ba0 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -0,0 +1,99 @@ +import argparse +import json +from typing import AsyncGenerator + +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import Response, StreamingResponse +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds +app = FastAPI() + + +@app.get("/healthz") +@app.get("/health") +def healthcheck(): + return "OK" + + +@app.post("/predict") +@app.post("/stream") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + # TODO: vLLM spends a long time decoding text repeatedly, that for every new token `text` is regenerated, + # (see detokenize_incrementally) which we should definitely optimize away. + async def stream_results() -> AsyncGenerator[str, None]: + async for request_output in results_generator: + ret = { + "text": request_output.outputs[0].text, + "count_prompt_tokens": len(request_output.prompt_token_ids), + "count_output_tokens": len(request_output.outputs[0].token_ids), + "log_probs": request_output.outputs[0].logprobs[-1], + } + yield f"data:{json.dumps(ret)}\n\n" + + async def abort_request() -> None: + await engine.abort(request_id) + + if stream: + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. + background_tasks.add_task(abort_request) + return StreamingResponse(stream_results(), background=background_tasks) + + # Non-streaming case + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + ret = { + "text": final_output.outputs[0].text, + "count_prompt_tokens": len(final_output.prompt_token_ids), + "count_output_tokens": len(final_output.outputs[0].token_ids), + "log_probs": final_output.outputs[0].logprobs, + } + return Response(content=json.dumps(ret)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) # None == IPv4 / IPv6 dualstack + parser.add_argument("--port", type=int, default=5005) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ) diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 0d3ae024..04ea1da8 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -55,6 +55,7 @@ s3_file_llm_fine_tune_repository: "s3://test-bucket" datadog_trace_enabled: false istio_enabled: true tgi_repository: "text-generation-inference" +vllm_repository: "vllm" # S3 access hf_user_fine_tuned_weights_prefix: "s3://test-bucket" diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4e30c41f..444fafb5 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -140,7 +140,7 @@ async def test_create_model_endpoint_use_case_success( bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( owner=user.team_id, name=create_llm_model_endpoint_request_llama_2.name ) - assert "--max-total-tokens" in bundle.flavor.command and "4096" in bundle.flavor.command + assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] @pytest.mark.asyncio From c5acde0f33191a1ea72ea5c7825c19d92657c126 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 31 Aug 2023 18:39:13 -0700 Subject: [PATCH 074/425] change max_input_length to half of max_total_tokens (#244) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 8abdf322..e57ef0a6 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -244,10 +244,10 @@ async def create_text_generation_inference_bundle( command = [] # TGI requires max_input_length < max_total_tokens - max_input_length = 2047 + max_input_length = 1024 max_total_tokens = 2048 if "llama-2" in model_name: - max_input_length = 4095 + max_input_length = 2048 max_total_tokens = 4096 subcommands = [] From 30d8ec79a743b0fd9c50f65361a2ba8c6574c214 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 31 Aug 2023 23:56:19 -0700 Subject: [PATCH 075/425] Validate Fine-tuning CSV headers (#243) * validate ft csv headers * reduce buffer size * isort * strict type * add test_create_fine_tune_invalid_headers * remove unused import --- .../use_cases/llm_fine_tuning_use_cases.py | 29 +++++++++++ .../tests/unit/domain/test_llm_use_cases.py | 48 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 331b0e48..0e837e09 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -1,6 +1,8 @@ +import csv import datetime import re +import smart_open from model_engine_server.common.dtos.llms import ( CancelFineTuneResponse, CreateFineTuneRequest, @@ -19,6 +21,7 @@ from model_engine_server.domain.services import LLMFineTuningService, ModelEndpointService DEFAULT_FINE_TUNING_METHOD = "lora" +REQUIRED_COLUMNS = ["prompt", "response"] MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER = 5 MAX_LLM_ENDPOINTS_PER_INTERNAL_USER = 15 @@ -43,6 +46,23 @@ def ensure_model_name_is_valid_k8s_label(model_name: str): return re.sub("[^-A-Za-z0-9_.]", "-", model_name).lstrip("-_.")[:62].rstrip("-_.") +def read_csv_headers(file_location: str): + """ + Read the headers of a csv file. Assumes the file exists and is valid. + """ + with smart_open.open(file_location, transport_params=dict(buffer_size=1024)) as file: + csv_reader = csv.DictReader(file) + return csv_reader.fieldnames + + +def are_dataset_headers_valid(file_location: str): + """ + Ensure the dataset headers are valid with required columns 'prompt' and 'response'. + """ + current_headers = read_csv_headers(file_location) + return all(required_header in current_headers for required_header in REQUIRED_COLUMNS) + + class CreateFineTuneV1UseCase: def __init__( self, @@ -120,6 +140,15 @@ async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFin else: validation_file = request.validation_file + if training_file is not None and not are_dataset_headers_valid(training_file): + raise InvalidRequestException( + f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in training dataset" + ) + if validation_file is not None and not are_dataset_headers_valid(validation_file): + raise InvalidRequestException( + f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in validation dataset" + ) + await self.llm_fine_tune_events_repository.initialize_events(user.team_id, fine_tuned_model) fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune( created_by=user.user_id, diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 444fafb5..b841ba1b 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1,4 +1,5 @@ from typing import Any, Tuple +from unittest import mock import pytest from model_engine_server.common.dtos.llms import ( @@ -652,6 +653,10 @@ async def test_create_llm_fine_tune_model_name_valid(): @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) async def test_create_fine_tune_success( fake_llm_fine_tuning_service, fake_model_endpoint_service, @@ -684,6 +689,10 @@ async def test_create_fine_tune_success( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) async def test_create_fine_tune_limit( fake_llm_fine_tuning_service, fake_model_endpoint_service, @@ -716,6 +725,10 @@ async def test_create_fine_tune_limit( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) async def test_create_fine_tune_long_suffix( fake_llm_fine_tuning_service, fake_model_endpoint_service, @@ -743,6 +756,41 @@ async def test_create_fine_tune_long_suffix( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,not_response"), +) +async def test_create_fine_tune_invalid_headers( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + with pytest.raises(InvalidRequestException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) async def test_get_fine_tune_events_success( fake_llm_fine_tuning_service, fake_llm_fine_tuning_events_repository, From ef439204673e104009306659aae2569a696edd0c Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:10:20 -0700 Subject: [PATCH 076/425] Sync scale from zero part 2 (#230) Add in keda scaled objects CRUD for supporting sync scale from zero --- .../service_template_config_map.yaml | 29 +++ .../model_engine_server/common/config.py | 12 ++ .../use_cases/model_endpoint_use_cases.py | 5 + ...s_inference_autoscaling_metrics_gateway.py | 1 + .../k8s_endpoint_resource_delegate.py | 189 ++++++++++++++++-- .../gateways/resources/k8s_resource_types.py | 27 +++ .../service_template_config_map_circleci.yaml | 41 ++++ .../repositories/ecr_docker_repository.py | 4 + .../domain/test_model_endpoint_use_cases.py | 15 ++ 9 files changed, 304 insertions(+), 19 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index fa52f51f..ad9b7f62 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -399,6 +399,35 @@ data: target: type: Value averageValue: ${CONCURRENCY} + keda-scaled-object.yaml: |- + apiVersion: keda.sh/v1alpha1 + kind: ScaledObject + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + scaleTargetRef: + name: ${RESOURCE_NAME} + pollingInterval: 5 + cooldownPeriod: 300 + minReplicaCount: ${MIN_WORKERS} + maxReplicaCount: ${MAX_WORKERS} + fallback: + failureThreshold: 3 + replicas: ${MIN_WORKERS} + triggers: + - type: redis + metadata: + address: ${REDIS_HOST_PORT} # Format must be host:port + passwordFromEnv: "" + listName: "launch-endpoint-autoscaling:${ENDPOINT_ID}" + listLength: "100" # something absurdly high so we don't scale past 1 pod + activationListLength: "0" + enableTLS: "false" + unsafeSsl: "false" + databaseIndex: "${REDIS_DB_INDEX}" service.yaml: |- apiVersion: v1 kind: Service diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 250a51e6..25cf5453 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -63,6 +63,18 @@ def from_yaml(cls, yaml_path): raw_data = yaml.safe_load(f) return HostedModelInferenceServiceConfig(**raw_data) + @property + def cache_redis_host_port(self) -> str: + # redis://redis.url:6379/ + # -> redis.url:6379 + return self.cache_redis_url.split("redis://")[1].split("/")[0] + + @property + def cache_redis_db_index(self) -> int: + # redis://redis.url:6379/ + # -> + return int(self.cache_redis_url.split("/")[-1]) + def read_default_config(): logger.info(f"Using config file path: `{SERVICE_CONFIG_PATH}`") diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 04e595d4..1633a72b 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -104,6 +104,11 @@ def validate_deployment_resources( max_workers: Optional[int], endpoint_type: ModelEndpointType, ) -> None: + if endpoint_type in [ModelEndpointType.STREAMING, ModelEndpointType.SYNC]: + # Special case for sync endpoints, where we can have 0, 1 min/max workers. + # Otherwise, fall through to the general case. + if min_workers == 0 and max_workers == 1: + return # TODO: we should be also validating the update request against the existing state in k8s (e.g. # so min_workers <= max_workers always) maybe this occurs already in update_model_endpoint. min_endpoint_size = 0 if endpoint_type == ModelEndpointType.ASYNC else 1 diff --git a/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py index 5761493e..a5bcc31e 100644 --- a/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py @@ -31,6 +31,7 @@ def __init__( @staticmethod def _find_redis_key(endpoint_id: str): + # Keep in line with keda scaled object yaml return f"launch-endpoint-autoscaling:{endpoint_id}" async def _emit_metric(self, endpoint_id: str, expiry_time: int): diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 7006ca1f..56836596 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -599,6 +599,46 @@ async def _create_vpa(vpa: Dict[str, Any], name: str) -> None: logger.exception("Got an exception when trying to apply the VerticalPodAutoscaler") raise + @staticmethod + async def _create_keda_scaled_object(scaled_object: Dict[str, Any], name: str) -> None: + custom_objects_api = get_kubernetes_custom_objects_client() + try: + await custom_objects_api.create_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + body=scaled_object, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"ScaledObject {name} already exists, replacing") + + # The async k8s client has a bug with patching custom objects, so we manually + # merge the new ScaledObject with the old one and then replace the old one with the merged + # one. See _create_vpa for more details. + # There is a setting `restoreToOriginalReplicaCount` in the keda ScaledObject that should be set to + # false which should make it safe to do this replace (as opposed to a patch) + existing_scaled_object = await custom_objects_api.get_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=name, + ) + new_scaled_object = deep_update(existing_scaled_object, scaled_object) + await custom_objects_api.replace_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=name, + body=new_scaled_object, + ) + else: + logger.exception("Got an exception when trying to apply the ScaledObject") + raise + @staticmethod async def _create_destination_rule(destination_rule: Dict[str, Any], name: str) -> None: """ @@ -995,6 +1035,28 @@ async def _delete_hpa(endpoint_id: str, deployment_name: str) -> bool: return False return True + @staticmethod + async def _delete_keda_scaled_object(endpoint_id: str) -> bool: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await custom_objects_client.delete_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent ScaledObject {k8s_resource_group_name}" + ) + else: + logger.exception(f"Deletion of ScaledObject {k8s_resource_group_name} failed") + return False + return True + # --- Private higher level fns that interact with k8s @staticmethod @@ -1102,19 +1164,46 @@ async def _create_or_update_resources( else: api_version = "autoscaling/v2beta2" - hpa_arguments = get_endpoint_resource_arguments_from_request( - k8s_resource_group_name=k8s_resource_group_name, - request=request, - sqs_queue_name=sqs_queue_name_str, - sqs_queue_url=sqs_queue_url_str, - endpoint_resource_name="horizontal-pod-autoscaler", - api_version=api_version, - ) - hpa_template = load_k8s_yaml("horizontal-pod-autoscaler.yaml", hpa_arguments) - await self._create_hpa( - hpa=hpa_template, - name=k8s_resource_group_name, - ) + # create exactly one of HPA or KEDA ScaledObject, depending if we request more than 0 min_workers + # Right now, keda only will support scaling from 0 to 1 + # TODO support keda scaling from 1 to N as well + if request.build_endpoint_request.min_workers > 0: + # Delete keda scaled object if it exists so exactly one of HPA or KEDA ScaledObject remains + await self._delete_keda_scaled_object( + build_endpoint_request.model_endpoint_record.id + ) + hpa_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="horizontal-pod-autoscaler", + api_version=api_version, + ) + hpa_template = load_k8s_yaml("horizontal-pod-autoscaler.yaml", hpa_arguments) + await self._create_hpa( + hpa=hpa_template, + name=k8s_resource_group_name, + ) + else: # min workers == 0, use keda + # Delete hpa if it exists so exactly one of HPA or KEDA ScaledObject remains + await self._delete_hpa( + build_endpoint_request.model_endpoint_record.id, k8s_resource_group_name + ) + keda_scaled_object_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="keda-scaled-object", + ) + keda_scaled_object_template = load_k8s_yaml( + "keda-scaled-object.yaml", keda_scaled_object_arguments + ) + await self._create_keda_scaled_object( + scaled_object=keda_scaled_object_template, + name=k8s_resource_group_name, + ) service_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, @@ -1204,6 +1293,17 @@ def _get_sync_autoscaling_params( per_worker=per_worker, ) + @staticmethod + def _get_sync_autoscaling_params_from_keda( + keda_config, + ) -> HorizontalAutoscalingEndpointParams: + spec = keda_config["spec"] + return dict( + max_workers=spec.get("maxReplicaCount"), + min_workers=spec.get("minReplicaCount"), + per_worker=1, # TODO dummy value, fill in when we autoscale from 0 to 1 + ) + async def _get_resources( self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType ) -> ModelEndpointInfraState: @@ -1232,10 +1332,36 @@ async def _get_resources( horizontal_autoscaling_params = self._get_async_autoscaling_params(deployment_config) elif endpoint_type in {ModelEndpointType.SYNC, ModelEndpointType.STREAMING}: autoscaling_client = get_kubernetes_autoscaling_client() - hpa_config = await autoscaling_client.read_namespaced_horizontal_pod_autoscaler( - k8s_resource_group_name, hmi_config.endpoint_namespace - ) - horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config) + custom_object_client = get_kubernetes_custom_objects_client() + try: + hpa_config = await autoscaling_client.read_namespaced_horizontal_pod_autoscaler( + k8s_resource_group_name, hmi_config.endpoint_namespace + ) + except ApiException as e: + if e.status == 404: + hpa_config = None + else: + raise e + try: + keda_scaled_object_config = await custom_object_client.get_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=k8s_resource_group_name, + ) + except ApiException: + keda_scaled_object_config = None + if hpa_config is not None: + horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config) + elif keda_scaled_object_config is not None: + horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda( + keda_scaled_object_config + ) + else: + raise EndpointResourceInfraException( + f"Could not find autoscaling config for {endpoint_type}" + ) else: raise ValueError(f"Unexpected endpoint type {endpoint_type}") @@ -1326,10 +1452,25 @@ async def _get_all_resources( vpas = [] else: raise + try: + keda_scaled_objects = ( + await custom_objects_client.list_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + ) + )["items"] + except ApiException as e: + if e.status == 404: + keda_scaled_objects = [] + else: + raise deployments_by_name = {deployment.metadata.name: deployment for deployment in deployments} hpas_by_name = {hpa.metadata.name: hpa for hpa in hpas} vpas_by_name = {vpa["metadata"]["name"]: vpa for vpa in vpas} + keda_scaled_objects_by_name = {kso["metadata"]["name"]: kso for kso in keda_scaled_objects} all_config_maps = await self._get_all_config_maps() # can safely assume hpa with same name as deployment corresponds to the same Launch Endpoint logger.info(f"Orphaned hpas: {set(hpas_by_name).difference(set(deployments_by_name))}") @@ -1340,6 +1481,7 @@ async def _get_all_resources( try: hpa_config = hpas_by_name.get(name, None) vpa_config = vpas_by_name.get(name, None) + keda_scaled_object_config = keda_scaled_objects_by_name.get(name, None) common_params = self._get_common_endpoint_params(deployment_config) launch_container = self._get_launch_container(deployment_config) @@ -1355,9 +1497,14 @@ async def _get_all_resources( if hpa_config: # Assume it's a sync endpoint # TODO I think this is correct but only barely, it introduces a coupling between - # an HPA existing and an endpoint being a sync endpoint. The "more correct" + # an HPA (or keda SO) existing and an endpoint being a sync endpoint. The "more correct" # thing to do is to query the db to get the endpoints, but it doesn't belong here horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config) + elif keda_scaled_object_config: + # Also assume it's a sync endpoint + horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda( + keda_scaled_object_config + ) else: horizontal_autoscaling_params = self._get_async_autoscaling_params( deployment_config @@ -1427,9 +1574,13 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - service_delete_succeeded = await self._delete_service( endpoint_id=endpoint_id, deployment_name=deployment_name ) + # we should have created exactly one of an HPA or a keda scaled object hpa_delete_succeeded = await self._delete_hpa( endpoint_id=endpoint_id, deployment_name=deployment_name ) + keda_scaled_object_succeeded = await self._delete_keda_scaled_object( + endpoint_id=endpoint_id + ) await self._delete_vpa(endpoint_id=endpoint_id) destination_rule_delete_succeeded = await self._delete_destination_rule( @@ -1443,7 +1594,7 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - deployment_delete_succeeded and config_map_delete_succeeded and service_delete_succeeded - and hpa_delete_succeeded + and (hpa_delete_succeeded or keda_scaled_object_succeeded) and destination_rule_delete_succeeded and virtual_service_delete_succeeded ) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 632ec7bf..6c0f9724 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -284,6 +284,14 @@ class HorizontalPodAutoscalerArguments(_BaseEndpointArguments): API_VERSION: str +class KedaScaledObjectArguments(_BaseEndpointArguments): + MIN_WORKERS: int + MAX_WORKERS: int + # CONCURRENCY: float # TODO add in when we scale from 1 -> N pods + REDIS_HOST_PORT: str + REDIS_DB_INDEX: str + + class UserConfigArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into user-config templates.""" @@ -1089,6 +1097,25 @@ def get_endpoint_resource_arguments_from_request( MIN_WORKERS=build_endpoint_request.min_workers, MAX_WORKERS=build_endpoint_request.max_workers, ) + elif endpoint_resource_name == "keda-scaled-object": + return KedaScaledObjectArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + # Scaled Object arguments + MIN_WORKERS=build_endpoint_request.min_workers, + MAX_WORKERS=build_endpoint_request.max_workers, + # CONCURRENCY=build_endpoint_request.concurrency, + REDIS_HOST_PORT=hmi_config.cache_redis_host_port, + REDIS_DB_INDEX=hmi_config.cache_redis_db_index, + ) elif endpoint_resource_name == "service": # Use ClusterIP by default for sync endpoint. # In Circle CI, we use a NodePort to expose the service to CI. diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 85f312b8..48f0e924 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -2558,6 +2558,47 @@ data: target: type: Value averageValue: ${CONCURRENCY} + keda-scaled-object.yaml: |- + apiVersion: keda.sh/v1alpha1 + kind: ScaledObject + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + scaleTargetRef: + name: ${RESOURCE_NAME} + pollingInterval: 5 + cooldownPeriod: 300 + minReplicaCount: ${MIN_WORKERS} + maxReplicaCount: ${MAX_WORKERS} + fallback: + failureThreshold: 3 + replicas: ${MIN_WORKERS} + triggers: + - type: redis + metadata: + address: ${REDIS_HOST_PORT} # Format must be host:port + passwordFromEnv: "" + listName: "launch-endpoint-autoscaling:${ENDPOINT_ID}" + listLength: "100" # something absurdly high so we don't scale past 1 pod + activationListLength: "0" + enableTLS: "false" + unsafeSsl: "false" + databaseIndex: "${REDIS_DB_INDEX}" service.yaml: |- apiVersion: v1 kind: Service diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index d2277ab3..8ca5dd61 100644 --- a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -29,6 +29,10 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: if image_params.requirements_folder: folders_to_include.append(image_params.requirements_folder) + dockerfile_root_folder = image_params.dockerfile.split("/")[0] + if dockerfile_root_folder not in folders_to_include: + folders_to_include.append(dockerfile_root_folder) + build_args = { "BASE_IMAGE": image_params.base_image, } diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index 1875d7d0..95901f8a 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -69,6 +69,21 @@ async def test_create_model_endpoint_use_case_success( assert response_3.endpoint_creation_task_id assert isinstance(response_3, CreateModelEndpointV1Response) + # test special case where sync/streaming endpoint that has 0-1 min-max workers works + request = create_model_endpoint_request_sync.copy() + request.min_workers = 0 + request.max_workers = 1 + response_4 = await use_case.execute(user=user, request=request) + assert response_4.endpoint_creation_task_id + assert isinstance(response_4, CreateModelEndpointV1Response) + + request = create_model_endpoint_request_streaming.copy() + request.min_workers = 0 + request.max_workers = 1 + response_5 = await use_case.execute(user=user, request=request) + assert response_5.endpoint_creation_task_id + assert isinstance(response_5, CreateModelEndpointV1Response) + @pytest.mark.asyncio async def test_create_model_endpoint_use_case_raises_invalid_value_exception( From 72454b7394d5d94ee65d80f09dea217b7c4d948c Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:29:47 -0700 Subject: [PATCH 077/425] Completions for vLLM endpoints (#245) * Completions for vLLM endpoints * remove todo * fix --- .../use_cases/llm_model_endpoint_use_cases.py | 79 +++++++++++++++++++ .../inference/vllm/Dockerfile | 2 +- .../inference/vllm/vllm_server.py | 14 +++- 3 files changed, 92 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index e57ef0a6..311a312d 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -744,6 +744,18 @@ def model_output_to_completion_output( except Exception as e: logger.exception(f"Error parsing text-generation-inference output {model_output}") raise e + elif model_content.inference_framework == LLMInferenceFramework.VLLM: + tokens = None + if with_token_probs: + tokens = [ + TokenOutput(token=model_output["tokens"][index], log_prob=list(t.values())[0]) + for index, t in enumerate(model_output["log_probs"]) + ] + return CompletionOutput( + text=model_output["text"], + num_completion_tokens=model_output["count_output_tokens"], + tokens=tokens, + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -879,6 +891,40 @@ async def execute( output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, model_endpoint, request.return_token_log_probs + ), + ) + elif endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + vllm_args: Any = { + "prompt": request.prompt, + "max_tokens": request.max_new_tokens, + } + if request.stop_sequences is not None: + vllm_args["stop"] = request.stop_sequences + vllm_args["temperature"] = request.temperature + if request.return_token_log_probs: + vllm_args["logprobs"] = 1 + + inference_request = SyncEndpointPredictV1Request( + args=vllm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, predict_request=inference_request + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + return CompletionSyncV1Response( + request_id=request_id, + output=None, + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( @@ -991,6 +1037,17 @@ async def execute( if request.temperature > 0: args["parameters"]["temperature"] = request.temperature args["parameters"]["do_sample"] = True + elif model_content.inference_framework == LLMInferenceFramework.VLLM: + args = { + "prompt": request.prompt, + "max_tokens": request.max_new_tokens, + } + if request.stop_sequences is not None: + args["stop"] = request.stop_sequences + args["temperature"] = request.temperature + if request.return_token_log_probs: + args["logprobs"] = 1 + args["stream"] = True inference_request = SyncEndpointPredictV1Request( args=args, @@ -1063,6 +1120,28 @@ async def execute( request_id=request_id, output=None, ) + elif model_content.inference_framework == LLMInferenceFramework.VLLM: + if res.status == TaskStatus.SUCCESS and result is not None: + token = None + if request.return_token_log_probs: + token = TokenOutput( + token=result["result"]["text"], + log_prob=list(result["result"]["log_probs"].values())[0], + ) + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["text"], + finished=result["result"]["finished"], + num_completion_tokens=result["result"]["count_output_tokens"], + token=token, + ), + ) + else: + yield CompletionStreamV1Response( + request_id=request_id, + output=None, + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index a7f34f3e..b8d440b2 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -1,7 +1,7 @@ FROM nvcr.io/nvidia/pytorch:22.12-py3 RUN pip uninstall torch -y -RUN pip install llm==0.1.3 ray[air]==2.6.3 +RUN pip install vllm==0.1.4 ray[air]==2.6.3 RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz COPY vllm_server.py /workspace/vllm_server.py diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 62c71ba0..954c143a 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -42,13 +42,18 @@ async def generate(request: Request) -> Response: # TODO: vLLM spends a long time decoding text repeatedly, that for every new token `text` is regenerated, # (see detokenize_incrementally) which we should definitely optimize away. async def stream_results() -> AsyncGenerator[str, None]: + last_output_text = "" async for request_output in results_generator: ret = { - "text": request_output.outputs[0].text, + "text": request_output.outputs[-1].text[len(last_output_text) :], "count_prompt_tokens": len(request_output.prompt_token_ids), "count_output_tokens": len(request_output.outputs[0].token_ids), - "log_probs": request_output.outputs[0].logprobs[-1], + "log_probs": request_output.outputs[0].logprobs[-1] + if sampling_params.logprobs + else None, + "finished": request_output.finished, } + last_output_text = request_output.outputs[-1].text yield f"data:{json.dumps(ret)}\n\n" async def abort_request() -> None: @@ -62,7 +67,11 @@ async def abort_request() -> None: # Non-streaming case final_output = None + tokens = [] + last_output_text = "" async for request_output in results_generator: + tokens.append(request_output.outputs[-1].text[len(last_output_text) :]) + last_output_text = request_output.outputs[-1].text if await request.is_disconnected(): # Abort the request if the client disconnects. await engine.abort(request_id) @@ -76,6 +85,7 @@ async def abort_request() -> None: "count_prompt_tokens": len(final_output.prompt_token_ids), "count_output_tokens": len(final_output.outputs[0].token_ids), "log_probs": final_output.outputs[0].logprobs, + "tokens": tokens, } return Response(content=json.dumps(ret)) From f0d0420829fbaab99da053ea18595ab1c598e61d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 12:19:13 -0700 Subject: [PATCH 078/425] Download bin files for TGI also (#247) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 311a312d..cfd69f96 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -340,7 +340,7 @@ def load_model_weights_sub_commands( else: if framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: subcommands.append( - f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '*.bin' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) else: subcommands.append( From acebca8e47a9a51a0c1a37025eec1b766fe196cc Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 1 Sep 2023 12:26:57 -0700 Subject: [PATCH 079/425] update team label (#246) --- .../services/docker_image_batch_job_llm_fine_tuning_service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index d5edd4aa..008d5dfc 100644 --- a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -97,7 +97,8 @@ async def create_fine_tune( gpu_type=di_batch_job_bundle.gpu_type, storage=di_batch_job_bundle.storage, ), - labels=dict(team="infra", product="llm-fine-tune"), + # TODO: Pass user-defined labels + labels=dict(team="egp", product="llm-fine-tune"), annotations=dict(fine_tuned_model=fine_tuned_model), mount_location=di_batch_job_bundle.mount_location, ) From b5cf6a939f14556a12acec665eaaa3f16a75fbda Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 12:52:31 -0700 Subject: [PATCH 080/425] Ianmacleod/completion sync error throws 4xx (#234) * changing 5xx error to 4xx error * . * . * adding completion stream changes * parsing error dictionary * . * . * fixing error handling for 400 * . * hacky way of fixing completion stream w error message * cleanup * cleanup, add docs * . * fixing indentation on docs --- docs/getting_started.md | 10 ++-- docs/guides/completions.md | 12 +++-- .../model_engine_server/api/llms_v1.py | 10 +++- .../use_cases/llm_model_endpoint_use_cases.py | 47 ++++++++++++++----- 4 files changed, 58 insertions(+), 21 deletions(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index ead931e2..a796bea0 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -74,7 +74,11 @@ stream = Completion.create( ) for response in stream: - if response.output: - print(response.output.text, end="") - sys.stdout.flush() + try: + if response.output: + print(response.output.text, end="") + sys.stdout.flush() + except: # an error occurred + print(stream.text) # print the error message out + break ``` diff --git a/docs/guides/completions.md b/docs/guides/completions.md index a5ea9a06..4719edc3 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -67,6 +67,8 @@ applications. When streaming, tokens will be sent as data-only To enable token streaming, pass `stream=True` to either [Completion.create](../../api/python_client/#llmengine.completion.Completion.create) or [Completion.acreate](../../api/python_client/#llmengine.completion.Completion.acreate). +Note that errors from streaming calls are returned back to the user as plain-text messages and currently need to be handled by the client. + An example of token streaming using the synchronous Completions API looks as follows: === "Token streaming with synchronous API in python" @@ -85,9 +87,13 @@ stream = Completion.create( ) for response in stream: - if response.output: - print(response.output.text, end="") - sys.stdout.flush() + try: + if response.output: + print(response.output.text, end="") + sys.stdout.flush() + except: # an error occurred + print(stream.text) # print the error message out + break ``` ## Async requests diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 50bbbbe8..3e7533da 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -199,6 +199,8 @@ async def create_completion_sync_task( ) from exc except ObjectHasInvalidValueException as exc: raise HTTPException(status_code=400, detail=str(exc)) + except InvalidRequestException as exc: + raise HTTPException(status_code=400, detail=str(exc)) except EndpointUnsupportedInferenceTypeException as exc: raise HTTPException( status_code=400, @@ -230,8 +232,12 @@ async def create_completion_stream_task( ) async def event_generator(): - async for message in response: - yield {"data": message.json()} + try: + async for message in response: + yield {"data": message.json()} + except InvalidRequestException as exc: + yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}} + return return EventSourceResponse(event_generator()) except UpstreamServiceError: diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index cfd69f96..13b0fa1c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -53,6 +53,8 @@ from model_engine_server.domain.exceptions import ( EndpointLabelsException, EndpointUnsupportedInferenceTypeException, + InvalidRequestException, + UpstreamServiceError, ) from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway from model_engine_server.domain.repositories import ModelBundleRepository @@ -741,9 +743,15 @@ def model_output_to_completion_output( num_completion_tokens=model_output["details"]["generated_tokens"], tokens=tokens, ) - except Exception as e: - logger.exception(f"Error parsing text-generation-inference output {model_output}") - raise e + except Exception: + logger.exception(f"Error parsing text-generation-inference output {model_output}.") + if model_output.get("error_type") == "validation": + raise InvalidRequestException(model_output.get("error")) # trigger a 400 + else: + raise UpstreamServiceError( + status_code=500, content=bytes(model_output["error"]) + ) + elif model_content.inference_framework == LLMInferenceFramework.VLLM: tokens = None if with_token_probs: @@ -924,7 +932,6 @@ async def execute( ) output = json.loads(predict_result.result["result"]) - return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( @@ -1106,15 +1113,29 @@ async def execute( token=result["result"]["token"]["text"], log_prob=result["result"]["token"]["logprob"], ) - yield CompletionStreamV1Response( - request_id=request_id, - output=CompletionStreamOutput( - text=result["result"]["token"]["text"], - finished=finished, - num_completion_tokens=num_completion_tokens, - token=token, - ), - ) + try: + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["token"]["text"], + finished=finished, + num_completion_tokens=num_completion_tokens, + token=token, + ), + ) + except Exception: + logger.exception( + f"Error parsing text-generation-inference output. Result: {result['result']}" + ) + if result["result"].get("error_type") == "validation": + raise InvalidRequestException( + result["result"].get("error") + ) # trigger a 400 + else: + raise UpstreamServiceError( + status_code=500, content=result.get("error") + ) # also change llms_v1.py that will return a 500 HTTPException so user can retry + else: yield CompletionStreamV1Response( request_id=request_id, From 27dd1208a10b4510ec12b2394c5760df6ccac2f2 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:37:34 -0700 Subject: [PATCH 081/425] Some fixes (#248) --- .../model-engine/templates/balloon_cpu_deployment.yaml | 2 +- clients/python/llmengine/model.py | 6 +++--- .../infra/services/image_cache_service.py | 10 ++++++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/charts/model-engine/templates/balloon_cpu_deployment.yaml b/charts/model-engine/templates/balloon_cpu_deployment.yaml index 1fd9e6c1..a7be9011 100644 --- a/charts/model-engine/templates/balloon_cpu_deployment.yaml +++ b/charts/model-engine/templates/balloon_cpu_deployment.yaml @@ -34,7 +34,7 @@ spec: resources: limits: memory: 28Gi - cpu: 8 + cpu: 6 command: - /bin/bash - -c diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 26bbcf2d..fd18b7b1 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -41,10 +41,10 @@ def create( quantize: Optional[Quantization] = None, checkpoint_path: Optional[str] = None, # General endpoint fields - cpus: int = 32, - memory: str = "192Gi", + cpus: int = 8, + memory: str = "40Gi", storage: str = "96Gi", - gpus: int = 4, + gpus: int = 1, min_workers: int = 0, max_workers: int = 1, per_worker: int = 10, diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 53b14980..b6343dcc 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -128,7 +128,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi if state.resource_state.gpus == 0 and ( ( state.image not in images_to_cache_priority["cpu"] - or last_updated_at + or last_updated_at.replace( + tzinfo=images_to_cache_priority["cpu"][state.image].last_updated_at.tzinfo + ) > images_to_cache_priority["cpu"][state.image].last_updated_at ) and self.docker_repository.image_exists(image_tag, repository_name) @@ -143,7 +145,11 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi if state.resource_state.gpu_type == gpu_type and ( ( state.image not in images_to_cache_priority[key] - or last_updated_at + or last_updated_at.replace( + tzinfo=images_to_cache_priority[key][ + state.image + ].last_updated_at.tzinfo + ) > images_to_cache_priority[key][state.image].last_updated_at ) and self.docker_repository.image_exists(image_tag, repository_name) From d578491e6d47824da2676762873bd36c10796354 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:48:45 -0700 Subject: [PATCH 082/425] Higher concurrency limit for gunicorn (#249) --- model-engine/model_engine_server/api/worker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/api/worker.py b/model-engine/model_engine_server/api/worker.py index 776b5dd6..945614da 100644 --- a/model-engine/model_engine_server/api/worker.py +++ b/model-engine/model_engine_server/api/worker.py @@ -1,8 +1,6 @@ from uvicorn.workers import UvicornWorker -# The target concurrency is around 50, so we set the limit to 32 with 4 workers -# for a total concurrency of 128 to allow for some headroom. -CONCURRENCY_LIMIT = 32 +CONCURRENCY_LIMIT = 1000 class LaunchWorker(UvicornWorker): From 77517bf87172de791723b5db2d3954fa117e8406 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:25:33 -0700 Subject: [PATCH 083/425] Pass labels to job config (#251) --- .../docker_image_batch_job_llm_fine_tuning_service.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index 008d5dfc..9e25b8cf 100644 --- a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -73,10 +73,14 @@ async def create_fine_tune( if not di_batch_job_bundle.public and di_batch_job_bundle.owner != owner: raise LLMFineTuningMethodNotImplementedException("Fine-tuning method not accessible") + # TODO: Pass user-defined labels + labels = dict(team="egp", product="llm-fine-tune") + batch_job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=created_by, owner=owner, job_config=dict( + **labels, gateway_url=os.getenv("GATEWAY_URL"), user_id=owner, training_file=training_file, @@ -97,8 +101,7 @@ async def create_fine_tune( gpu_type=di_batch_job_bundle.gpu_type, storage=di_batch_job_bundle.storage, ), - # TODO: Pass user-defined labels - labels=dict(team="egp", product="llm-fine-tune"), + labels=labels, annotations=dict(fine_tuned_model=fine_tuned_model), mount_location=di_batch_job_bundle.mount_location, ) From f209eb3454f4749db1a4630e7f7119b4181c2b36 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:03:03 -0700 Subject: [PATCH 084/425] Bump python client version from 0.0.0beta12 to 0.0.0beta13 (#253) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 4b1b86fb..768cbb86 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta12" +__version__ = "0.0.0.beta13" from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 6352055d..51afec3e 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta12" +version = "0.0.0.beta13" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 6af08c9a..2b51d491 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta12", + version="0.0.0.beta13", packages=find_packages(), ) From e00623095d0c5abd3d7037c7452f4083e49036a9 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:16:30 -0700 Subject: [PATCH 085/425] Add comments (#250) --- model-engine/model_engine_server/api/worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-engine/model_engine_server/api/worker.py b/model-engine/model_engine_server/api/worker.py index 945614da..d08113b5 100644 --- a/model-engine/model_engine_server/api/worker.py +++ b/model-engine/model_engine_server/api/worker.py @@ -1,5 +1,7 @@ from uvicorn.workers import UvicornWorker +# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit, before adding rate limiting just increase the concurrency +# We'll autoscale at target concurrency of a much lower number (around 50), and this just makes sure we don't 503 with bursty traffic CONCURRENCY_LIMIT = 1000 From 2e969eeeea5268257961d1f0081d018d62e88982 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:24:28 -0700 Subject: [PATCH 086/425] Fix vllm docker tensor paralllel (#254) --- model-engine/model_engine_server/inference/vllm/Dockerfile | 3 ++- .../model_engine_server/inference/vllm/requirements.txt | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index b8d440b2..d03a2c03 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -1,7 +1,8 @@ FROM nvcr.io/nvidia/pytorch:22.12-py3 RUN pip uninstall torch -y -RUN pip install vllm==0.1.4 ray[air]==2.6.3 +COPY requirements.txt /workspace/requirements.txt +RUN pip install -r requirements.txt RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz COPY vllm_server.py /workspace/vllm_server.py diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index afd523e3..05654616 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1 +1,3 @@ -ray[air]==2.6.3 \ No newline at end of file +ray==2.6.3 +vllm==0.1.4 +pydantic==1.10.12 From b6514c002e7d58a70ef984007c954ccb648a6435 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 5 Sep 2023 13:37:17 -0700 Subject: [PATCH 087/425] Increase liveness timeout for main container (#255) --- .../model-engine/templates/service_template_config_map.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index ad9b7f62..8bdd4eee 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -285,6 +285,12 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + livenessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + timeoutSeconds: 5 resources: requests: {{- if eq $device "gpu" }} From d8a5554732014501c56b23023eeaa11df716256a Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 5 Sep 2023 14:37:22 -0700 Subject: [PATCH 088/425] Mark batch jobs as not safe to evict (#256) --- charts/model-engine/templates/service_template_config_map.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 8bdd4eee..070185d8 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -609,6 +609,7 @@ data: version: v1 annotations: ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never {{- with $node_selector }} From 2c7c66d844a044f20e13bb8214670df225c43cac Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:34:41 -0700 Subject: [PATCH 089/425] removing timezone tzinfo in favor of utc (#257) * removing timezone tzinfo in favor of utc * adding utc everywhere in last_updated * adding variable assignment * fixing timezone assignment to utc * fixing logging --- .../infra/services/image_cache_service.py | 63 ++++++++++--------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index b6343dcc..5eec2bad 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Dict, NamedTuple, Tuple +import pytz from model_engine_server.common.config import hmi_config from model_engine_server.common.env_vars import GIT_TAG from model_engine_server.core.config import infra_config @@ -108,7 +109,11 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi if record is None: continue - last_updated_at = record.last_updated_at or datetime.min + last_updated_at = ( + record.last_updated_at.replace(tzinfo=pytz.utc) + if record.last_updated_at is not None + else datetime.min.replace(tzinfo=pytz.utc) + ) has_no_available_workers = int(state.deployment_state.available_workers == 0) is_high_priority = int(state.high_priority is True) @@ -125,36 +130,36 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi image_repository_and_tag = state.image.split("/", 1)[1] repository_name, image_tag = image_repository_and_tag.split(":") - if state.resource_state.gpus == 0 and ( - ( - state.image not in images_to_cache_priority["cpu"] - or last_updated_at.replace( - tzinfo=images_to_cache_priority["cpu"][state.image].last_updated_at.tzinfo + try: + if state.resource_state.gpus == 0 and ( + ( + state.image not in images_to_cache_priority["cpu"] + or last_updated_at.replace(tzinfo=pytz.utc) + > images_to_cache_priority["cpu"][state.image].last_updated_at ) - > images_to_cache_priority["cpu"][state.image].last_updated_at - ) - and self.docker_repository.image_exists(image_tag, repository_name) - ): - images_to_cache_priority["cpu"][state.image] = cache_priority - elif state.resource_state.gpus > 0: - for gpu_type, key in [ - (GpuType.NVIDIA_AMPERE_A10, "a10"), - (GpuType.NVIDIA_AMPERE_A100, "a100"), - (GpuType.NVIDIA_TESLA_T4, "t4"), - ]: - if state.resource_state.gpu_type == gpu_type and ( - ( - state.image not in images_to_cache_priority[key] - or last_updated_at.replace( - tzinfo=images_to_cache_priority[key][ - state.image - ].last_updated_at.tzinfo + and self.docker_repository.image_exists(image_tag, repository_name) + ): + images_to_cache_priority["cpu"][state.image] = cache_priority + elif state.resource_state.gpus > 0: + for gpu_type, key in [ + (GpuType.NVIDIA_AMPERE_A10, "a10"), + (GpuType.NVIDIA_AMPERE_A100, "a100"), + (GpuType.NVIDIA_TESLA_T4, "t4"), + ]: + if state.resource_state.gpu_type == gpu_type and ( + ( + state.image not in images_to_cache_priority[key] + or last_updated_at.replace(tzinfo=pytz.utc) + > images_to_cache_priority[key][state.image].last_updated_at ) - > images_to_cache_priority[key][state.image].last_updated_at - ) - and self.docker_repository.image_exists(image_tag, repository_name) - ): - images_to_cache_priority[key][state.image] = cache_priority + and self.docker_repository.image_exists(image_tag, repository_name) + ): + images_to_cache_priority[key][state.image] = cache_priority + except Exception as exc: + logger.warning( + f"Endpoint {endpoint_id} had an error. Error message: {exc}. Skipping caching ..." + ) + continue images_to_cache = CachedImages(cpu=[], a10=[], a100=[], t4=[]) for key, val in images_to_cache_priority.items(): From 3500e767e64829ee68f1794cece2dcf17670390b Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 7 Sep 2023 10:06:12 -0700 Subject: [PATCH 090/425] invalid CSV input returns InvalidRequestException (#258) * invalid CSV input returns invalidrequestexception * adding helper function for error handling * . * fixing logging statements --- .../use_cases/llm_fine_tuning_use_cases.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 0e837e09..a66fc3ff 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -1,6 +1,7 @@ import csv import datetime import re +from typing import Optional import smart_open from model_engine_server.common.dtos.llms import ( @@ -48,7 +49,7 @@ def ensure_model_name_is_valid_k8s_label(model_name: str): def read_csv_headers(file_location: str): """ - Read the headers of a csv file. Assumes the file exists and is valid. + Read the headers of a csv file. """ with smart_open.open(file_location, transport_params=dict(buffer_size=1024)) as file: csv_reader = csv.DictReader(file) @@ -63,6 +64,26 @@ def are_dataset_headers_valid(file_location: str): return all(required_header in current_headers for required_header in REQUIRED_COLUMNS) +def check_file_is_valid(file_name: Optional[str], file_type: str): + """ + Ensure the file is valid with required columns 'prompt' and 'response', isn't malformatted, and exists. + file_type: 'training' or 'validation' + """ + try: + if file_name is not None and not are_dataset_headers_valid(file_name): + raise InvalidRequestException( + f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in {file_type} dataset" + ) + except FileNotFoundError: + raise InvalidRequestException( + f"Cannot find the {file_type} file. Verify the path and file name are correct." + ) + except csv.Error as exc: + raise InvalidRequestException( + f"Cannot parse the {file_type} dataset as CSV. Details: {exc}" + ) + + class CreateFineTuneV1UseCase: def __init__( self, @@ -140,14 +161,8 @@ async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFin else: validation_file = request.validation_file - if training_file is not None and not are_dataset_headers_valid(training_file): - raise InvalidRequestException( - f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in training dataset" - ) - if validation_file is not None and not are_dataset_headers_valid(validation_file): - raise InvalidRequestException( - f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in validation dataset" - ) + check_file_is_valid(training_file, "training") + check_file_is_valid(validation_file, "validation") await self.llm_fine_tune_events_repository.initialize_events(user.team_id, fine_tuned_model) fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune( From b59199897332e3085f7520a5b302d4de4e480295 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:00:44 -0700 Subject: [PATCH 091/425] bumping image tag (#262) --- .../model_engine_server/inference/vllm/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 05654616..05047c7b 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ ray==2.6.3 -vllm==0.1.4 +vllm==0.1.5 pydantic==1.10.12 From d5a15478901812fee628871df974d4504e0d876d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 12 Sep 2023 13:14:04 -0700 Subject: [PATCH 092/425] Integrate LightLLM (#273) * Integrate LightLLM * wip --- charts/model-engine/values_circleci.yaml | 1 + .../model_engine_server/common/config.py | 1 + .../domain/entities/llm_entity.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 192 ++++++++++++++++++ .../service_config_circleci.yaml | 1 + 5 files changed, 196 insertions(+) diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 8279f00e..e770f31d 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -141,6 +141,7 @@ config: istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" + lightllm_repository: "lightllm" hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET" # Service Account diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 25cf5453..64098ea1 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -56,6 +56,7 @@ class HostedModelInferenceServiceConfig: datadog_trace_enabled: bool tgi_repository: str vllm_repository: str + lightllm_repository: str @classmethod def from_yaml(cls, yaml_path): diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index 80344b54..dfb6f63c 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -11,6 +11,7 @@ class LLMInferenceFramework(str, Enum): DEEPSPEED = "deepspeed" TEXT_GENERATION_INFERENCE = "text_generation_inference" VLLM = "vllm" + LIGHTLLM = "lightllm" class Quantization(str, Enum): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 13b0fa1c..92f9588a 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -119,6 +119,15 @@ "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", }, + LLMInferenceFramework.LIGHTLLM: { + "llama-7b": "decapoda-research/llama-7b-hf", + "llama-2-7b": "meta-llama/Llama-2-7b-hf", + "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", + "llama-2-13b": "meta-llama/Llama-2-13b-hf", + "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", + "llama-2-70b": "meta-llama/Llama-2-70b-hf", + "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", + }, } @@ -221,6 +230,15 @@ async def create_model_bundle( num_shards, checkpoint_path, ) + elif framework == LLMInferenceFramework.LIGHTLLM: + bundle_id = await self.create_lightllm_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + checkpoint_path, + ) else: raise ObjectHasInvalidValueException( f"Framework {framework} is not supported for source {source}." @@ -499,6 +517,86 @@ async def create_vllm_bundle( ) ).model_bundle_id + async def create_lightllm_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + checkpoint_path: Optional[str], + ): + command = [] + + # TODO: incorporate auto calculate max_total_token_num from https://github.com/ModelTC/lightllm/pull/81 + max_total_token_num = 6000 # LightLLM default + if num_shards == 1: + max_total_token_num = 15000 # Default for Llama 2 7B on 1 x A10 + elif num_shards == 2: + max_total_token_num = 21000 # Default for Llama 2 13B on 2 x A10 + elif num_shards == 4: + max_total_token_num = 70000 # Default for Llama 2 13B on 4 x A10 + max_req_input_len = 2047 + max_req_total_len = 2048 + if "llama-2" in model_name: + max_req_input_len = 4095 + max_req_total_len = 4096 + + subcommands = [] + if checkpoint_path is not None: + if checkpoint_path.startswith("s3://"): + final_weights_folder = "model_files" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.LIGHTLLM, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) + else: + raise ObjectHasInvalidValueException( + f"Not able to load checkpoint path {checkpoint_path}." + ) + else: + final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] + + subcommands.append( + f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto" + ) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.lightllm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/generate", + streaming_predict_route="/generate_stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: @@ -764,6 +862,19 @@ def model_output_to_completion_output( num_completion_tokens=model_output["count_output_tokens"], tokens=tokens, ) + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + print(model_output) + tokens = None + if with_token_probs: + tokens = [ + TokenOutput(token=t["text"], log_prob=t["logprob"]) + for t in model_output["tokens"] + ] + return CompletionOutput( + text=model_output["generated_text"][0], + num_completion_tokens=model_output["count_output_tokens"], + tokens=tokens, + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -925,6 +1036,44 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request ) + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + return CompletionSyncV1Response( + request_id=request_id, + output=None, + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, model_endpoint, request.return_token_log_probs + ), + ) + elif endpoint_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + lightllm_args: Any = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + }, + } + # TODO: implement stop sequences + if request.temperature > 0: + lightllm_args["parameters"]["temperature"] = request.temperature + lightllm_args["parameters"]["do_sample"] = True + else: + lightllm_args["parameters"]["do_sample"] = False + if request.return_token_log_probs: + lightllm_args["parameters"]["return_details"] = True + + inference_request = SyncEndpointPredictV1Request( + args=lightllm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, predict_request=inference_request + ) + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: return CompletionSyncV1Response( request_id=request_id, @@ -1055,6 +1204,25 @@ async def execute( if request.return_token_log_probs: args["logprobs"] = 1 args["stream"] = True + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + args = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + }, + } + # TODO: stop sequences + if request.temperature > 0: + args["parameters"]["temperature"] = request.temperature + args["parameters"]["do_sample"] = True + else: + args["parameters"]["do_sample"] = False + if request.return_token_log_probs: + args["parameters"]["return_details"] = True + else: + raise EndpointUnsupportedInferenceTypeException( + f"Unsupported inference framework {model_content.inference_framework}" + ) inference_request = SyncEndpointPredictV1Request( args=args, @@ -1163,6 +1331,30 @@ async def execute( request_id=request_id, output=None, ) + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + if res.status == TaskStatus.SUCCESS and result is not None: + print(result) + token = None + num_completion_tokens += 1 + if request.return_token_log_probs: + token = TokenOutput( + token=result["result"]["token"]["text"], + log_prob=result["result"]["token"]["logprob"], + ) + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["token"]["text"], + finished=result["result"]["finished"], + num_completion_tokens=num_completion_tokens, + token=token, + ), + ) + else: + yield CompletionStreamV1Response( + request_id=request_id, + output=None, + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 04ea1da8..25b55c7a 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -56,6 +56,7 @@ datadog_trace_enabled: false istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" +lightllm_repository: "lightllm" # S3 access hf_user_fine_tuned_weights_prefix: "s3://test-bucket" From a4252637e55ee809007092bf08e1ac4e6044ba8e Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:54:07 -0700 Subject: [PATCH 093/425] removing datadog interfaces logging (#275) --- model-engine/model_engine_server/api/dependencies.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 2997542a..e9d8424b 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -272,10 +272,8 @@ async def get_external_interfaces(): try: from plugins.dependencies import get_external_interfaces as get_custom_external_interfaces - logger.info("Using custom external interfaces") yield get_custom_external_interfaces() except ModuleNotFoundError: - logger.info("Using default external interfaces") yield get_default_external_interfaces() finally: pass @@ -287,10 +285,8 @@ async def get_external_interfaces_read_only(): get_external_interfaces_read_only as get_custom_external_interfaces_read_only, ) - logger.info("Using custom external interfaces") yield get_custom_external_interfaces_read_only() except ModuleNotFoundError: - logger.info("Using default external interfaces") yield get_default_external_interfaces_read_only() finally: pass From 017bd7bda4a4ccccdd3d83fd2be09b6824a766bb Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:26:30 -0700 Subject: [PATCH 094/425] Ianmacleod/vllm by default (#274) * bumping image tag * new git sha * increase vllm version --- .../model_engine_server/inference/vllm/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 05047c7b..7d8f12fc 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ ray==2.6.3 -vllm==0.1.5 +vllm==0.1.7 pydantic==1.10.12 From 0b32a27e4cc74a21948d8b5035588d50e11b3236 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:57:55 -0700 Subject: [PATCH 095/425] update docs (#276) * update docs * update docs with note about jupyter --- docs/getting_started.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/getting_started.md b/docs/getting_started.md index a796bea0..46741d1b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -37,6 +37,13 @@ export SCALE_API_KEY="[Your API key]" You can also add in the line above to your `.zshrc` or `.bash_profile` so it's automatically set for future sessions. +Alternatively, you can also set your API key using either of the following patterns: +``` +llmengine.api_engine.api_key = "abc" +llmengine.api_engine.set_api_key("abc") +``` +These patterns are useful for Jupyter Notebook users to set API keys without the need for using `os.environ`. + ## Example Code ### Sample Completion From 9b519fa4f6996bd741abb54975999d0411cd1464 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:54:39 -0700 Subject: [PATCH 096/425] Ianmacleod/add model delete (#261) * starting work on model.delete() * adding server implementation * adding client changes to delete based on id, not name of model * changing use case to be based on name, not id * adding client model changes * adding unit tests * updating tests * . * . * fix * fix * pls work * addressing feedback from review * reverting accidental change to conftest * add new test --- clients/python/llmengine/model.py | 12 +-- .../model_engine_server/api/llms_v1.py | 42 +++++++++ .../model_engine_server/common/dtos/llms.py | 4 + .../use_cases/llm_model_endpoint_use_cases.py | 42 ++++++++- .../tests/unit/domain/test_llm_use_cases.py | 86 +++++++++++++++++++ 5 files changed, 179 insertions(+), 7 deletions(-) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index fd18b7b1..021a9ff5 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -334,7 +334,7 @@ def list(cls) -> ListLLMEndpointsResponse: return ListLLMEndpointsResponse.parse_obj(response) @classmethod - def delete(cls, model: str) -> DeleteLLMEndpointResponse: + def delete(cls, model_endpoint_name: str) -> DeleteLLMEndpointResponse: """ Deletes an LLM model. @@ -345,11 +345,11 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse: Engine, an error will be thrown. Args: - model (`str`): - Name of the model + model_endpoint_name (`str`): + Name of the model endpoint to be deleted Returns: - response: whether the model was successfully deleted + response: whether the model endpoint was successfully deleted === "Deleting model in Python" ```python @@ -366,7 +366,9 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse: } ``` """ - response = cls._delete(f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT) + response = cls._delete( + f"v1/llm/model-endpoints/{model_endpoint_name}", timeout=DEFAULT_TIMEOUT + ) return DeleteLLMEndpointResponse.parse_obj(response) @classmethod diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 3e7533da..af114adb 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -21,6 +21,7 @@ CreateFineTuneResponse, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, + DeleteLLMEndpointResponse, GetFineTuneEventsResponse, GetFineTuneResponse, GetLLMModelEndpointV1Response, @@ -40,9 +41,11 @@ ) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.exceptions import ( + EndpointDeleteFailedException, EndpointLabelsException, EndpointResourceInvalidRequestException, EndpointUnsupportedInferenceTypeException, + ExistingEndpointOperationInProgressException, InvalidRequestException, LLMFineTuningMethodNotImplementedException, LLMFineTuningQuotaReached, @@ -59,6 +62,7 @@ CompletionStreamV1UseCase, CompletionSyncV1UseCase, CreateLLMModelEndpointV1UseCase, + DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ListLLMModelEndpointsV1UseCase, ModelDownloadV1UseCase, @@ -384,3 +388,41 @@ async def download_model_endpoint( status_code=404, detail="The requested fine-tuned model could not be found.", ) from exc + + +@llm_router_v1.delete( + "/model-endpoints/{model_endpoint_name}", response_model=DeleteLLMEndpointResponse +) +async def delete_llm_model_endpoint( + model_endpoint_name: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> DeleteLLMEndpointResponse: + add_trace_resource_name("llm_model_endpoints_delete") + logger.info(f"DELETE /model-endpoints/{model_endpoint_name} for {auth}") + try: + use_case = DeleteLLMEndpointByNameUseCase( + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + model_endpoint_service=external_interfaces.model_endpoint_service, + ) + return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name) + except (ObjectNotFoundException) as exc: + raise HTTPException( + status_code=404, + detail="The requested model endpoint could not be found.", + ) from exc + except (ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=403, + detail="You don't have permission to delete the requested model endpoint.", + ) from exc + except ExistingEndpointOperationInProgressException as exc: + raise HTTPException( + status_code=409, + detail="Existing operation on endpoint in progress, try again later.", + ) from exc + except EndpointDeleteFailedException as exc: # pragma: no cover + raise HTTPException( + status_code=500, + detail="deletion of endpoint failed.", + ) from exc diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index d62f7992..2735f577 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -228,3 +228,7 @@ class ModelDownloadResponse(BaseModel): urls: Dict[str, str] = Field( ..., description="Dictionary of (file_name, url) pairs to download the model from." ) + + +class DeleteLLMEndpointResponse(BaseModel): + deleted: bool diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 92f9588a..af5c0b9c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -21,6 +21,7 @@ CompletionSyncV1Response, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, + DeleteLLMEndpointResponse, GetLLMModelEndpointV1Response, ListLLMModelEndpointsV1Response, ModelDownloadRequest, @@ -779,8 +780,45 @@ async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndp return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) -class DeleteLLMModelEndpointByIdV1UseCase: - pass +class DeleteLLMEndpointByNameUseCase: + """ + Use case for deleting an LLM Model Endpoint of a given user by endpoint name. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, model_endpoint_name: str) -> DeleteLLMEndpointResponse: + """ + Runs the use case to delete the LLM endpoint owned by the user with the given name. + + Args: + user: The owner of the model endpoint. + model_endpoint_name: The name of the model endpoint. + + Returns: + A response object that contains a boolean indicating if deletion was successful. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.user_id, name=model_endpoint_name, order_by=None + ) + if len(model_endpoints) != 1: + raise ObjectNotFoundException + model_endpoint = model_endpoints[0] + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + raise ObjectNotAuthorizedException + await self.model_endpoint_service.delete_model_endpoint(model_endpoint.record.id) + return DeleteLLMEndpointResponse(deleted=True) def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]: diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index b841ba1b..da7a6451 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -35,6 +35,7 @@ CompletionStreamV1UseCase, CompletionSyncV1UseCase, CreateLLMModelEndpointV1UseCase, + DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ModelDownloadV1UseCase, ) @@ -869,3 +870,88 @@ async def test_download_nonexistent_model_raises_not_found( ) with pytest.raises(ObjectNotFoundException): await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_delete_model_success( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + test_api_key: str, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response = await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name + ) + remaining_endpoint_model_service = await fake_model_endpoint_service.get_model_endpoint( + llm_model_endpoint_sync[0].record.id + ) + assert remaining_endpoint_model_service is None + assert response.deleted is True + + +@pytest.mark.asyncio +async def test_delete_nonexistent_model_raises_not_found( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + test_api_key: str, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(ObjectNotFoundException): + await use_case.execute(user=user, model_endpoint_name="nonexistent-model") + + +@pytest.mark.asyncio +async def test_delete_unauthorized_model_raises_not_authorized( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id="fakeapikey", team_id="fakeapikey", is_privileged_user=True) + with pytest.raises(ObjectNotAuthorizedException): + await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name + ) + + +@pytest.mark.asyncio +async def test_delete_public_inference_model_raises_not_authorized( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + test_api_key, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User( + user_id="fakeapikey", team_id="faketeam", is_privileged_user=True + ) # write access is based on team_id, so team_id != owner's team_id + with pytest.raises( + ObjectNotAuthorizedException + ): # user cannot delete public inference model they don't own + await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name + ) From 7f9e38d1bd08ce95f448da19c6f7e7e73eda1bda Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:45:09 -0700 Subject: [PATCH 097/425] =?UTF-8?q?fixing=20cacher,=20tested=20in=20prod?= =?UTF-8?q?=20version=20of=20cacher=20deployment=20in=20k8s=20and=E2=80=A6?= =?UTF-8?q?=20(#278)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fixing cacher, tested in prod version of cacher deployment in k8s and seems to be working * update logging --- .../infra/services/image_cache_service.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 5eec2bad..5bae890b 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -64,7 +64,7 @@ def _cache_finetune_llm_images( is_high_priority=1, # make it a high priority has_no_available_workers=1, # assuming it has no available workers so that it will be at top after reverse sorting - last_updated_at=datetime.max, + last_updated_at=datetime.max.replace(tzinfo=pytz.utc), # setting it to max to ensure it will be at top after reverse sorting ) @@ -135,7 +135,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi ( state.image not in images_to_cache_priority["cpu"] or last_updated_at.replace(tzinfo=pytz.utc) - > images_to_cache_priority["cpu"][state.image].last_updated_at + > images_to_cache_priority["cpu"][state.image].last_updated_at.replace( + tzinfo=pytz.utc + ) ) and self.docker_repository.image_exists(image_tag, repository_name) ): @@ -150,7 +152,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi ( state.image not in images_to_cache_priority[key] or last_updated_at.replace(tzinfo=pytz.utc) - > images_to_cache_priority[key][state.image].last_updated_at + > images_to_cache_priority[key][ + state.image + ].last_updated_at.replace(tzinfo=pytz.utc) ) and self.docker_repository.image_exists(image_tag, repository_name) ): @@ -162,9 +166,13 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi continue images_to_cache = CachedImages(cpu=[], a10=[], a100=[], t4=[]) - for key, val in images_to_cache_priority.items(): - images_to_cache[key] = sorted( # type: ignore - val.keys(), key=lambda image: val[image], reverse=True - )[:IMAGES_TO_CACHE_PER_INSTANCE_TYPE] + try: + for key, val in images_to_cache_priority.items(): + images_to_cache[key] = sorted( # type: ignore + val.keys(), key=lambda image: val[image], reverse=True + )[:IMAGES_TO_CACHE_PER_INSTANCE_TYPE] + logger.info("sorted images to cache successfully") + except Exception as exc: + logger.warning(f"sorting had an error. Error message: {exc}. Skipping sorting...") await self.image_cache_gateway.create_or_update_image_cache(images_to_cache) From f75a9290ce51c1abeae66681ea98631944b8aa6d Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Sun, 17 Sep 2023 22:21:49 -0700 Subject: [PATCH 098/425] Ianmacleod/fix cacher (#279) * fixing cacher, tested in prod version of cacher deployment in k8s and seems to be working * update logging * removing unnecessary logging statements and try/except blocks --- .../infra/services/image_cache_service.py | 71 ++++++++----------- 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 5bae890b..47ab21d3 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -130,49 +130,38 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi image_repository_and_tag = state.image.split("/", 1)[1] repository_name, image_tag = image_repository_and_tag.split(":") - try: - if state.resource_state.gpus == 0 and ( - ( - state.image not in images_to_cache_priority["cpu"] - or last_updated_at.replace(tzinfo=pytz.utc) - > images_to_cache_priority["cpu"][state.image].last_updated_at.replace( - tzinfo=pytz.utc - ) + if state.resource_state.gpus == 0 and ( + ( + state.image not in images_to_cache_priority["cpu"] + or last_updated_at.replace(tzinfo=pytz.utc) + > images_to_cache_priority["cpu"][state.image].last_updated_at.replace( + tzinfo=pytz.utc ) - and self.docker_repository.image_exists(image_tag, repository_name) - ): - images_to_cache_priority["cpu"][state.image] = cache_priority - elif state.resource_state.gpus > 0: - for gpu_type, key in [ - (GpuType.NVIDIA_AMPERE_A10, "a10"), - (GpuType.NVIDIA_AMPERE_A100, "a100"), - (GpuType.NVIDIA_TESLA_T4, "t4"), - ]: - if state.resource_state.gpu_type == gpu_type and ( - ( - state.image not in images_to_cache_priority[key] - or last_updated_at.replace(tzinfo=pytz.utc) - > images_to_cache_priority[key][ - state.image - ].last_updated_at.replace(tzinfo=pytz.utc) - ) - and self.docker_repository.image_exists(image_tag, repository_name) - ): - images_to_cache_priority[key][state.image] = cache_priority - except Exception as exc: - logger.warning( - f"Endpoint {endpoint_id} had an error. Error message: {exc}. Skipping caching ..." ) - continue - + and self.docker_repository.image_exists(image_tag, repository_name) + ): + images_to_cache_priority["cpu"][state.image] = cache_priority + elif state.resource_state.gpus > 0: + for gpu_type, key in [ + (GpuType.NVIDIA_AMPERE_A10, "a10"), + (GpuType.NVIDIA_AMPERE_A100, "a100"), + (GpuType.NVIDIA_TESLA_T4, "t4"), + ]: + if state.resource_state.gpu_type == gpu_type and ( + ( + state.image not in images_to_cache_priority[key] + or last_updated_at.replace(tzinfo=pytz.utc) + > images_to_cache_priority[key][state.image].last_updated_at.replace( + tzinfo=pytz.utc + ) + ) + and self.docker_repository.image_exists(image_tag, repository_name) + ): + images_to_cache_priority[key][state.image] = cache_priority images_to_cache = CachedImages(cpu=[], a10=[], a100=[], t4=[]) - try: - for key, val in images_to_cache_priority.items(): - images_to_cache[key] = sorted( # type: ignore - val.keys(), key=lambda image: val[image], reverse=True - )[:IMAGES_TO_CACHE_PER_INSTANCE_TYPE] - logger.info("sorted images to cache successfully") - except Exception as exc: - logger.warning(f"sorting had an error. Error message: {exc}. Skipping sorting...") + for key, val in images_to_cache_priority.items(): + images_to_cache[key] = sorted( # type: ignore + val.keys(), key=lambda image: val[image], reverse=True + )[:IMAGES_TO_CACHE_PER_INSTANCE_TYPE] await self.image_cache_gateway.create_or_update_image_cache(images_to_cache) From 1ef384750a95c92ca36554aa646f4661c94273e6 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 18 Sep 2023 13:15:15 -0700 Subject: [PATCH 099/425] add vllm to inference framework enum (#280) --- clients/python/llmengine/data_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 34eaf0f9..c6f9edb7 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -14,6 +14,7 @@ class LLMInferenceFramework(str, Enum): DEEPSPEED = "deepspeed" TEXT_GENERATION_INFERENCE = "text_generation_inference" + VLLM = "vllm" class LLMSource(str, Enum): From 26119b6b4061f3514d96458acd5ff30c5ae4029f Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 18 Sep 2023 14:30:47 -0700 Subject: [PATCH 100/425] Ianmacleod/update client enum with lightllm (#281) * add vllm to inference framework enum * add lightllm as well --- clients/python/llmengine/data_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index c6f9edb7..1a8baba3 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -15,6 +15,7 @@ class LLMInferenceFramework(str, Enum): DEEPSPEED = "deepspeed" TEXT_GENERATION_INFERENCE = "text_generation_inference" VLLM = "vllm" + LIGHTLLM = "lightllm" class LLMSource(str, Enum): From 02f9876a781507d25cfd2a26ef3f514a18213e6e Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 18 Sep 2023 15:29:53 -0700 Subject: [PATCH 101/425] Some fixes for endpoints (#283) --- .../templates/service_template_config_map.yaml | 7 ++----- model-engine/model_engine_server/common/resource_limits.py | 4 ++-- .../inference/configs/service--http_forwarder.yaml | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 070185d8..90c78441 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -127,6 +127,7 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: 0.1 @@ -172,6 +173,7 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: 0.1 @@ -285,11 +287,6 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 - livenessProbe: - httpGet: - path: ${HEALTHCHECK_ROUTE} - port: ${USER_CONTAINER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} timeoutSeconds: 5 resources: requests: diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index ee19af55..502f7dd7 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -43,8 +43,8 @@ GpuType.NVIDIA_AMPERE_A100E: A100_INSTANCE_LIMITS, } -FORWARDER_CPU_USAGE = 0.5 -FORWARDER_MEMORY_USAGE = "1Gi" +FORWARDER_CPU_USAGE = 1 +FORWARDER_MEMORY_USAGE = "2Gi" FORWARDER_STORAGE_USAGE = "1G" logger = make_logger(filename_wo_ext(__name__)) diff --git a/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml index f7f046d5..f0e3eef1 100644 --- a/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml +++ b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml @@ -16,4 +16,4 @@ forwarder: batch_route: null model_engine_unwrap: true serialize_results_as_string: false - max_concurrency: 20 + max_concurrency: 100 From 5acc4b271e3bc09c539cf907809ebb85d49968f0 Mon Sep 17 00:00:00 2001 From: Phil Chen <92065453+phil-scale@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:28:44 -0700 Subject: [PATCH 102/425] Add actual pytests to integration tests (#227) --- .circleci/config.yml | 37 + charts/model-engine/templates/_helpers.tpl | 4 + .../templates/service_config_map.yaml | 28 + .../service_template_config_map.yaml | 2 +- charts/model-engine/values_circleci.yaml | 2 +- integration_tests/__init__.py | 0 integration_tests/rest_api_utils.py | 793 ++++++++++++++++++ integration_tests/test_batch_jobs.py | 25 + integration_tests/test_bundles.py | 23 + integration_tests/test_docs.py | 210 +++++ integration_tests/test_endpoints.py | 228 +++++ integration_tests/test_file.py | 27 + integration_tests/test_fine_tunes.py | 31 + model-engine/Dockerfile | 2 + .../model_engine_server/api/dependencies.py | 5 +- .../entrypoints/k8s_cache.py | 4 +- .../inference/forwarding/echo_server.py | 56 ++ .../service_template_config_map_circleci.yaml | 66 +- .../infra/gateways/s3_file_storage_gateway.py | 4 +- .../infra/repositories/__init__.py | 2 + .../repositories/fake_docker_repository.py | 21 + .../service_builder/tasks_v1.py | 5 +- .../service_config_circleci.yaml | 2 +- .../services/test_image_cache_service.py | 10 +- 24 files changed, 1548 insertions(+), 39 deletions(-) create mode 100644 integration_tests/__init__.py create mode 100644 integration_tests/rest_api_utils.py create mode 100644 integration_tests/test_batch_jobs.py create mode 100644 integration_tests/test_bundles.py create mode 100644 integration_tests/test_docs.py create mode 100644 integration_tests/test_endpoints.py create mode 100644 integration_tests/test_file.py create mode 100644 integration_tests/test_fine_tunes.py create mode 100644 model-engine/model_engine_server/inference/forwarding/echo_server.py create mode 100644 model-engine/model_engine_server/infra/repositories/fake_docker_repository.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 78751757..63763dea 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -138,12 +138,48 @@ jobs: name: Pre-load model-engine image to minikube command: | minikube --logtostderr -v 1 image load model-engine:$CIRCLE_SHA1 + - run: + name: Pre-load integration test images to minikube + command: | + docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile \ + --build-arg BASE_IMAGE=pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime \ + --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-base-requirements.txt" \ + -t temp:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1 . + + touch $CIRCLE_SHA1-requirements.txt + echo -e "cloudpickle==2.1.0\npyyaml==6.0" > $CIRCLE_SHA1-requirements.txt + + DOCKER_BUILDKIT=1 docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile \ + --build-arg BASE_IMAGE=temp:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1 \ + --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-requirements.txt" \ + -t $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-021694 . + rm $CIRCLE_SHA1-requirements.txt + + minikube --logtostderr -v 1 image load $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-021694 - run: name: Install helm chart command: | pushd $HOME/project/charts cat model-engine/values_circleci.yaml | envsubst > model-engine/values_circleci_subst.yaml helm install model-engine model-engine --values model-engine/values_circleci_subst.yaml --set tag=$CIRCLE_SHA1 --atomic --debug + - run: + name: Change python version to 3.8.12 + command: | + pyenv install 3.8.12 + pyenv global 3.8.12 + - run: + name: Install integration test dependencies + command: | + sudo apt-get update && sudo apt-get install -y libcurl4-openssl-dev libssl-dev python3-dev + pip install -r model-engine/requirements.txt + - install_client + - install_server + - run: + name: Run integration tests + command: | + pushd $HOME/project + kubectl port-forward svc/model-engine 5001:80 & + GIT_TAG=$CIRCLE_SHA1 pytest integration_tests executors: ubuntu-large: @@ -188,6 +224,7 @@ commands: - run: name: Install LLM Engine client command: | + pip install --upgrade pip pip install -e $HOME/project/clients/python run_unit_tests_python_client: description: Unit tests of the python client diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 8bcebe2c..9389bcac 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -281,6 +281,10 @@ env: - name: REDIS_AUTH_TOKEN value: {{ .Values.redis.auth }} {{- end }} + {{- if eq .Values.context "circleci" }} + - name: CIRCLECI + value: "true" + {{- end }} {{- end }} {{- define "modelEngine.serviceEnvGitTagFromHelmVar" }} diff --git a/charts/model-engine/templates/service_config_map.yaml b/charts/model-engine/templates/service_config_map.yaml index 9234296d..b6809b22 100644 --- a/charts/model-engine/templates/service_config_map.yaml +++ b/charts/model-engine/templates/service_config_map.yaml @@ -23,4 +23,32 @@ data: {{ $key }}: {{ $value | quote }} {{- end }} {{- end }} + +--- + +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "modelEngine.fullname" . }}-service-config + namespace: {{ .Values.config.values.launch.endpoint_namespace }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" +data: + launch_service_config: |- + datadog_trace_enabled: {{ .Values.datadog_trace_enabled | default false | quote }} + {{- with .Values.config.values.launch }} + {{- range $key, $value := . }} + {{ $key }}: {{ $value | quote }} + {{- end }} + {{- end }} + infra_service_config: |- + env: {{ .Values.context | quote }} + {{- with .Values.config.values.infra }} + {{- range $key, $value := . }} + {{ $key }}: {{ $value | quote }} + {{- end }} + {{- end }} {{- end }} diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 90c78441..3bd9674f 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -143,7 +143,7 @@ data: name: http {{- else if eq $mode "streaming" }} - name: http-forwarder - image: {{ $forwarder_repository }}:{{ $tag }} + image: {{ $forwarder_repository }}:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index e770f31d..657c5f50 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -82,7 +82,7 @@ config: dns_host_domain: localhost default_region: us-west-2 ml_account_id: "$CIRCLECI_AWS_ACCOUNT_ID" - docker_repo_prefix: "CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com" + docker_repo_prefix: "$CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com" redis_host: redis-message-broker-master.default s3_bucket: "$CIRCLECI_AWS_S3_BUCKET" profile_ml_worker: "default" diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py new file mode 100644 index 00000000..fb0dd7c3 --- /dev/null +++ b/integration_tests/rest_api_utils.py @@ -0,0 +1,793 @@ +import asyncio +import inspect +import json +import os +import time +from typing import Any, Dict, List, Sequence + +import aiohttp +import requests +from model_engine_server.common.dtos.tasks import TaskStatus +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +_DEFAULT_BASE_PATH = "http://localhost:5001" +BASE_PATH = os.environ.get("BASE_PATH", _DEFAULT_BASE_PATH) +print(f"Integration tests using gateway {BASE_PATH=}") +DEFAULT_NETWORK_TIMEOUT_SEC = 10 + +# Generate some fake 24-character user IDs. +# We don't want different people to get user ID collisions but at the same time we want people to +# consistently use the same user IDs so that they can clean up their extra endpoints. +USER_PREFIX = os.getenv("SERVICE_IDENTIFIER", "test")[:8] +USER_ID_0 = USER_PREFIX + "0" * (24 - len(USER_PREFIX)) +USER_ID_1 = USER_PREFIX + "1" * (24 - len(USER_PREFIX)) + +DEFAULT_USERS: Sequence[str] = ( + USER_ID_0, + USER_ID_1, +) + + +def echo_load_predict_fn(model): + def echo(**keyword_args): + return model(**keyword_args) + + return echo + + +def echo_load_model_fn(): + def my_model(**keyword_args): + return {k: v for k, v in keyword_args.items()} + + return my_model + + +CREATE_MODEL_BUNDLE_REQUEST_SIMPLE = { + "name": "model_bundle_simple", + "schema_location": "s3://model-engine-integration-tests/model_bundles/echo_schemas", + "metadata": { + "test_key": "test_value", + }, + "flavor": { + "flavor": "cloudpickle_artifact", + "load_predict_fn": inspect.getsource(echo_load_predict_fn), + "load_model_fn": inspect.getsource(echo_load_model_fn), + "framework": { + "framework_type": "pytorch", + "pytorch_image_tag": "1.7.1-cuda11.0-cudnn8-runtime", + }, + "requirements": ["cloudpickle==2.1.0", "pyyaml==6.0"], + "location": "s3://model-engine-integration-tests/model_bundles/echo_bundle", + }, +} + +CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE = { + "name": "model_bundle_runnable_image", + "schema_location": "s3://model-engine-integration-tests/model_bundles/echo_schemas", + "metadata": { + "test_key": "test_value", + }, + "flavor": { + "flavor": "streaming_enhanced_runnable_image", + "repository": "model-engine", + "tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", + "command": [ + "dumb-init", + "--", + "ddtrace-run", + "python", + "-m", + "model_engine_server.inference.forwarding.echo_server", + "--port", + "5005", + ], + "streaming_command": [ + "dumb-init", + "--", + "ddtrace-run", + "python", + "-m", + "model_engine_server.inference.forwarding.echo_server", + "--port", + "5005", + ], + "env": { + "TEST_KEY": "test_value", + "ML_INFRA_SERVICES_CONFIG_PATH": "/workspace/model-engine/model_engine_server/core/configs/default.yaml", + # infra configs are mounted here + "HTTP_HOST": "0.0.0.0", # Hack for uvicorn to work in minikube + }, + "protocol": "http", + "readiness_initial_delay_seconds": 20, + }, +} + +CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE = { + "bundle_name": "model_bundle_simple", + "name": "model-endpoint-simple-async", + "endpoint_type": "async", + "cpus": "0.5", + "memory": "500Mi", + "min_workers": 1, + "max_workers": 1, + "gpus": 0, + "per_worker": 1, + "labels": {"team": "infra", "product": "launch"}, + "metadata": {}, +} + +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE = CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE.copy() +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["name"] = "model-endpoint-simple-sync" +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["endpoint_type"] = "sync" + +CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = { + "bundle_name": "model_bundle_runnable_image", + "name": "model-endpoint-runnable-image-async", + "post_inference_hooks": [], + "endpoint_type": "async", + "cpus": "1", + "gpus": 0, + "memory": "1Gi", + "optimize_costs": False, + "min_workers": 1, + "max_workers": 1, + "per_worker": 1, + "labels": {"team": "infra", "product": "launch"}, + "metadata": {"key": "value"}, +} + +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE.copy() +) +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE[ + "name" +] = "model-endpoint-runnable-image-sync-streaming" +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE["endpoint_type"] = "streaming" + +UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE = { + "bundle_name": "model_bundle_simple", + "cpus": "1", + "memory": "1Gi", + "max_workers": 2, +} + +UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = { + "bundle_name": "model_bundle_runnable_image", + "cpus": "2", + "memory": "2Gi", + "max_workers": 2, +} + +INFERENCE_PAYLOAD: Dict[str, Any] = { + "args": {"y": 1}, + "url": None, +} + +INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE: Dict[str, Any] = INFERENCE_PAYLOAD.copy() +INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE["return_pickled"] = False + +INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE: Dict[str, Any] = INFERENCE_PAYLOAD.copy() +INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE["return_pickled"] = True + +CREATE_BATCH_JOB_REQUEST: Dict[str, Any] = { + "bundle_name": "model_bundle_simple", + "input_path": "TBA", + "serialization_format": "JSON", + "labels": {"team": "infra", "product": "launch"}, + "resource_requests": { + "memory": "500Mi", + "max_workers": 1, + "gpus": 0, + }, +} + +CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { + "name": "di_batch_job_bundle_1", + "image_repository": "model-engine", + "image_tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", + "command": ["jq", ".", "/launch_mount_location/file"], + "env": {"ENV1": "VAL1"}, + "mount_location": "/launch_mount_location/file", + "resource_requests": { + "cpus": 0.1, + "memory": "10Mi", + }, +} + +CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST: Dict[str, Any] = { + "docker_image_batch_job_bundle_name": "di_batch_job_bundle_1", + "job_config": {"data": {"to": "mount"}}, + "labels": {"team": "infra", "product": "testing"}, + "resource_requests": {"cpus": 0.15, "memory": "15Mi"}, +} + +CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { + "name": "fine_tune_di_batch_job_bundle_1", + "image_repository": "model-engine", + "image_tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", + "command": ["cat", "/launch_mount_location/file"], + "env": {"ENV1": "VAL1"}, + "mount_location": "/launch_mount_location/file", + "resource_requests": { + "cpus": 0.1, + "memory": "10Mi", + }, + "public": True, +} + +CREATE_FINE_TUNE_REQUEST: Dict[str, Any] = { + "model": "test_base_model", + "training_file": "s3://model-engine-integration-tests/fine_tune_files/run_through_walls.csv", + "validation_file": None, + # "fine_tuning_method": "test_fine_tuning_method", # ignored until we change it + "hyperparameters": {}, +} + + +def create_model_bundle( + create_model_bundle_request: Dict[str, Any], user_id: str, version: str +) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/{version}/model-bundles", + json=create_model_bundle_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def get_latest_model_bundle(model_name: str, user_id: str, version: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/{version}/model-bundles/latest?model_name={model_name}", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_or_create_model_bundle( + create_model_bundle_request: Dict[str, Any], user_id: str, version: str +) -> Dict[str, Any]: + # In v1, we will no longer have the uniqueness constraint of (name, created_by) but right now + # for backwards compatibility, such a constraint exists. As a result, we use this get-or-create + # method as a temporary workaround since v1 will not support bundle deletion initially. + try: + return get_latest_model_bundle(create_model_bundle_request["name"], user_id, version) + except: # noqa: E722 + return create_model_bundle(create_model_bundle_request, user_id, version) + + +def replace_model_bundle_name_with_id(request: Dict[str, Any], user_id: str, version): + if "bundle_name" in request: + model_bundle = get_latest_model_bundle(request["bundle_name"], user_id, version) + request["model_bundle_id"] = model_bundle["id"] + del request["bundle_name"] + + +def create_model_endpoint( + create_model_endpoint_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + create_model_endpoint_request = create_model_endpoint_request.copy() + replace_model_bundle_name_with_id(create_model_endpoint_request, user_id, "v1") + response = requests.post( + f"{BASE_PATH}/v1/model-endpoints", + json=create_model_endpoint_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_batch_job(create_batch_job_request: Dict[str, Any], user_id: str) -> Dict[str, Any]: + create_batch_job_request = create_batch_job_request.copy() + replace_model_bundle_name_with_id(create_batch_job_request, user_id, "v2") + response = requests.post( + f"{BASE_PATH}/v1/batch-jobs", + json=create_batch_job_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def cancel_batch_job(batch_job_id: str, user_id: str) -> Dict[str, Any]: + response = requests.put( + f"{BASE_PATH}/v1/batch-jobs/{batch_job_id}", + json={"cancel": True}, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/v1/docker-image-batch-job-bundles", + json=create_docker_image_batch_job_bundle_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_latest_docker_image_batch_job_bundle(bundle_name: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/docker-image-batch-job-bundles/latest?bundle_name={bundle_name}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_or_create_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request: Dict[str, Any], user_id: str +): + try: + return get_latest_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request["name"], user_id + ) + except: # noqa: E722 + return create_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request, user_id + ) + + +def get_docker_image_batch_job_bundle_by_id( + docker_image_batch_job_bundle_id: str, user_id: str +) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/docker-image-batch-job-bundles/{docker_image_batch_job_bundle_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_docker_image_batch_job( + create_docker_image_batch_job_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/v1/docker-image-batch-jobs", + json=create_docker_image_batch_job_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_docker_image_batch_job(batch_job_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/docker-image-batch-jobs/{batch_job_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_fine_tune(create_fine_tune_request: Dict[str, Any], user_id: str) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/v1/llm/fine-tunes", + json=create_fine_tune_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_fine_tune_by_id(fine_tune_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/llm/fine-tunes/{fine_tune_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def list_fine_tunes(user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/llm/fine-tunes", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def cancel_fine_tune_by_id(fine_tune_id: str, user_id: str) -> Dict[str, Any]: + response = requests.put( + f"{BASE_PATH}/v1/llm/fine-tunes/{fine_tune_id}/cancel", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def upload_file(file, user_id: str) -> Dict[str, Any]: + files = {"file": file} + response = requests.post( + f"{BASE_PATH}/v1/files", + files=files, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_file_by_id(file_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/files/{file_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def list_files(user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/files", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=30, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def delete_file_by_id(file_id: str, user_id: str) -> Dict[str, Any]: + response = requests.delete( + f"{BASE_PATH}/v1/files/{file_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_file_content_by_id(file_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/files/{file_id}/content", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(6), wait=wait_fixed(1)) +def get_model_endpoint(name: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/model-endpoints?name={name}", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json()["model_endpoints"][0] + + +def update_model_endpoint( + endpoint_name: str, update_model_endpoint_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + update_model_endpoint_request = update_model_endpoint_request.copy() + replace_model_bundle_name_with_id(update_model_endpoint_request, user_id, "v2") + endpoint = get_model_endpoint(endpoint_name, user_id) + response = requests.put( + f"{BASE_PATH}/v1/model-endpoints/{endpoint['id']}", + json=update_model_endpoint_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def delete_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + response = requests.delete( + f"{BASE_PATH}/v1/model-endpoints/{endpoint['id']}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]: + response = requests.get( + f"{BASE_PATH}/v1/model-endpoints", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json()["model_endpoints"] + + +async def create_async_task( + model_endpoint_id: str, + create_async_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/async-tasks?model_endpoint_id={model_endpoint_id}", + json=create_async_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + return (await response.json())["task_id"] + + +async def create_async_tasks( + endpoint_name: str, create_async_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + async with aiohttp.ClientSession() as session: + tasks = [] + for create_async_task_request in create_async_task_requests: + task = create_async_task(endpoint["id"], create_async_task_request, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def create_sync_task( + model_endpoint_id: str, + create_sync_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/sync-tasks?model_endpoint_id={model_endpoint_id}", + json=create_sync_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return await response.json() + + +async def create_streaming_task( + model_endpoint_id: str, + create_streaming_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/streaming-tasks?model_endpoint_id={model_endpoint_id}", + json=create_streaming_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return (await response.read()).decode() + + +async def create_sync_tasks( + endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + async with aiohttp.ClientSession() as session: + tasks = [] + for create_sync_task_request in create_sync_task_requests: + task = create_sync_task(endpoint["id"], create_sync_task_request, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def create_streaming_tasks( + endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + async with aiohttp.ClientSession() as session: + tasks = [] + for create_streaming_task_request in create_streaming_task_requests: + task = create_streaming_task( + endpoint["id"], create_streaming_task_request, user_id, session + ) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def get_async_task( + task_id: str, user_id: str, session: aiohttp.ClientSession +) -> Dict[str, Any]: + async with session.get( + f"{BASE_PATH}/v1/async-tasks/{task_id}", + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + return await response.json() + + +async def get_async_tasks(task_ids: List[str], user_id: str) -> List[Dict[str, Any]]: + async with aiohttp.ClientSession() as session: + tasks = [] + for task_id in task_ids: + task = get_async_task(task_id, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +# Wait 25 minutes (1500 seconds) for endpoints to build. +@retry(stop=stop_after_attempt(25), wait=wait_fixed(60)) +def ensure_n_ready_endpoints_long(n: int, user_id: str): + endpoints = list_model_endpoints(user_id) + ready_endpoints = [endpoint for endpoint in endpoints if endpoint["status"] == "READY"] + print( + f"User {user_id} Current num endpoints: {len(endpoints)}, num ready endpoints: {len(ready_endpoints)}" + ) + assert ( + len(ready_endpoints) >= n + ), f"Expected {n} ready endpoints, got {len(ready_endpoints)}. Look through endpoint builder for errors." + + +# Wait 2 minutes (120 seconds) for endpoints to build. +@retry(stop=stop_after_attempt(12), wait=wait_fixed(10)) +def ensure_n_ready_endpoints_short(n: int, user_id: str): + endpoints = list_model_endpoints(user_id) + ready_endpoints = [endpoint for endpoint in endpoints if endpoint["status"] == "READY"] + print( + f"User {user_id} Current num endpoints: {len(endpoints)}, num ready endpoints: {len(ready_endpoints)}" + ) + assert len(ready_endpoints) >= n + + +def delete_all_endpoints(user_id): + endpoints = list_model_endpoints(user_id) + for i, endpoint in enumerate(endpoints): + response = delete_model_endpoint(endpoint["name"], user_id) + assert response["deleted"] + print(f"[{i + 1}/{len(endpoints)}] Deleted {endpoint=}") + + +# Wait up to 5 minutes (300 seconds) for the gateway to be ready. +@retry(stop=stop_after_attempt(30), wait=wait_fixed(10)) +def ensure_gateway_ready(): + response = requests.get(f"{BASE_PATH}/healthz") + assert response.ok + + +# Wait up to 10 minutes (600 seconds) for the pods to spin up. +@retry(stop=stop_after_attempt(200), wait=wait_fixed(3)) +def ensure_nonzero_available_workers(endpoint_name: str, user_id: str): + simple_endpoint = get_model_endpoint(endpoint_name, user_id) + assert simple_endpoint.get("deployment_state", {}).get("available_workers", 0) + + +def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_pickled: bool): + print(response) + assert response["status"] == "SUCCESS" + assert response["traceback"] is None + if return_pickled: + assert response["result"]["result_url"].startswith("s3://") + else: + assert response["result"] == {"result": '{"y": 1}'} + + +# Wait up to 30 seconds for the tasks to be returned. +@retry( + stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) +) +def ensure_all_async_tasks_success(task_ids: List[str], user_id: str, return_pickled: bool): + responses = asyncio.run(get_async_tasks(task_ids, user_id)) + for response in responses: + if response["status"] not in (TaskStatus.PENDING, TaskStatus.SUCCESS, TaskStatus.STARTED): + print(response) + raise ValueError("Task failed!") + ensure_inference_task_response_is_correct(response, return_pickled) + + +def delete_existing_endpoints(users: Sequence[str] = DEFAULT_USERS) -> None: + if len(users) == 0: + raise ValueError("Must supply at least one user!") + + # list all endpoints before attempting to delete them + print(f"[{len({users})} ] Listing all user endpoints... ({users})") + all_endpoint_info = [] + for i, u in enumerate(users): + u_endpoints = list_model_endpoints(u) + all_endpoint_info.append(u_endpoints) + k8s_endpoint_names = [ + f"launch-endpoint-id-{endpoint['id'].replace('_', '-')}" for endpoint in u_endpoints + ] + print( + f"[{i + 1}/{len(users)}] {len(u_endpoints)} endpoints for user {u}: {k8s_endpoint_names}" + ) + + if all([len(info) == 0 for info in all_endpoint_info]): + return + + # delete the endpoints: if this fails, manually remove the dangling k8s deployments + # and delete the user's endpoints from the hosted_model_inference.endpoints table + # i.e. by default this is running the following SQL: + # + # >>>> delete from model_engine_server.endpoints where created_by in ( + # >>>> 'test00000000000000000000', + # >>>> 'test11111111111111111111', + # >>>> ) + # + time.sleep(15) # need to sleep to allow the cache to refresh + print(f"[{len({users})}] Deleting all user endpoints...") + try: + for i, u in enumerate(users): + print(f"[{i + 1}/{len(users)}] Deleting all endpoints for user with ID {u}") + delete_all_endpoints(u) + except Exception: # noqa + try: + j: str = json.dumps(all_endpoint_info, indent=2) + except Exception as j_error: # noqa + j = f"[FAILED TO JSON ENCODE {j_error}]\n{all_endpoint_info}" + barrier: str = "-" * 80 + print(f"ERROR! Deletion failed. All endpoint information:\n{barrier}\n{j}\n{barrier}") + raise + + time.sleep(15) diff --git a/integration_tests/test_batch_jobs.py b/integration_tests/test_batch_jobs.py new file mode 100644 index 00000000..8f4f1dec --- /dev/null +++ b/integration_tests/test_batch_jobs.py @@ -0,0 +1,25 @@ +from .rest_api_utils import ( + CREATE_BATCH_JOB_REQUEST, + CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST, + CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST, + USER_ID_0, + cancel_batch_job, + create_batch_job, + create_docker_image_batch_job, + get_or_create_docker_image_batch_job_bundle, +) +from .test_bundles import model_bundles # noqa + + +def test_di_batch_jobs(model_bundles) -> None: # noqa + get_or_create_docker_image_batch_job_bundle( + CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 + ) + create_docker_image_batch_job(CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST, USER_ID_0) + + batch_job_id = create_batch_job(CREATE_BATCH_JOB_REQUEST, USER_ID_0)["job_id"] + + # TODO: assert that batch job actually succeeds. + + cancel_response = cancel_batch_job(batch_job_id, USER_ID_0) + assert cancel_response["success"] diff --git a/integration_tests/test_bundles.py b/integration_tests/test_bundles.py new file mode 100644 index 00000000..5d38d80f --- /dev/null +++ b/integration_tests/test_bundles.py @@ -0,0 +1,23 @@ +import pytest + +from .rest_api_utils import ( + CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, + CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, + USER_ID_0, + USER_ID_1, + create_model_bundle, + get_latest_model_bundle, +) + + +@pytest.fixture(scope="session") +def model_bundles(): + for user in [USER_ID_0, USER_ID_1]: + for create_bundle_request in [ + CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, + CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, + ]: + create_model_bundle(create_bundle_request, user, "v2") + bundle = get_latest_model_bundle(create_bundle_request["name"], user, "v2") + assert bundle["name"] == create_bundle_request["name"] + assert bundle["metadata"] == create_bundle_request["metadata"] diff --git a/integration_tests/test_docs.py b/integration_tests/test_docs.py new file mode 100644 index 00000000..2185154e --- /dev/null +++ b/integration_tests/test_docs.py @@ -0,0 +1,210 @@ +# Ignore lint errors for f-strings because the f-strings are actually regex expressions. +# flake8: noqa: W605 +import importlib.util +import os +import random +import re +from pathlib import Path +from textwrap import dedent + +import pytest +from _pytest.assertion.rewrite import AssertionRewritingHook + +ROOT_DIR = Path(__file__).parent.parent + +TEST_SKIP_MAGIC_STRING = "# test='skip'" + + +@pytest.fixture +def tmp_work_path(tmp_path: Path): + """ + Create a temporary working directory. + """ + previous_cwd = Path.cwd() + os.chdir(tmp_path) + + yield tmp_path + + os.chdir(previous_cwd) + + +class SetEnv: + def __init__(self): + self.envars = set() + + def __call__(self, name, value): + self.envars.add(name) + os.environ[name] = value + + def clear(self): + for n in self.envars: + os.environ.pop(n) + + +@pytest.fixture +def env(): + setenv = SetEnv() + + yield setenv + + setenv.clear() + + +@pytest.fixture() +def seed() -> int: + """Returns a random seed between 0 and 999, inclusive.""" + return random.randint(0, 999) + + +@pytest.fixture() +def integration_test_user_id() -> str: + return "62bc820451dbea002b1c5421" + + +def modify_source(source: str, seed: int) -> str: + # Adds some custom logic to update code from docs to comply with some requirements. + source = re.sub(r"('team'|\"team\"): ('\w+'|\"\w+\")", r"'team': 'infra'", source) + source = re.sub( + r"('product'|\"product\"): ('\w+'|\"\w+\")", + r"'product': 'launch-integration-test'", + source, + ) + + # Add suffix to avoid name collisions + source = re.sub( + r"('endpoint_name'|\"endpoint_name\"): ('(\w+)'|\"(\w+)\")", + f"'endpoint_name': '\g<3>\g<4>-{seed}'", + source, + ) + source = re.sub( + r"endpoint_name=('(\w+)'|\"(\w+)\")", + f"endpoint_name='\g<2>\g<3>-{seed}'", + source, + ) + source = re.sub(r'"repository": "..."', '"repository": "launch_rearch"', source) + source = re.sub( + r'"tag": "..."', '"tag": "11d9d42047cc9a0c6435b19e5e91bc7e0ad31efc-cpu"', source + ) + source = re.sub( + r'"command": ...', + """"command": [ + "dumb-init", + "--", + "ddtrace-run", + "run-service", + "--config", + "/install/launch_rearch/config/service--user_defined_code.yaml", + "--concurrency", + "1", + "--http", + "production", + "--port", + "5005", + ]""", + source, + ) + source = re.sub( + r'"streaming_command": ...', + """"streaming_command": [ + "dumb-init", + "--", + "ddtrace-run", + "run-streamer", + "--config", + "/install/std-ml-srv/tests/resources/example_echo_streaming_service_configuration.yaml", + "--concurrency", + "1", + "--http-mode", + "production", + "--port", + "5005", + ]""", + source, + ) + return source + + +@pytest.fixture +def import_execute(request, tmp_work_path: Path): + def _import_execute(module_name: str, source: str, seed: int, rewrite_assertions: bool = False): + if rewrite_assertions: + loader = AssertionRewritingHook(config=request.config) + loader.mark_rewrite(module_name) + else: + loader = None + + module_path = tmp_work_path / f"{module_name}.py" + modified_source = modify_source(source, seed) + module_path.write_text(modified_source) + spec = importlib.util.spec_from_file_location("__main__", str(module_path), loader=loader) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except KeyboardInterrupt: + print("KeyboardInterrupt") + + return _import_execute + + +def extract_code_chunks(path: Path, text: str, offset: int): + rel_path = path.relative_to(ROOT_DIR) + for m_code in re.finditer(r"```(.*?)$\n(.*?)\n( *)```", text, flags=re.M | re.S): + prefix = m_code.group(1).lower() + if not prefix.startswith(("py", "{.py")): + continue + + start_line = offset + text[: m_code.start()].count("\n") + 1 + code = dedent(m_code.group(2)) + end_line = start_line + code.count("\n") + 1 + source = "\n" * start_line + code + if TEST_SKIP_MAGIC_STRING in prefix or TEST_SKIP_MAGIC_STRING in code: + source = "__skip__" + yield pytest.param( + f"{path.stem}_{start_line}_{end_line}", source, id=f"{rel_path}:{start_line}-{end_line}" + ) + + +def generate_code_chunks(*directories: str): + for d in directories: + for path in (ROOT_DIR / d).glob("**/*"): + if path.suffix == ".py": + code = path.read_text() + for m_docstring in re.finditer(r'(^\s*)r?"""$(.*?)\1"""', code, flags=re.M | re.S): + start_line = code[: m_docstring.start()].count("\n") + docstring = m_docstring.group(2) + yield from extract_code_chunks(path, docstring, start_line) + elif path.suffix == ".md": + # TODO: remove this hack to skip llms.md + if "llms.md" in path.name: + continue + code = path.read_text() + yield from extract_code_chunks(path, code, 0) + + +# Assumes that launch-python-client is cloned at `models/launch-python-client` +@pytest.mark.parametrize( + "module_name,source_code", + generate_code_chunks( + "launch-python-client/docs", + "launch-python-client/launch", + "launch_internal/docs", + "launch_internal/launch_internal", + ), +) +def test_docs_examples( + module_name, + source_code, + import_execute, + env, + seed, + integration_test_user_id, +): + if source_code == "__skip__": + pytest.skip("test='skip' on code snippet") + + env("LAUNCH_API_KEY", os.getenv("LAUNCH_TEST_API_KEY", integration_test_user_id)) + + try: + import_execute(module_name, source_code, seed, True) + except Exception: + raise diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py new file mode 100644 index 00000000..2af5a257 --- /dev/null +++ b/integration_tests/test_endpoints.py @@ -0,0 +1,228 @@ +import asyncio +import time + +import pytest + +from .rest_api_utils import ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE, + CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE, + CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + INFERENCE_PAYLOAD, + INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE, + INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE, + UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE, + USER_ID_0, + create_async_tasks, + create_model_endpoint, + create_streaming_tasks, + create_sync_tasks, + delete_existing_endpoints, + delete_model_endpoint, + ensure_all_async_tasks_success, + ensure_gateway_ready, + ensure_inference_task_response_is_correct, + ensure_n_ready_endpoints_long, + ensure_n_ready_endpoints_short, + ensure_nonzero_available_workers, + get_model_endpoint, + update_model_endpoint, +) + + +@pytest.fixture(autouse=True) +def delete_endpoints(capsys): + try: + ensure_gateway_ready() + delete_existing_endpoints() + except Exception: + with capsys.disabled(): + print("Endpoint deletion failed") + + +@pytest.mark.parametrize( + "create_endpoint_request,update_endpoint_request,inference_requests", + [ + ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE, + UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE, + [ + (INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE, True), + (INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE, False), + ], + ), + ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + [(INFERENCE_PAYLOAD, False)], + ), + ], +) +def test_async_model_endpoint( + capsys, create_endpoint_request, update_endpoint_request, inference_requests +): + with capsys.disabled(): + try: + user = USER_ID_0 + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_model_endpoint(create_endpoint_request, user) + ensure_n_ready_endpoints_long(1, user) + + print(f"Updating {create_endpoint_request['name']} model endpoint...") + update_model_endpoint( + create_endpoint_request["name"], + update_endpoint_request, + user, + ) + # Let the cache update + time.sleep(30) + # Endpoint builds should be cached now. + ensure_n_ready_endpoints_short(1, user) + + print("Checking endpoint state...") + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert ( + endpoint["deployment_state"]["max_workers"] + == update_endpoint_request["max_workers"] + ) + + time.sleep(10) + + for inference_payload, return_pickled in inference_requests: + print( + f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." + ) + task_ids = asyncio.run( + create_async_tasks( + create_endpoint_request["name"], + [inference_payload] * 3, + user, + ) + ) + print("Retrieving async task results...") + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + ensure_all_async_tasks_success(task_ids, user, return_pickled) + finally: + delete_model_endpoint(create_endpoint_request["name"], user) + + +def test_sync_model_endpoint(capsys): + with capsys.disabled(): + try: + user = USER_ID_0 + create_endpoint_request = CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE + update_endpoint_request = UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE + inference_requests = [ + (INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE, True), + (INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE, False), + ] + + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_model_endpoint(create_endpoint_request, user) + ensure_n_ready_endpoints_short(1, user) + + print(f"Updating {create_endpoint_request['name']} model endpoint...") + update_model_endpoint( + create_endpoint_request["name"], + update_endpoint_request, + user, + ) + # Let the cache update + time.sleep(30) + # Endpoint builds should be cached now. + ensure_n_ready_endpoints_short(1, user) + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + + print("Checking endpoint state...") + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert ( + endpoint["deployment_state"]["max_workers"] + == update_endpoint_request["max_workers"] + ) + + time.sleep(10) + + for inference_payload, return_pickled in inference_requests: + print( + f"Sending sync tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." + ) + task_responses = asyncio.run( + create_sync_tasks( + create_endpoint_request["name"], + [inference_payload], + user, + ) + ) + for response in task_responses: + ensure_inference_task_response_is_correct(response, return_pickled) + finally: + delete_model_endpoint(create_endpoint_request["name"], user) + + +def test_sync_streaming_model_endpoint(capsys): + with capsys.disabled(): + try: + user = USER_ID_0 + create_endpoint_request = CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE + update_endpoint_request = UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE + + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_model_endpoint(create_endpoint_request, user) + ensure_n_ready_endpoints_short(1, user) + + print(f"Updating {create_endpoint_request['name']} model endpoint...") + update_model_endpoint( + create_endpoint_request["name"], + update_endpoint_request, + user, + ) + # Let the cache update + time.sleep(30) + # Endpoint builds should be cached now. + ensure_n_ready_endpoints_short(1, user) + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + + print("Checking endpoint state...") + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert ( + endpoint["deployment_state"]["max_workers"] + == update_endpoint_request["max_workers"] + ) + + time.sleep(5) + + print(f"Sending sync tasks to {create_endpoint_request['name']} for user {user} ...") + task_responses = asyncio.run( + create_sync_tasks( + create_endpoint_request["name"], + [INFERENCE_PAYLOAD] * 3, + user, + ) + ) + for response in task_responses: + ensure_inference_task_response_is_correct(response, False) + + print( + f"Sending streaming tasks to {create_endpoint_request['name']} for user {user} ..." + ) + task_responses = asyncio.run( + create_streaming_tasks( + create_endpoint_request["name"], + [INFERENCE_PAYLOAD] * 5, + user, + ) + ) + for response in task_responses: + assert ( + response.strip() + == 'data: {"status": "SUCCESS", "result": {"result": {"y": 1}}, "traceback": null}' + ) + finally: + delete_model_endpoint(create_endpoint_request["name"], user) diff --git a/integration_tests/test_file.py b/integration_tests/test_file.py new file mode 100644 index 00000000..53c10345 --- /dev/null +++ b/integration_tests/test_file.py @@ -0,0 +1,27 @@ +from .rest_api_utils import ( # list_files, delete_file_by_id, + get_file_by_id, + get_file_content_by_id, + upload_file, +) + + +def test_files() -> None: + user = "62bc820451dbea002b1c5421" # CDS needs proper user ID + + upload_response = upload_file(open(__file__, "rb"), user) + file_id = upload_response["id"] + + content = get_file_content_by_id(file_id, user) + assert content["id"] == file_id + assert content["content"] + + get_response = get_file_by_id(file_id, user) + assert get_response["id"] == file_id + assert get_response["filename"] == "test_file.py" + + # TODO: add tests back + # list_response = list_files(user) + # assert len(list_response["files"]) > 0 + + # delete_response = delete_file_by_id(file_id, user) + # assert delete_response["deleted"] diff --git a/integration_tests/test_fine_tunes.py b/integration_tests/test_fine_tunes.py new file mode 100644 index 00000000..a5aee7ac --- /dev/null +++ b/integration_tests/test_fine_tunes.py @@ -0,0 +1,31 @@ +from .rest_api_utils import ( # CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, CREATE_FINE_TUNE_REQUEST, USER_ID_0, cancel_fine_tune_by_id, create_docker_image_batch_job_bundle, create_fine_tune, get_fine_tune_by_id, + USER_ID_1, + list_fine_tunes, +) + + +def test_fine_tunes() -> None: + # TODO: get this test to work (move LLM fine tune repository to database rather than in S3) + + # di_batch_job_id = create_docker_image_batch_job_bundle( + # CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 + # )["docker_image_batch_job_bundle_id"] + + # create_response = create_fine_tune(CREATE_FINE_TUNE_REQUEST, USER_ID_0) + # fine_tune_id = create_response["id"] + + # get_response = get_fine_tune_by_id(fine_tune_id, USER_ID_0) + # assert get_response["id"] == fine_tune_id + + # list_response_0_before = list_fine_tunes(USER_ID_0) + # num_jobs = len(list_response_0_before["jobs"]) + # assert num_jobs >= 1 + + list_response_1 = list_fine_tunes(USER_ID_1) + assert len(list_response_1["jobs"]) == 0 + + # cancel_response = cancel_fine_tune_by_id(fine_tune_id, USER_ID_0) + # assert cancel_response["success"] + + # list_response_0_after = list_fine_tunes(USER_ID_0) + # assert len(list_response_0_after["jobs"]) == num_jobs - 1 diff --git a/model-engine/Dockerfile b/model-engine/Dockerfile index 7e186157..80939559 100644 --- a/model-engine/Dockerfile +++ b/model-engine/Dockerfile @@ -48,6 +48,8 @@ COPY model-engine/setup.py /workspace/model-engine/setup.py COPY model-engine/model_engine_server /workspace/model-engine/model_engine_server RUN pip install -e . +COPY integration_tests /workspace/integration_tests + WORKDIR /workspace ENV PYTHONPATH /workspace ENV WORKSPACE /workspace diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index e9d8424b..bdd158db 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -80,6 +80,7 @@ DbModelEndpointRecordRepository, DbTriggerRepository, ECRDockerRepository, + FakeDockerRepository, RedisModelEndpointCacheRepository, S3FileLLMFineTuneEventsRepository, S3FileLLMFineTuneRepository, @@ -232,8 +233,10 @@ def _get_external_interfaces( file_storage_gateway = S3FileStorageGateway() + docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() + external_interfaces = ExternalInterfaces( - docker_repository=ECRDockerRepository(), + docker_repository=docker_repository, model_bundle_repository=model_bundle_repository, model_endpoint_service=model_endpoint_service, llm_model_endpoint_service=llm_model_endpoint_service, diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 3802129b..445dd83c 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -37,7 +37,7 @@ from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, ) -from model_engine_server.infra.repositories import ECRDockerRepository +from model_engine_server.infra.repositories import ECRDockerRepository, FakeDockerRepository from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, ) @@ -117,7 +117,7 @@ async def main(args: Any): sqs_delegate=sqs_delegate, ) image_cache_gateway = ImageCacheGateway() - docker_repo = ECRDockerRepository() + docker_repo = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() while True: loop_start = time.time() await loop_iteration( diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py new file mode 100644 index 00000000..12470cfc --- /dev/null +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -0,0 +1,56 @@ +""" +This file is for testing purposes only. It serves as simple server to mock a deployed model. +""" +import argparse +import subprocess + +from fastapi import FastAPI, Request +from sse_starlette.sse import EventSourceResponse + +app = FastAPI() + + +@app.get("/healthz") +@app.get("/readyz") +def healthcheck(): + return "OK" + + +@app.post("/predict") +async def predict(request: Request): + return await request.json() + + +@app.post("/stream") +async def stream(request: Request): + value = (await request.body()).decode() + return EventSourceResponse([{"data": value}].__iter__()) + + +def entrypoint(): + parser = argparse.ArgumentParser() + parser.add_argument("--num-workers", type=int, default=1) + parser.add_argument("--host", type=str, default="[::]") + parser.add_argument("--port", type=int, default=5009) + + args = parser.parse_args() + + command = [ + "gunicorn", + "--bind", + f"{args.host}:{args.port}", + "--timeout", + "1200", + "--keep-alive", + "2", + "--worker-class", + "uvicorn.workers.UvicornWorker", + "--workers", + str(args.num_workers), + "model_engine_server.inference.forwarding.echo_server:app", + ] + subprocess.run(command) + + +if __name__ == "__main__": + entrypoint() diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 48f0e924..e50e1623 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -7,11 +7,11 @@ metadata: name: model-engine-service-template-config labels: team: infra - product: launch + product: model-engine helm.sh/chart: model-engine-0.1.0 app.kubernetes.io/managed-by: Helm - app.kubernetes.io/version: 7034db9f84a3a6009d2ef738e5497b300f24f6cd - tags.datadoghq.com/version: 7034db9f84a3a6009d2ef738e5497b300f24f6cd + app.kubernetes.io/version: a93c7fe34529efde2b468b9cbbf3abf300308164 + tags.datadoghq.com/version: a93c7fe34529efde2b468b9cbbf3abf300308164 tags.datadoghq.com/env: circleci annotations: "helm.sh/hook": pre-install,pre-upgrade @@ -106,7 +106,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -377,7 +377,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -596,7 +596,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -864,7 +864,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1088,7 +1088,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:7034db9f84a3a6009d2ef738e5497b300f24f6cd + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1327,7 +1327,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1459,6 +1459,7 @@ data: periodSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -1604,7 +1605,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1692,6 +1693,7 @@ data: periodSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -1829,7 +1831,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1966,6 +1968,7 @@ data: periodSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2103,7 +2106,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -2196,6 +2199,7 @@ data: periodSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2333,7 +2337,7 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:7034db9f84a3a6009d2ef738e5497b300f24f6cd + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -2428,6 +2432,7 @@ data: periodSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2776,7 +2781,7 @@ data: name: default-config containers: - name: main - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} env: - name: DD_SERVICE value: ${RESOURCE_NAME} @@ -2807,11 +2812,13 @@ data: value: "true" - name: LAUNCH_SERVICE_TEMPLATE_FOLDER value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: CIRCLECI + value: "true" - name: DD_VERSION value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: IfNotPresent + imagePullPolicy: Always command: - dumb-init - -- @@ -2930,11 +2937,13 @@ data: value: "true" - name: LAUNCH_SERVICE_TEMPLATE_FOLDER value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: CIRCLECI + value: "true" - name: DD_VERSION value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: IfNotPresent + imagePullPolicy: Always command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. @@ -2956,7 +2965,7 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} command: - python - -m @@ -3074,15 +3083,18 @@ data: value: "true" - name: LAUNCH_SERVICE_TEMPLATE_FOLDER value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: CIRCLECI + value: "true" - name: DD_VERSION value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: IfNotPresent + imagePullPolicy: Always command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -3101,7 +3113,7 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/launch/gateway:${GIT_TAG} + image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} command: - python - -m @@ -3134,7 +3146,7 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: @@ -3149,7 +3161,7 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 @@ -3171,7 +3183,7 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: @@ -3186,7 +3198,7 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 @@ -3212,7 +3224,7 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: @@ -3227,7 +3239,7 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 @@ -3253,7 +3265,7 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: @@ -3268,7 +3280,7 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: launch + product: model-engine use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 diff --git a/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py index 7f297f61..a5020740 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py @@ -36,7 +36,7 @@ async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: ) return FileMetadata( id=file_id, - filename=get_s3_url(owner, file_id), + filename=file_id, size=obj.get("ContentLength"), owner=owner, updated_at=obj.get("LastModified"), @@ -57,7 +57,7 @@ async def upload_file(self, owner: str, filename: str, content: bytes) -> str: with self.filesystem_gateway.open( get_s3_url(owner, filename), mode="w", aws_profile=infra_config().profile_ml_worker ) as f: - f.write(content) + f.write(content.decode("utf-8")) return filename async def delete_file(self, owner: str, file_id: str) -> bool: diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index bf109926..93fd708b 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -7,6 +7,7 @@ from .db_model_endpoint_record_repository import DbModelEndpointRecordRepository from .db_trigger_repository import DbTriggerRepository from .ecr_docker_repository import ECRDockerRepository +from .fake_docker_repository import FakeDockerRepository from .feature_flag_repository import FeatureFlagRepository from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository @@ -24,6 +25,7 @@ "DbModelEndpointRecordRepository", "DbTriggerRepository", "ECRDockerRepository", + "FakeDockerRepository", "FeatureFlagRepository", "LLMFineTuneRepository", "ModelEndpointRecordRepository", diff --git a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py new file mode 100644 index 00000000..b7fa39a6 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py @@ -0,0 +1,21 @@ +from typing import Optional + +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + + +class FakeDockerRepository(DockerRepository): + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + return True + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError("FakeDockerRepository build_image() not implemented") diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index 772a5297..f7dc0d4e 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -37,6 +37,7 @@ from model_engine_server.infra.repositories import ( DbModelEndpointRecordRepository, ECRDockerRepository, + FakeDockerRepository, RedisFeatureFlagRepository, RedisModelEndpointCacheRepository, ) @@ -66,8 +67,10 @@ def get_live_endpoint_builder_service( else: monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() + docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() + service = LiveEndpointBuilderService( - docker_repository=ECRDockerRepository(), + docker_repository=docker_repository, resource_gateway=LiveEndpointResourceGateway( sqs_delegate=sqs_delegate, ), diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 25b55c7a..17e36639 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -50,7 +50,7 @@ billing_queue_arn: none # There's a separate piece of infra that caches k8s state onto redis, so we need a url to it cache_redis_url: redis://127.0.0.1:6379/15 -s3_file_llm_fine_tune_repository: "s3://test-bucket" +s3_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune_repository/circleci" datadog_trace_enabled: false istio_enabled: true diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py index c2ce5243..aa1821fa 100644 --- a/model-engine/tests/unit/infra/services/test_image_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -30,13 +30,17 @@ async def test_image_cache_success( await fake_image_cache_service.execute(infra_states) # type: ignore gateway: Any = fake_image_cache_service.image_cache_gateway - assert f"{infra_config().docker_repo_prefix}/my-repo:abcdefg222" in gateway.cached_images["t4"] assert ( - f"{infra_config().docker_repo_prefix}/my-repo:abcdefg111111111" + f"{infra_config().ml_account_id}.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg222" in gateway.cached_images["t4"] ) assert ( - f"{infra_config().docker_repo_prefix}/my-repo:abcdefg00000" in gateway.cached_images["t4"] + f"{infra_config().ml_account_id}.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg111111111" + in gateway.cached_images["t4"] + ) + assert ( + f"{infra_config().ml_account_id}.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg00000" + in gateway.cached_images["t4"] ) From 109b15b10a1b152323a35e8641c156595e32a99a Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 19 Sep 2023 07:35:12 -0700 Subject: [PATCH 103/425] bump pypi version (#284) --- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 51afec3e..fbe0ddde 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta13" +version = "0.0.0.beta14" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 2b51d491..c37b0da6 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta13", + version="0.0.0.beta14", packages=find_packages(), ) From 3884f85de8788e7c44063f191fedf9359c7c65ee Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 20 Sep 2023 14:01:49 -0700 Subject: [PATCH 104/425] Ianmacleod/error handling (#282) * changing exception location * adding exception handling for generic exceptions * moving more error handling, still haven't gotten middleware working * precommit * add new error handler middleware * iterating * iterating * iterating * setting DD_REMOTE_CONFIGURATION_ENABLED as false so that we don't get unneccessary logging errors * updated tracer working? * cleanup * precommit hooks --- charts/model-engine/templates/_helpers.tpl | 6 ++ model-engine/model_engine_server/api/app.py | 31 +++++++++- .../model_engine_server/api/batch_jobs_v1.py | 10 ++-- .../api/docker_image_batch_job_bundles_v1.py | 6 +- .../model_engine_server/api/files_v1.py | 4 +- .../model_engine_server/api/llms_v1.py | 26 ++++---- .../api/model_bundles_v1.py | 4 +- .../api/model_bundles_v2.py | 4 +- .../api/model_endpoints_v1.py | 58 ++++++++++++++---- .../model_engine_server/api/tasks_v1.py | 6 +- .../model_engine_server/api/triggers_v1.py | 10 ++-- .../common/datadog_utils.py | 6 ++ .../core/domain_exceptions.py | 59 ------------------ .../model_engine_server/core/loggers.py | 11 ++-- .../model_engine_server/domain/exceptions.py | 60 ++++++++++++++++++- .../use_cases/async_inference_use_cases.py | 10 ++-- .../domain/use_cases/batch_job_use_cases.py | 12 ++-- ...docker_image_batch_job_bundle_use_cases.py | 8 +-- .../domain/use_cases/file_use_cases.py | 2 +- .../use_cases/llm_fine_tuning_use_cases.py | 7 ++- .../use_cases/llm_model_endpoint_use_cases.py | 8 +-- .../use_cases/model_bundle_use_cases.py | 10 ++-- .../use_cases/model_endpoint_use_cases.py | 8 +-- .../streaming_inference_use_cases.py | 10 ++-- .../use_cases/sync_inference_use_cases.py | 10 ++-- .../domain/use_cases/trigger_use_cases.py | 11 ++-- .../infra/repositories/db_repository_mixin.py | 2 +- ...s3_file_llm_fine_tune_events_repository.py | 2 +- .../live_batch_job_orchestration_service.py | 2 +- .../services/live_endpoint_builder_service.py | 6 +- .../services/live_model_endpoint_service.py | 10 ++-- model-engine/tests/unit/api/test_tasks.py | 6 +- model-engine/tests/unit/conftest.py | 6 +- .../domain/test_async_inference_use_cases.py | 4 +- ...docker_image_batch_job_bundle_use_cases.py | 2 +- .../tests/unit/domain/test_llm_use_cases.py | 8 +-- .../domain/test_model_bundle_use_cases.py | 2 +- .../domain/test_model_endpoint_use_cases.py | 8 +-- .../test_streaming_inference_use_cases.py | 6 +- .../domain/test_sync_inference_use_cases.py | 4 +- .../test_db_batch_job_record_repository.py | 2 +- ...ocker_image_batch_job_bundle_repository.py | 6 +- .../test_db_model_bundle_repository.py | 2 +- ...est_db_model_endpoint_record_repository.py | 2 +- ...st_live_batch_job_orchestration_service.py | 2 +- .../test_live_endpoint_builder_service.py | 6 +- .../test_live_model_endpoint_service.py | 6 +- 47 files changed, 280 insertions(+), 211 deletions(-) delete mode 100644 model-engine/model_engine_server/core/domain_exceptions.py diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 9389bcac..0fcf816d 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -124,6 +124,8 @@ podAffinity: env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -184,6 +186,8 @@ env: env: - name: DATADOG_TRACE_ENABLED value: "${DATADOG_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -230,6 +234,8 @@ env: env: - name: DATADOG_TRACE_ENABLED value: "{{ .Values.datadog_trace_enabled }}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_ENV value: {{ .Values.context }} - name: DD_AGENT_HOST diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 786f097a..a13d62e6 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -1,7 +1,9 @@ import os +import traceback from pathlib import Path -from fastapi import FastAPI, Response +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1 from model_engine_server.api.dependencies import get_or_create_aioredis_pool @@ -16,6 +18,9 @@ from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 from model_engine_server.api.tasks_v1 import inference_task_router_v1 from model_engine_server.api.triggers_v1 import trigger_router_v1 +from model_engine_server.common.datadog_utils import get_request_id +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from starlette.middleware.base import BaseHTTPMiddleware app = FastAPI(title="launch", version="1.0.0", redoc_url="/api") @@ -30,6 +35,30 @@ app.include_router(file_router_v1) app.include_router(trigger_router_v1) +logger = make_logger(filename_wo_ext(__name__)) + + +class ExceptionLoggingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + return await call_next(request) + except Exception as e: + tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + structured_log = {"error": str(e), "traceback": "".join(tb_str)} + logger.error("Unhandled exception: %s", structured_log) + request_id = get_request_id() + return JSONResponse( + { + "status_code": 500, + "content": { + "error": f"Internal error for request_id {request_id}. Our team has been notified." + }, + } + ) + + +app.add_middleware(ExceptionLoggingMiddleware) + # TODO: Remove this once we have a better way to serve internal docs INTERNAL_DOCS_PATH = str(Path(__file__).parents[3] / "launch_internal/site") if os.path.exists(INTERNAL_DOCS_PATH): diff --git a/model-engine/model_engine_server/api/batch_jobs_v1.py b/model-engine/model_engine_server/api/batch_jobs_v1.py index 7e939d9c..022b9dc8 100644 --- a/model-engine/model_engine_server/api/batch_jobs_v1.py +++ b/model-engine/model_engine_server/api/batch_jobs_v1.py @@ -22,16 +22,14 @@ UpdateDockerImageBatchJobV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - DockerImageNotFoundException, - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, EndpointLabelsException, EndpointResourceInvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, ) from model_engine_server.domain.use_cases.batch_job_use_cases import ( CreateBatchJobV1UseCase, diff --git a/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py index 96cc3d49..1444a39b 100644 --- a/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py +++ b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py @@ -15,12 +15,12 @@ ) from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( + EndpointResourceInvalidRequestException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.core.loggers import filename_wo_ext, make_logger -from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException from model_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( CreateDockerImageBatchJobBundleV1UseCase, GetDockerImageBatchJobBundleByIdV1UseCase, diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py index a2d23ba3..8c50cc53 100644 --- a/model-engine/model_engine_server/api/files_v1.py +++ b/model-engine/model_engine_server/api/files_v1.py @@ -16,11 +16,11 @@ UploadFileResponse, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.use_cases.file_use_cases import ( DeleteFileUseCase, GetFileContentUseCase, diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index af114adb..78117376 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -1,7 +1,6 @@ """LLM Model Endpoint routes for the hosted model inference service. """ from typing import Optional -from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException, Query from model_engine_server.api.dependencies import ( @@ -10,7 +9,7 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_request_id, add_trace_resource_name +from model_engine_server.common.datadog_utils import add_trace_resource_name, get_request_id from model_engine_server.common.dtos.llms import ( CancelFineTuneResponse, CompletionStreamV1Request, @@ -32,13 +31,6 @@ ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectHasInvalidValueException, - ObjectNotApprovedException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, @@ -49,6 +41,11 @@ InvalidRequestException, LLMFineTuningMethodNotImplementedException, LLMFineTuningQuotaReached, + ObjectAlreadyExistsException, + ObjectHasInvalidValueException, + ObjectNotApprovedException, + ObjectNotAuthorizedException, + ObjectNotFoundException, UpstreamServiceError, ) from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( @@ -192,10 +189,12 @@ async def create_completion_sync_task( user=auth, model_endpoint_name=model_endpoint_name, request=request ) except UpstreamServiceError: - request_id = str(uuid4()) - add_trace_request_id(request_id) + request_id = get_request_id() logger.exception(f"Upstream service error for request {request_id}") - return CompletionSyncV1Response(request_id=request_id, output=None) + raise HTTPException( + status_code=500, + detail=f"Upstream service error for request_id {request_id}.", + ) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -245,8 +244,7 @@ async def event_generator(): return EventSourceResponse(event_generator()) except UpstreamServiceError: - request_id = str(uuid4()) - add_trace_request_id(request_id) + request_id = get_request_id() logger.exception(f"Upstream service error for request {request_id}") return EventSourceResponse( iter((CompletionStreamV1Response(request_id=request_id).json(),)) diff --git a/model-engine/model_engine_server/api/model_bundles_v1.py b/model-engine/model_engine_server/api/model_bundles_v1.py index de24f860..e192af13 100644 --- a/model-engine/model_engine_server/api/model_bundles_v1.py +++ b/model-engine/model_engine_server/api/model_bundles_v1.py @@ -19,13 +19,13 @@ ModelBundleV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV1UseCase, CreateModelBundleV1UseCase, diff --git a/model-engine/model_engine_server/api/model_bundles_v2.py b/model-engine/model_engine_server/api/model_bundles_v2.py index 94801916..d35de5cf 100644 --- a/model-engine/model_engine_server/api/model_bundles_v2.py +++ b/model-engine/model_engine_server/api/model_bundles_v2.py @@ -19,13 +19,13 @@ ModelBundleV2Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV2UseCase, CreateModelBundleV2UseCase, diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index d37f8bf6..9a6c9da4 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -12,7 +12,7 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name +from model_engine_server.common.datadog_utils import add_trace_resource_name, get_request_id from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, @@ -24,19 +24,17 @@ UpdateModelEndpointV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectHasInvalidValueException, - ObjectNotApprovedException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, EndpointLabelsException, EndpointResourceInvalidRequestException, ExistingEndpointOperationInProgressException, + ObjectAlreadyExistsException, + ObjectHasInvalidValueException, + ObjectNotApprovedException, + ObjectNotAuthorizedException, + ObjectNotFoundException, ) from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CreateModelEndpointV1UseCase, @@ -94,6 +92,13 @@ async def create_model_endpoint( status_code=404, detail="The specified model bundle could not be found.", ) from exc + except Exception as exc: + request_id = get_request_id() + logger.exception(f"Internal service error for request {request_id}: {exc}") + raise HTTPException( + status_code=500, + detail=f"Internal error for request_id {request_id}.", + ) @model_endpoint_router_v1.get("/model-endpoints", response_model=ListModelEndpointsV1Response) @@ -108,10 +113,18 @@ async def list_model_endpoints( """ add_trace_resource_name("model_endpoints_get") logger.info(f"GET /model-endpoints?name={name}&order_by={order_by} for {auth}") - use_case = ListModelEndpointsV1UseCase( - model_endpoint_service=external_interfaces.model_endpoint_service, - ) - return await use_case.execute(user=auth, name=name, order_by=order_by) + try: + use_case = ListModelEndpointsV1UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + ) + return await use_case.execute(user=auth, name=name, order_by=order_by) + except Exception as exc: + request_id = get_request_id() + logger.exception(f"Internal service error for request {request_id}: {exc}") + raise HTTPException( + status_code=500, + detail=f"Internal error for request_id {request_id}.", + ) @model_endpoint_router_v1.get( @@ -137,6 +150,13 @@ async def get_model_endpoint( status_code=404, detail=f"Model Endpoint {model_endpoint_id} was not found.", ) from exc + except Exception as exc: + request_id = get_request_id() + logger.exception(f"Internal service error for request {request_id}: {exc}") + raise HTTPException( + status_code=500, + detail=f"Internal error for request_id {request_id}.", + ) @model_endpoint_router_v1.put( @@ -181,6 +201,13 @@ async def update_model_endpoint( status_code=409, detail="Existing operation on endpoint in progress, try again later.", ) from exc + except Exception as exc: + request_id = get_request_id() + logger.exception(f"Internal service error for request {request_id}: {exc}") + raise HTTPException( + status_code=500, + detail=f"Internal error for request_id {request_id}.", + ) @model_endpoint_router_v1.delete( @@ -216,3 +243,10 @@ async def delete_model_endpoint( status_code=500, detail="deletion of endpoint failed, compute resources still exist.", ) from exc + except Exception as exc: + request_id = get_request_id() + logger.exception(f"Internal service error for request {request_id}: {exc}") + raise HTTPException( + status_code=500, + detail=f"Internal error for request_id {request_id}.", + ) diff --git a/model-engine/model_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py index 443b7fb7..05fdb270 100644 --- a/model-engine/model_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -16,13 +16,11 @@ TaskStatus, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.exceptions import ( EndpointUnsupportedInferenceTypeException, + ObjectNotAuthorizedException, + ObjectNotFoundException, UpstreamServiceError, ) from model_engine_server.domain.use_cases.async_inference_use_cases import ( diff --git a/model-engine/model_engine_server/api/triggers_v1.py b/model-engine/model_engine_server/api/triggers_v1.py index cc32180e..30f3310b 100644 --- a/model-engine/model_engine_server/api/triggers_v1.py +++ b/model-engine/model_engine_server/api/triggers_v1.py @@ -15,17 +15,15 @@ UpdateTriggerV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - DockerImageNotFoundException, - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.exceptions import ( CronSyntaxException, + DockerImageNotFoundException, EndpointLabelsException, EndpointResourceInvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, TriggerNameAlreadyExistsException, ) from model_engine_server.domain.use_cases.trigger_use_cases import ( diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py index c73fa2f9..26152f03 100644 --- a/model-engine/model_engine_server/common/datadog_utils.py +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -17,3 +17,9 @@ def add_trace_request_id(request_id: str): current_span = tracer.current_span() if current_span: current_span.set_tag("launch.request_id", request_id) + + +def get_request_id(): + """Gets the request id for an api call (in our case, dd trace id) so that we can filter in Datadog easier""" + current_span = tracer.current_span() + return current_span.trace_id if current_span else None diff --git a/model-engine/model_engine_server/core/domain_exceptions.py b/model-engine/model_engine_server/core/domain_exceptions.py deleted file mode 100644 index 62068614..00000000 --- a/model-engine/model_engine_server/core/domain_exceptions.py +++ /dev/null @@ -1,59 +0,0 @@ -from dataclasses import dataclass - - -class DomainException(Exception): - """ - Base class for exceptions thrown for domain (business logic) errors. - """ - - -class ObjectAlreadyExistsException(DomainException): - """ - Thrown when the user tries to create a model with a name that already exists. - """ - - -class ObjectNotFoundException(DomainException): - """ - Thrown when a required object is not found, e.g. when creating a version for a nonexistent model - """ - - -class ObjectNotAuthorizedException(DomainException): - """ - Thrown when a user tries to access an object they don't own. - """ - - -class ObjectHasInvalidValueException(DomainException, ValueError): - """ - Thrown when a user tries to create an object with an invalid value. - """ - - -class ObjectNotApprovedException(DomainException): - """ - Thrown when a required object is not approved, e.g. for a Bundle in review. - """ - - -@dataclass -class DockerImageNotFoundException(DomainException): - """ - Thrown when a user tries to specify a custom Docker image that cannot be found. - """ - - repository: str - tag: str - - -class DockerBuildFailedException(DomainException): - """ - Thrown if the server failed to build a docker image. - """ - - -class ReadOnlyDatabaseException(DomainException): - """ - Thrown if the server attempted to write to a read-only database. - """ diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index e8245199..593302c0 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -10,7 +10,7 @@ import ddtrace import json_log_formatter import tqdm -from ddtrace import tracer +from ddtrace.tracer import Tracer # DO NOT CHANGE LOGGING FORMAT LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s" @@ -82,12 +82,13 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d if request_id: extra["request_id"] = request_id - context = tracer.current_trace_context() - trace_id, span_id = (context.trace_id, context.span_id) if context else (0, 0) + context = Tracer().get_log_correlation_context() + trace_id = context.get("trace_id") + span_id = context.get("span_id") # add ids to event dictionary - extra["dd.trace_id"] = trace_id - extra["dd.span_id"] = span_id + extra["dd.trace_id"] = trace_id or 0 + extra["dd.span_id"] = span_id or 0 # add the env, service, and version configured for the tracer. # If tracing is not set up, then this should pull values from DD_ENV, DD_SERVICE, and DD_VERSION. diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index c31eb0ad..934a5e21 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -1,4 +1,62 @@ -from model_engine_server.core.domain_exceptions import DomainException +from dataclasses import dataclass + + +class DomainException(Exception): + """ + Base class for exceptions thrown for domain (business logic) errors. + """ + + +class ObjectAlreadyExistsException(DomainException): + """ + Thrown when the user tries to create a model with a name that already exists. + """ + + +class ObjectNotFoundException(DomainException): + """ + Thrown when a required object is not found, e.g. when creating a version for a nonexistent model + """ + + +class ObjectNotAuthorizedException(DomainException): + """ + Thrown when a user tries to access an object they don't own. + """ + + +class ObjectHasInvalidValueException(DomainException, ValueError): + """ + Thrown when a user tries to create an object with an invalid value. + """ + + +class ObjectNotApprovedException(DomainException): + """ + Thrown when a required object is not approved, e.g. for a Bundle in review. + """ + + +@dataclass +class DockerImageNotFoundException(DomainException): + """ + Thrown when a user tries to specify a custom Docker image that cannot be found. + """ + + repository: str + tag: str + + +class DockerBuildFailedException(DomainException): + """ + Thrown if the server failed to build a docker image. + """ + + +class ReadOnlyDatabaseException(DomainException): + """ + Thrown if the server attempted to write to a read-only database. + """ class ExistingEndpointOperationInProgressException(DomainException): diff --git a/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py index 3b8a5ddf..647905f2 100644 --- a/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py @@ -4,15 +4,15 @@ GetAsyncTaskV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) from model_engine_server.domain.entities import ModelEndpointType -from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService DEFAULT_TASK_TIMEOUT_SECONDS = 86400 diff --git a/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py index 0a1bb1f5..7ea13e11 100644 --- a/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py @@ -17,17 +17,17 @@ ) from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - DockerImageNotFoundException, - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) from model_engine_server.domain.gateways import CronJobGateway, DockerImageBatchJobGateway from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, diff --git a/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py b/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py index 29d40fe8..3767ffe5 100644 --- a/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py @@ -8,13 +8,13 @@ ) from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) +from model_engine_server.domain.exceptions import ( + ObjectNotAuthorizedException, + ObjectNotFoundException, +) from model_engine_server.domain.repositories import DockerImageBatchJobBundleRepository diff --git a/model-engine/model_engine_server/domain/use_cases/file_use_cases.py b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py index e646e8a0..a3ede743 100644 --- a/model-engine/model_engine_server/domain/use_cases/file_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py @@ -6,8 +6,8 @@ UploadFileResponse, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ObjectNotFoundException from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.gateways import FileStorageGateway logger = make_logger(filename_wo_ext(__file__)) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index a66fc3ff..039b15ad 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -13,10 +13,13 @@ ListFineTunesResponse, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ObjectNotFoundException from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.entities import BatchJobStatus -from model_engine_server.domain.exceptions import InvalidRequestException, LLMFineTuningQuotaReached +from model_engine_server.domain.exceptions import ( + InvalidRequestException, + LLMFineTuningQuotaReached, + ObjectNotFoundException, +) from model_engine_server.domain.gateways import FileStorageGateway from model_engine_server.domain.repositories import LLMFineTuneEventsRepository from model_engine_server.domain.services import LLMFineTuningService, ModelEndpointService diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index af5c0b9c..9e4b6ba6 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -33,11 +33,6 @@ from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.entities import ( LLMInferenceFramework, @@ -55,6 +50,9 @@ EndpointLabelsException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, UpstreamServiceError, ) from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway diff --git a/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py index be75e695..d79b8793 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py @@ -16,11 +16,6 @@ ModelBundleV2Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - DockerImageNotFoundException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) @@ -37,6 +32,11 @@ TensorflowFramework, ZipArtifactFlavor, ) +from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) from model_engine_server.domain.gateways import ModelPrimitiveGateway from model_engine_server.domain.repositories import DockerRepository, ModelBundleRepository diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 1633a72b..bab01204 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -24,11 +24,6 @@ from model_engine_server.common.resource_limits import MAX_ENDPOINT_SIZE, validate_resource_requests from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS, RESTRICTED_ENDPOINT_LABELS from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, @@ -43,6 +38,9 @@ EndpointInfraStateNotFound, EndpointLabelsException, EndpointResourceInvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, ) from model_engine_server.domain.repositories import ModelBundleRepository from model_engine_server.domain.services import ModelEndpointService diff --git a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index 1fb70023..e17b512e 100644 --- a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -5,15 +5,15 @@ SyncEndpointPredictV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) from model_engine_server.domain.entities import ModelEndpointType -from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService diff --git a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index d785beed..16196ab6 100644 --- a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -3,15 +3,15 @@ SyncEndpointPredictV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) from model_engine_server.domain.entities import ModelEndpointType -from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService diff --git a/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py b/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py index b616c299..a0bd1769 100644 --- a/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py @@ -14,16 +14,17 @@ from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS from model_engine_server.core.auth.authentication_repository import User from model_engine_server.core.config import infra_config -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.exceptions import ( + CronSyntaxException, DockerImageNotFoundException, + EndpointLabelsException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.domain.authorization.live_authorization_module import ( - LiveAuthorizationModule, -) -from model_engine_server.domain.exceptions import CronSyntaxException, EndpointLabelsException from model_engine_server.domain.gateways.cron_job_gateway import CronJobGateway from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, diff --git a/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py b/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py index cd8bc402..e0e9f242 100644 --- a/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py +++ b/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py @@ -2,7 +2,7 @@ from functools import wraps from typing import Callable -from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException from sqlalchemy.ext.asyncio import AsyncSession diff --git a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py index 90f179c9..6993d1d0 100644 --- a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py @@ -6,8 +6,8 @@ import boto3 import smart_open from model_engine_server.core.config import infra_config -from model_engine_server.core.domain_exceptions import ObjectNotFoundException from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent +from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( LLMFineTuneEventsRepository, ) diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index 258a4429..7a096c2a 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -16,7 +16,6 @@ TaskStatus, ) from model_engine_server.core.config import infra_config -from model_engine_server.core.domain_exceptions import ObjectNotFoundException from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.entities import ( BatchJobProgress, @@ -25,6 +24,7 @@ BatchJobStatus, ModelEndpointStatus, ) +from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.gateways import AsyncModelEndpointInferenceGateway from model_engine_server.domain.services import ModelEndpointService from model_engine_server.domain.use_cases.async_inference_use_cases import ( diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index eabbf034..1a8c0c7d 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -19,7 +19,6 @@ from model_engine_server.common.io import open_wrapper from model_engine_server.common.serialization_utils import bool_to_str from model_engine_server.core.config import infra_config -from model_engine_server.core.domain_exceptions import DockerBuildFailedException from model_engine_server.core.loggers import make_logger from model_engine_server.core.notification_gateway import NotificationApp, NotificationGateway from model_engine_server.core.utils.env import environment @@ -40,7 +39,10 @@ TensorflowFramework, ZipArtifactFlavor, ) -from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.exceptions import ( + DockerBuildFailedException, + EndpointResourceInfraException, +) from model_engine_server.domain.gateways import MonitoringMetricsGateway from model_engine_server.domain.repositories import DockerRepository from model_engine_server.domain.services import EndpointBuilderService diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index ab88886c..dba1a055 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -3,10 +3,6 @@ from datadog import statsd from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.settings import generate_deployment_name -from model_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectNotFoundException, -) from model_engine_server.core.loggers import filename_wo_ext, make_logger from model_engine_server.domain.entities import ( CallbackAuth, @@ -20,7 +16,11 @@ ModelEndpointType, StorageSpecificationType, ) -from model_engine_server.domain.exceptions import EndpointDeleteFailedException +from model_engine_server.domain.exceptions import ( + EndpointDeleteFailedException, + ObjectAlreadyExistsException, + ObjectNotFoundException, +) from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, ModelEndpointsSchemaGateway, diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index db65a80f..5192f025 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -2,12 +2,12 @@ from unittest.mock import AsyncMock, MagicMock, patch from model_engine_server.common.dtos.tasks import EndpointPredictV1Request -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, + UpstreamServiceError, ) -from model_engine_server.domain.entities import ModelBundle, ModelEndpoint -from model_engine_server.domain.exceptions import UpstreamServiceError def test_create_async_task_success( diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 9714fd9d..8ea517fd 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -41,7 +41,6 @@ TaskStatus, ) from model_engine_server.common.settings import generate_destination -from model_engine_server.core.domain_exceptions import ObjectNotFoundException from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway from model_engine_server.db.endpoint_row_lock import get_lock_key from model_engine_server.db.models import BatchJob as OrmBatchJob @@ -87,7 +86,10 @@ DockerImageBatchJobBundle, ) from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate -from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.exceptions import ( + EndpointResourceInfraException, + ObjectNotFoundException, +) from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, CronJobGateway, diff --git a/model-engine/tests/unit/domain/test_async_inference_use_cases.py b/model-engine/tests/unit/domain/test_async_inference_use_cases.py index 7a122b3b..4334480d 100644 --- a/model-engine/tests/unit/domain/test_async_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_async_inference_use_cases.py @@ -3,11 +3,11 @@ import pytest from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.domain.entities import ModelEndpoint from model_engine_server.domain.use_cases.async_inference_use_cases import ( CreateAsyncInferenceTaskV1UseCase, GetAsyncInferenceTaskV1UseCase, diff --git a/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py b/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py index 9522c9d5..4f62b79d 100644 --- a/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py +++ b/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py @@ -5,7 +5,7 @@ ) from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index da7a6451..f79d1e39 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -14,16 +14,14 @@ ) from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.entities import ModelEndpoint, ModelEndpointType from model_engine_server.domain.exceptions import ( EndpointUnsupportedInferenceTypeException, InvalidRequestException, LLMFineTuningQuotaReached, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, ) from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( MAX_LLM_ENDPOINTS_PER_INTERNAL_USER, diff --git a/model-engine/tests/unit/domain/test_model_bundle_use_cases.py b/model-engine/tests/unit/domain/test_model_bundle_use_cases.py index d9b4bc25..ae2bb7e2 100644 --- a/model-engine/tests/unit/domain/test_model_bundle_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_bundle_use_cases.py @@ -10,7 +10,7 @@ ModelBundleV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectNotAuthorizedException, ObjectNotFoundException, diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index 95901f8a..49e017fa 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -17,16 +17,14 @@ STORAGE_LIMIT, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) from model_engine_server.domain.entities import ModelBundle, ModelEndpoint from model_engine_server.domain.exceptions import ( EndpointBillingTagsMalformedException, EndpointLabelsException, EndpointResourceInvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, ) from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CreateModelEndpointV1UseCase, diff --git a/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py b/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py index 191fa0f4..9da48267 100644 --- a/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py @@ -3,12 +3,12 @@ import pytest from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.domain.entities import ModelEndpoint -from model_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException from model_engine_server.domain.use_cases.streaming_inference_use_cases import ( CreateStreamingInferenceTaskV1UseCase, ) diff --git a/model-engine/tests/unit/domain/test_sync_inference_use_cases.py b/model-engine/tests/unit/domain/test_sync_inference_use_cases.py index 879d5345..673cafa1 100644 --- a/model-engine/tests/unit/domain/test_sync_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_sync_inference_use_cases.py @@ -3,11 +3,11 @@ import pytest from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.domain_exceptions import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from model_engine_server.domain.entities import ModelEndpoint from model_engine_server.domain.use_cases.sync_inference_use_cases import ( CreateSyncInferenceTaskV1UseCase, ) diff --git a/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py b/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py index d52d327b..214a1ebc 100644 --- a/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py @@ -3,9 +3,9 @@ from unittest.mock import AsyncMock import pytest -from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException from model_engine_server.db.models import BatchJob, Bundle from model_engine_server.domain.entities import BatchJobRecord +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException from model_engine_server.infra.repositories.db_batch_job_record_repository import ( DbBatchJobRecordRepository, OrmBatchJob, diff --git a/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py b/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py index b28bf81f..2bfaab3b 100644 --- a/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py @@ -4,13 +4,15 @@ import pytest from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle from model_engine_server.domain.entities import GpuType from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from model_engine_server.domain.exceptions import CorruptRecordInfraStateException +from model_engine_server.domain.exceptions import ( + CorruptRecordInfraStateException, + ReadOnlyDatabaseException, +) from model_engine_server.infra.repositories import DbDockerImageBatchJobBundleRepository from model_engine_server.infra.repositories.db_docker_image_batch_job_bundle_repository import ( translate_docker_image_batch_job_bundle_orm_to_entity, diff --git a/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py b/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py index dd73b221..4eb94d20 100644 --- a/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py @@ -4,7 +4,6 @@ import pytest from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException from model_engine_server.db.models import Bundle from model_engine_server.domain.entities import ( CloudpickleArtifactFlavor, @@ -12,6 +11,7 @@ ModelBundlePackagingType, PytorchFramework, ) +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException from model_engine_server.infra.repositories.db_model_bundle_repository import ( DbModelBundleRepository, OrmModelBundle, diff --git a/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py b/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py index 8d751272..3ad72127 100644 --- a/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py @@ -4,9 +4,9 @@ import pytest from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from model_engine_server.core.domain_exceptions import ReadOnlyDatabaseException from model_engine_server.db.models import Bundle, Endpoint from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException from model_engine_server.infra.gateways import FakeMonitoringMetricsGateway from model_engine_server.infra.repositories import db_model_endpoint_record_repository from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( diff --git a/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py index 8d894622..11b2abe5 100644 --- a/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py +++ b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py @@ -6,7 +6,6 @@ import pytest from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, ResponseSchema, TaskStatus -from model_engine_server.core.domain_exceptions import ObjectNotFoundException from model_engine_server.domain.entities import ( BatchJob, BatchJobSerializationFormat, @@ -15,6 +14,7 @@ ModelEndpoint, ModelEndpointStatus, ) +from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.infra.gateways import LiveBatchJobProgressGateway from model_engine_server.infra.services import ( LiveBatchJobOrchestrationService, diff --git a/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py index 6d5724fb..31d44b2d 100644 --- a/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py +++ b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py @@ -8,14 +8,16 @@ BuildEndpointResponse, BuildEndpointStatus, ) -from model_engine_server.core.domain_exceptions import DockerBuildFailedException from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway from model_engine_server.core.notification_gateway import NotificationApp from model_engine_server.domain.entities.model_bundle_entity import ( ArtifactLike, RunnableImageFlavor, ) -from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.exceptions import ( + DockerBuildFailedException, + EndpointResourceInfraException, +) from model_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( FakeMonitoringMetricsGateway, ) diff --git a/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py index 87cbab0f..66969005 100644 --- a/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py +++ b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py @@ -2,10 +2,6 @@ from unittest.mock import AsyncMock import pytest -from model_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectNotFoundException, -) from model_engine_server.domain.entities import ( ModelBundle, ModelEndpoint, @@ -15,6 +11,8 @@ from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, ExistingEndpointOperationInProgressException, + ObjectAlreadyExistsException, + ObjectNotFoundException, ) from model_engine_server.infra.services import LiveModelEndpointService From c40cd50d59b234c64fbeac37056e0fba801d4d76 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 22 Sep 2023 19:02:49 -0700 Subject: [PATCH 105/425] Ianmacleod/add datadog tracing in training (#287) * got datadog tracing working locally, need to figure out exactly what changed * iterating to get things working in training * iterating, adding fake error route * iterating, adding fake error route pt2 * adding timestamp to user logs for errors * adding timestamp to user logs for errors without __init__ for midleware * pushing draft pr * working version of tags for request_id as uuid in middleware * cleanup * removing exctra logging test line * removing testing in logging output, adding type:ignore where needed * removing dd patching --- model-engine/model_engine_server/api/app.py | 60 +++++++++++-------- .../model_engine_server/api/llms_v1.py | 6 +- .../api/model_endpoints_v1.py | 46 ++------------ .../common/datadog_utils.py | 6 -- .../model_engine_server/core/loggers.py | 20 +++---- 5 files changed, 52 insertions(+), 86 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index a13d62e6..f87fcf76 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -1,7 +1,10 @@ import os import traceback +import uuid +from datetime import datetime from pathlib import Path +import pytz from fastapi import FastAPI, Request, Response from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles @@ -18,9 +21,14 @@ from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 from model_engine_server.api.tasks_v1 import inference_task_router_v1 from model_engine_server.api.triggers_v1 import trigger_router_v1 -from model_engine_server.common.datadog_utils import get_request_id -from model_engine_server.core.loggers import filename_wo_ext, make_logger -from starlette.middleware.base import BaseHTTPMiddleware +from model_engine_server.core.loggers import ( + filename_wo_ext, + get_request_id, + make_logger, + set_request_id, +) + +logger = make_logger(filename_wo_ext(__name__)) app = FastAPI(title="launch", version="1.0.0", redoc_url="/api") @@ -35,29 +43,33 @@ app.include_router(file_router_v1) app.include_router(trigger_router_v1) -logger = make_logger(filename_wo_ext(__name__)) - - -class ExceptionLoggingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - try: - return await call_next(request) - except Exception as e: - tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - structured_log = {"error": str(e), "traceback": "".join(tb_str)} - logger.error("Unhandled exception: %s", structured_log) - request_id = get_request_id() - return JSONResponse( - { - "status_code": 500, - "content": { - "error": f"Internal error for request_id {request_id}. Our team has been notified." - }, - } - ) +@app.middleware("http") +async def dispatch(request: Request, call_next): + try: + set_request_id(str(uuid.uuid4())) + return await call_next(request) + except Exception as e: + tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + request_id = get_request_id() + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": str(e), + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Unhandled exception: %s", structured_log) + return JSONResponse( + { + "status_code": 500, + "content": { + "error": "Internal error occurred. Our team has been notified.", + "timestamp": timestamp, + "request_id": request_id, + }, + } + ) -app.add_middleware(ExceptionLoggingMiddleware) # TODO: Remove this once we have a better way to serve internal docs INTERNAL_DOCS_PATH = str(Path(__file__).parents[3] / "launch_internal/site") diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 78117376..9f6b48d8 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -9,7 +9,7 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name, get_request_id +from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.llms import ( CancelFineTuneResponse, CompletionStreamV1Request, @@ -31,7 +31,7 @@ ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import filename_wo_ext, get_request_id, make_logger from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, EndpointLabelsException, @@ -247,7 +247,7 @@ async def event_generator(): request_id = get_request_id() logger.exception(f"Upstream service error for request {request_id}") return EventSourceResponse( - iter((CompletionStreamV1Response(request_id=request_id).json(),)) + iter((CompletionStreamV1Response(request_id=request_id).json(),)) # type: ignore ) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index 9a6c9da4..4bf3cf32 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -12,7 +12,7 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name, get_request_id +from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, @@ -92,13 +92,6 @@ async def create_model_endpoint( status_code=404, detail="The specified model bundle could not be found.", ) from exc - except Exception as exc: - request_id = get_request_id() - logger.exception(f"Internal service error for request {request_id}: {exc}") - raise HTTPException( - status_code=500, - detail=f"Internal error for request_id {request_id}.", - ) @model_endpoint_router_v1.get("/model-endpoints", response_model=ListModelEndpointsV1Response) @@ -113,18 +106,10 @@ async def list_model_endpoints( """ add_trace_resource_name("model_endpoints_get") logger.info(f"GET /model-endpoints?name={name}&order_by={order_by} for {auth}") - try: - use_case = ListModelEndpointsV1UseCase( - model_endpoint_service=external_interfaces.model_endpoint_service, - ) - return await use_case.execute(user=auth, name=name, order_by=order_by) - except Exception as exc: - request_id = get_request_id() - logger.exception(f"Internal service error for request {request_id}: {exc}") - raise HTTPException( - status_code=500, - detail=f"Internal error for request_id {request_id}.", - ) + use_case = ListModelEndpointsV1UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + ) + return await use_case.execute(user=auth, name=name, order_by=order_by) @model_endpoint_router_v1.get( @@ -150,13 +135,6 @@ async def get_model_endpoint( status_code=404, detail=f"Model Endpoint {model_endpoint_id} was not found.", ) from exc - except Exception as exc: - request_id = get_request_id() - logger.exception(f"Internal service error for request {request_id}: {exc}") - raise HTTPException( - status_code=500, - detail=f"Internal error for request_id {request_id}.", - ) @model_endpoint_router_v1.put( @@ -201,13 +179,6 @@ async def update_model_endpoint( status_code=409, detail="Existing operation on endpoint in progress, try again later.", ) from exc - except Exception as exc: - request_id = get_request_id() - logger.exception(f"Internal service error for request {request_id}: {exc}") - raise HTTPException( - status_code=500, - detail=f"Internal error for request_id {request_id}.", - ) @model_endpoint_router_v1.delete( @@ -243,10 +214,3 @@ async def delete_model_endpoint( status_code=500, detail="deletion of endpoint failed, compute resources still exist.", ) from exc - except Exception as exc: - request_id = get_request_id() - logger.exception(f"Internal service error for request {request_id}: {exc}") - raise HTTPException( - status_code=500, - detail=f"Internal error for request_id {request_id}.", - ) diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py index 26152f03..c73fa2f9 100644 --- a/model-engine/model_engine_server/common/datadog_utils.py +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -17,9 +17,3 @@ def add_trace_request_id(request_id: str): current_span = tracer.current_span() if current_span: current_span.set_tag("launch.request_id", request_id) - - -def get_request_id(): - """Gets the request id for an api call (in our case, dd trace id) so that we can filter in Datadog easier""" - current_span = tracer.current_span() - return current_span.trace_id if current_span else None diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index 593302c0..94c96998 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -10,7 +10,7 @@ import ddtrace import json_log_formatter import tqdm -from ddtrace.tracer import Tracer +from ddtrace import tracer # DO NOT CHANGE LOGGING FORMAT LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s" @@ -47,7 +47,7 @@ def get_request_id() -> Optional[str]: def set_request_id(request_id: str) -> None: """Set the request id in the context variable.""" - ctx_var_request_id.set(request_id) + ctx_var_request_id.set(request_id) # type: ignore def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: @@ -82,13 +82,9 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d if request_id: extra["request_id"] = request_id - context = Tracer().get_log_correlation_context() - trace_id = context.get("trace_id") - span_id = context.get("span_id") - - # add ids to event dictionary - extra["dd.trace_id"] = trace_id or 0 - extra["dd.span_id"] = span_id or 0 + current_span = tracer.current_span() + extra["dd.trace_id"] = current_span.trace_id if current_span else 0 + extra["dd.span_id"] = current_span.span_id if current_span else 0 # add the env, service, and version configured for the tracer. # If tracing is not set up, then this should pull values from DD_ENV, DD_SERVICE, and DD_VERSION. @@ -188,7 +184,7 @@ def logger_name(*, fallback_name: Optional[str] = None) -> str: # in which case we use it's file name if hasattr(calling_module, "__file__"): - return filename_wo_ext(calling_module.__file__) + return filename_wo_ext(calling_module.__file__) # type: ignore if fallback_name is not None: fallback_name = fallback_name.strip() if len(fallback_name) > 0: @@ -260,8 +256,8 @@ def silence_chatty_datadog_loggers(*, silence_internal_writer: bool = False) -> silence_chatty_logger("ddtrace.internal.writer", quieter=logging.FATAL) -@contextmanager -def loggers_at_level(*loggers_or_names, new_level: int) -> None: +@contextmanager # type: ignore +def loggers_at_level(*loggers_or_names, new_level: int) -> None: # type: ignore """Temporarily set one or more loggers to a specific level, resetting to previous levels on context end. :param:`loggers_or_names` is one or more :class:`logging.Logger` instances, or `str` names From 142f65fe51b43fbee001709bbdee75be854e452a Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 25 Sep 2023 16:09:02 -0700 Subject: [PATCH 106/425] Add retries for flaky integration tests (#289) --- integration_tests/rest_api_utils.py | 7 +++++- integration_tests/test_bundles.py | 4 ++++ integration_tests/test_endpoints.py | 34 ++++++++++++++++++----------- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index fb0dd7c3..604b8744 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -224,6 +224,11 @@ def my_model(**keyword_args): } +@retry(stop=stop_after_attempt(300), wait=wait_fixed(2)) +def ensure_launch_gateway_healthy(): + assert requests.get(f"{BASE_PATH}/healthz").status_code == 200 + + def create_model_bundle( create_model_bundle_request: Dict[str, Any], user_id: str, version: str ) -> Dict[str, Any]: @@ -735,7 +740,7 @@ def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_p # Wait up to 30 seconds for the tasks to be returned. @retry( - stop=stop_after_attempt(30), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) + stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) ) def ensure_all_async_tasks_success(task_ids: List[str], user_id: str, return_pickled: bool): responses = asyncio.run(get_async_tasks(task_ids, user_id)) diff --git a/integration_tests/test_bundles.py b/integration_tests/test_bundles.py index 5d38d80f..cb8a45e7 100644 --- a/integration_tests/test_bundles.py +++ b/integration_tests/test_bundles.py @@ -1,4 +1,5 @@ import pytest +from tenacity import retry, stop_after_attempt, wait_fixed from .rest_api_utils import ( CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, @@ -6,12 +7,15 @@ USER_ID_0, USER_ID_1, create_model_bundle, + ensure_launch_gateway_healthy, get_latest_model_bundle, ) @pytest.fixture(scope="session") +@retry(stop=stop_after_attempt(10), wait=wait_fixed(30)) def model_bundles(): + ensure_launch_gateway_healthy() for user in [USER_ID_0, USER_ID_1]: for create_bundle_request in [ CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index 2af5a257..ad40a2d9 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -2,6 +2,7 @@ import time import pytest +from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed from .rest_api_utils import ( CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, @@ -41,6 +42,23 @@ def delete_endpoints(capsys): print("Endpoint deletion failed") +@retry(stop=stop_after_attempt(3), wait=wait_fixed(10), retry=retry_if_exception_type(RetryError)) +def ensure_async_inference_works(user, create_endpoint_request, inference_payload, return_pickled): + print( + f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." + ) + task_ids = asyncio.run( + create_async_tasks( + create_endpoint_request["name"], + [inference_payload] * 3, + user, + ) + ) + print("Retrieving async task results...") + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + ensure_all_async_tasks_success(task_ids, user, return_pickled) + + @pytest.mark.parametrize( "create_endpoint_request,update_endpoint_request,inference_requests", [ @@ -89,22 +107,12 @@ def test_async_model_endpoint( == update_endpoint_request["max_workers"] ) - time.sleep(10) + time.sleep(20) for inference_payload, return_pickled in inference_requests: - print( - f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." - ) - task_ids = asyncio.run( - create_async_tasks( - create_endpoint_request["name"], - [inference_payload] * 3, - user, - ) + ensure_async_inference_works( + user, create_endpoint_request, inference_payload, return_pickled ) - print("Retrieving async task results...") - ensure_nonzero_available_workers(create_endpoint_request["name"], user) - ensure_all_async_tasks_success(task_ids, user, return_pickled) finally: delete_model_endpoint(create_endpoint_request["name"], user) From 643ac1f617588f112d4a4d3e8ab3a8e3cf31f412 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 25 Sep 2023 17:15:31 -0700 Subject: [PATCH 107/425] Ianmacleod/pypi version nudge (#290) * adding pypi version bump nudge * version bump * changing naming conventions for env variable to be consistent * . --- clients/python/llmengine/__init__.py | 31 +++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 768cbb86..e0910afa 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta13" +__version__ = "0.0.0.b14" +import os from typing import Sequence +import requests from llmengine.completion import Completion from llmengine.data_types import ( CancelFineTuneResponse, @@ -67,3 +69,30 @@ "Model", "UploadFileResponse", ) + + +def check_version(): + try: + current_version = __version__ + response = requests.get("https://pypi.org/pypi/scale-llm-engine/json") + latest_version = response.json()["info"]["version"] + + if current_version != latest_version: + print( + f"A newer version ({latest_version}) of 'scale-llm-engine' is available. Please upgrade!" + ) + print("To upgrade, run: pip install --upgrade scale-llm-engine") + print( + "Don't want to see this message? Set the environment variable 'LLM_ENGINE_DISABLE_VERSION_CHECK' to 'true'." + ) + except requests.RequestException: + # Handle exceptions related to the request (like timeouts, connection errors, etc.) + print( + "Failed to check for the most recent llm-engine package version. Please check your internet connection." + ) + except Exception: + print("Something went wrong with checking for the most recent llm-engine package version.") + + +if not os.environ.get("LLM_ENGINE_DISABLE_VERSION_CHECK"): + check_version() From 4b3ca5244093a27708f83170a612abfb678ca85d Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 25 Sep 2023 18:35:48 -0700 Subject: [PATCH 108/425] Propagate extra server args to the gunicorn command (#291) --- .../inference/forwarding/echo_server.py | 3 ++- .../inference/forwarding/http_forwarder.py | 6 +++++- .../inference/sync_inference/start_fastapi_server.py | 8 ++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py index 12470cfc..0a44b832 100644 --- a/model-engine/model_engine_server/inference/forwarding/echo_server.py +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -33,7 +33,7 @@ def entrypoint(): parser.add_argument("--host", type=str, default="[::]") parser.add_argument("--port", type=int, default=5009) - args = parser.parse_args() + args, extra_args = parser.parse_known_args() command = [ "gunicorn", @@ -48,6 +48,7 @@ def entrypoint(): "--workers", str(args.num_workers), "model_engine_server.inference.forwarding.echo_server:app", + *extra_args, ] subprocess.run(command) diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index 85de6ded..5943bc50 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -138,8 +138,9 @@ def entrypoint(): parser.add_argument("--host", type=str, default="[::]") parser.add_argument("--port", type=int, default=5000) parser.add_argument("--set", type=str, action="append") + parser.add_argument("--graceful-timeout", type=int, default=600) - args = parser.parse_args() + args, extra_args = parser.parse_known_args() values = [f"CONFIG_FILE={args.config}"] if args.set is not None: @@ -160,8 +161,11 @@ def entrypoint(): "uvicorn.workers.UvicornWorker", "--workers", str(args.num_workers), + "--graceful-timeout", + str(args.graceful_timeout), *envs, "model_engine_server.inference.forwarding.http_forwarder:app", + *extra_args, ] subprocess.run(command) diff --git a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py index 97aea0ed..2b3aef79 100644 --- a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py @@ -1,3 +1,4 @@ +import argparse import os import subprocess @@ -8,6 +9,10 @@ def start_server(): + parser = argparse.ArgumentParser() + parser.add_argument("--graceful-timeout", type=int, default=600) + args, extra_args = parser.parse_known_args() + # TODO: HTTPS command = [ "gunicorn", @@ -21,7 +26,10 @@ def start_server(): "uvicorn.workers.UvicornWorker", "--workers", str(NUM_PROCESSES), + "--graceful-timeout", + str(args.graceful_timeout), "model_engine_server.inference.sync_inference.fastapi_server:app", + *extra_args, ] unset_sensitive_envvars() subprocess.run(command) From 476543df16a09ca5dde1f45852249dbfd28fa187 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 26 Sep 2023 09:49:23 -0700 Subject: [PATCH 109/425] Support AWQ for vLLM (#292) --- .../model_engine_server/domain/entities/llm_entity.py | 1 + .../domain/use_cases/llm_model_endpoint_use_cases.py | 8 ++++++++ .../model_engine_server/inference/vllm/requirements.txt | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index dfb6f63c..0624857f 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -16,6 +16,7 @@ class LLMInferenceFramework(str, Enum): class Quantization(str, Enum): BITSANDBYTES = "bitsandbytes" + AWQ = "awq" @dataclass diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 9e4b6ba6..98ea40ec 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -227,6 +227,7 @@ async def create_model_bundle( framework_image_tag, endpoint_name, num_shards, + quantize, checkpoint_path, ) elif framework == LLMInferenceFramework.LIGHTLLM: @@ -453,6 +454,7 @@ async def create_vllm_bundle( framework_image_tag: str, endpoint_unique_name: str, num_shards: int, + quantize: Optional[Quantization], checkpoint_path: Optional[str], ): command = [] @@ -482,6 +484,12 @@ async def create_vllm_bundle( f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens}" ) + if quantize: + if quantize == Quantization.AWQ: + subcommands[-1] = subcommands[-1] + f" --quantization {quantize}" + else: + raise InvalidRequestException(f"Quantization {quantize} is not supported by vLLM.") + command = [ "/bin/bash", "-c", diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 7d8f12fc..db3b97a4 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ ray==2.6.3 -vllm==0.1.7 +git+https://github.com/vllm-project/vllm.git@7d7e3b78a3c265ab3c57eeff43af56f509907998#egg=vllm pydantic==1.10.12 From 57d2ab2d1ca6335190b3b2f77295cbd99e88b4c8 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 27 Sep 2023 11:17:45 -0700 Subject: [PATCH 110/425] Update issue templates (#288) * Update issue templates * add bug emoji * add rocket ship emoji * modifying gh issues templates * consolidating bug report template * adding new feature prioritization comments * whoops --- .github/ISSUE_TEMPLATE/bug_report.md | 41 ++++++++++++++++++++ .github/ISSUE_TEMPLATE/custom.md | 10 +++++ .github/ISSUE_TEMPLATE/feature_request.md | 46 +++++++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/custom.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..725c9b4f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,41 @@ +--- +name: "\U0001F41B Bug report" +about: Submit a bug report to help us improve LLM Engine. If this doesn't look right, [choose a different type.](https://github.com/scaleapi/llm-engine/issues/new/choose) +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** +Thank you for taking the time to file a bug report! Before you do so, please take a look at existing open issues and make sure that your issue is not already documented. If it isn't, please provide us with a clear and concise description of what the bug is. + +**LLM Engine Version** +- LLM Engine Version: + +**System Version** +- Python Version: +- Operating System: + +**Timestamp and Request ID** +_If you ran into an internal error while using `llm-engine`, please provide the following_ +- Timestamp: +- Request ID: + +**Minimal Reproducible Example** +Steps to reproduce the behavior: +1. Install LLM Engine '....' +2. Make API call '....' +3. See error + +Please provide a code snippet that documents how your bug can be reproduced. +``` +import llmengine +... +``` + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md new file mode 100644 index 00000000..89130c0b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/custom.md @@ -0,0 +1,10 @@ +--- +name: Custom issue template +about: If your issue doesn't fall into a bug template or feature request, please provide some information on it here. +title: '' +labels: '' +assignees: '' + +--- + + diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..3043cd19 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,46 @@ +--- +name: "\U0001F680 Feature request" +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +## Feature Request + +**What is the problem you're currently running into?** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Why do you want this feature?** +A clear and concise description of why you want the feature. + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. + +### Prioritization + +- **Does this feature block you from using the project?** + - [ ] Yes + - [ ] No + +- **How many users will benefit from this feature?** + - [ ] Just me + - [ ] Few people might benefit + - [ ] Many users will love it! + +- **Complexity** + - [ ] I believe it's a simple feature to implement + - [ ] It might require some effort to implement + - [ ] It's probably complex, and might take significant effort + +--- + +Thank you for your contribution to `llm-engine`. Please ensure you've given the feature considerable thought before submitting it. Once your feature request is accepted, and you're interested in building it, please mention it so that the maintainers can guide you! + From 0a18866dcac35fc671771ffbd097d4324d1ec8a0 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 27 Sep 2023 13:35:01 -0700 Subject: [PATCH 111/425] Ian GitHub templates (#293) * Update issue templates again (remove extra line) --- .github/ISSUE_TEMPLATE/bug_report.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 725c9b4f..ac062d42 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,6 +1,6 @@ --- name: "\U0001F41B Bug report" -about: Submit a bug report to help us improve LLM Engine. If this doesn't look right, [choose a different type.](https://github.com/scaleapi/llm-engine/issues/new/choose) +about: Submit a bug report to help us improve LLM Engine. title: '' labels: bug assignees: '' @@ -18,9 +18,9 @@ Thank you for taking the time to file a bug report! Before you do so, please tak - Operating System: **Timestamp and Request ID** -_If you ran into an internal error while using `llm-engine`, please provide the following_ -- Timestamp: -- Request ID: +_If you ran into an internal error while using `llm-engine`, please provide the following. These fields are provided in the JSON Response when an internal error occurs._ +- `timestamp`: +- `request_id`: **Minimal Reproducible Example** Steps to reproduce the behavior: From 6f7f4473aeed3cb7b3c3294aa2d7e893a27634ad Mon Sep 17 00:00:00 2001 From: Ian Macleod Date: Wed, 27 Sep 2023 21:45:35 +0000 Subject: [PATCH 112/425] bump the llm engine pypi version --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index e0910afa..55f0b995 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.b14" +__version__ = "0.0.0.b15" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index fbe0ddde..8dc324cd 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta14" +version = "0.0.0.beta15" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index c37b0da6..c25d1f49 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta14", + version="0.0.0.beta15", packages=find_packages(), ) From d57e75580991f5d941b57791e64ad2eb04dacedf Mon Sep 17 00:00:00 2001 From: Ian Macleod Date: Wed, 27 Sep 2023 22:11:02 +0000 Subject: [PATCH 113/425] changing version formatting so that upgrading will work --- clients/python/llmengine/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 55f0b995..2b7cb0bc 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.b15" +__version__ = "0.0.0b15" import os from typing import Sequence @@ -76,6 +76,8 @@ def check_version(): current_version = __version__ response = requests.get("https://pypi.org/pypi/scale-llm-engine/json") latest_version = response.json()["info"]["version"] + print(current_version) + print(latest_version) if current_version != latest_version: print( From 53af4fb8d32dcbf7b498cbbe2f88ac3a611db9d6 Mon Sep 17 00:00:00 2001 From: Ian Macleod Date: Wed, 27 Sep 2023 22:12:57 +0000 Subject: [PATCH 114/425] clean up print statements --- clients/python/llmengine/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 2b7cb0bc..017ae95c 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -76,8 +76,6 @@ def check_version(): current_version = __version__ response = requests.get("https://pypi.org/pypi/scale-llm-engine/json") latest_version = response.json()["info"]["version"] - print(current_version) - print(latest_version) if current_version != latest_version: print( From 0c701f8134ba91f92e216341b553a6b1c2cd8a62 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:49:03 -0700 Subject: [PATCH 115/425] bump the llm engine pypi version (#294) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 017ae95c..af4ac08b 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b15" +__version__ = "0.0.0b16" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 8dc324cd..c117c42b 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta15" +version = "0.0.0.beta16" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index c25d1f49..862607ea 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta15", + version="0.0.0.beta16", packages=find_packages(), ) From cb5e8a4a23d8da6f88d4780e50d1a82eec7b947a Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 29 Sep 2023 12:38:52 -0700 Subject: [PATCH 116/425] Add A100e GPU type (#299) --- clients/python/llmengine/data_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 1a8baba3..08a7f0be 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -32,6 +32,7 @@ class GpuType(str, Enum): NVIDIA_TESLA_T4 = "nvidia-tesla-t4" NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" + NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" class ModelEndpointType(str, Enum): From c77c9f0e07ffbcb006821e99dd7a102e4d20c8ca Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 29 Sep 2023 14:09:39 -0700 Subject: [PATCH 117/425] bump client version (#300) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index af4ac08b..2c0eeb45 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b16" +__version__ = "0.0.0b17" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index c117c42b..ff781834 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta16" +version = "0.0.0.beta17" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 862607ea..3ce304b4 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta16", + version="0.0.0.beta17", packages=find_packages(), ) From 90233702607d4c3018cbe4cf3c3161154a05c96f Mon Sep 17 00:00:00 2001 From: Frances Yuan <139272087+francesy-scale@users.noreply.github.com> Date: Fri, 29 Sep 2023 14:52:45 -0700 Subject: [PATCH 118/425] Add repetition_penalty, top_k, top_p to Completion (#295) * add repetition_penalty, top_k, top_p * add frequency_penalty, presence_penalty, add lightllm * add comments * fix * fix Optional, add params validation * remove repetition_penalty * add back optional, update validation function * type check --- clients/python/llmengine/completion.py | 64 +++++++++++++ clients/python/llmengine/data_types.py | 8 ++ .../model_engine_server/common/dtos/llms.py | 40 +++++++- .../use_cases/llm_model_endpoint_use_cases.py | 94 ++++++++++++++++++- 4 files changed, 203 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 8cecd765..507754d8 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -33,6 +33,10 @@ async def acreate( temperature: float = 0.2, stop_sequences: Optional[List[str]] = None, return_token_log_probs: Optional[bool] = False, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: @@ -72,6 +76,26 @@ async def acreate( Whether to return the log probabilities of generated tokens. When True, the response will include a list of tokens and their log probabilities. + presence_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + frequency_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + top_k (Optional[int]): + Integer that controls the number of top tokens to consider. + Range: [1, infinity). -1 means consider all tokens. + + top_p (Optional[float]): + Float that controls the cumulative probability of the top tokens to consider. + Range: (0.0, 1.0]. 1.0 means consider all tokens. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -164,6 +188,10 @@ async def _acreate_stream( temperature=temperature, stop_sequences=stop_sequences, return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, timeout=timeout, ) @@ -184,6 +212,10 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse: temperature=temperature, stop_sequences=stop_sequences, return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, ) @classmethod @@ -195,6 +227,10 @@ def create( temperature: float = 0.2, stop_sequences: Optional[List[str]] = None, return_token_log_probs: Optional[bool] = False, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: @@ -235,6 +271,26 @@ def create( Whether to return the log probabilities of generated tokens. When True, the response will include a list of tokens and their log probabilities. + presence_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + frequency_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + top_k (Optional[int]): + Integer that controls the number of top tokens to consider. + Range: [1, infinity). -1 means consider all tokens. + + top_p (Optional[float]): + Float that controls the cumulative probability of the top tokens to consider. + Range: (0.0, 1.0]. 1.0 means consider all tokens. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -317,6 +373,10 @@ def _create_stream(**kwargs): temperature=temperature, stop_sequences=stop_sequences, return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, ) else: @@ -326,6 +386,10 @@ def _create_stream(**kwargs): temperature=temperature, stop_sequences=stop_sequences, return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, ).dict() response = cls.post_sync( resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 08a7f0be..2cdc2f89 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -269,6 +269,10 @@ class CompletionSyncV1Request(BaseModel): temperature: float = Field(..., ge=0.0) stop_sequences: Optional[List[str]] = Field(default=None) return_token_log_probs: Optional[bool] = Field(default=False) + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_k: Optional[int] = Field(default=None, ge=-1) + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) class TokenOutput(BaseModel): @@ -330,6 +334,10 @@ class CompletionStreamV1Request(BaseModel): temperature: float = Field(..., ge=0.0) stop_sequences: Optional[List[str]] = Field(default=None) return_token_log_probs: Optional[bool] = Field(default=False) + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_k: Optional[int] = Field(default=None, ge=-1) + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) class CompletionStreamOutput(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 2735f577..27a12ddc 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -104,7 +104,7 @@ class CompletionSyncV1Request(BaseModel): prompt: str max_new_tokens: int - temperature: float = Field(ge=0, le=1) + temperature: float = Field(ge=0.0, le=1.0) """ Temperature of the sampling. Setting to 0 equals to greedy sampling. """ @@ -116,6 +116,24 @@ class CompletionSyncV1Request(BaseModel): """ Whether to return the log probabilities of the tokens. """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ class TokenOutput(BaseModel): @@ -145,7 +163,7 @@ class CompletionStreamV1Request(BaseModel): prompt: str max_new_tokens: int - temperature: float = Field(ge=0, le=1) + temperature: float = Field(ge=0.0, le=1.0) """ Temperature of the sampling. Setting to 0 equals to greedy sampling. """ @@ -157,6 +175,24 @@ class CompletionStreamV1Request(BaseModel): """ Whether to return the log probabilities of the tokens. Only affects behavior for text-generation-inference models """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ class CompletionStreamOutput(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 98ea40ec..76dff95b 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -8,7 +8,7 @@ import math import os from dataclasses import asdict -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import Any, AsyncIterable, Dict, List, Optional, Union from uuid import uuid4 from model_engine_server.common.config import hmi_config @@ -839,6 +839,54 @@ def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]: return tokens +def validate_and_update_completion_params( + inference_framework: LLMInferenceFramework, + request: Union[CompletionSyncV1Request, CompletionStreamV1Request], +) -> Union[CompletionSyncV1Request, CompletionStreamV1Request]: + # top_k, top_p + if inference_framework in [ + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + if request.temperature == 0: + if request.top_k not in [-1, None] or request.top_p not in [1.0, None]: + raise ObjectHasInvalidValueException( + "top_k and top_p can't be enabled when temperature is 0." + ) + if request.top_k == 0: + raise ObjectHasInvalidValueException( + "top_k needs to be strictly positive, or set it to be -1 / None to disable top_k." + ) + if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + request.top_k = None if request.top_k == -1 else request.top_k + request.top_p = None if request.top_p == 1.0 else request.top_p + if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: + request.top_k = -1 if request.top_k is None else request.top_k + request.top_p = 1.0 if request.top_p is None else request.top_p + else: + if request.top_k or request.top_p: + raise ObjectHasInvalidValueException( + "top_k and top_p are only supported in text-generation-inference, vllm, lightllm." + ) + + # presence_penalty, frequency_penalty + if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: + request.presence_penalty = ( + 0.0 if request.presence_penalty is None else request.presence_penalty + ) + request.frequency_penalty = ( + 0.0 if request.frequency_penalty is None else request.frequency_penalty + ) + else: + if request.presence_penalty or request.frequency_penalty: + raise ObjectHasInvalidValueException( + "presence_penalty and frequency_penalty are only supported in vllm, lightllm." + ) + + return request + + class CompletionSyncV1UseCase: """ Use case for running a prompt completion on an LLM endpoint. @@ -983,6 +1031,15 @@ async def execute( endpoint_id=model_endpoint.record.id ) endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + validated_request = validate_and_update_completion_params( + endpoint_content.inference_framework, request + ) + if not isinstance(validated_request, CompletionSyncV1Request): + raise ValueError( + f"request has type {validated_request.__class__.__name__}, expected type CompletionSyncV1Request" + ) + request = validated_request + if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED: args: Any = { "prompts": [request.prompt], @@ -1036,6 +1093,10 @@ async def execute( if request.temperature > 0: tgi_args["parameters"]["temperature"] = request.temperature tgi_args["parameters"]["do_sample"] = True + tgi_args["parameters"]["top_k"] = request.top_k + tgi_args["parameters"]["top_p"] = request.top_p + else: + tgi_args["parameters"]["do_sample"] = False inference_request = SyncEndpointPredictV1Request( args=tgi_args, @@ -1064,10 +1125,15 @@ async def execute( vllm_args: Any = { "prompt": request.prompt, "max_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, } if request.stop_sequences is not None: vllm_args["stop"] = request.stop_sequences vllm_args["temperature"] = request.temperature + if request.temperature > 0: + vllm_args["top_k"] = request.top_k + vllm_args["top_p"] = request.top_p if request.return_token_log_probs: vllm_args["logprobs"] = 1 @@ -1098,12 +1164,16 @@ async def execute( "inputs": request.prompt, "parameters": { "max_new_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, }, } # TODO: implement stop sequences if request.temperature > 0: lightllm_args["parameters"]["temperature"] = request.temperature lightllm_args["parameters"]["do_sample"] = True + lightllm_args["top_k"] = request.top_k + lightllm_args["top_p"] = request.top_p else: lightllm_args["parameters"]["do_sample"] = False if request.return_token_log_probs: @@ -1172,6 +1242,7 @@ async def execute( request_id = str(uuid4()) add_trace_request_id(request_id) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( owner=user.team_id, name=model_endpoint_name, order_by=None ) @@ -1209,6 +1280,14 @@ async def execute( ) model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + validated_request = validate_and_update_completion_params( + model_content.inference_framework, request + ) + if not isinstance(validated_request, CompletionStreamV1Request): + raise ValueError( + f"request has type {validated_request.__class__.__name__}, expected type CompletionStreamV1Request" + ) + request = validated_request args: Any = None if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: @@ -1237,14 +1316,23 @@ async def execute( if request.temperature > 0: args["parameters"]["temperature"] = request.temperature args["parameters"]["do_sample"] = True + args["parameters"]["top_k"] = request.top_k + args["parameters"]["top_p"] = request.top_p + else: + args["parameters"]["do_sample"] = False elif model_content.inference_framework == LLMInferenceFramework.VLLM: args = { "prompt": request.prompt, "max_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, } if request.stop_sequences is not None: args["stop"] = request.stop_sequences args["temperature"] = request.temperature + if request.temperature > 0: + args["top_k"] = request.top_k + args["top_p"] = request.top_p if request.return_token_log_probs: args["logprobs"] = 1 args["stream"] = True @@ -1253,12 +1341,16 @@ async def execute( "inputs": request.prompt, "parameters": { "max_new_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, }, } # TODO: stop sequences if request.temperature > 0: args["parameters"]["temperature"] = request.temperature args["parameters"]["do_sample"] = True + args["parameters"]["top_k"] = request.top_k + args["parameters"]["top_p"] = request.top_p else: args["parameters"]["do_sample"] = False if request.return_token_log_probs: From 8723c535240458349189a9392fd470b1a3d4f9f2 Mon Sep 17 00:00:00 2001 From: Sam Denton <106690182+sam-scale@users.noreply.github.com> Date: Fri, 29 Sep 2023 16:17:11 -0700 Subject: [PATCH 119/425] Smartly check safetensors vs. bin (#296) * Smartly check safetensors vs. bin * Fix formatting * Add unit test * Add unit test * heh hope this works. * refactoring * adding new utils file, removing test * adding in unit test, refactoring again * adding artifact gateway to use case * renaming gateway function * whoops * cleanup --------- Co-authored-by: Ian Macleod Co-authored-by: Ian Macleod <139901935+ian-scale@users.noreply.github.com> --- .../model_engine_server/api/llms_v1.py | 1 + .../core/aws/storage_client.py | 2 +- .../domain/gateways/llm_artifact_gateway.py | 7 ++++ .../use_cases/llm_model_endpoint_use_cases.py | 32 +++++++++++++++- .../infra/gateways/s3_llm_artifact_gateway.py | 13 +++++++ model-engine/tests/unit/conftest.py | 5 +++ .../tests/unit/domain/test_llm_use_cases.py | 38 +++++++++++++++++++ 7 files changed, 95 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 9f6b48d8..67abfefa 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -92,6 +92,7 @@ async def create_model_endpoint( create_model_bundle_use_case=create_model_bundle_use_case, model_bundle_repository=external_interfaces.model_bundle_repository, model_endpoint_service=external_interfaces.model_endpoint_service, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: diff --git a/model-engine/model_engine_server/core/aws/storage_client.py b/model-engine/model_engine_server/core/aws/storage_client.py index c73c500d..814b00c4 100644 --- a/model-engine/model_engine_server/core/aws/storage_client.py +++ b/model-engine/model_engine_server/core/aws/storage_client.py @@ -20,7 +20,7 @@ def sync_storage_client(**kwargs) -> BaseClient: - return session(infra_config().profile_ml_worker).client("s3", **kwargs) + return session(infra_config().profile_ml_worker).client("s3", **kwargs) # type: ignore def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefined-builtin diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py index dba41676..21e3c697 100644 --- a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py @@ -7,6 +7,13 @@ class LLMArtifactGateway(ABC): Abstract Base Class for interacting with llm artifacts. """ + @abstractmethod + def list_files(self, path: str, **kwargs) -> List[str]: + """ + Gets a list of files from a given path. + """ + pass + @abstractmethod def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: """ diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 76dff95b..388ed6aa 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -134,6 +134,25 @@ DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes +def _exclude_safetensors_or_bin(model_files: List[str]) -> Optional[str]: + """ + This function is used to determine whether to exclude "*.safetensors" or "*.bin" files + based on which file type is present more often in the checkpoint folder. The less + frequently present file type is excluded. + If both files are equally present, no exclusion string is returned. + """ + exclude_str = None + if len([f for f in model_files if f.endswith(".safetensors")]) > len( + [f for f in model_files if f.endswith(".bin")] + ): + exclude_str = "*.bin" + elif len([f for f in model_files if f.endswith(".safetensors")]) < len( + [f for f in model_files if f.endswith(".bin")] + ): + exclude_str = "*.safetensors" + return exclude_str + + def _model_endpoint_entity_to_get_llm_model_endpoint_response( model_endpoint: ModelEndpoint, ) -> GetLLMModelEndpointV1Response: @@ -182,11 +201,13 @@ def __init__( create_model_bundle_use_case: CreateModelBundleV2UseCase, model_bundle_repository: ModelBundleRepository, model_endpoint_service: ModelEndpointService, + llm_artifact_gateway: LLMArtifactGateway, ): self.authz_module = LiveAuthorizationModule() self.create_model_bundle_use_case = create_model_bundle_use_case self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service + self.llm_artifact_gateway = llm_artifact_gateway async def create_model_bundle( self, @@ -358,14 +379,21 @@ def load_model_weights_sub_commands( ] ) else: - if framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + # Let's check whether to exclude "*.safetensors" or "*.bin" files + checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) + model_files = [f for f in checkpoint_files if "model" in f] + + exclude_str = _exclude_safetensors_or_bin(model_files) + + if exclude_str is None: subcommands.append( f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) else: subcommands.append( - f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '*.safetensors' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '{exclude_str}' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) + return subcommands async def create_deepspeed_bundle( diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index d46be385..6ce80446 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -3,6 +3,7 @@ import boto3 from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.core.utils.url import parse_attachment_url from model_engine_server.domain.gateways import LLMArtifactGateway @@ -17,6 +18,18 @@ def _get_s3_resource(self, kwargs): resource = session.resource("s3") return resource + def list_files(self, path: str, **kwargs) -> List[str]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url(path) + bucket = parsed_remote.bucket + key = parsed_remote.key + # From here: https://dev.to/aws-builders/how-to-list-contents-of-s3-bucket-using-boto3-python-47mm + files = [ + bucket_object["Key"] + for bucket_object in s3.list_objects_v2(Bucket=bucket, Prefix=key)["Contents"] + ] + return files + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: s3 = self._get_s3_resource(kwargs) # parsing prefix to get S3 bucket name diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 8ea517fd..cee6db6e 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -748,11 +748,16 @@ async def initialize_events(self, user_id: str, model_endpoint_name: str): class FakeLLMArtifactGateway(LLMArtifactGateway): def __init__(self): self.existing_models = [] + self.s3_bucket = {"fake-checkpoint": ["fake.bin, fake2.bin", "fake3.safetensors"]} self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} def _add_model(self, owner: str, model_name: str): self.existing_models.append((owner, model_name)) + def list_files(self, path: str, **kwargs) -> List[str]: + if path in self.s3_bucket: + return self.s3_bucket[path] + def get_model_weights_urls(self, owner: str, model_name: str): if (owner, model_name) in self.existing_models: return self.urls diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index f79d1e39..7171e5b4 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -36,6 +36,7 @@ DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ModelDownloadV1UseCase, + _exclude_safetensors_or_bin, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -47,6 +48,7 @@ async def test_create_model_endpoint_use_case_success( fake_model_endpoint_service, fake_docker_repository_image_always_exists, fake_model_primitive_gateway, + fake_llm_artifact_gateway, create_llm_model_endpoint_request_async: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, @@ -62,6 +64,7 @@ async def test_create_model_endpoint_use_case_success( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) @@ -150,6 +153,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( fake_model_endpoint_service, fake_docker_repository_image_always_exists, fake_model_primitive_gateway, + fake_llm_artifact_gateway, create_llm_model_endpoint_text_generation_inference_request_async: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, ): @@ -163,6 +167,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -202,6 +207,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception fake_model_endpoint_service, fake_docker_repository_image_always_exists, fake_model_primitive_gateway, + fake_llm_artifact_gateway, create_llm_model_endpoint_request_invalid_model_name: CreateLLMModelEndpointV1Request, ): fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository @@ -214,6 +220,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -953,3 +960,34 @@ async def test_delete_public_inference_model_raises_not_authorized( await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name ) + + +@pytest.mark.asyncio +async def test_exclude_safetensors_or_bin_majority_bin_returns_exclude_safetensors(): + fake_model_files = ["fake.bin", "fake2.bin", "fake3.safetensors", "model.json", "optimizer.pt"] + assert _exclude_safetensors_or_bin(fake_model_files) == "*.safetensors" + + +@pytest.mark.asyncio +async def test_exclude_safetensors_or_bin_majority_safetensors_returns_exclude_bin(): + fake_model_files = [ + "fake.bin", + "fake2.safetensors", + "fake3.safetensors", + "model.json", + "optimizer.pt", + ] + assert _exclude_safetensors_or_bin(fake_model_files) == "*.bin" + + +@pytest.mark.asyncio +async def test_exclude_safetensors_or_bin_equal_bins_and_safetensors_returns_none(): + fake_model_files = [ + "fake.bin", + "fake2.safetensors", + "fake3.safetensors", + "fake4.bin", + "model.json", + "optimizer.pt", + ] + assert _exclude_safetensors_or_bin(fake_model_files) is None From 6e0bed1cf119b98654d12369b1987e022ba402d1 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 2 Oct 2023 09:59:29 -0700 Subject: [PATCH 120/425] adding s3 session function instead of client function (#302) --- .../infra/gateways/s3_llm_artifact_gateway.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index 6ce80446..12f03d2a 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -23,11 +23,10 @@ def list_files(self, path: str, **kwargs) -> List[str]: parsed_remote = parse_attachment_url(path) bucket = parsed_remote.bucket key = parsed_remote.key - # From here: https://dev.to/aws-builders/how-to-list-contents-of-s3-bucket-using-boto3-python-47mm - files = [ - bucket_object["Key"] - for bucket_object in s3.list_objects_v2(Bucket=bucket, Prefix=key)["Contents"] - ] + + # Using resource's bucket object to get its objects with specific prefix + s3_bucket = s3.Bucket(bucket) + files = [obj.key for obj in s3_bucket.objects.filter(Prefix=key)] return files def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: From 4817ffcbfe9ea3a41ed94ca6f3f0b3cad4dfa67f Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 3 Oct 2023 16:50:53 -0700 Subject: [PATCH 121/425] Increase graceful timeout and hardcode AWS_PROFILE (#306) --- .../model-engine/templates/service_template_config_map.yaml | 4 ++++ .../inference/sync_inference/start_fastapi_server.py | 2 +- .../infra/gateways/resources/k8s_resource_types.py | 1 + model-engine/requirements-test.txt | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 3bd9674f..8eda04b1 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -77,7 +77,11 @@ data: spec: affinity: {{- include "modelEngine.serviceTemplateAffinity" . | nindent 12 }} + {{- if eq $mode "async" }} + terminationGracePeriodSeconds: 1800 + {{- else }} terminationGracePeriodSeconds: 600 + {{- end }} {{- if $service_template_service_account_name }} serviceAccount: {{ $service_template_service_account_name }} {{- else }} diff --git a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py index 2b3aef79..2c93b770 100644 --- a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py @@ -10,7 +10,7 @@ def start_server(): parser = argparse.ArgumentParser() - parser.add_argument("--graceful-timeout", type=int, default=600) + parser.add_argument("--graceful-timeout", type=int, default=1800) args, extra_args = parser.parse_known_args() # TODO: HTTPS diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 6c0f9724..9e77ac68 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -507,6 +507,7 @@ def get_endpoint_resource_arguments_from_request( main_env = [] if isinstance(flavor, RunnableImageLike) and flavor.env: main_env = [{"name": key, "value": value} for key, value in flavor.env.items()] + main_env.append({"name": "AWS_PROFILE", "value": build_endpoint_request.aws_role}) infra_service_config_volume_mount_path = "/infra-config" forwarder_config_file_name = "service--forwarder.yaml" diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index f93718b3..9ad7b6e2 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -9,6 +9,7 @@ pytest-mypy==0.9.1 pytest-mypy-plugins==1.10.1 pytest-asyncio==0.20.1 pytest-pylint==0.18.0 +pylint<3.0.0 types-cachetools==5.3.0.5 types-croniter==1.4.0.0 types-PyYAML==6.0.7 From 8e9f0c9d753740ef664c7e1383332e99b486eb35 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:51:44 -0700 Subject: [PATCH 122/425] bump pypi version (#303) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 2c0eeb45..f52e0462 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b17" +__version__ = "0.0.0b18" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index ff781834..6aa1063a 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta17" +version = "0.0.0.beta18" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 3ce304b4..97d6eaae 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta17", + version="0.0.0.beta18", packages=find_packages(), ) From acdd1ea0f29f164c2295bc19b546054df50354c9 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 4 Oct 2023 10:47:18 -0700 Subject: [PATCH 123/425] Ianmacleod/add mistral (#307) * add mistral 7b instruct * adding mistral support * update docs * update docs again * add mistral 7b max model len and max num batched tokens --- docs/model_zoo.md | 34 ++++++++++--------- .../use_cases/llm_model_endpoint_use_cases.py | 23 ++++++++++--- .../inference/vllm/requirements.txt | 2 +- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 5c0bab7c..0c61c38a 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -2,22 +2,24 @@ Scale hosts the following models in the LLM Engine Model Zoo: -| Model Name | Inference APIs Available | Fine-tuning APIs Available | -| --------------------- | ------------------------ | -------------------------- | -| `llama-7b` | ✅ | ✅ | -| `llama-2-7b` | ✅ | ✅ | -| `llama-2-7b-chat` | ✅ | | -| `llama-2-13b` | ✅ | | -| `llama-2-13b-chat` | ✅ | | -| `llama-2-70b` | ✅ | ✅ | -| `llama-2-70b-chat` | ✅ | | -| `falcon-7b` | ✅ | | -| `falcon-7b-instruct` | ✅ | | -| `falcon-40b` | ✅ | | -| `falcon-40b-instruct` | ✅ | | -| `mpt-7b` | ✅ | | -| `mpt-7b-instruct` | ✅ | ✅ | -| `flan-t5-xxl` | ✅ | | +| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | +| --------------------- | ------------------------ | -------------------------- | ------------------------------ | +| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | +| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | +| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | +| `llama-2-13b` | ✅ | | text-generation-inference, vllm | +| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | +| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | +| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | +| `falcon-7b` | ✅ | | text-generation-inference, vllm | +| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | +| `falcon-40b` | ✅ | | text-generation-inference, vllm | +| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | +| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | +| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | +| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | +| `mistral-7b` | ✅ | | vllm | +| `mistral-7b-instruct` | ✅ | | vllm | ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 388ed6aa..d1a746c9 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -117,6 +117,8 @@ "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", + "mistral-7b": "mistralai/Mistral-7B-v0.1", + "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", }, LLMInferenceFramework.LIGHTLLM: { "llama-7b": "decapoda-research/llama-7b-hf", @@ -488,13 +490,21 @@ async def create_vllm_bundle( command = [] max_num_batched_tokens = 2560 # vLLM's default + max_model_len = None if "llama-2" in model_name: max_num_batched_tokens = 4096 # Need to be bigger than model's context window + if "mistral" in model_name: + max_num_batched_tokens = 8000 + max_model_len = 8000 subcommands = [] if checkpoint_path is not None: if checkpoint_path.startswith("s3://"): - final_weights_folder = "model_files" + # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder + if "mistral" in model_name: + final_weights_folder = "mistral_files" + else: + final_weights_folder = "model_files" subcommands += self.load_model_weights_sub_commands( LLMInferenceFramework.VLLM, framework_image_tag, @@ -508,9 +518,14 @@ async def create_vllm_bundle( else: final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] - subcommands.append( - f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens}" - ) + if max_model_len: + subcommands.append( + f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens} --max-model-len {max_model_len}" + ) + else: + subcommands.append( + f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens}" + ) if quantize: if quantize == Quantization.AWQ: diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index db3b97a4..b5407ab9 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ ray==2.6.3 -git+https://github.com/vllm-project/vllm.git@7d7e3b78a3c265ab3c57eeff43af56f509907998#egg=vllm +vllm==0.2.0 pydantic==1.10.12 From ee1d41cb78cc61ab11786cfe3d3989c0b64369ee Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 5 Oct 2023 13:33:46 -0700 Subject: [PATCH 124/425] Ianmacleod/add falcon 180b (#309) * add falcon 180b * also add regular and chat --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index d1a746c9..41f0a92c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -102,6 +102,8 @@ "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", + "falcon-180b": "tiiuae/falcon-180B", + "falcon-180b-chat": "tiiuae/falcon-180B-chat", }, LLMInferenceFramework.VLLM: { "mpt-7b": "mosaicml/mpt-7b", From 1dc42c301f492c4b857b15b020d7affc21b706bf Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 6 Oct 2023 16:02:23 -0700 Subject: [PATCH 125/425] update 180b inference framework (#310) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 41f0a92c..5f4eb872 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -102,8 +102,6 @@ "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", - "falcon-180b": "tiiuae/falcon-180B", - "falcon-180b-chat": "tiiuae/falcon-180B-chat", }, LLMInferenceFramework.VLLM: { "mpt-7b": "mosaicml/mpt-7b", @@ -121,6 +119,8 @@ "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", "mistral-7b": "mistralai/Mistral-7B-v0.1", "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", + "falcon-180b": "tiiuae/falcon-180B", + "falcon-180b-chat": "tiiuae/falcon-180B-chat", }, LLMInferenceFramework.LIGHTLLM: { "llama-7b": "decapoda-research/llama-7b-hf", From c0f51ab65fe009042a71cd3f5cd1a521f78e00cd Mon Sep 17 00:00:00 2001 From: mfagundo-scale <142335718+mfagundo-scale@users.noreply.github.com> Date: Mon, 9 Oct 2023 11:36:31 -0700 Subject: [PATCH 126/425] Adding code llama to TGI (#311) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 5f4eb872..f1eb665a 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -102,6 +102,9 @@ "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", + "code-llama-7b": "codellama/CodeLlama-7b-hf", + "code-llama-13b": "codellama/CodeLlama-13b-hf", + "code-llama-34b": "codellama/CodeLlama-34b-hf", }, LLMInferenceFramework.VLLM: { "mpt-7b": "mosaicml/mpt-7b", From 43cfaba1bd0bb6a09100ea06eafad63123dd40d1 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 10 Oct 2023 17:05:05 -0700 Subject: [PATCH 127/425] Add AWQ enum (#317) --- clients/python/llmengine/data_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 2cdc2f89..211106d8 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -24,6 +24,7 @@ class LLMSource(str, Enum): class Quantization(str, Enum): BITSANDBYTES = "bitsandbytes" + AWQ = "awq" class GpuType(str, Enum): From 483cd253b168dcfb78f6c917b24d2f29c3bd5fcc Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 11 Oct 2023 11:17:43 -0700 Subject: [PATCH 128/425] Fix documentation to reference Files API (#312) --- clients/python/llmengine/file.py | 5 ++++- clients/python/llmengine/fine_tuning.py | 10 ++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/clients/python/llmengine/file.py b/clients/python/llmengine/file.py index 097fed1f..670efda3 100644 --- a/clients/python/llmengine/file.py +++ b/clients/python/llmengine/file.py @@ -22,9 +22,12 @@ def upload(cls, file: BufferedReader) -> UploadFileResponse: """ Uploads a file to LLM engine. + For use in [FineTune creation](./#llmengine.fine_tuning.FineTune.create), this should be a CSV file with two columns: `prompt` and `response`. + A maximum of 100,000 rows of data is currently supported. + Args: file (`BufferedReader`): - A file opened with open(file_path, "r") + A local file opened with `open(file_path, "r")` Returns: UploadFileResponse: an object that contains the ID of the uploaded file diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index e15f36a7..bf9dcf0d 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -38,8 +38,10 @@ def create( This API can be used to fine-tune a model. The _model_ is the name of base model ([Model Zoo](../../model_zoo) for available models) to fine-tune. The training and validation files should consist of prompt and response pairs. `training_file` - and `validation_file` must be publicly accessible HTTP or HTTPS URLs to a CSV file - that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is + and `validation_file` must be either publicly accessible HTTP or HTTPS URLs, or + file IDs of files uploaded to LLM Engine's [Files API](./#llmengine.File) (these + will have the `file-` prefix). The referenced files must be CSV files that include + two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. For sequences longer than the native `max_seq_length` of the model, the sequences will be truncated. @@ -52,10 +54,10 @@ def create( The name of the base model to fine-tune. See [Model Zoo](../../model_zoo) for the list of available models to fine-tune. training_file (`str`): - Publicly accessible URL to a CSV file for training. When no validation_file is provided, one will automatically be created using a 10% split of the training_file data. + Publicly accessible URL or file ID referencing a CSV file for training. When no validation_file is provided, one will automatically be created using a 10% split of the training_file data. validation_file (`Optional[str]`): - Publicly accessible URL to a CSV file for validation. The validation file is used to compute metrics which let LLM Engine pick the best fine-tuned checkpoint, which will be used for inference when fine-tuning is complete. + Publicly accessible URL or file ID referencing a CSV file for validation. The validation file is used to compute metrics which let LLM Engine pick the best fine-tuned checkpoint, which will be used for inference when fine-tuning is complete. hyperparameters (`Optional[Dict[str, str]]`): A dict of hyperparameters to customize fine-tuning behavior. From 4ef1eedab233d16e624d6c48118013b41ba12c64 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:49:56 -0700 Subject: [PATCH 129/425] Return TGI errors (#313) * Return TGI errors * remove prints * fix lint --- .../use_cases/llm_model_endpoint_use_cases.py | 31 ++-- model-engine/tests/unit/conftest.py | 132 ++++++++++++++++++ .../tests/unit/domain/test_llm_use_cases.py | 49 ++++++- 3 files changed, 199 insertions(+), 13 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index f1eb665a..97d9f69b 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -909,7 +909,10 @@ def validate_and_update_completion_params( if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: request.top_k = None if request.top_k == -1 else request.top_k request.top_p = None if request.top_p == 1.0 else request.top_p - if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: + if inference_framework in [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: request.top_k = -1 if request.top_k is None else request.top_k request.top_p = 1.0 if request.top_p is None else request.top_p else: @@ -919,7 +922,10 @@ def validate_and_update_completion_params( ) # presence_penalty, frequency_penalty - if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: + if inference_framework in [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: request.presence_penalty = ( 0.0 if request.presence_penalty is None else request.presence_penalty ) @@ -987,14 +993,17 @@ def model_output_to_completion_output( raise InvalidRequestException(model_output.get("error")) # trigger a 400 else: raise UpstreamServiceError( - status_code=500, content=bytes(model_output["error"]) + status_code=500, content=bytes(model_output["error"], "utf-8") ) elif model_content.inference_framework == LLMInferenceFramework.VLLM: tokens = None if with_token_probs: tokens = [ - TokenOutput(token=model_output["tokens"][index], log_prob=list(t.values())[0]) + TokenOutput( + token=model_output["tokens"][index], + log_prob=list(t.values())[0], + ) for index, t in enumerate(model_output["log_probs"]) ] return CompletionOutput( @@ -1003,7 +1012,6 @@ def model_output_to_completion_output( tokens=tokens, ) elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: - print(model_output) tokens = None if with_token_probs: tokens = [ @@ -1109,7 +1117,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: @@ -1152,7 +1161,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1191,7 +1201,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1233,7 +1244,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1517,7 +1529,6 @@ async def execute( ) elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: if res.status == TaskStatus.SUCCESS and result is not None: - print(result) token = None num_completion_tokens += 1 if request.return_token_log_probs: diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index cee6db6e..b55a2b50 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3672,6 +3672,138 @@ def llm_model_endpoint_sync( return model_endpoint, model_endpoint_json +@pytest.fixture +def llm_model_endpoint_sync_tgi( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.SYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "123", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "sync", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": "1", + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + @pytest.fixture def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint: # model_bundle_5 is a runnable bundle diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 7171e5b4..edb53abb 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -22,6 +22,7 @@ ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, + UpstreamServiceError, ) from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( MAX_LLM_ENDPOINTS_PER_INTERNAL_USER, @@ -171,7 +172,8 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( - user=user, request=create_llm_model_endpoint_text_generation_inference_request_streaming + user=user, + request=create_llm_model_endpoint_text_generation_inference_request_streaming, ) assert response_1.endpoint_creation_task_id assert isinstance(response_1, CreateLLMModelEndpointV1Response) @@ -196,7 +198,8 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( with pytest.raises(ObjectHasInvalidValueException): await use_case.execute( - user=user, request=create_llm_model_endpoint_text_generation_inference_request_async + user=user, + request=create_llm_model_endpoint_text_generation_inference_request_async, ) @@ -483,6 +486,40 @@ async def test_completion_sync_use_case_predict_failed( assert response_1.output is None +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed_with_errors( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync_tgi: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": """ + { + "error": "Request failed during generation: Server error: transport error", + "error_type": "generation" + } +""" + }, + traceback="failed to predict", + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync_tgi[0].record.name, + request=completion_sync_request, + ) + + @pytest.mark.asyncio async def test_completion_sync_use_case_not_sync_endpoint_raises( test_api_key: str, @@ -964,7 +1001,13 @@ async def test_delete_public_inference_model_raises_not_authorized( @pytest.mark.asyncio async def test_exclude_safetensors_or_bin_majority_bin_returns_exclude_safetensors(): - fake_model_files = ["fake.bin", "fake2.bin", "fake3.safetensors", "model.json", "optimizer.pt"] + fake_model_files = [ + "fake.bin", + "fake2.bin", + "fake3.safetensors", + "model.json", + "optimizer.pt", + ] assert _exclude_safetensors_or_bin(fake_model_files) == "*.safetensors" From 65afe0ad978c3dc4f0a14e75c3342c70db227129 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 11 Oct 2023 16:11:59 -0700 Subject: [PATCH 130/425] Fix streaming endpoint failure handling (#314) * Fix streaming endpoint failure handling * Fix streaming endpoint failure handling * remove print * comments * client side changes * client side changes * fix * strong typing --- .pre-commit-config.yaml | 1 + clients/python/llmengine/data_types.py | 21 +++++ docs/getting_started.md | 11 +-- docs/guides/completions.md | 11 +-- .../model_engine_server/api/llms_v1.py | 93 ++++++++++++------- .../model_engine_server/common/dtos/llms.py | 20 ++++ .../use_cases/llm_model_endpoint_use_cases.py | 2 +- model-engine/tests/unit/api/test_llms.py | 57 ++++++++++++ model-engine/tests/unit/api/test_tasks.py | 1 + 9 files changed, 168 insertions(+), 49 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3fe2075c..bb2d9cc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,6 +55,7 @@ repos: hooks: - id: mypy name: mypy-clients-python + files: clients/python/.* entry: mypy --config-file clients/python/mypy.ini language: system - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 211106d8..07612420 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -355,6 +355,24 @@ class CompletionStreamOutput(BaseModel): """Detailed token information.""" +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" + + class CompletionStreamResponse(BaseModel): """ Response object for a stream prompt completion task. @@ -372,6 +390,9 @@ class CompletionStreamResponse(BaseModel): output: Optional[CompletionStreamOutput] = None """Completion output.""" + error: Optional[StreamError] = None + """Error of the response (if any).""" + class CreateFineTuneRequest(BaseModel): """ diff --git a/docs/getting_started.md b/docs/getting_started.md index 46741d1b..fea0531a 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -81,11 +81,10 @@ stream = Completion.create( ) for response in stream: - try: - if response.output: - print(response.output.text, end="") - sys.stdout.flush() - except: # an error occurred - print(stream.text) # print the error message out + if response.output: + print(response.output.text, end="") + sys.stdout.flush() + else: # an error occurred + print(response.error) # print the error message out break ``` diff --git a/docs/guides/completions.md b/docs/guides/completions.md index 4719edc3..dee51f61 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -87,12 +87,11 @@ stream = Completion.create( ) for response in stream: - try: - if response.output: - print(response.output.text, end="") - sys.stdout.flush() - except: # an error occurred - print(stream.text) # print the error message out + if response.output: + print(response.output.text, end="") + sys.stdout.flush() + else: # an error occurred + print(response.error) # print the error message out break ``` diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 67abfefa..92ddad0e 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -1,7 +1,10 @@ """LLM Model Endpoint routes for the hosted model inference service. """ +import traceback +from datetime import datetime from typing import Optional +import pytz from fastapi import APIRouter, Depends, HTTPException, Query from model_engine_server.api.dependencies import ( ExternalInterfaces, @@ -28,6 +31,8 @@ ListLLMModelEndpointsV1Response, ModelDownloadRequest, ModelDownloadResponse, + StreamError, + StreamErrorContent, ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User @@ -71,6 +76,34 @@ logger = make_logger(filename_wo_ext(__name__)) +def handle_streaming_exception( + e: Exception, + code: int, + message: str, +): + tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + request_id = get_request_id() + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": message, + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Exception: %s", structured_log) + return { + "data": CompletionStreamV1Response( + request_id=str(request_id), + error=StreamError( + status_code=code, + content=StreamErrorContent( + error=message, + timestamp=timestamp, + ), + ), + ).json() + } + + @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response) async def create_model_endpoint( request: CreateLLMModelEndpointV1Request, @@ -226,42 +259,30 @@ async def create_completion_stream_task( logger.info( f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" ) - try: - use_case = CompletionStreamV1UseCase( - model_endpoint_service=external_interfaces.model_endpoint_service, - llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, - ) - response = use_case.execute( - user=auth, model_endpoint_name=model_endpoint_name, request=request - ) + use_case = CompletionStreamV1UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + ) + response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request) - async def event_generator(): - try: - async for message in response: - yield {"data": message.json()} - except InvalidRequestException as exc: - yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}} - return + async def event_generator(): + try: + async for message in response: + yield {"data": message.json()} + except (InvalidRequestException, ObjectHasInvalidValueException) as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except ( + ObjectNotFoundException, + ObjectNotAuthorizedException, + EndpointUnsupportedInferenceTypeException, + ) as exc: + yield handle_streaming_exception(exc, 404, str(exc)) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) - return EventSourceResponse(event_generator()) - except UpstreamServiceError: - request_id = get_request_id() - logger.exception(f"Upstream service error for request {request_id}") - return EventSourceResponse( - iter((CompletionStreamV1Response(request_id=request_id).json(),)) # type: ignore - ) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail="The specified endpoint could not be found.", - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except EndpointUnsupportedInferenceTypeException as exc: - raise HTTPException( - status_code=400, - detail=f"Unsupported inference type: {str(exc)}", - ) from exc + return EventSourceResponse(event_generator()) @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse) @@ -405,12 +426,12 @@ async def delete_llm_model_endpoint( model_endpoint_service=external_interfaces.model_endpoint_service, ) return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name) - except (ObjectNotFoundException) as exc: + except ObjectNotFoundException as exc: raise HTTPException( status_code=404, detail="The requested model endpoint could not be found.", ) from exc - except (ObjectNotAuthorizedException) as exc: + except ObjectNotAuthorizedException as exc: raise HTTPException( status_code=403, detail="You don't have permission to delete the requested model endpoint.", diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 27a12ddc..bf0b7519 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -202,6 +202,24 @@ class CompletionStreamOutput(BaseModel): token: Optional[TokenOutput] = None +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" + + class CompletionStreamV1Response(BaseModel): """ Response object for a stream prompt completion task. @@ -209,6 +227,8 @@ class CompletionStreamV1Response(BaseModel): request_id: str output: Optional[CompletionStreamOutput] = None + error: Optional[StreamError] = None + """Error of the response (if any).""" class CreateFineTuneRequest(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 97d9f69b..5b179872 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1308,7 +1308,7 @@ async def execute( ) if len(model_endpoints) == 0: - raise ObjectNotFoundException + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 2e909aeb..32178b49 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -113,6 +113,32 @@ def test_completion_sync_success( assert response_1.json().keys() == {"output", "request_id"} +def test_completion_sync_endpoint_not_found_returns_404( + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + completion_sync_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_sync[0] + .infra_state.deployment_name: llm_model_endpoint_sync[0] + .infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + response_1 = client.post( + f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", + auth=("no_user", ""), + json=completion_sync_request, + ) + assert response_1.status_code == 404 + + @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") def test_completion_stream_success( llm_model_endpoint_streaming: ModelEndpoint, @@ -136,6 +162,7 @@ def test_completion_stream_success( f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", auth=("no_user", ""), json=completion_stream_request, + stream=True, ) assert response_1.status_code == 200 count = 0 @@ -146,3 +173,33 @@ def test_completion_stream_success( ) count += 1 assert count == 1 + + +@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") +def test_completion_stream_endpoint_not_found_returns_404( + llm_model_endpoint_streaming: ModelEndpoint, + completion_stream_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + response_1 = client.post( + f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + stream=True, + ) + + assert response_1.status_code == 200 + + for message in response_1: + assert "404" in message.decode("utf-8") diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index 5192f025..611195bd 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -364,6 +364,7 @@ def test_create_streaming_task_success( f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", auth=(test_api_key, ""), json=endpoint_predict_request_1[1], + stream=True, ) assert response.status_code == 200 count = 0 From 60ac144c55aad971cdd7f152f4f7816ce2fb7d2f Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 11 Oct 2023 18:07:59 -0700 Subject: [PATCH 131/425] Validate quantization (#315) * Validate quantization * comments --- clients/python/llmengine/model.py | 2 +- .../model_engine_server/common/env_vars.py | 3 +- .../use_cases/llm_model_endpoint_use_cases.py | 24 ++++++++++++++- model-engine/tests/unit/domain/conftest.py | 29 ++++++++++++++++++- .../tests/unit/domain/test_llm_use_cases.py | 29 +++++++++++++++++++ 5 files changed, 83 insertions(+), 4 deletions(-) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 021a9ff5..b5ed181a 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -76,7 +76,7 @@ def create( num_shards (`int`): Number of shards for the LLM. When bigger than 1, LLM will be sharded - to multiple GPUs. Number of GPUs must be larger than num_shards. + to multiple GPUs. Number of GPUs must be equal or larger than num_shards. Only affects behavior for text-generation-inference models quantize (`Optional[Quantization]`): diff --git a/model-engine/model_engine_server/common/env_vars.py b/model-engine/model_engine_server/common/env_vars.py index a51a7698..9d5ca20a 100644 --- a/model-engine/model_engine_server/common/env_vars.py +++ b/model-engine/model_engine_server/common/env_vars.py @@ -2,6 +2,7 @@ A place for defining, setting, and referencing all environment variables used in Launch. """ import os +import sys from typing import Optional, Sequence from model_engine_server.common.constants import PROJECT_ROOT @@ -73,5 +74,5 @@ def get_boolean_env_var(name: str) -> bool: logger.warning("LOCAL development & testing mode is ON") GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND") -if GIT_TAG == "GIT_TAG_NOT_FOUND": +if GIT_TAG == "GIT_TAG_NOT_FOUND" and "pytest" not in sys.modules: raise ValueError("GIT_TAG environment variable must be set") diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 5b179872..82ca48e1 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -136,6 +136,13 @@ }, } +_SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = { + LLMInferenceFramework.DEEPSPEED: [], + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES], + LLMInferenceFramework.VLLM: [Quantization.AWQ], + LLMInferenceFramework.LIGHTLLM: [], +} + NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes @@ -198,8 +205,21 @@ def validate_num_shards( raise ObjectHasInvalidValueException("DeepSpeed requires more than 1 GPU.") if num_shards != gpus: raise ObjectHasInvalidValueException( - f"DeepSpeed requires num shard {num_shards} to be the same as number of GPUs {gpus}." + f"Num shard {num_shards} must be the same as number of GPUs {gpus} for DeepSpeed." ) + if num_shards > gpus: + raise ObjectHasInvalidValueException( + f"Num shard {num_shards} must be less than or equal to the number of GPUs {gpus}." + ) + + +def validate_quantization( + quantize: Optional[Quantization], inference_framework: LLMInferenceFramework +) -> None: + if quantize is not None and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework]: + raise ObjectHasInvalidValueException( + f"Quantization {quantize} is not supported for inference framework {inference_framework}. Supported quantization types are {_SUPPORTED_QUANTIZATIONS[inference_framework]}." + ) class CreateLLMModelEndpointV1UseCase: @@ -667,10 +687,12 @@ async def execute( validate_post_inference_hooks(user, request.post_inference_hooks) validate_model_name(request.model_name, request.inference_framework) validate_num_shards(request.num_shards, request.inference_framework, request.gpus) + validate_quantization(request.quantize, request.inference_framework) if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 85e57ea4..6a958ed4 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -18,6 +18,7 @@ ) from model_engine_server.domain.entities import ( GpuType, + LLMInferenceFramework, ModelBundle, ModelBundleEnvironmentParams, ModelBundleFrameworkType, @@ -283,7 +284,6 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( inference_framework="deepspeed", inference_framework_image_tag="test_tag", num_shards=2, - quantize=Quantization.BITSANDBYTES, endpoint_type=ModelEndpointType.STREAMING, metadata={}, post_inference_hooks=["billing"], @@ -356,6 +356,33 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp ) +@pytest.fixture +def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_1", + model_name="nonexist", + source="hugging_face", + inference_framework=LLMInferenceFramework.VLLM, + inference_framework_image_tag="test_tag", + num_shards=2, + quantize=Quantization.BITSANDBYTES, + endpoint_type=ModelEndpointType.SYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage=None, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + @pytest.fixture def completion_sync_request() -> CompletionSyncV1Request: return CompletionSyncV1Request( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index edb53abb..c71995ea 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -232,6 +232,35 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception ) +@pytest.mark.asyncio +async def test_create_llm_model_endpoint_use_case_quantization_exception( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_invalid_quantization: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, request=create_llm_model_endpoint_request_invalid_quantization + ) + + @pytest.mark.asyncio async def test_get_llm_model_endpoint_use_case_raises_not_found( test_api_key: str, From d30a1a5d504d76d68a6406ee726e570af129f8b8 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 12 Oct 2023 14:56:42 -0700 Subject: [PATCH 132/425] Properly return PENDING status for docker image batch jobs/fine tune jobs (#318) --- .../infra/gateways/live_cron_job_gateway.py | 20 +- .../live_docker_image_batch_job_gateway.py | 57 ++++- .../unit/infra/gateways/k8s_fake_objects.py | 67 ++++++ .../test_k8s_endpoint_resource_delegate.py | 13 +- ...est_live_docker_image_batch_job_gateway.py | 211 ++++++++++++++++++ 5 files changed, 351 insertions(+), 17 deletions(-) create mode 100644 model-engine/tests/unit/infra/gateways/k8s_fake_objects.py diff --git a/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py index b8316b25..257f7cbd 100644 --- a/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py @@ -10,9 +10,11 @@ from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( LAUNCH_JOB_ID_LABEL_SELECTOR, _parse_job_status_from_k8s_obj, + make_job_id_to_pods_mapping, ) from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_batch_client, + get_kubernetes_core_client, load_k8s_yaml, maybe_load_kube_config, ) @@ -97,6 +99,20 @@ async def list_jobs( logger.exception("Got an exception when trying to list the Jobs") raise EndpointResourceInfraException from exc + core_client = get_kubernetes_core_client() + + try: + label_selector = f"trigger_id={trigger_id}" if trigger_id else f"owner={owner},job-name" + pods = await core_client.list_namespaced_pod( + namespace=hmi_config.endpoint_namespace, + label_selector=label_selector, + ) + except ApiException as exc: + logger.exception("Got an exception when trying to list the Pods") + raise EndpointResourceInfraException from exc + + pods_per_job = make_job_id_to_pods_mapping(pods.items) + return [ DockerImageBatchJob( id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR), @@ -104,7 +120,9 @@ async def list_jobs( owner=job.metadata.labels.get("owner"), created_at=job.metadata.creation_timestamp, completed_at=job.status.completion_time, - status=_parse_job_status_from_k8s_obj(job), + status=_parse_job_status_from_k8s_obj( + job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)] + ), ) for job in jobs.items ] diff --git a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py index bc1d6a9b..40e09a8c 100644 --- a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py @@ -1,9 +1,11 @@ import os import re +from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union from kubernetes_asyncio.client.models.v1_job import V1Job +from kubernetes_asyncio.client.models.v1_pod import V1Pod from kubernetes_asyncio.client.rest import ApiException from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests @@ -17,6 +19,7 @@ ) from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_batch_client, + get_kubernetes_core_client, load_k8s_yaml, maybe_load_kube_config, ) @@ -84,7 +87,7 @@ def _k8s_job_name_from_id(job_id: str): return f"launch-di-batch-job-{job_id}" -def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus: +def _parse_job_status_from_k8s_obj(job: V1Job, pods: List[V1Pod]) -> BatchJobStatus: status = job.status # these counts are the number of pods in some given status if status.failed is not None and status.failed > 0: @@ -94,10 +97,30 @@ def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus: if status.ready is not None and status.ready > 0: return BatchJobStatus.RUNNING # empirically this doesn't happen if status.active is not None and status.active > 0: - return BatchJobStatus.RUNNING # TODO this might be a mix of pending and running + for pod in pods: + # In case there are multiple pods for a given job (e.g. if a pod gets shut down) + # let's interpret the job as running if any of the pods are running + # I haven't empirically seen this, but guard against it just in case. + if pod.status.phase == "Running": + return BatchJobStatus.RUNNING + return BatchJobStatus.PENDING return BatchJobStatus.PENDING +def make_job_id_to_pods_mapping(pods: List[V1Pod]) -> defaultdict: + """ + Returns a defaultdict mapping job IDs to pods + """ + job_id_to_pods_mapping = defaultdict(list) + for pod in pods: + job_id = pod.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR) + if job_id is not None: + job_id_to_pods_mapping[job_id].append(pod) + else: + logger.warning(f"Pod {pod.metadata.name} has no job ID label") + return job_id_to_pods_mapping + + class LiveDockerImageBatchJobGateway(DockerImageBatchJobGateway): def __init__(self): pass @@ -282,10 +305,21 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker logger.exception("Got an exception when trying to read the Job") raise EndpointResourceInfraException from exc + core_client = get_kubernetes_core_client() + try: + pods = await core_client.list_namespaced_pod( + namespace=hmi_config.endpoint_namespace, + label_selector=f"{LAUNCH_JOB_ID_LABEL_SELECTOR}={batch_job_id}", + ) + except ApiException as exc: + logger.exception("Got an exception when trying to read pods for the Job") + raise EndpointResourceInfraException from exc + # This pod list isn't always needed, but it's simpler code-wise to always make the request + job_labels = job.metadata.labels annotations = job.metadata.annotations - status = _parse_job_status_from_k8s_obj(job) + status = _parse_job_status_from_k8s_obj(job, pods.items) return DockerImageBatchJob( id=batch_job_id, @@ -309,6 +343,19 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc logger.exception("Got an exception when trying to list the Jobs") raise EndpointResourceInfraException from exc + core_client = get_kubernetes_core_client() + try: + pods = await core_client.list_namespaced_pod( + namespace=hmi_config.endpoint_namespace, + label_selector=f"{OWNER_LABEL_SELECTOR}={owner},job-name", # get only pods associated with a job + ) + except ApiException as exc: + logger.exception("Got an exception when trying to read pods for the Job") + raise EndpointResourceInfraException from exc + + # Join jobs + pods + pods_per_job = make_job_id_to_pods_mapping(pods.items) + return [ DockerImageBatchJob( id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR), @@ -317,7 +364,9 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc created_at=job.metadata.creation_timestamp, completed_at=job.status.completion_time, annotations=job.metadata.annotations, - status=_parse_job_status_from_k8s_obj(job), + status=_parse_job_status_from_k8s_obj( + job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)] + ), ) for job in jobs.items ] diff --git a/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py b/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py new file mode 100644 index 00000000..55039109 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py @@ -0,0 +1,67 @@ +# Various fake k8s objects to be used in mocking out the python k8s api client +# Only classes are defined here. If you need to add various fields to the classes, please do so here. + +from dataclasses import dataclass, field +from datetime import datetime +from typing import List, Optional + + +@dataclass +class FakeK8sV1ObjectMeta: + name: str = "fake_name" + namespace: str = "fake_namespace" + annotations: dict = field(default_factory=dict) + labels: dict = field(default_factory=dict) + creation_timestamp: datetime = datetime(2021, 1, 1, 0, 0, 0, 0) + # TODO: everything else + + +@dataclass +class FakeK8sV1PodStatus: + phase: str = "Running" + # TODO: everything else + + +@dataclass +class FakeK8sV1JobStatus: + active: int = 0 + succeeded: int = 0 + failed: int = 0 + ready: int = 0 + terminating: int = 0 + completion_time: Optional[datetime] = None + + +@dataclass +class FakeK8sV1Job: + metadata: FakeK8sV1ObjectMeta = FakeK8sV1ObjectMeta() + status: FakeK8sV1JobStatus = FakeK8sV1JobStatus() + # TODO: spec, api_version, kind + + +@dataclass +class FakeK8sV1JobList: + items: List[FakeK8sV1Job] = field(default_factory=list) + + +@dataclass +class FakeK8sV1Pod: + metadata: FakeK8sV1ObjectMeta = FakeK8sV1ObjectMeta() + status: FakeK8sV1PodStatus = FakeK8sV1PodStatus() + # TODO: spec, api_version, kind + + +@dataclass +class FakeK8sV1PodList: + items: List[FakeK8sV1Pod] = field(default_factory=list) + + +@dataclass +class FakeK8sEnvVar: + name: str + value: str + + +@dataclass +class FakeK8sDeploymentContainer: + env: List[FakeK8sEnvVar] diff --git a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index 40fe6c14..93e2c8e3 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch @@ -25,21 +24,11 @@ DictStrStr, ResourceArguments, ) +from tests.unit.infra.gateways.k8s_fake_objects import FakeK8sDeploymentContainer, FakeK8sEnvVar MODULE_PATH = "model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate" -@dataclass -class FakeK8sEnvVar: - name: str - value: str - - -@dataclass -class FakeK8sDeploymentContainer: - env: List[FakeK8sEnvVar] - - @pytest.fixture def mock_get_kubernetes_cluster_version(): mock_version = "1.26" diff --git a/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py index 2f4c5c2a..b792b3d4 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py @@ -1,11 +1,222 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from model_engine_server.domain.entities import BatchJobStatus from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( K8sEnvDict, + LiveDockerImageBatchJobGateway, _add_list_values, _check_batch_job_id_valid, _get_job_id, ) +from tests.unit.infra.gateways.k8s_fake_objects import ( + FakeK8sV1Job, + FakeK8sV1JobList, + FakeK8sV1JobStatus, + FakeK8sV1ObjectMeta, + FakeK8sV1Pod, + FakeK8sV1PodList, + FakeK8sV1PodStatus, +) + +MODULE_PATH = "model_engine_server.infra.gateways.live_docker_image_batch_job_gateway" + + +@pytest.fixture +def mock_core_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_core_client", + return_value=mock_client, + ): + yield mock_client + + +@pytest.fixture +def mock_batch_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_batch_client", + return_value=mock_client, + ): + yield mock_client + + +@pytest.fixture +def docker_image_batch_job_gateway(): + gateway = LiveDockerImageBatchJobGateway() + return gateway + + +@pytest.mark.parametrize( + "active, succeeded, failed, pod_phase, pod_exists, expected_status", + [ + [1, 0, 0, "Running", True, BatchJobStatus.RUNNING], + [0, 1, 0, "Succeeded", True, BatchJobStatus.SUCCESS], + [0, 0, 1, "Failed", True, BatchJobStatus.FAILURE], + [1, 0, 0, "Pending", True, BatchJobStatus.PENDING], + [0, 0, 0, "Pending", False, BatchJobStatus.PENDING], + ], +) +@pytest.mark.asyncio +async def test_get_docker_image_batch_job_phase( + active, + succeeded, + failed, + pod_phase, + pod_exists, + expected_status, + docker_image_batch_job_gateway, + mock_core_client, + mock_batch_client, +): + if pod_exists: + pod_items = [ + FakeK8sV1Pod( + metadata=FakeK8sV1ObjectMeta( + labels={ + "job-name": "job-name", + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + } + ), + status=FakeK8sV1PodStatus( + phase=pod_phase, + ), + ) + ] + else: + pod_items = [] + + mock_core_client.list_namespaced_pod.return_value = FakeK8sV1PodList(items=pod_items) + mock_batch_client.list_namespaced_job.return_value = FakeK8sV1JobList( + items=[ + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + }, + ), + status=FakeK8sV1JobStatus( + active=active, + succeeded=succeeded, + failed=failed, + ), + ) + ] + ) + + job = await docker_image_batch_job_gateway.get_docker_image_batch_job("launch_job_id") + assert job is not None + assert job.status == expected_status + + +@pytest.mark.asyncio +async def test_list_docker_image_batch_jobs( + docker_image_batch_job_gateway, + mock_core_client, + mock_batch_client, +): + mock_core_client.list_namespaced_pod.return_value = FakeK8sV1PodList( + items=[ + FakeK8sV1Pod( + metadata=FakeK8sV1ObjectMeta( + labels={ + "job-name": "job-name", + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + } + ), + status=FakeK8sV1PodStatus( + phase="Running", + ), + ), + FakeK8sV1Pod( + metadata=FakeK8sV1ObjectMeta( + labels={ + "job-name": "job-name2", + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id2", + } + ), + status=FakeK8sV1PodStatus( + phase="Succeeded", + ), + ), + ] + ) + mock_batch_client.list_namespaced_job.return_value = FakeK8sV1JobList( + items=[ + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + }, + ), + status=FakeK8sV1JobStatus( + active=1, + succeeded=0, + failed=0, + ), + ), + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name2", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id2", + }, + ), + status=FakeK8sV1JobStatus( + active=0, + succeeded=1, + failed=0, + ), + ), + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name3", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id3", + }, + ), + status=FakeK8sV1JobStatus( + active=0, + succeeded=0, + failed=0, + ), + ), + ] + ) + + jobs = await docker_image_batch_job_gateway.list_docker_image_batch_jobs(owner="owner") + assert len(jobs) == 3 + job_ids_to_phases = {job.id: job.status for job in jobs} + assert job_ids_to_phases["launch_job_id"] == BatchJobStatus.RUNNING + assert job_ids_to_phases["launch_job_id2"] == BatchJobStatus.SUCCESS + assert job_ids_to_phases["launch_job_id3"] == BatchJobStatus.PENDING +# Small function functionality tests def test_valid_job_ids_are_valid(): for _ in range(20): # _get_job_id() is nondeterministic From 4367b83254a178969016c41ac0d424582528c625 Mon Sep 17 00:00:00 2001 From: William Song Date: Thu, 12 Oct 2023 20:19:41 -0700 Subject: [PATCH 133/425] add user_id and team_id as log facets (#321) * add user_id and team_id as log facets, refactor a little * fix lint, remove draft comments --- model-engine/model_engine_server/api/app.py | 8 ++-- .../model_engine_server/api/dependencies.py | 11 ++++- .../model_engine_server/api/llms_v1.py | 11 +++-- .../model_engine_server/core/loggers.py | 48 +++++++++++++------ 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index f87fcf76..1593b951 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -22,10 +22,10 @@ from model_engine_server.api.tasks_v1 import inference_task_router_v1 from model_engine_server.api.triggers_v1 import trigger_router_v1 from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, filename_wo_ext, - get_request_id, make_logger, - set_request_id, ) logger = make_logger(filename_wo_ext(__name__)) @@ -47,11 +47,11 @@ @app.middleware("http") async def dispatch(request: Request, call_next): try: - set_request_id(str(uuid.uuid4())) + LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) return await call_next(request) except Exception as e: tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - request_id = get_request_id() + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { "error": str(e), diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index bdd158db..89854841 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -13,7 +13,12 @@ from model_engine_server.core.auth.fake_authentication_repository import ( FakeAuthenticationRepository, ) -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + filename_wo_ext, + make_logger, +) from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync from model_engine_server.domain.gateways import ( CronJobGateway, @@ -330,6 +335,10 @@ async def verify_authentication( headers={"WWW-Authenticate": "Basic"}, ) + # set logger context with identity data + LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) + LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) + return auth diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 92ddad0e..4917ee32 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -36,7 +36,12 @@ ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, get_request_id, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + filename_wo_ext, + make_logger, +) from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, EndpointLabelsException, @@ -82,7 +87,7 @@ def handle_streaming_exception( message: str, ): tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - request_id = get_request_id() + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { "error": message, @@ -223,7 +228,7 @@ async def create_completion_sync_task( user=auth, model_endpoint_name=model_endpoint_name, request=request ) except UpstreamServiceError: - request_id = get_request_id() + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) logger.exception(f"Upstream service error for request {request_id}") raise HTTPException( status_code=500, diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index 94c96998..ce0ee847 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -5,7 +5,8 @@ import sys import warnings from contextlib import contextmanager -from typing import Optional, Sequence +from enum import Enum +from typing import Dict, Optional, Sequence import ddtrace import json_log_formatter @@ -16,8 +17,6 @@ LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s" # REQUIRED FOR DATADOG COMPATIBILITY -ctx_var_request_id = contextvars.ContextVar("ctx_var_request_id", default=None) - __all__: Sequence[str] = ( # most common imports "make_logger", @@ -35,19 +34,37 @@ "loggers_at_level", # utils "filename_wo_ext", - "get_request_id", - "set_request_id", + "LoggerTagKey", + "LoggerTagManager", ) -def get_request_id() -> Optional[str]: - """Get the request id from the context variable.""" - return ctx_var_request_id.get() +class LoggerTagKey(str, Enum): + REQUEST_ID = "request_id" + TEAM_ID = "team_id" + USER_ID = "user_id" + + +class LoggerTagManager: + _context_vars: Dict[LoggerTagKey, contextvars.ContextVar] = {} + @classmethod + def get(cls, key: LoggerTagKey) -> Optional[str]: + """Get the value from the context variable.""" + ctx_var = cls._context_vars.get(key) + if ctx_var is not None: + return ctx_var.get() + return None -def set_request_id(request_id: str) -> None: - """Set the request id in the context variable.""" - ctx_var_request_id.set(request_id) # type: ignore + @classmethod + def set(cls, key: LoggerTagKey, value: Optional[str]) -> None: + """Set the value in the context variable.""" + if value is not None: + ctx_var = cls._context_vars.get(key) + if ctx_var is None: + ctx_var = contextvars.ContextVar(f"ctx_var_{key.name.lower()}", default=None) + cls._context_vars[key] = ctx_var + ctx_var.set(value) def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: @@ -77,10 +94,11 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d extra["lineno"] = record.lineno extra["pathname"] = record.pathname - # add the http request id if it exists - request_id = ctx_var_request_id.get() - if request_id: - extra["request_id"] = request_id + # add additional logger tags + for tag_key in LoggerTagKey: + tag_value = LoggerTagManager.get(tag_key) + if tag_value: + extra[tag_key.value] = tag_value current_span = tracer.current_span() extra["dd.trace_id"] = current_span.trace_id if current_span else 0 From 744e26382403e96eb8fc5a48898e332d70508257 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 12 Oct 2023 21:46:04 -0700 Subject: [PATCH 134/425] publish 0.0.0b19 (#322) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index f52e0462..694a969f 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b18" +__version__ = "0.0.0b19" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 6aa1063a..19225a91 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta18" +version = "0.0.0.beta19" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 97d6eaae..8a9895b2 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta18", + version="0.0.0.beta19", packages=find_packages(), ) From dcc3404d920c1f5cd5034ac4cf1efd84d45cb988 Mon Sep 17 00:00:00 2001 From: William Song Date: Mon, 16 Oct 2023 08:56:51 -0700 Subject: [PATCH 135/425] Auth for post_file client route (#323) --- clients/python/README.md | 2 +- clients/python/llmengine/api_engine.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/clients/python/README.md b/clients/python/README.md index 21befcb4..e9f6d289 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -19,7 +19,7 @@ If you are using LLM Engine, you can get your API key from Set the `SCALE_API_KEY` environment variable to your API key. If you are using your own infrastructure, you can set the -`LLM_ENGINE_SERVE_BASE_PATH` environment variable to the base URL of your +`LLM_ENGINE_BASE_PATH` environment variable to the base URL of your self-hosted `llmengine` endpoint. ```python diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index aa857183..1431d6cb 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -150,6 +150,7 @@ def post_file( files=files, timeout=timeout, headers={"x-api-key": api_key}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) From 1ecc1927ec15be20defb2690eed460a0dd707ccf Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 16 Oct 2023 18:02:31 -0700 Subject: [PATCH 136/425] Add pod disruption budget to all endpoints (#328) * Add pod disruption budget to all endpoints * Delete pdb as well * fix test * fix tests --- .../service_template_config_map.yaml | 13 ++++ .../gateways/resources/image_cache_gateway.py | 3 - .../k8s_endpoint_resource_delegate.py | 78 +++++++++++++++++++ .../gateways/resources/k8s_resource_types.py | 23 +++++- .../service_template_config_map_circleci.yaml | 25 ++++++ .../test_k8s_endpoint_resource_delegate.py | 33 ++++++++ 6 files changed, 170 insertions(+), 5 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 8eda04b1..1df03140 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -515,6 +515,19 @@ data: cpu: ${CPUS} memory: ${MEMORY} controlledResources: ["cpu", "memory"] + pod-disruption-budget.yaml: |- + apiVersion: policy/v1 + kind: PodDisruptionBudget + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + minAvailable: 1 + selector: + matchLabels: + app: ${RESOURCE_NAME} batch-job-orchestration-job.yaml: |- apiVersion: batch/v1 kind: Job diff --git a/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py index bdd15e27..fc5a7e54 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py @@ -23,9 +23,6 @@ class CachedImages(TypedDict): t4: List[str] -KUBERNETES_MAX_LENGTH = 64 - - class ImageCacheGateway: async def create_or_update_image_cache(self, cached_images: CachedImages) -> None: """ diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 56836596..9e129be6 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -67,6 +67,7 @@ _kubernetes_core_api = None _kubernetes_autoscaling_api = None _kubernetes_batch_api = None +_kubernetes_policy_api = None _kubernetes_custom_objects_api = None _kubernetes_cluster_version = None @@ -147,6 +148,16 @@ def get_kubernetes_batch_client(): # pragma: no cover return _kubernetes_batch_api +def get_kubernetes_policy_client(): # pragma: no cover + if _lazy_load_kubernetes_clients: + global _kubernetes_policy_api + else: + _kubernetes_policy_api = None + if not _kubernetes_policy_api: + _kubernetes_policy_api = kubernetes_asyncio.client.PolicyV1Api() + return _kubernetes_policy_api + + def get_kubernetes_custom_objects_client(): # pragma: no cover if _lazy_load_kubernetes_clients: global _kubernetes_custom_objects_api @@ -599,6 +610,37 @@ async def _create_vpa(vpa: Dict[str, Any], name: str) -> None: logger.exception("Got an exception when trying to apply the VerticalPodAutoscaler") raise + @staticmethod + async def _create_pdb(pdb: Dict[str, Any], name: str) -> None: + """ + Lower-level function to create/patch a k8s PodDisruptionBudget (pdb) + Args: + pdb: PDB body (a nested Dict in the format specified by Kubernetes) + name: The name of the pdb on K8s + + Returns: + Nothing; raises a k8s ApiException if failure + + """ + policy_api = get_kubernetes_policy_client() + try: + await policy_api.create_namespaced_pod_disruption_budget( + namespace=hmi_config.endpoint_namespace, + body=pdb, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"PodDisruptionBudget {name} already exists, replacing") + + await policy_api.patch_namespaced_pod_disruption_budget( + name=name, + namespace=hmi_config.endpoint_namespace, + body=pdb, + ) + else: + logger.exception("Got an exception when trying to apply the PodDisruptionBudget") + raise + @staticmethod async def _create_keda_scaled_object(scaled_object: Dict[str, Any], name: str) -> None: custom_objects_api = get_kubernetes_custom_objects_client() @@ -1035,6 +1077,27 @@ async def _delete_hpa(endpoint_id: str, deployment_name: str) -> bool: return False return True + @staticmethod + async def _delete_pdb(endpoint_id: str) -> bool: + policy_client = get_kubernetes_policy_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await policy_client.delete_namespaced_pod_disruption_budget( + namespace=hmi_config.endpoint_namespace, + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent PodDisruptionBudget {k8s_resource_group_name}" + ) + else: + logger.exception( + f"Deletion of PodDisruptionBudget {k8s_resource_group_name} failed" + ) + return False + return True + @staticmethod async def _delete_keda_scaled_object(endpoint_id: str) -> bool: custom_objects_client = get_kubernetes_custom_objects_client() @@ -1152,6 +1215,19 @@ async def _create_or_update_resources( name=k8s_resource_group_name, ) + pdb_config_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="pod-disruption-budget", + ) + pdb_template = load_k8s_yaml("pod-disruption-budget.yaml", pdb_config_arguments) + await self._create_pdb( + pdb=pdb_template, + name=k8s_resource_group_name, + ) + if model_endpoint_record.endpoint_type in { ModelEndpointType.SYNC, ModelEndpointType.STREAMING, @@ -1561,6 +1637,7 @@ async def _delete_resources_async(self, endpoint_id: str, deployment_name: str) endpoint_id=endpoint_id, deployment_name=deployment_name ) await self._delete_vpa(endpoint_id=endpoint_id) + await self._delete_pdb(endpoint_id=endpoint_id) return deployment_delete_succeeded and config_map_delete_succeeded async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) -> bool: @@ -1582,6 +1659,7 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - endpoint_id=endpoint_id ) await self._delete_vpa(endpoint_id=endpoint_id) + await self._delete_pdb(endpoint_id=endpoint_id) destination_rule_delete_succeeded = await self._delete_destination_rule( endpoint_id=endpoint_id diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 9e77ac68..72b2c196 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -66,7 +66,7 @@ LAUNCH_HIGH_PRIORITY_CLASS = "model-engine-high-priority" LAUNCH_DEFAULT_PRIORITY_CLASS = "model-engine-default-priority" -KUBERNETES_MAX_LENGTH = 64 +IMAGE_HASH_MAX_LENGTH = 32 FORWARDER_PORT = 5000 USER_CONTAINER_PORT = 5005 ARTIFACT_LIKE_CONTAINER_PORT = FORWARDER_PORT @@ -329,6 +329,12 @@ class VerticalPodAutoscalerArguments(_BaseEndpointArguments): MEMORY: str +class PodDisruptionBudgetArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into pod disruption budget templates.""" + + pass + + class VirtualServiceArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into virtual-service templates.""" @@ -432,7 +438,7 @@ class VerticalAutoscalingEndpointParams(TypedDict): def compute_image_hash(image: str) -> str: - return str(hashlib.md5(str(image).encode()).hexdigest())[:KUBERNETES_MAX_LENGTH] + return str(hashlib.sha256(str(image).encode()).hexdigest())[:IMAGE_HASH_MAX_LENGTH] def container_start_triton_cmd( @@ -1184,5 +1190,18 @@ def get_endpoint_resource_arguments_from_request( CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), ) + elif endpoint_resource_name == "pod-disruption-budget": + return PodDisruptionBudgetArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + ) else: raise Exception(f"Unknown resource name: {endpoint_resource_name}") diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index e50e1623..606fee3e 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -2728,6 +2728,31 @@ data: cpu: ${CPUS} memory: ${MEMORY} controlledResources: ["cpu", "memory"] + pod-disruption-budget.yaml: |- + apiVersion: policy/v1 + kind: PodDisruptionBudget + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + minAvailable: 1 + selector: + matchLabels: + app: ${RESOURCE_NAME} batch-job-orchestration-job.yaml: |- apiVersion: batch/v1 kind: Job diff --git a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index 93e2c8e3..acd298b3 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -69,6 +69,16 @@ def mock_autoscaling_client(): yield mock_client +@pytest.fixture +def mock_policy_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_policy_client", + return_value=mock_client, + ): + yield mock_client + + @pytest.fixture def mock_custom_objects_client(): mock_client = AsyncMock() @@ -276,6 +286,7 @@ async def test_create_async_endpoint_has_correct_labels( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, mock_get_kubernetes_cluster_version, create_resources_request_async_runnable_image: CreateOrUpdateResourcesRequest, @@ -323,6 +334,11 @@ async def test_create_async_endpoint_has_correct_labels( ) assert delete_custom_object_call_args_list == [] + # Verify PDB labels + create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args + pdb_body = create_pdb_call_args.kwargs["body"] + _verify_non_deployment_labels(pdb_body, request) + if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC: assert create_custom_object_call_args_list == [] _verify_custom_object_plurals( @@ -339,6 +355,7 @@ async def test_create_streaming_endpoint_has_correct_labels( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, mock_get_kubernetes_cluster_version, create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest, @@ -365,6 +382,11 @@ async def test_create_streaming_endpoint_has_correct_labels( config_map_body = create_config_map_call_args.kwargs["body"] _verify_non_deployment_labels(config_map_body, request) + # Verify PDB labels + create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args + pdb_body = create_pdb_call_args.kwargs["body"] + _verify_non_deployment_labels(pdb_body, request) + # Verify HPA labels create_hpa_call_args = ( mock_autoscaling_client.create_namespaced_horizontal_pod_autoscaler.call_args @@ -406,6 +428,7 @@ async def test_create_sync_endpoint_has_correct_labels( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, mock_get_kubernetes_cluster_version, create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest, @@ -441,6 +464,11 @@ async def test_create_sync_endpoint_has_correct_labels( hpa_body = create_hpa_call_args.kwargs["body"] _verify_non_deployment_labels(hpa_body, request) + # Verify PDB labels + create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args + pdb_body = create_pdb_call_args.kwargs["body"] + _verify_non_deployment_labels(pdb_body, request) + # Make sure that an VPA is created if optimize_costs is True. build_endpoint_request = request.build_endpoint_request optimize_costs = build_endpoint_request.optimize_costs @@ -477,6 +505,7 @@ async def test_create_sync_endpoint_has_correct_k8s_service_type( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, mock_get_kubernetes_cluster_version, create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest, @@ -531,6 +560,7 @@ async def test_get_resources_async_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): k8s_endpoint_resource_delegate.__setattr__( @@ -590,6 +620,7 @@ async def test_get_resources_sync_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): k8s_endpoint_resource_delegate.__setattr__( @@ -653,6 +684,7 @@ async def test_delete_resources_async_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): deleted = await k8s_endpoint_resource_delegate.delete_resources( @@ -667,6 +699,7 @@ async def test_delete_resources_sync_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): deleted = await k8s_endpoint_resource_delegate.delete_resources( From 82824343336e3fabc82de64e94fdc146c72f1052 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:39:03 -0700 Subject: [PATCH 137/425] create celery worker with inference worker profile (#327) * create celery worker with inference worker * try get aws profile --- model-engine/model_engine_server/core/celery/app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 7e87d2f0..42651d56 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -504,7 +504,8 @@ def _get_backend_url_and_conf( elif backend_protocol == "s3": backend_url = "s3://" if aws_role is None: - aws_session = session(infra_config().profile_ml_worker) + aws_profile = os.getenv("AWS_PROFILE", infra_config().profile_ml_worker) + aws_session = session(aws_profile) else: aws_session = session(aws_role) out_conf_changes.update( From 5407e4d825836ff71686fa9e328fd76abfe8347c Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:53:48 -0700 Subject: [PATCH 138/425] Bump http forwarder request CPU (#330) --- .../model-engine/templates/service_template_config_map.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 1df03140..15486174 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -134,7 +134,7 @@ data: timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -180,7 +180,7 @@ data: timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: From a5245a51a0b94c970aa632dfb78b228852f0d232 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:37:47 -0700 Subject: [PATCH 139/425] [Docs] Clarify get-events API usage (#320) * spin out a monitoring finetune section * add code snippet * example error messages * while we're here update mistral * add link to canonical def of hyperparams --- docs/guides/fine_tuning.md | 34 +++++++++++++++++++++++++++++++--- docs/model_zoo.md | 4 ++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index 24b7babd..705f1b23 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -141,8 +141,8 @@ from llmengine import FineTune response = FineTune.create( model="llama-2-7b", - training_file="file-7DLVeLdN2Ty4M2m", - validation_file="file-ezSRtpgKQyItI26", + training_file="file-AbCDeLdN2Ty4M2m", + validation_file="file-ezSRpgtKQyItI26", ) print(response.json()) @@ -152,7 +152,35 @@ See the [Model Zoo](../../model_zoo) to see which models have fine-tuning suppor See [Integrations](../integrations.md) to see how to track fine-tuning metrics. -Once the fine-tune is launched, you can also [get the status of your fine-tune](../../api/python_client/#llmengine.fine_tuning.FineTune.get). You can also [list events that your fine-tune produces](../../api/python_client/#llmengine.fine_tuning.FineTune.get_events). +## Monitoring the fine-tune + +Once the fine-tune is launched, you can also [get the status of your fine-tune](../../api/python_client/#llmengine.fine_tuning.FineTune.get). +You can also [list events that your fine-tune produces](../../api/python_client/#llmengine.fine_tuning.FineTune.get_events). +```python +from llmengine import FineTune + +fine_tune_id = "ft-cabcdefghi1234567890" +fine_tune = FineTune.get(fine_tune_id) +print(fine_tune.status) # BatchJobStatus.RUNNING +print(fine_tune.fine_tuned_model) # "llama-2-7b.700101-000000 + +fine_tune_events = FineTune.get_events(fine_tune_id) +for event in fine_tune_events.events: + print(event) +# Prints something like: +# timestamp=1697590000.0 message="{'loss': 12.345, 'learning_rate': 0.0, 'epoch': 0.97}" level='info' +# timestamp=1697590000.0 message="{'eval_loss': 23.456, 'eval_runtime': 19.876, 'eval_samples_per_second': 4.9, 'eval_steps_per_second': 4.9, 'epoch': 0.97}" level='info' +# timestamp=1697590020.0 message="{'train_runtime': 421.234, 'train_samples_per_second': 2.042, 'train_steps_per_second': 0.042, 'total_flos': 123.45, 'train_loss': 34.567, 'epoch': 0.97}" level='info' + + +``` + +The status of your fine-tune will give a high-level overview of the fine-tune's progress. +The events of your fine-tune will give more detail, such as the training loss and validation loss at each epoch, +as well as any errors that may have occurred. If you encounter any errors with your fine-tune, +the events are a good place to start debugging. For example, if you see `Unable to read training or validation dataset`, +you may need to make your files accessible to LLM Engine. If you see `Invalid value received for lora parameter 'lora_alpha'!`, +you should [check that your hyperparameters are valid](../../api/python_client/#llmengine.fine_tuning.FineTune.create). ## Making inference calls to your fine-tune diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 0c61c38a..7c438019 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -18,8 +18,8 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | | `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | | `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | -| `mistral-7b` | ✅ | | vllm | -| `mistral-7b-instruct` | ✅ | | vllm | +| `mistral-7b` | ✅ | ✅ | vllm | +| `mistral-7b-instruct` | ✅ | ✅ | vllm | ## Usage From e7cb20daf921ddabbeee99b909d866a7ecb81f28 Mon Sep 17 00:00:00 2001 From: William Song Date: Thu, 19 Oct 2023 11:42:19 -1000 Subject: [PATCH 140/425] Enable additional Datadog tagging for jobs (#324) --- charts/model-engine/templates/_helpers.tpl | 3 +++ model-engine/model_engine_server/common/env_vars.py | 1 + .../gateways/live_batch_job_orchestration_gateway.py | 8 +++++++- .../gateways/live_docker_image_batch_job_gateway.py | 9 ++++++++- .../infra/gateways/resources/k8s_resource_types.py | 1 + .../docker_image_batch_job_llm_fine_tuning_service.py | 6 ++++++ 6 files changed, 26 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 0fcf816d..75b69dc3 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -85,7 +85,10 @@ endpoint_name: ${ENDPOINT_NAME} {{- define "modelEngine.jobTemplateLabels" -}} {{- include "modelEngine.baseTemplateLabels" . | printf "%s\n" -}} launch_job_id: ${JOB_ID} +tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} +tags.datadoghq.com/user_id: ${OWNER} +tags.datadoghq.com/team: ${TEAM} {{- end }} {{- define "modelEngine.serviceTemplateAsyncAnnotations" -}} diff --git a/model-engine/model_engine_server/common/env_vars.py b/model-engine/model_engine_server/common/env_vars.py index 9d5ca20a..ad7478fa 100644 --- a/model-engine/model_engine_server/common/env_vars.py +++ b/model-engine/model_engine_server/common/env_vars.py @@ -64,6 +64,7 @@ def get_boolean_env_var(name: str) -> bool: ) """The path to the config map containing the Launch service template. """ +logger.info(f"{LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH=}") LAUNCH_SERVICE_TEMPLATE_FOLDER: Optional[str] = os.environ.get("LAUNCH_SERVICE_TEMPLATE_FOLDER") """The path to the folder containing the Launch service template. If set, this overrides diff --git a/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py index 40cc9f9d..2f12c943 100644 --- a/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py @@ -3,7 +3,12 @@ from kubernetes_asyncio.client.rest import ApiException from model_engine_server.common.config import hmi_config from model_engine_server.common.env_vars import GIT_TAG -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + filename_wo_ext, + make_logger, +) from model_engine_server.domain.entities import BatchJobSerializationFormat from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.infra.gateways import BatchJobOrchestrationGateway @@ -55,6 +60,7 @@ async def create_batch_job_orchestrator( BATCH_JOB_MAX_RUNTIME=int(timeout_seconds + SHUTDOWN_GRACE_PERIOD), BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, GIT_TAG=GIT_TAG, + REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", ) resource_key = "batch-job-orchestration-job.yaml" deployment_spec = load_k8s_yaml(resource_key, substitution_kwargs) diff --git a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py index 40e09a8c..cb7154af 100644 --- a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py @@ -11,7 +11,12 @@ from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.serialization_utils import python_json_to_b64 from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + filename_wo_ext, + make_logger, +) from model_engine_server.domain.entities.batch_job_entity import BatchJobStatus, DockerImageBatchJob from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( @@ -231,6 +236,7 @@ def _generate_job_spec( # GPU Arguments GPU_TYPE=resource_requests.gpu_type.value, GPUS=resource_requests.gpus or 1, + REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", ) else: resource_key = "docker-image-batch-job-cpu.yaml" @@ -259,6 +265,7 @@ def _generate_job_spec( LOCAL_FILE_NAME=mount_location, FILE_CONTENTS_B64ENCODED=job_config_b64encoded, AWS_ROLE=infra_config().profile_ml_inference_worker, + REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", ) resource_spec = load_k8s_yaml(resource_key, substitution_kwargs) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 72b2c196..048f8002 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -167,6 +167,7 @@ class _JobArguments(_BaseResourceArguments): JOB_ID: str BATCH_JOB_MAX_RUNTIME: int BATCH_JOB_TTL_SECONDS_AFTER_FINISHED: int + REQUEST_ID: str class _DockerImageBatchJobArguments(_JobArguments): diff --git a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index 9e25b8cf..f4622a16 100644 --- a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import FineTuneHparamValueType from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob from model_engine_server.domain.exceptions import ( @@ -17,6 +18,8 @@ from model_engine_server.domain.services import LLMFineTuningService from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository +logger = make_logger(logger_name()) + class DockerImageBatchJobLLMFineTuningService(LLMFineTuningService): def __init__( @@ -76,6 +79,9 @@ async def create_fine_tune( # TODO: Pass user-defined labels labels = dict(team="egp", product="llm-fine-tune") + logger.info( + f"Using bundle {di_batch_job_bundle.id} for fine-tune job: {di_batch_job_bundle.image_repository=}, {di_batch_job_bundle.image_tag=}" + ) batch_job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=created_by, owner=owner, From fe24d634967ffa2437a85cd81faa22dbf114390b Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:59:34 -0700 Subject: [PATCH 141/425] fix celery worker profile for s3 access (#333) * change profile * fix profile settings --- model-engine/model_engine_server/core/celery/app.py | 3 +-- .../inference/forwarding/celery_forwarder.py | 1 + .../infra/gateways/celery_task_queue_gateway.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 42651d56..7e87d2f0 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -504,8 +504,7 @@ def _get_backend_url_and_conf( elif backend_protocol == "s3": backend_url = "s3://" if aws_role is None: - aws_profile = os.getenv("AWS_PROFILE", infra_config().profile_ml_worker) - aws_session = session(aws_profile) + aws_session = session(infra_config().profile_ml_worker) else: aws_session = session(aws_role) out_conf_changes.update( diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 16e7fc34..6206f711 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -95,6 +95,7 @@ def create_celery_service( app: Celery = celery_app( name=None, s3_bucket=infra_config().s3_bucket, + aws_role=infra_config().profile_ml_inference_worker, task_visibility=task_visibility, broker_type=str(BrokerType.SQS.value if sqs_url else BrokerType.REDIS.value), broker_transport_options={"predefined_queues": {queue_name: {"url": sqs_url}}} diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 66f39f83..8d487029 100644 --- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -11,9 +11,7 @@ from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway celery_redis = celery_app( - None, - s3_bucket=infra_config().s3_bucket, - broker_type=str(BrokerType.REDIS.value), + None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value) ) celery_redis_24h = celery_app( None, From 1a3b5e046abeffa38f5f55fbfb3e6b0069518ae0 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:43:03 -0700 Subject: [PATCH 142/425] Hardcode number of forwarder workers (#334) * Hardcode number of forwarder workers * revert worker count change for async endpoints * type --- .../templates/service_template_config_map.yaml | 4 ++-- .../model_engine_server/common/resource_limits.py | 1 + .../infra/gateways/resources/k8s_resource_types.py | 9 +++++++++ .../service_template_config_map_circleci.yaml | 12 ++++++------ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 15486174..281bc2ea 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -117,7 +117,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -161,7 +161,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index 502f7dd7..64e25669 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -46,6 +46,7 @@ FORWARDER_CPU_USAGE = 1 FORWARDER_MEMORY_USAGE = "2Gi" FORWARDER_STORAGE_USAGE = "1G" +FORWARDER_WORKER_COUNT = 2 logger = make_logger(filename_wo_ext(__name__)) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 048f8002..e417058d 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -10,6 +10,7 @@ FORWARDER_CPU_USAGE, FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_USAGE, + FORWARDER_WORKER_COUNT, ) from model_engine_server.common.serialization_utils import python_json_to_b64 from model_engine_server.core.config import infra_config @@ -136,6 +137,7 @@ class _SyncRunnableImageDeploymentArguments(TypedDict): """Keyword-arguments for substituting into sync deployment templates.""" FORWARDER_PORT: int + FORWARDER_WORKER_COUNT: int class _StreamingDeploymentArguments(TypedDict): @@ -143,6 +145,7 @@ class _StreamingDeploymentArguments(TypedDict): FORWARDER_PORT: int STREAMING_PREDICT_ROUTE: str + FORWARDER_WORKER_COUNT: int class _RunnableImageDeploymentArguments(_BaseDeploymentArguments): @@ -691,6 +694,7 @@ def get_endpoint_resource_arguments_from_request( USER_CONTAINER_PORT=USER_CONTAINER_PORT, # Streaming Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, ) elif endpoint_resource_name == "deployment-runnable-image-streaming-gpu": assert isinstance(flavor, StreamingEnhancedRunnableImageFlavor) @@ -735,6 +739,7 @@ def get_endpoint_resource_arguments_from_request( USER_CONTAINER_PORT=USER_CONTAINER_PORT, # Streaming Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # GPU Deployment Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, @@ -780,6 +785,7 @@ def get_endpoint_resource_arguments_from_request( USER_CONTAINER_PORT=USER_CONTAINER_PORT, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, ) elif endpoint_resource_name == "deployment-runnable-image-sync-gpu": assert isinstance(flavor, RunnableImageLike) @@ -823,6 +829,7 @@ def get_endpoint_resource_arguments_from_request( USER_CONTAINER_PORT=USER_CONTAINER_PORT, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # GPU Deployment Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, @@ -982,6 +989,7 @@ def get_endpoint_resource_arguments_from_request( USER_CONTAINER_PORT=USER_CONTAINER_PORT, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # Triton Deployment Arguments TRITON_MODEL_REPOSITORY=flavor.triton_model_repository, TRITON_CPUS=str(flavor.triton_num_cpu), @@ -1033,6 +1041,7 @@ def get_endpoint_resource_arguments_from_request( USER_CONTAINER_PORT=USER_CONTAINER_PORT, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # GPU Deployment Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 606fee3e..63e94d0b 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -610,7 +610,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -878,7 +878,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -1102,7 +1102,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -1845,7 +1845,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -2120,7 +2120,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -2351,7 +2351,7 @@ data: - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set From 38caaa63561f0be6b7e617c135babfe0ea29e990 Mon Sep 17 00:00:00 2001 From: William Song Date: Thu, 19 Oct 2023 16:10:32 -1000 Subject: [PATCH 143/425] use make_logger(logger_name()) as standard (#337) --- model-engine/model_engine_server/api/app.py | 4 ++-- model-engine/model_engine_server/api/batch_jobs_v1.py | 4 ++-- model-engine/model_engine_server/api/dependencies.py | 4 ++-- .../api/docker_image_batch_job_bundles_v1.py | 4 ++-- model-engine/model_engine_server/api/files_v1.py | 4 ++-- model-engine/model_engine_server/api/llms_v1.py | 4 ++-- model-engine/model_engine_server/api/model_bundles_v1.py | 4 ++-- model-engine/model_engine_server/api/model_bundles_v2.py | 4 ++-- .../model_engine_server/api/model_endpoints_docs_v1.py | 4 ++-- model-engine/model_engine_server/api/model_endpoints_v1.py | 4 ++-- model-engine/model_engine_server/api/tasks_v1.py | 4 ++-- model-engine/model_engine_server/api/triggers_v1.py | 4 ++-- model-engine/model_engine_server/common/config.py | 4 ++-- model-engine/model_engine_server/common/resource_limits.py | 4 ++-- model-engine/model_engine_server/common/service_requests.py | 4 ++-- model-engine/model_engine_server/core/aws/secrets.py | 4 ++-- model-engine/model_engine_server/core/config.py | 4 ++-- .../model_engine_server/core/docker/docker_image.py | 5 ++--- model-engine/model_engine_server/core/loggers.py | 5 ++--- model-engine/model_engine_server/core/utils/timer.py | 5 ++--- model-engine/model_engine_server/db/base.py | 4 ++-- model-engine/model_engine_server/db/endpoint_row_lock.py | 4 ++-- .../domain/use_cases/batch_job_use_cases.py | 4 ++-- .../model_engine_server/domain/use_cases/file_use_cases.py | 4 ++-- .../domain/use_cases/llm_fine_tuning_use_cases.py | 4 ++-- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++-- .../domain/use_cases/model_endpoint_use_cases.py | 4 ++-- model-engine/model_engine_server/entrypoints/k8s_cache.py | 4 ++-- .../start_docker_image_batch_job_init_container.py | 4 ++-- .../model_engine_server/inference/async_inference/tasks.py | 4 ++-- model-engine/model_engine_server/inference/common.py | 4 ++-- .../inference/download_and_inject_bundle.py | 4 ++-- .../model_engine_server/inference/post_inference_hooks.py | 4 ++-- .../model_engine_server/inference/service_requests.py | 4 ++-- .../inference/sync_inference/fastapi_server.py | 4 ++-- .../infra/gateways/live_batch_job_orchestration_gateway.py | 4 ++-- .../infra/gateways/live_batch_job_progress_gateway.py | 4 ++-- .../infra/gateways/live_cron_job_gateway.py | 4 ++-- .../infra/gateways/live_docker_image_batch_job_gateway.py | 4 ++-- .../live_streaming_model_endpoint_inference_gateway.py | 4 ++-- .../gateways/live_sync_model_endpoint_inference_gateway.py | 4 ++-- .../infra/gateways/resources/image_cache_gateway.py | 4 ++-- .../gateways/resources/k8s_endpoint_resource_delegate.py | 4 ++-- .../gateways/resources/live_endpoint_resource_gateway.py | 4 ++-- .../resources/live_sqs_endpoint_resource_delegate.py | 4 ++-- .../repositories/db_model_endpoint_record_repository.py | 4 ++-- .../infra/services/image_cache_service.py | 4 ++-- .../infra/services/live_batch_job_orchestration_service.py | 4 ++-- .../infra/services/live_batch_job_service.py | 4 ++-- .../infra/services/live_endpoint_builder_service.py | 6 +++--- .../infra/services/live_llm_model_endpoint_service.py | 4 ++-- .../infra/services/live_model_endpoint_service.py | 4 ++-- 52 files changed, 105 insertions(+), 108 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 1593b951..b3a41dfd 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -24,11 +24,11 @@ from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, - filename_wo_ext, + logger_name, make_logger, ) -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) app = FastAPI(title="launch", version="1.0.0", redoc_url="/api") diff --git a/model-engine/model_engine_server/api/batch_jobs_v1.py b/model-engine/model_engine_server/api/batch_jobs_v1.py index 022b9dc8..6241b202 100644 --- a/model-engine/model_engine_server/api/batch_jobs_v1.py +++ b/model-engine/model_engine_server/api/batch_jobs_v1.py @@ -22,7 +22,7 @@ UpdateDockerImageBatchJobV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, EndpointLabelsException, @@ -43,7 +43,7 @@ batch_job_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @batch_job_router_v1.post("/batch-jobs", response_model=CreateBatchJobV1Response) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 89854841..30062252 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -16,7 +16,7 @@ from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, - filename_wo_ext, + logger_name, make_logger, ) from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync @@ -100,7 +100,7 @@ ) from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) AUTH = HTTPBasic(auto_error=False) diff --git a/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py index 1444a39b..4b2980be 100644 --- a/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py +++ b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py @@ -15,7 +15,7 @@ ) from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( EndpointResourceInvalidRequestException, ObjectNotAuthorizedException, @@ -30,7 +30,7 @@ docker_image_batch_job_bundle_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @docker_image_batch_job_bundle_router_v1.post( diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py index 8c50cc53..dd52c10b 100644 --- a/model-engine/model_engine_server/api/files_v1.py +++ b/model-engine/model_engine_server/api/files_v1.py @@ -16,7 +16,7 @@ UploadFileResponse, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, @@ -30,7 +30,7 @@ ) file_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @file_router_v1.post("/files", response_model=UploadFileResponse) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 4917ee32..7e73ef70 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -39,7 +39,7 @@ from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, - filename_wo_ext, + logger_name, make_logger, ) from model_engine_server.domain.exceptions import ( @@ -78,7 +78,7 @@ from sse_starlette.sse import EventSourceResponse llm_router_v1 = APIRouter(prefix="/v1/llm") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) def handle_streaming_exception( diff --git a/model-engine/model_engine_server/api/model_bundles_v1.py b/model-engine/model_engine_server/api/model_bundles_v1.py index e192af13..c83bda5d 100644 --- a/model-engine/model_engine_server/api/model_bundles_v1.py +++ b/model-engine/model_engine_server/api/model_bundles_v1.py @@ -19,7 +19,7 @@ ModelBundleV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, @@ -35,7 +35,7 @@ ) model_bundle_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_bundle_router_v1.post("/model-bundles", response_model=CreateModelBundleV1Response) diff --git a/model-engine/model_engine_server/api/model_bundles_v2.py b/model-engine/model_engine_server/api/model_bundles_v2.py index d35de5cf..39f4a7d8 100644 --- a/model-engine/model_engine_server/api/model_bundles_v2.py +++ b/model-engine/model_engine_server/api/model_bundles_v2.py @@ -19,7 +19,7 @@ ModelBundleV2Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, @@ -35,7 +35,7 @@ ) model_bundle_router_v2 = APIRouter(prefix="/v2") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_bundle_router_v2.post("/model-bundles", response_model=CreateModelBundleV2Response) diff --git a/model-engine/model_engine_server/api/model_endpoints_docs_v1.py b/model-engine/model_engine_server/api/model_endpoints_docs_v1.py index 9b7f1d1f..f4f2d734 100644 --- a/model-engine/model_engine_server/api/model_endpoints_docs_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_docs_v1.py @@ -8,14 +8,14 @@ verify_authentication, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.use_cases.model_endpoints_schema_use_cases import ( GetModelEndpointsSchemaV1UseCase, ) from starlette.responses import HTMLResponse model_endpoints_docs_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_endpoints_docs_router_v1.get("/model-endpoints-schema.json") diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index 4bf3cf32..e761d2c5 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -24,7 +24,7 @@ UpdateModelEndpointV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, EndpointLabelsException, @@ -45,7 +45,7 @@ ) model_endpoint_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_endpoint_router_v1.post("/model-endpoints", response_model=CreateModelEndpointV1Response) diff --git a/model-engine/model_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py index 05fdb270..25b97838 100644 --- a/model-engine/model_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -16,7 +16,7 @@ TaskStatus, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( EndpointUnsupportedInferenceTypeException, ObjectNotAuthorizedException, @@ -36,7 +36,7 @@ from sse_starlette.sse import EventSourceResponse inference_task_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @inference_task_router_v1.post("/async-tasks", response_model=CreateAsyncTaskV1Response) diff --git a/model-engine/model_engine_server/api/triggers_v1.py b/model-engine/model_engine_server/api/triggers_v1.py index 30f3310b..30c95acd 100644 --- a/model-engine/model_engine_server/api/triggers_v1.py +++ b/model-engine/model_engine_server/api/triggers_v1.py @@ -15,7 +15,7 @@ UpdateTriggerV1Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( CronSyntaxException, DockerImageNotFoundException, @@ -36,7 +36,7 @@ trigger_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @trigger_router_v1.post("/triggers", response_model=CreateTriggerV1Response) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 64098ea1..de76ff96 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -7,9 +7,9 @@ from typing import Sequence import yaml -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) __all__: Sequence[str] = ( "DEFAULT_SERVICE_CONFIG_PATH", diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index 64e25669..10bf0f0d 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -1,6 +1,6 @@ from typing import Optional, Union, cast -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( CpuSpecificationType, GpuType, @@ -48,7 +48,7 @@ FORWARDER_STORAGE_USAGE = "1G" FORWARDER_WORKER_COUNT = 2 -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) def validate_resource_requests( diff --git a/model-engine/model_engine_server/common/service_requests.py b/model-engine/model_engine_server/common/service_requests.py index 9f5327d4..96aeb6f0 100644 --- a/model-engine/model_engine_server/common/service_requests.py +++ b/model-engine/model_engine_server/common/service_requests.py @@ -4,7 +4,7 @@ import requests from model_engine_server.common.errors import HTTP429Exception, UpstreamHTTPSvcError -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from tenacity import ( RetryError, Retrying, @@ -13,7 +13,7 @@ wait_exponential, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) SYNC_ENDPOINT_RETRIES = 10 # Must be an integer >= 0 SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 diff --git a/model-engine/model_engine_server/core/aws/secrets.py b/model-engine/model_engine_server/core/aws/secrets.py index 37ed25e1..3c39b259 100644 --- a/model-engine/model_engine_server/core/aws/secrets.py +++ b/model-engine/model_engine_server/core/aws/secrets.py @@ -6,9 +6,9 @@ import boto3 from botocore.exceptions import ClientError from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) @lru_cache(maxsize=2) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 53942fb2..ef08839c 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -12,9 +12,9 @@ from typing import Optional, Sequence import yaml -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) __all__: Sequence[str] = ( "DEFAULT_CONFIG_PATH", diff --git a/model-engine/model_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py index f61294c3..66edb928 100644 --- a/model-engine/model_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -5,7 +5,6 @@ """ import base64 -import logging import os import pathlib import subprocess @@ -17,11 +16,11 @@ import click import docker from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import logger_name, make_logger from .remote_build import MODELS_ROOT, build_remote_wrapper -logger = make_logger("ml_serve.docker_image", log_level=logging.INFO) +logger = make_logger(logger_name()) def _get_aws_creds() -> Dict[str, str]: diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index ce0ee847..30b0deee 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -33,7 +33,6 @@ "silence_chatty_logger", "loggers_at_level", # utils - "filename_wo_ext", "LoggerTagKey", "LoggerTagManager", ) @@ -202,7 +201,7 @@ def logger_name(*, fallback_name: Optional[str] = None) -> str: # in which case we use it's file name if hasattr(calling_module, "__file__"): - return filename_wo_ext(calling_module.__file__) # type: ignore + return _filename_wo_ext(calling_module.__file__) # type: ignore if fallback_name is not None: fallback_name = fallback_name.strip() if len(fallback_name) > 0: @@ -316,6 +315,6 @@ def loggers_at_level(*loggers_or_names, new_level: int) -> None: # type: ignore log.setLevel(level) -def filename_wo_ext(filename: str) -> str: +def _filename_wo_ext(filename: str) -> str: """Gets the filename, without the file extension, if present.""" return os.path.split(filename)[1].split(".", 1)[0] diff --git a/model-engine/model_engine_server/core/utils/timer.py b/model-engine/model_engine_server/core/utils/timer.py index 6936cfa7..53a6f8fe 100644 --- a/model-engine/model_engine_server/core/utils/timer.py +++ b/model-engine/model_engine_server/core/utils/timer.py @@ -26,9 +26,8 @@ class timer: # pylint: disable=invalid-name The other use case is to pass in a `name` and a `logger`. The timing will be recorded when the context block is exited: - >>> from model_engine_server.core.loggers import make_logger - >>> - >>> log = make_logger("my-main-program") + >>> from model_engine_server.core.loggers import make_logger, logger_name >>> + >>> log = make_logger(logger_name()) >>> >>> with timer(logger=log, name="timing-func-f"): >>> f() diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index 37496e19..0f882ea3 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -6,14 +6,14 @@ import sqlalchemy from model_engine_server.core.aws.secrets import get_key_file from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def get_key_file_name(environment: str) -> str: diff --git a/model-engine/model_engine_server/db/endpoint_row_lock.py b/model-engine/model_engine_server/db/endpoint_row_lock.py index 676546f6..b3d0e307 100644 --- a/model-engine/model_engine_server/db/endpoint_row_lock.py +++ b/model-engine/model_engine_server/db/endpoint_row_lock.py @@ -4,12 +4,12 @@ import time from contextlib import AbstractContextManager -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from sqlalchemy import BIGINT, cast, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.session import Session -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) BLOCKING_LOCK_TIMEOUT_SECONDS = 120 BLOCKING_LOCK_TIMEOUT_POLL_FREQ_SECONDS = 0.5 diff --git a/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py index 7ea13e11..d1f98cc9 100644 --- a/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py @@ -17,7 +17,7 @@ ) from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) @@ -41,7 +41,7 @@ validate_labels, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class CreateBatchJobV1UseCase: diff --git a/model-engine/model_engine_server/domain/use_cases/file_use_cases.py b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py index a3ede743..47f3162a 100644 --- a/model-engine/model_engine_server/domain/use_cases/file_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py @@ -6,11 +6,11 @@ UploadFileResponse, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.gateways import FileStorageGateway -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class UploadFileUseCase: diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 039b15ad..689adfdb 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -13,7 +13,7 @@ ListFineTunesResponse, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import BatchJobStatus from model_engine_server.domain.exceptions import ( InvalidRequestException, @@ -34,7 +34,7 @@ # k8s labels need to be <= 62 characters, timestamp takes 13 characters, 2 characters for periods, # model name is currently 17 long, but want to add a bit of buffer. -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def is_model_name_suffix_valid(model_name: str): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 82ca48e1..a3980fed 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -33,7 +33,7 @@ from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( LLMInferenceFramework, LLMMetadata, @@ -72,7 +72,7 @@ validate_post_inference_hooks, ) -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) _SUPPORTED_MODEL_NAMES = { LLMInferenceFramework.DEEPSPEED: { diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index bab01204..8c58cdcc 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -24,7 +24,7 @@ from model_engine_server.common.resource_limits import MAX_ENDPOINT_SIZE, validate_resource_requests from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS, RESTRICTED_ENDPOINT_LABELS from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.authorization.live_authorization_module import ( LiveAuthorizationModule, ) @@ -48,7 +48,7 @@ CONVERTED_FROM_ARTIFACT_LIKE_KEY = "_CONVERTED_FROM_ARTIFACT_LIKE" MODEL_BUNDLE_CHANGED_KEY = "_MODEL_BUNDLE_CHANGED" -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) def model_endpoint_entity_to_get_model_endpoint_response( diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 445dd83c..354238c9 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -13,7 +13,7 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.constants import READYZ_FPATH from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.base import SessionAsyncNullPool from model_engine_server.domain.gateways import MonitoringMetricsGateway from model_engine_server.domain.repositories import DockerRepository @@ -55,7 +55,7 @@ ModelEndpointCacheWriteService, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) # This is the entrypoint to the k8s cacher try: diff --git a/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py index f26662c3..1c0048be 100644 --- a/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py +++ b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py @@ -4,10 +4,10 @@ import model_engine_server.core.aws.storage_client as storage_client from model_engine_server.common.serialization_utils import b64_to_str from model_engine_server.core.aws.storage_client import s3_fileobj_exists -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.utils.url import parse_attachment_url -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def main(input_local: str, local_file: str, remote_file: str, file_contents_b64encoded: str): diff --git a/model-engine/model_engine_server/inference/async_inference/tasks.py b/model-engine/model_engine_server/inference/async_inference/tasks.py index 074e12ef..6fce0588 100644 --- a/model-engine/model_engine_server/inference/async_inference/tasks.py +++ b/model-engine/model_engine_server/inference/async_inference/tasks.py @@ -6,7 +6,7 @@ from model_engine_server.common.constants import READYZ_FPATH from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.common.serialization_utils import str_to_bool -from model_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.utils.timer import timer from model_engine_server.domain.entities import ModelEndpointConfig from model_engine_server.inference.async_inference.celery import async_inference_service @@ -20,7 +20,7 @@ ) from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler -logger = make_logger(__name__) +logger = make_logger(logger_name()) # This should be safe as long as the celery workers are separate processes # (or we're using pool=solo) so they're not shared between threads diff --git a/model-engine/model_engine_server/inference/common.py b/model-engine/model_engine_server/inference/common.py index a1242371..2655eb12 100644 --- a/model-engine/model_engine_server/inference/common.py +++ b/model-engine/model_engine_server/inference/common.py @@ -12,12 +12,12 @@ from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, RequestSchema from model_engine_server.common.io import open_wrapper from model_engine_server.common.serialization_utils import b64_to_python_json -from model_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.utils.timer import timer from model_engine_server.domain.entities import ModelEndpointConfig from model_engine_server.inference.service_requests import make_request -logger = make_logger(__name__) +logger = make_logger(logger_name()) s3_client = None diff --git a/model-engine/model_engine_server/inference/download_and_inject_bundle.py b/model-engine/model_engine_server/inference/download_and_inject_bundle.py index 74fb3b15..7fa1f726 100644 --- a/model-engine/model_engine_server/inference/download_and_inject_bundle.py +++ b/model-engine/model_engine_server/inference/download_and_inject_bundle.py @@ -2,9 +2,9 @@ import os import shutil -from model_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(__name__) +logger = make_logger(logger_name()) LOCAL_BUNDLE_PATH = os.getenv("LOCAL_BUNDLE_PATH", "") LOAD_MODEL_MODULE_PATH = os.getenv("LOAD_MODEL_MODULE_PATH", "") diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 00abaa5d..05dba306 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -4,14 +4,14 @@ import requests from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK from model_engine_server.common.dtos.tasks import EndpointPredictV1Request -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( InferenceMonitoringMetricsGateway, ) from tenacity import Retrying, stop_after_attempt, wait_exponential -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class PostInferenceHook(ABC): diff --git a/model-engine/model_engine_server/inference/service_requests.py b/model-engine/model_engine_server/inference/service_requests.py index ad94c7f5..795ff91f 100644 --- a/model-engine/model_engine_server/inference/service_requests.py +++ b/model-engine/model_engine_server/inference/service_requests.py @@ -13,9 +13,9 @@ from model_engine_server.common.io import open_wrapper from model_engine_server.common.service_requests import make_sync_request_with_retries from model_engine_server.core.celery import TaskVisibility, celery_app -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) # TODO now that we're on SQS this won't work, since it connects to redis s3_bucket: str = os.environ.get("CELERY_S3_BUCKET") # type: ignore diff --git a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py index f25bece2..bec1c50c 100644 --- a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py @@ -6,7 +6,7 @@ from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status from model_engine_server.common.dtos.tasks import EndpointPredictV1Request -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import ( get_endpoint_config, load_predict_fn_or_cls, @@ -22,7 +22,7 @@ NAME, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class MultiprocessingConcurrencyLimiter: diff --git a/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py index 2f12c943..93b87ce8 100644 --- a/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py @@ -6,7 +6,7 @@ from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, - filename_wo_ext, + logger_name, make_logger, ) from model_engine_server.domain.entities import BatchJobSerializationFormat @@ -26,7 +26,7 @@ SHUTDOWN_GRACE_PERIOD = 60 -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class LiveBatchJobOrchestrationGateway(BatchJobOrchestrationGateway): diff --git a/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py index 7de8f8aa..ef7c506e 100644 --- a/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py @@ -1,10 +1,10 @@ from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import BatchJobProgress from model_engine_server.infra.gateways import BatchJobProgressGateway from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def get_batch_job_progress_location(user_id: str, batch_job_id: str): diff --git a/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py index 257f7cbd..2970bead 100644 --- a/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py @@ -3,7 +3,7 @@ from kubernetes_asyncio.client.rest import ApiException from model_engine_server.common import dict_not_none from model_engine_server.common.config import hmi_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.domain.gateways.cron_job_gateway import CronJobGateway @@ -22,7 +22,7 @@ BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS = 10 -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def _k8s_cron_job_name_from_id(trigger_id: str): diff --git a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py index cb7154af..8ad2c09a 100644 --- a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py @@ -14,7 +14,7 @@ from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, - filename_wo_ext, + logger_name, make_logger, ) from model_engine_server.domain.entities.batch_job_entity import BatchJobStatus, DockerImageBatchJob @@ -54,7 +54,7 @@ BATCH_JOB_MAX_RUNTIME_SECONDS = 86400 * 7 # 7 days BATCH_JOB_TTL_SECONDS_AFTER_FINISHED = 86400 * 3 # 3 days -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class K8sEnvDict(TypedDict): diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index dd3aec47..e1519abc 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -12,7 +12,7 @@ ) from model_engine_server.common.env_vars import CIRCLECI, LOCAL from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( NoHealthyUpstreamException, TooManyRequestsException, @@ -34,7 +34,7 @@ wait_exponential, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) SYNC_ENDPOINT_RETRIES = 8 # Must be an integer >= 0 SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index 0a763b1f..dd427f93 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -11,7 +11,7 @@ ) from model_engine_server.common.env_vars import CIRCLECI, LOCAL from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( NoHealthyUpstreamException, TooManyRequestsException, @@ -32,7 +32,7 @@ wait_exponential, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) SYNC_ENDPOINT_RETRIES = 8 # Must be an integer >= 0 SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 diff --git a/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py index fc5a7e54..84f5c011 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py @@ -3,7 +3,7 @@ from kubernetes_asyncio.client.rest import ApiException from model_engine_server.common.config import hmi_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_apps_client, load_k8s_yaml, @@ -13,7 +13,7 @@ compute_image_hash, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class CachedImages(TypedDict): diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 9e129be6..154fdb5b 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -24,7 +24,7 @@ ) from model_engine_server.common.serialization_utils import b64_to_python_json, str_to_bool from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( ModelEndpointConfig, ModelEndpointDeploymentState, @@ -52,7 +52,7 @@ from packaging import version from pydantic.utils import deep_update -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) HTTP_PORT = 5000 diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 838cc592..2d6b7410 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -2,7 +2,7 @@ from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, @@ -21,7 +21,7 @@ SQSEndpointResourceDelegate, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class SqsQueueInfo(QueueInfo): diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py index fae21d5e..6d9f6597 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py @@ -7,7 +7,7 @@ from aiobotocore.client import AioBaseClient from model_engine_server.common.config import hmi_config from model_engine_server.core.aws.roles import session -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( SQSEndpointResourceDelegate, @@ -15,7 +15,7 @@ ) from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) __all__: Sequence[str] = ("LiveSQSEndpointResourceDelegate",) diff --git a/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py index 69fef3de..bfd8cab0 100644 --- a/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py @@ -5,7 +5,7 @@ from cachetools import TTLCache from model_engine_server.common import dict_not_none from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key from model_engine_server.db.models import Endpoint as OrmModelEndpoint from model_engine_server.domain.entities import ModelEndpointRecord @@ -23,7 +23,7 @@ from sqlalchemy import or_, text from sqlalchemy.ext.asyncio import AsyncSession -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) CACHE_SIZE = 512 CACHE_TTL_SECONDS = 15.0 # Kubernetes caching is 15 seconds as well diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 47ab21d3..d79f4c49 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -5,7 +5,7 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.env_vars import GIT_TAG from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState from model_engine_server.domain.repositories import DockerRepository from model_engine_server.infra.gateways.resources.image_cache_gateway import ( @@ -16,7 +16,7 @@ ModelEndpointRecordRepository, ) -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) IMAGES_TO_CACHE_PER_INSTANCE_TYPE = 32 diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index 7a096c2a..76c6cd38 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -16,7 +16,7 @@ TaskStatus, ) from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( BatchJobProgress, BatchJobRecord, @@ -39,7 +39,7 @@ BatchJobOrchestrationService, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) @dataclass diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_service.py index 9036de50..f9e6f904 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_service.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, @@ -17,7 +17,7 @@ BatchJobRecordRepository, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) DEFAULT_ENDPOINT_CPUS_BATCH_JOB = 3 DEFAULT_ENDPOINT_MEMORY_BATCH_JOB = "12Gi" diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 1a8c0c7d..64871730 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -19,7 +19,7 @@ from model_engine_server.common.io import open_wrapper from model_engine_server.common.serialization_utils import bool_to_str from model_engine_server.core.config import infra_config -from model_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.notification_gateway import NotificationApp, NotificationGateway from model_engine_server.core.utils.env import environment from model_engine_server.domain.entities import ( @@ -64,9 +64,9 @@ if LOCAL: with environment(KUBERNETES_SERVICE_HOST=None): - logger = make_logger("model_engine_server.service_builder") + logger = make_logger(logger_name()) else: - logger = make_logger("model_engine_server.service_builder") + logger = make_logger(logger_name()) __all__: Sequence[str] = ( "INITIAL_K8S_CACHE_TTL_SECONDS", diff --git a/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py index 41b6e5a9..644e0df6 100644 --- a/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py @@ -1,7 +1,7 @@ from typing import List, Optional from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ModelEndpoint from model_engine_server.domain.services import LLMModelEndpointService from model_engine_server.infra.repositories.model_endpoint_record_repository import ( @@ -9,7 +9,7 @@ ) from model_engine_server.infra.services import LiveModelEndpointService -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class LiveLLMModelEndpointService(LLMModelEndpointService): diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index dba1a055..ede0c39e 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -3,7 +3,7 @@ from datadog import statsd from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.settings import generate_deployment_name -from model_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, @@ -38,7 +38,7 @@ ModelEndpointRecordRepository, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) STATSD_CACHE_HIT_NAME = "launch.get_infra_state.cache_hit" STATSD_CACHE_MISS_NAME = "launch.get_infra_state.cache_miss" From 49eb538200705ec14cb537409ca2cd8d9ddf9c18 Mon Sep 17 00:00:00 2001 From: Sam Denton <106690182+sam-scale@users.noreply.github.com> Date: Fri, 20 Oct 2023 11:40:35 -0700 Subject: [PATCH 144/425] Fix up the mammoth max length issue. (#335) * Fix up the mammoth max length issue. * clean up typing --- .../use_cases/llm_model_endpoint_use_cases.py | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index a3980fed..afcb1ecf 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -124,6 +124,12 @@ "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", "falcon-180b": "tiiuae/falcon-180B", "falcon-180b-chat": "tiiuae/falcon-180B-chat", + "code-llama-7b": "codellama/CodeLlama-7b-hf", + "code-llama-13b": "codellama/CodeLlama-13b-hf", + "code-llama-34b": "codellama/CodeLlama-34b-hf", + "mammoth-coder-llama-2-7b": "TIGER-Lab/MAmmoTH-Coder-7B", + "mammoth-coder-llama-2-13b": "TIGER-Lab/MAmmoTH-Coder-13B", + "mammoth-coder-llama-2-34b": "TIGER-Lab/MAmmoTH-Coder-34B", }, LLMInferenceFramework.LIGHTLLM: { "llama-7b": "decapoda-research/llama-7b-hf", @@ -143,6 +149,20 @@ LLMInferenceFramework.LIGHTLLM: [], } +# We need a dict where if we need to override we can +# NOTE: These are in *descending* order of priority. e.g. if you see 'mammoth-coder' +# you'll use that override and not listen to the 'llama-2' override +_VLLM_MODEL_LENGTH_OVERRIDES: Dict[str, Dict[str, int]] = { + "mammoth-coder": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, + # Based on config here: https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B/blob/main/config.json#L12 + # Can also see 13B, 34B there too + "code-llama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, + # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 + # Can also see 13B, 34B there too + "llama-2": {"max_model_len": 4096, "max_num_batched_tokens": 4096}, + "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, +} + NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes @@ -514,13 +534,14 @@ async def create_vllm_bundle( ): command = [] - max_num_batched_tokens = 2560 # vLLM's default - max_model_len = None - if "llama-2" in model_name: - max_num_batched_tokens = 4096 # Need to be bigger than model's context window - if "mistral" in model_name: - max_num_batched_tokens = 8000 - max_model_len = 8000 + max_num_batched_tokens: int = 2560 # vLLM's default + max_model_len: Optional[int] = None + + for key, value in _VLLM_MODEL_LENGTH_OVERRIDES.items(): + if key in model_name: + max_model_len = value["max_model_len"] + max_num_batched_tokens = value["max_num_batched_tokens"] + break subcommands = [] if checkpoint_path is not None: From 271156c6e4e0438eccfd0cd869aed9f18b6565c5 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 20 Oct 2023 13:48:14 -0700 Subject: [PATCH 145/425] Add docs for Model.create, update default values and fix per_worker concurrency (#332) * Add docs for Model.create * comments * comments * fixes * add examples * fix * comments --- clients/python/llmengine/data_types.py | 4 +- clients/python/llmengine/model.py | 132 ++++++++++++++---- docs/api/python_client.md | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 2 +- .../infra/gateways/k8s_resource_parser.py | 16 +-- .../gateways/test_k8s_resource_parser.py | 12 +- 6 files changed, 120 insertions(+), 47 deletions(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 07612420..f5a5a0b2 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -37,8 +37,6 @@ class GpuType(str, Enum): class ModelEndpointType(str, Enum): - ASYNC = "async" - SYNC = "sync" STREAMING = "streaming" @@ -135,7 +133,7 @@ class CreateLLMEndpointRequest(BaseModel): # LLM specific fields model_name: str source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM inference_framework_image_tag: str num_shards: int = 1 """ diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index b5ed181a..3bd88944 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -36,18 +36,18 @@ def create( model: str, inference_framework_image_tag: str, source: LLMSource = LLMSource.HUGGING_FACE, - inference_framework: LLMInferenceFramework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE, - num_shards: int = 4, + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM, + num_shards: int = 1, quantize: Optional[Quantization] = None, checkpoint_path: Optional[str] = None, # General endpoint fields cpus: int = 8, - memory: str = "40Gi", - storage: str = "96Gi", + memory: str = "24Gi", + storage: str = "40Gi", gpus: int = 1, min_workers: int = 0, max_workers: int = 1, - per_worker: int = 10, + per_worker: int = 2, endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING, gpu_type: Optional[str] = "nvidia-ampere-a10", high_priority: Optional[bool] = False, @@ -57,7 +57,8 @@ def create( labels: Optional[Dict[str, str]] = None, ) -> CreateLLMEndpointResponse: """ - Create an LLM model. Note: This feature is only available for self-hosted users. + Create an LLM model. Note: This API is only available for self-hosted users. + Args: name (`str`): Name of the endpoint @@ -72,32 +73,34 @@ def create( Source of the LLM. Currently only HuggingFace is supported inference_framework (`LLMInferenceFramework`): - Inference framework for the LLM. Currently only DeepSpeed is supported + Inference framework for the LLM. Current supported frameworks are + LLMInferenceFramework.DEEPSPEED, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM and LLMInferenceFramework.LIGHTLLM num_shards (`int`): Number of shards for the LLM. When bigger than 1, LLM will be sharded to multiple GPUs. Number of GPUs must be equal or larger than num_shards. - Only affects behavior for text-generation-inference models quantize (`Optional[Quantization]`): - Quantization for the LLM. Only affects behavior for text-generation-inference models + Quantization method for the LLM. `text_generation_inference` supports `bitsandbytes` and `vllm` supports `awq`. checkpoint_path (`Optional[str]`): - Path to the checkpoint for the LLM. For now we only support loading a tar file from AWS S3. - Safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be slower). - Only affects behavior for text-generation-inference models + Remote path to the checkpoint for the LLM. LLM engine must have permission to access the given path. + Can be either a folder or a tar file. Folder is preferred since we don't need to untar and model loads faster. + For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). cpus (`int`): Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater - than or equal to 1 + than or equal to 1. Recommendation is set it to 8 * GPU count. memory (`str`): Amount of memory each worker should get, e.g. "4Gi", "512Mi", etc. This must - be a positive amount of memory + be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. storage (`str`): Amount of local ephemeral storage each worker should get, e.g. "4Gi", - "512Mi", etc. This must be a positive amount of storage + "512Mi", etc. This must be a positive amount of storage. + Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. gpus (`int`): Number of gpus each worker should get, e.g. 0, 1, etc. @@ -105,8 +108,10 @@ def create( min_workers (`int`): The minimum number of workers. Must be greater than or equal to 0. This should be determined by computing the minimum throughput of your workload and - dividing it by the throughput of a single worker. This field must be at least ``1`` - for synchronous endpoints + dividing it by the throughput of a single worker. When this number is 0, + max_workers must be 1, and the endpoint will autoscale between + 0 and 1 pods. When this number is greater than 0, max_workers can be any number + greater or equal to min_workers. max_workers (`int`): The maximum number of workers. Must be greater than or equal to 0, @@ -116,25 +121,22 @@ def create( per_worker (`int`): The maximum number of concurrent requests that an individual worker can - service. Launch automatically scales the number of workers for the endpoint so that + service. LLM engine automatically scales the number of workers for the endpoint so that each worker is processing ``per_worker`` requests, subject to the limits defined by ``min_workers`` and ``max_workers`` - - If the average number of concurrent requests per worker is lower than ``per_worker``, then the number of workers will be reduced. - Otherwise, if the average number of concurrent requests per worker is higher than ``per_worker``, then the number of workers will be increased to meet the elevated traffic. - Here is our recommendation for computing ``per_worker``: - 1. Compute ``min_workers`` and ``max_workers`` per your minimum and maximum throughput requirements. 2. Determine a value for the maximum number of concurrent requests in the workload. Divide this number by ``max_workers``. Doing this ensures that the number of workers will "climb" to ``max_workers``. endpoint_type (`ModelEndpointType`): - ``"sync"``, ``"async"`` or ``"streaming"``. + Currently only ``"streaming"`` endpoints are supported. gpu_type (`Optional[str]`): If specifying a non-zero number of gpus, this controls the type of gpu @@ -142,6 +144,8 @@ def create( - ``nvidia-tesla-t4`` - ``nvidia-ampere-a10`` + - ``nvidia-ampere-a100`` + - ``nvidia-ampere-a100e`` high_priority (`Optional[bool]`): Either ``True`` or ``False``. Enabling this will allow the created @@ -151,7 +155,7 @@ def create( List of hooks to trigger after inference tasks are served default_callback_url (`Optional[str]`): - The default callback url to use for async endpoints. + The default callback url to use for sync completion requests. This can be overridden in the task parameters for each individual task. post_inference_hooks must contain "callback" for the callback to be triggered @@ -159,11 +163,89 @@ def create( If ``True``, this endpoint will be available to all user IDs for inference - labels (`Optional[Dict[str, str]]`): An optional dictionary of key/value pairs to associate with this endpoint Returns: - CreateLLMEndpointResponse: creation task ID of the created Model. + CreateLLMEndpointResponse: creation task ID of the created Model. Currently not used. + + === "Create Llama 2 7B model in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-7b-test" + model="llama-2-7b", + inference_framework_image_tag="0.2.1.post1", + inference_framework=LLMInferenceFramework.VLLM, + num_shards=1, + checkpoint_path="s3://path/to/checkpoint", + cpus=8, + memory="24Gi", + storage="40Gi", + gpus=1, + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + gpu_type="nvidia-ampere-a10", + public_inference=False, + ) + + print(response.json()) + ``` + + === "Create Llama 2 13B model in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-13b-test" + model="llama-2-13b", + inference_framework_image_tag="0.2.1.post1", + inference_framework=LLMInferenceFramework.VLLM, + num_shards=2, + checkpoint_path="s3://path/to/checkpoint", + cpus=16, + memory="48Gi", + storage="80Gi", + gpus=2, + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + gpu_type="nvidia-ampere-a10", + public_inference=False, + ) + + print(response.json()) + ``` + + === "Create Llama 2 70B model with 8bit quantization in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-70b-test" + model="llama-2-70b", + inference_framework_image_tag="0.9.4", + inference_framework=LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + num_shards=4, + quantize="bitsandbytes", + checkpoint_path="s3://path/to/checkpoint", + cpus=40, + memory="96Gi", + storage="200Gi", + gpus=4, + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + gpu_type="nvidia-ampere-a10", + public_inference=False, + ) + + print(response.json()) + ``` """ post_inference_hooks_strs = None if post_inference_hooks is not None: diff --git a/docs/api/python_client.md b/docs/api/python_client.md index 427ae8b6..bdbc6f3e 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -18,6 +18,7 @@ ::: llmengine.Model selection: members: + - create - get - list - delete diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index afcb1ecf..13f48363 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -717,7 +717,7 @@ async def execute( ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( - f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference and vLLM." + f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM and LightLLM." ) bundle = await self.create_model_bundle( diff --git a/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py b/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py index a8626f65..947a734d 100644 --- a/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py +++ b/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py @@ -1,10 +1,8 @@ import hashlib +import math import re from typing import Union -MAX_CONCURRENCY_TO_TARGET_CONCURRENCY_RATIO = 2.0 - - # found this regex floating around somewhere, probably validates k8s requests in general: # '^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$' @@ -57,12 +55,12 @@ def parse_mem_request(req: str): def get_node_port(service_name: str) -> int: """Hashes the service name to a port number in the range [30000, 32767]""" - return int(hashlib.md5(service_name.encode()).hexdigest(), 16) % (32768 - 30000) + 30000 + return int(hashlib.sha256(service_name.encode()).hexdigest(), 16) % (32768 - 30000) + 30000 def get_target_concurrency_from_per_worker_value(per_worker: int) -> float: """Returns the target concurrency given a per-worker value""" - return per_worker / MAX_CONCURRENCY_TO_TARGET_CONCURRENCY_RATIO + return per_worker def get_per_worker_value_from_target_concurrency(concurrency: Union[str, int, float]) -> int: @@ -70,13 +68,7 @@ def get_per_worker_value_from_target_concurrency(concurrency: Union[str, int, fl Inverse of get_target_concurrency_from_per_worker_value """ - return int( - round( - parse_cpu_request(str(concurrency)) - * MAX_CONCURRENCY_TO_TARGET_CONCURRENCY_RATIO - / 1000.0 - ) - ) + return int(math.ceil(parse_cpu_request(str(concurrency)) / 1000.0)) def format_bytes(num_bytes) -> str: diff --git a/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py b/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py index 7f59350a..dd3462d5 100644 --- a/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py +++ b/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py @@ -99,13 +99,13 @@ def test_parse_mem_request(): @pytest.mark.parametrize( "input_value", [ - "1", - "1.5", - "500m", - "5500m", + ("1", "1"), + ("1.5", "2"), + ("500m", "1"), + ("5500m", "6"), ], ) def test_get_target_concurrency_to_per_worker_value(input_value): assert get_target_concurrency_from_per_worker_value( - parse_cpu_request(str(get_per_worker_value_from_target_concurrency(input_value))) - ) == parse_cpu_request(input_value) + parse_cpu_request(str(get_per_worker_value_from_target_concurrency(input_value[0]))) + ) == parse_cpu_request(input_value[1]) From 567ae66911085ec7e233292872658ff5ff6b1664 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 20 Oct 2023 15:14:06 -0700 Subject: [PATCH 146/425] updating docs to add codellama models (#343) --- docs/model_zoo.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 7c438019..05f24d60 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -20,6 +20,9 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | | `mistral-7b` | ✅ | ✅ | vllm | | `mistral-7b-instruct` | ✅ | ✅ | vllm | +| `code-llama-7b` | ✅ | | text-generation-inference, vllm | +| `code-llama-13b` | ✅ | | text-generation-inference, vllm | +| `code-llama-34b` | ✅ | | text-generation-inference, vllm | ## Usage From c0d30b0e65fbaf8b21a45e0a2a622f3f990f0b8d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 20 Oct 2023 15:44:57 -0700 Subject: [PATCH 147/425] Add PodDisruptionBudget to model engine (#342) * Add PDB to model-engine * missing file --- .../templates/pod_disruption_budget.yaml | 17 +++++++++++++++++ charts/model-engine/values_circleci.yaml | 4 ++++ charts/model-engine/values_sample.yaml | 4 ++++ 3 files changed, 25 insertions(+) create mode 100644 charts/model-engine/templates/pod_disruption_budget.yaml diff --git a/charts/model-engine/templates/pod_disruption_budget.yaml b/charts/model-engine/templates/pod_disruption_budget.yaml new file mode 100644 index 00000000..0959caec --- /dev/null +++ b/charts/model-engine/templates/pod_disruption_budget.yaml @@ -0,0 +1,17 @@ +{{- if .Values.podDisruptionBudget.enabled }} +apiVersion: policy/v1 +kind: PodDisruptionBudget +metadata: + name: {{ include "modelEngine.fullname" . }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} +spec: + {{- if .Values.podDisruptionBudget.minAvailable }} + minAvailable: {{ .Values.podDisruptionBudget.minAvailable }} + {{- else }} + maxUnavailable: {{ .Values.podDisruptionBudget.maxUnavailable }} + {{- end }} + selector: + matchLabels: + {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} +{{- end }} diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 657c5f50..1562ffbc 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -63,6 +63,10 @@ autoscaling: prewarming: enabled: false +podDisruptionBudget: + enabled: true + minAvailable: 1 + resources: requests: cpu: 2 diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 70d740cf..61ed8404 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -63,6 +63,10 @@ autoscaling: prewarming: enabled: false +podDisruptionBudget: + enabled: true + minAvailable: 1 + # resources specify the k8s resources for LLM Engine server deployments (e.g gateway, cache, and builder deployments) resources: requests: From 280b86735dc48898d641d62f111758152c1effbe Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:22:08 -0700 Subject: [PATCH 148/425] Allow auth to accept API keys (#326) * allow auth to accept API keys * fix try-except * refactor auth repo methods * test change * use auth plugin * move dd monitoring gateway * types --- .../model_engine_server/api/dependencies.py | 44 +++++++++++++++---- .../core/auth/authentication_repository.py | 20 ++------- .../auth/fake_authentication_repository.py | 18 +++----- .../entrypoints/k8s_cache.py | 15 ++----- .../start_batch_job_orchestration.py | 13 ++---- .../infra/gateways/__init__.py | 2 - .../datadog_monitoring_metrics_gateway.py | 38 ---------------- .../service_builder/tasks_v1.py | 18 ++------ model-engine/tests/unit/api/conftest.py | 8 ++-- 9 files changed, 62 insertions(+), 114 deletions(-) delete mode 100644 model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 30062252..08f362cc 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Callable, Iterator, Optional +from typing import Callable, Optional import aioredis from fastapi import Depends, HTTPException, status @@ -26,6 +26,7 @@ FileStorageGateway, LLMArtifactGateway, ModelPrimitiveGateway, + MonitoringMetricsGateway, TaskQueueGateway, ) from model_engine_server.domain.repositories import ( @@ -134,6 +135,24 @@ class ExternalInterfaces: cron_job_gateway: CronJobGateway +def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway: + monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + return monitoring_metrics_gateway + + +def get_monitoring_metrics_gateway() -> MonitoringMetricsGateway: + try: + from plugins.dependencies import ( + get_monitoring_metrics_gateway as get_custom_monitoring_metrics_gateway, + ) + + return get_custom_monitoring_metrics_gateway() + except ModuleNotFoundError: + return get_default_monitoring_metrics_gateway() + finally: + pass + + def _get_external_interfaces( read_only: bool, session: Callable[[], AsyncSession] ) -> ExternalInterfaces: @@ -144,7 +163,7 @@ def _get_external_interfaces( redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + monitoring_metrics_gateway = get_monitoring_metrics_gateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, @@ -300,12 +319,21 @@ async def get_external_interfaces_read_only(): pass -def get_auth_repository() -> Iterator[AuthenticationRepository]: +def get_default_auth_repository() -> AuthenticationRepository: + auth_repo = FakeAuthenticationRepository() + return auth_repo + + +async def get_auth_repository(): """ Dependency for an AuthenticationRepository. This implementation returns a fake repository. """ try: - yield FakeAuthenticationRepository() + from plugins.dependencies import get_auth_repository as get_custom_auth_repository + + yield get_custom_auth_repository() + except ModuleNotFoundError: + yield get_default_auth_repository() finally: pass @@ -318,15 +346,15 @@ async def verify_authentication( Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, raises a 401. """ - user_id = credentials.username if credentials is not None else None - if user_id is None: + username = credentials.username if credentials is not None else None + if username is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="No user id was passed in", + detail="No authentication was passed in", headers={"WWW-Authenticate": "Basic"}, ) - auth = await auth_repo.get_auth_from_user_id_async(user_id=user_id) + auth = await auth_repo.get_auth_from_username_async(username=username) if not auth: raise HTTPException( diff --git a/model-engine/model_engine_server/core/auth/authentication_repository.py b/model-engine/model_engine_server/core/auth/authentication_repository.py index 1b1d1f77..f3a60847 100644 --- a/model-engine/model_engine_server/core/auth/authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/authentication_repository.py @@ -24,25 +24,13 @@ def is_allowed_team(team: str) -> bool: """ @abstractmethod - def get_auth_from_user_id(self, user_id: str) -> Optional[User]: + def get_auth_from_username(self, username: str) -> Optional[User]: """ - Returns authentication information associated with a given user_id. + Returns authentication information associated with a given Basic Auth username. """ @abstractmethod - def get_auth_from_api_key(self, api_key: str) -> Optional[User]: + async def get_auth_from_username_async(self, username: str) -> Optional[User]: """ - Returns authentication information associated with a given api_key. - """ - - @abstractmethod - async def get_auth_from_user_id_async(self, user_id: str) -> Optional[User]: - """ - Returns authentication information associated with a given user_id. - """ - - @abstractmethod - async def get_auth_from_api_key_async(self, api_key: str) -> Optional[User]: - """ - Returns authentication information associated with a given api_key. + Returns authentication information associated with a given Basic Auth username. """ diff --git a/model-engine/model_engine_server/core/auth/fake_authentication_repository.py b/model-engine/model_engine_server/core/auth/fake_authentication_repository.py index d3e5f4c1..ff38e768 100644 --- a/model-engine/model_engine_server/core/auth/fake_authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/fake_authentication_repository.py @@ -13,16 +13,10 @@ def __init__(self, user_team_override: Optional[Dict[str, str]] = None): def is_allowed_team(team: str) -> bool: return True - def get_auth_from_user_id(self, user_id: str) -> Optional[User]: - team_id = self.user_team_override.get(user_id, user_id) - return User(user_id=user_id, team_id=team_id, is_privileged_user=True) + def get_auth_from_username(self, username: str) -> Optional[User]: + team_id = self.user_team_override.get(username, username) + return User(user_id=username, team_id=team_id, is_privileged_user=True) - async def get_auth_from_user_id_async(self, user_id: str) -> Optional[User]: - team_id = self.user_team_override.get(user_id, user_id) - return User(user_id=user_id, team_id=team_id, is_privileged_user=True) - - def get_auth_from_api_key(self, api_key: str) -> Optional[User]: - return User(user_id=api_key, team_id=api_key, is_privileged_user=True) - - async def get_auth_from_api_key_async(self, api_key: str) -> Optional[User]: - return User(user_id=api_key, team_id=api_key, is_privileged_user=True) + async def get_auth_from_username_async(self, username: str) -> Optional[User]: + team_id = self.user_team_override.get(username, username) + return User(user_id=username, team_id=team_id, is_privileged_user=True) diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 354238c9..12a6e82f 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -10,17 +10,13 @@ from kubernetes import config as kube_config from kubernetes.config.config_exception import ConfigException +from model_engine_server.api.dependencies import get_monitoring_metrics_gateway from model_engine_server.common.config import hmi_config from model_engine_server.common.constants import READYZ_FPATH -from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH +from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.base import SessionAsyncNullPool -from model_engine_server.domain.gateways import MonitoringMetricsGateway from model_engine_server.domain.repositories import DockerRepository -from model_engine_server.infra.gateways import ( - DatadogMonitoringMetricsGateway, - FakeMonitoringMetricsGateway, -) from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) @@ -95,16 +91,13 @@ async def main(args: Any): logger.info(f"Using cache redis url {redis_url}") cache_repo = RedisModelEndpointCacheRepository(redis_info=redis_url) - monitoring_metrics_gateway: MonitoringMetricsGateway - if SKIP_AUTH: - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() - else: - monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() + monitoring_metrics_gateway = get_monitoring_metrics_gateway() endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, session=SessionAsyncNullPool, read_only=True, ) + sqs_delegate: SQSEndpointResourceDelegate if CIRCLECI: sqs_delegate = FakeSQSEndpointResourceDelegate() diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 6139a2a0..c52442ab 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -4,16 +4,14 @@ from datetime import timedelta import aioredis +from model_engine_server.api.dependencies import get_monitoring_metrics_gateway from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.model_endpoints import BrokerType -from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH +from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.db.base import SessionAsyncNullPool from model_engine_server.domain.entities import BatchJobSerializationFormat -from model_engine_server.domain.gateways import MonitoringMetricsGateway from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, - DatadogMonitoringMetricsGateway, - FakeMonitoringMetricsGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobProgressGateway, LiveModelEndpointInfraGateway, @@ -58,11 +56,8 @@ async def run_batch_job( redis = aioredis.Redis(connection_pool=pool) redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) - monitoring_metrics_gateway: MonitoringMetricsGateway - if SKIP_AUTH: - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() - else: - monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() + + monitoring_metrics_gateway = get_monitoring_metrics_gateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ) diff --git a/model-engine/model_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py index 0f2b5faa..de4eb6b7 100644 --- a/model-engine/model_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -3,7 +3,6 @@ from .batch_job_orchestration_gateway import BatchJobOrchestrationGateway from .batch_job_progress_gateway import BatchJobProgressGateway from .celery_task_queue_gateway import CeleryTaskQueueGateway -from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway from .fake_model_primitive_gateway import FakeModelPrimitiveGateway from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway @@ -26,7 +25,6 @@ "BatchJobOrchestrationGateway", "BatchJobProgressGateway", "CeleryTaskQueueGateway", - "DatadogMonitoringMetricsGateway", "FakeModelPrimitiveGateway", "FakeMonitoringMetricsGateway", "LiveAsyncModelEndpointInferenceGateway", diff --git a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py deleted file mode 100644 index 4dc73f69..00000000 --- a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py +++ /dev/null @@ -1,38 +0,0 @@ -from datadog import statsd -from model_engine_server.core.config import infra_config -from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway - - -class DatadogMonitoringMetricsGateway(MonitoringMetricsGateway): - def __init__(self): - self.tags = [f"env:{infra_config().env}"] - - def emit_attempted_build_metric(self): - statsd.increment("scale_launch.service_builder.attempt", tags=self.tags) - - def emit_successful_build_metric(self): - statsd.increment("scale_launch.service_builder.success", tags=self.tags) - - def emit_build_time_metric(self, duration_seconds: float): - statsd.distribution( - "scale_launch.service_builder.endpoint_build_time", duration_seconds, tags=self.tags - ) - - def emit_image_build_cache_hit_metric(self, image_type: str): - statsd.increment( - f"scale_launch.service_builder.{image_type}_image_cache_hit", tags=self.tags - ) - - def emit_image_build_cache_miss_metric(self, image_type: str): - statsd.increment( - f"scale_launch.service_builder.{image_type}_image_cache_miss", tags=self.tags - ) - - def emit_docker_failed_build_metric(self): - statsd.increment("scale_launch.service_builder.docker_failed", tags=self.tags) - - def emit_database_cache_hit_metric(self): - statsd.increment("scale_launch.database_cache.hit", tags=self.tags) - - def emit_database_cache_miss_metric(self): - statsd.increment("scale_launch.database_cache.miss", tags=self.tags) diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index f7dc0d4e..b8b38a28 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -4,21 +4,17 @@ import aioredis from celery.signals import worker_process_init +from model_engine_server.api.dependencies import get_monitoring_metrics_gateway from model_engine_server.common.config import hmi_config from model_engine_server.common.constants import READYZ_FPATH from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, ) -from model_engine_server.common.env_vars import CIRCLECI, SKIP_AUTH +from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway from model_engine_server.db.base import SessionAsyncNullPool -from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway -from model_engine_server.infra.gateways import ( - DatadogMonitoringMetricsGateway, - FakeMonitoringMetricsGateway, - S3FilesystemGateway, -) +from model_engine_server.infra.gateways import S3FilesystemGateway from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( FakeSQSEndpointResourceDelegate, ) @@ -61,14 +57,8 @@ def get_live_endpoint_builder_service( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) notification_gateway = FakeNotificationGateway() - monitoring_metrics_gateway: MonitoringMetricsGateway - if SKIP_AUTH: - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() - else: - monitoring_metrics_gateway = DatadogMonitoringMetricsGateway() - + monitoring_metrics_gateway = get_monitoring_metrics_gateway() docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() - service = LiveEndpointBuilderService( docker_repository=docker_repository, resource_gateway=LiveEndpointResourceGateway( diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index fa6f08fa..9b085915 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -69,11 +69,11 @@ def fake_verify_authentication( Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, raises a 401. """ - auth_user_id = credentials.username if credentials is not None else None - if not auth_user_id: - raise HTTPException(status_code=401, detail="No user id was passed in") + auth_username = credentials.username if credentials is not None else None + if not auth_username: + raise HTTPException(status_code=401, detail="No authentication was passed in") - auth = auth_repo.get_auth_from_user_id(user_id=auth_user_id) + auth = auth_repo.get_auth_from_username(username=auth_username) if not auth: raise HTTPException(status_code=401, detail="Could not authenticate user") From aa06906a3e0e8486b8ffdff3aee518da73e8b9a6 Mon Sep 17 00:00:00 2001 From: William Song Date: Fri, 20 Oct 2023 13:52:05 -1000 Subject: [PATCH 149/425] Add job_name in build logs for easier debugging (#340) --- .../common/dtos/docker_repository.py | 1 + .../core/docker/remote_build.py | 7 ++++--- .../infra/repositories/ecr_docker_repository.py | 4 +++- .../services/live_endpoint_builder_service.py | 16 +++++++++------- model-engine/tests/unit/conftest.py | 2 +- .../test_live_endpoint_builder_service.py | 2 +- 6 files changed, 19 insertions(+), 13 deletions(-) diff --git a/model-engine/model_engine_server/common/dtos/docker_repository.py b/model-engine/model_engine_server/common/dtos/docker_repository.py index 5548eead..6e4651d9 100644 --- a/model-engine/model_engine_server/common/dtos/docker_repository.py +++ b/model-engine/model_engine_server/common/dtos/docker_repository.py @@ -17,6 +17,7 @@ class BuildImageRequest(BaseModel): class BuildImageResponse(BaseModel): status: bool logs: str + job_name: str # TODO: We may want to add a DTO for streaming logs from the docker build to users. diff --git a/model-engine/model_engine_server/core/docker/remote_build.py b/model-engine/model_engine_server/core/docker/remote_build.py index 6261334e..26d58721 100644 --- a/model-engine/model_engine_server/core/docker/remote_build.py +++ b/model-engine/model_engine_server/core/docker/remote_build.py @@ -48,6 +48,7 @@ class BuildResult: status: bool logs: str + job_name: str def zip_context( @@ -398,13 +399,13 @@ def cleanup_logs_process(): ) elif event["object"].status.phase == "Succeeded": cleanup_logs_process() - return BuildResult(status=True, logs=_read_pod_logs(pod_name)) + return BuildResult(status=True, logs=_read_pod_logs(pod_name), job_name=job_name) elif event["object"].status.phase == "Failed": cleanup_logs_process() - return BuildResult(status=False, logs=_read_pod_logs(pod_name)) + return BuildResult(status=False, logs=_read_pod_logs(pod_name), job_name=job_name) if logs_process is not None: logs_process.kill() - return BuildResult(status=False, logs=_read_pod_logs(pod_name)) + return BuildResult(status=False, logs=_read_pod_logs(pod_name), job_name=job_name) def build_remote_block( diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index 8ca5dd61..47aeb61c 100644 --- a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -47,4 +47,6 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: folders_to_include=folders_to_include, build_args=build_args, ) - return BuildImageResponse(status=build_result.status, logs=build_result.logs) + return BuildImageResponse( + status=build_result.status, logs=build_result.logs, job_name=build_result.job_name + ) diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 64871730..7073b39f 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -174,7 +174,7 @@ async def build_endpoint( base_image_params = self.get_base_image_params( build_endpoint_request, logger_adapter ) - logger.info(f"base_image_params: {base_image_params}") + logger_adapter.info(f"base_image_params: {base_image_params}") base_image = await self._build_image( base_image_params, build_endpoint_request, @@ -227,7 +227,9 @@ async def build_endpoint( if os.path.exists(model_bundle_path): os.remove(model_bundle_path) else: - logger.error(f"No bundle object found at {model_bundle_path}!") + logger_adapter.error( + f"No bundle object found at {model_bundle_path}!" + ) except DockerBuildFailedException: log_error("Failed to build base and user docker images") @@ -493,8 +495,8 @@ def get_base_image_params( inference_folder = "model-engine/model_engine_server/inference" base_path: str = os.getenv("WORKSPACE") # type: ignore - logger.info(f"inference_folder: {inference_folder}") - logger.info(f"dockerfile: {inference_folder}/{dockerfile}") + logger_adapter.info(f"inference_folder: {inference_folder}") + logger_adapter.info(f"dockerfile: {inference_folder}/{dockerfile}") return BuildImageRequest( repo="launch/inference", image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], @@ -614,7 +616,7 @@ def _get_inject_bundle_image_params( pass _, model_bundle_path = tempfile.mkstemp(dir=bundle_folder, suffix=".zip") bundle_url = model_bundle.location - logger.info( + logger_adapter.info( f"Downloading bundle from serialized object at location {bundle_url} to local path {model_bundle_path}" ) with open_wrapper(bundle_url, "rb") as bundle_data: # type: ignore @@ -678,6 +680,7 @@ async def _build_image( ) build_result_status = build_result.status build_result_logs: str = build_result.logs + logger_adapter.info(f"Image Build job: {build_result.job_name}") except Exception: # noqa build_result_status = False s3_logs_location: Optional[str] = None @@ -759,8 +762,7 @@ async def _build_image( else: self.monitoring_metrics_gateway.emit_image_build_cache_hit_metric(image_type) logger_adapter.info( - f"Image {image_params.repo}:{image_params.image_tag} already exists, " - f"skipping build for {endpoint_id=}" + f"Image already exists, skipping build. Image={image_params.repo}:{image_params.image_tag}, {endpoint_id=}" ) return self.docker_repository.get_image_url(image_params.image_tag, image_params.repo) diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index b55a2b50..b784e5c4 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -668,7 +668,7 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str: def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: if self.raises_error: raise Exception("I hope you're handling this!") - return BuildImageResponse(status=True, logs="") + return BuildImageResponse(status=True, logs="", job_name="test-job-name") class FakeModelEndpointCacheRepository(ModelEndpointCacheRepository): diff --git a/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py index 31d44b2d..bf568c9a 100644 --- a/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py +++ b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py @@ -213,7 +213,7 @@ async def test_build_endpoint_build_result_failed_yields_docker_build_failed_exc repo.add_model_endpoint_record(build_endpoint_request_sync_pytorch.model_endpoint_record) endpoint_builder_service_empty_docker_not_built.docker_repository.__setattr__( "build_image", - Mock(return_value=BuildImageResponse(status=False, logs="")), + Mock(return_value=BuildImageResponse(status=False, logs="", job_name="")), ) with pytest.raises(DockerBuildFailedException): await endpoint_builder_service_empty_docker_not_built.build_endpoint( From a5d904d6a705d016f57d7c09883dd25f48f26cbc Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 23 Oct 2023 12:18:12 -0700 Subject: [PATCH 150/425] Make PDB optional (#344) Co-authored-by: Ian Macleod <139901935+ian-scale@users.noreply.github.com> --- charts/model-engine/templates/pod_disruption_budget.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/charts/model-engine/templates/pod_disruption_budget.yaml b/charts/model-engine/templates/pod_disruption_budget.yaml index 0959caec..67b7a02f 100644 --- a/charts/model-engine/templates/pod_disruption_budget.yaml +++ b/charts/model-engine/templates/pod_disruption_budget.yaml @@ -1,4 +1,4 @@ -{{- if .Values.podDisruptionBudget.enabled }} +{{- if and .Values.podDisruptionBudget .Values.podDisruptionBudget.enabled }} apiVersion: policy/v1 kind: PodDisruptionBudget metadata: From f2c253ffb91012c0c81db3583ddfb086a0554f62 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Mon, 23 Oct 2023 14:32:21 -0700 Subject: [PATCH 151/425] Revert "fix celery worker profile for s3 access (#333)" (#345) This reverts commit fe24d634967ffa2437a85cd81faa22dbf114390b. --- model-engine/model_engine_server/core/celery/app.py | 3 ++- .../inference/forwarding/celery_forwarder.py | 1 - .../infra/gateways/celery_task_queue_gateway.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 7e87d2f0..42651d56 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -504,7 +504,8 @@ def _get_backend_url_and_conf( elif backend_protocol == "s3": backend_url = "s3://" if aws_role is None: - aws_session = session(infra_config().profile_ml_worker) + aws_profile = os.getenv("AWS_PROFILE", infra_config().profile_ml_worker) + aws_session = session(aws_profile) else: aws_session = session(aws_role) out_conf_changes.update( diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 6206f711..16e7fc34 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -95,7 +95,6 @@ def create_celery_service( app: Celery = celery_app( name=None, s3_bucket=infra_config().s3_bucket, - aws_role=infra_config().profile_ml_inference_worker, task_visibility=task_visibility, broker_type=str(BrokerType.SQS.value if sqs_url else BrokerType.REDIS.value), broker_transport_options={"predefined_queues": {queue_name: {"url": sqs_url}}} diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 8d487029..66f39f83 100644 --- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -11,7 +11,9 @@ from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway celery_redis = celery_app( - None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value) + None, + s3_bucket=infra_config().s3_bucket, + broker_type=str(BrokerType.REDIS.value), ) celery_redis_24h = celery_app( None, From c7dae60e02abe0bfc90461fee49e7cf833277f4b Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:00:46 -0700 Subject: [PATCH 152/425] Revert "Revert "fix celery worker profile for s3 access (#333)" (#345)" (#346) This reverts commit f2c253ffb91012c0c81db3583ddfb086a0554f62. --- model-engine/model_engine_server/core/celery/app.py | 3 +-- .../inference/forwarding/celery_forwarder.py | 1 + .../infra/gateways/celery_task_queue_gateway.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 42651d56..7e87d2f0 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -504,8 +504,7 @@ def _get_backend_url_and_conf( elif backend_protocol == "s3": backend_url = "s3://" if aws_role is None: - aws_profile = os.getenv("AWS_PROFILE", infra_config().profile_ml_worker) - aws_session = session(aws_profile) + aws_session = session(infra_config().profile_ml_worker) else: aws_session = session(aws_role) out_conf_changes.update( diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 16e7fc34..6206f711 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -95,6 +95,7 @@ def create_celery_service( app: Celery = celery_app( name=None, s3_bucket=infra_config().s3_bucket, + aws_role=infra_config().profile_ml_inference_worker, task_visibility=task_visibility, broker_type=str(BrokerType.SQS.value if sqs_url else BrokerType.REDIS.value), broker_transport_options={"predefined_queues": {queue_name: {"url": sqs_url}}} diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 66f39f83..8d487029 100644 --- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -11,9 +11,7 @@ from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway celery_redis = celery_app( - None, - s3_bucket=infra_config().s3_bucket, - broker_type=str(BrokerType.REDIS.value), + None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value) ) celery_redis_24h = celery_app( None, From f894c1014e97c78b023f4ee4ef872289e9e9fc12 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 24 Oct 2023 15:13:28 -0700 Subject: [PATCH 153/425] Pass file ID to fine-tuning script (#347) --- .../use_cases/llm_fine_tuning_use_cases.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 689adfdb..268569be 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -147,33 +147,33 @@ async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFin fine_tuned_model = ensure_model_name_is_valid_k8s_label(fine_tuned_model) if request.training_file.startswith("file-"): - training_file = await self.file_storage_gateway.get_url_from_id( + training_file_url = await self.file_storage_gateway.get_url_from_id( user.team_id, request.training_file ) - if training_file is None: + if training_file_url is None: raise ObjectNotFoundException("Training file does not exist") else: - training_file = request.training_file + training_file_url = request.training_file if request.validation_file is not None and request.validation_file.startswith("file-"): - validation_file = await self.file_storage_gateway.get_url_from_id( + validation_file_url = await self.file_storage_gateway.get_url_from_id( user.team_id, request.validation_file ) - if validation_file is None: + if validation_file_url is None: raise ObjectNotFoundException("Validation file does not exist") else: - validation_file = request.validation_file + validation_file_url = request.validation_file - check_file_is_valid(training_file, "training") - check_file_is_valid(validation_file, "validation") + check_file_is_valid(training_file_url, "training") + check_file_is_valid(validation_file_url, "validation") await self.llm_fine_tune_events_repository.initialize_events(user.team_id, fine_tuned_model) fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune( created_by=user.user_id, owner=user.team_id, model=request.model, - training_file=training_file, - validation_file=validation_file, + training_file=request.training_file, # for Files API, pass file ID rather than signed URL since the latter expires; fine-tuning script will get file content from Files API + validation_file=request.validation_file, fine_tuning_method=DEFAULT_FINE_TUNING_METHOD, hyperparameters=request.hyperparameters, fine_tuned_model=fine_tuned_model, From d2d4d10fc91567b1ec86220f6e1cbe1b507283fb Mon Sep 17 00:00:00 2001 From: Sam Denton <106690182+sam-scale@users.noreply.github.com> Date: Tue, 24 Oct 2023 16:23:13 -0700 Subject: [PATCH 154/425] llama should have None max length (#348) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 13f48363..3bdec782 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -152,14 +152,14 @@ # We need a dict where if we need to override we can # NOTE: These are in *descending* order of priority. e.g. if you see 'mammoth-coder' # you'll use that override and not listen to the 'llama-2' override -_VLLM_MODEL_LENGTH_OVERRIDES: Dict[str, Dict[str, int]] = { +_VLLM_MODEL_LENGTH_OVERRIDES: Dict[str, Dict[str, Optional[int]]] = { "mammoth-coder": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, # Based on config here: https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B/blob/main/config.json#L12 # Can also see 13B, 34B there too "code-llama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 # Can also see 13B, 34B there too - "llama-2": {"max_model_len": 4096, "max_num_batched_tokens": 4096}, + "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, } @@ -534,7 +534,7 @@ async def create_vllm_bundle( ): command = [] - max_num_batched_tokens: int = 2560 # vLLM's default + max_num_batched_tokens: Optional[int] = 2560 # vLLM's default max_model_len: Optional[int] = None for key, value in _VLLM_MODEL_LENGTH_OVERRIDES.items(): From d2effdcad39632ba9badacfe668acd96a35d9aee Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 24 Oct 2023 18:06:02 -0700 Subject: [PATCH 155/425] taking out codellama13b and 34b (#349) --- docs/model_zoo.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 05f24d60..148b1dfc 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -21,8 +21,6 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `mistral-7b` | ✅ | ✅ | vllm | | `mistral-7b-instruct` | ✅ | ✅ | vllm | | `code-llama-7b` | ✅ | | text-generation-inference, vllm | -| `code-llama-13b` | ✅ | | text-generation-inference, vllm | -| `code-llama-34b` | ✅ | | text-generation-inference, vllm | ## Usage From 2603b9b6ecd9cb6809870ca11e9081308f172305 Mon Sep 17 00:00:00 2001 From: Edward Park Date: Wed, 25 Oct 2023 18:57:36 +0200 Subject: [PATCH 156/425] Change DATADOG_TRACE_ENABLED to DD_TRACE_ENABLED (#350) --- charts/model-engine/templates/_helpers.tpl | 12 ++--- .../templates/service_config_map.yaml | 4 +- charts/model-engine/values.yaml | 2 +- charts/model-engine/values_circleci.yaml | 2 +- charts/model-engine/values_sample.yaml | 4 +- docs/guides/self_hosting.md | 2 +- .../model_engine_server/common/config.py | 2 +- .../k8s_endpoint_resource_delegate.py | 4 +- .../gateways/resources/k8s_resource_types.py | 24 +++++----- .../service_template_config_map_circleci.yaml | 46 +++++++++---------- .../services/live_endpoint_builder_service.py | 2 +- .../service_config_circleci.yaml | 2 +- 12 files changed, 53 insertions(+), 53 deletions(-) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 75b69dc3..8e9f1da3 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -125,8 +125,8 @@ podAffinity: {{- define "modelEngine.baseServiceTemplateEnv" -}} env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_REMOTE_CONFIGURATION_ENABLED value: "false" - name: DD_SERVICE @@ -187,8 +187,8 @@ env: {{- define "modelEngine.baseForwarderTemplateEnv" -}} env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_REMOTE_CONFIGURATION_ENABLED value: "false" - name: DD_SERVICE @@ -235,8 +235,8 @@ env: {{- define "modelEngine.serviceEnvBase" }} env: - - name: DATADOG_TRACE_ENABLED - value: "{{ .Values.datadog_trace_enabled }}" + - name: DD_TRACE_ENABLED + value: "{{ .Values.dd_trace_enabled }}" - name: DD_REMOTE_CONFIGURATION_ENABLED value: "false" - name: DD_ENV diff --git a/charts/model-engine/templates/service_config_map.yaml b/charts/model-engine/templates/service_config_map.yaml index b6809b22..70a12755 100644 --- a/charts/model-engine/templates/service_config_map.yaml +++ b/charts/model-engine/templates/service_config_map.yaml @@ -10,7 +10,7 @@ metadata: "helm.sh/hook-weight": "-2" data: launch_service_config: |- - datadog_trace_enabled: {{ .Values.datadog_trace_enabled | default false | quote }} + dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }} {{- with .Values.config.values.launch }} {{- range $key, $value := . }} {{ $key }}: {{ $value | quote }} @@ -38,7 +38,7 @@ metadata: "helm.sh/hook-weight": "-2" data: launch_service_config: |- - datadog_trace_enabled: {{ .Values.datadog_trace_enabled | default false | quote }} + dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }} {{- with .Values.config.values.launch }} {{- range $key, $value := . }} {{ $key }}: {{ $value | quote }} diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index c228a34a..666f1c4d 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -1,4 +1,4 @@ -datadog_trace_enabled: true +dd_trace_enabled: true spellbook: enabled: false redis: diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 1562ffbc..a82e2bad 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -141,7 +141,7 @@ config: billing_queue_arn: none cache_redis_url: redis://redis-message-broker-master.default/15 s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET" - datadog_trace_enabled: false + dd_trace_enabled: false istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 61ed8404..75bb808e 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -118,8 +118,8 @@ config: cache_redis_url: redis://llm-engine-prod-cache.use1.cache.amazonaws.com:6379/15 # s3_file_llm_fine_tuning_job_repository [required] is the S3 URI for the S3 bucket/key that you wish to save fine-tuned assests s3_file_llm_fine_tuning_job_repository: "s3://llm-engine/llm-ft-job-repository" - # datadog_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster - datadog_trace_enabled: false + # dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster + dd_trace_enabled: false # Asynchronous endpoints configs (coming soon) sqs_profile: default diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 0c446191..84aaa376 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -114,7 +114,7 @@ Below are the configurations to specify in the `values_sample.yaml` file. | config.values.llm_engine.endpoint_namespace | K8s namespace the endpoints will be created in | Yes | | config.values.llm_engine.cache_redis_url | The full url for the redis cluster you wish to connect | Yes | | config.values.llm_engine.s3_file_llm_fine_tuning_job_repository | The S3 URI for the S3 bucket/key that you wish to save fine-tuned assets | Yes | -| config.values.datadog_trace_enabled | Whether to enable datadog tracing, datadog must be installed in the cluster | No | +| config.values.dd_trace_enabled | Whether to enable datadog tracing, datadog must be installed in the cluster | No | ## Play With It Once `helm install` succeeds, you can forward port `5000` from a `llm-engine` pod and test sending requests to it. diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index de76ff96..7d62b9dd 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -53,7 +53,7 @@ class HostedModelInferenceServiceConfig: s3_file_llm_fine_tune_repository: str hf_user_fine_tuned_weights_prefix: str istio_enabled: bool - datadog_trace_enabled: bool + dd_trace_enabled: bool tgi_repository: str vllm_repository: str lightllm_repository: str diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 154fdb5b..ae8576e3 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -60,7 +60,7 @@ # and where the user actually owns the files BASE_PATH_IN_ENDPOINT = "/app" -DATADOG_ENV_VAR = {"DATADOG_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"} +DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"} _lazy_load_kubernetes_clients = True _kubernetes_apps_api = None @@ -237,7 +237,7 @@ def add_datadog_env_to_main_container(deployment_template: Dict[str, Any]) -> No user_container_envs.extend( [ { - "name": "DATADOG_TRACE_ENABLED", + "name": "DD_TRACE_ENABLED", "value": "false" if CIRCLECI else "true", }, { diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index e417058d..1a7998e2 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -106,7 +106,7 @@ class _BaseDeploymentArguments(_BaseEndpointArguments): PRIORITY: str IMAGE: str IMAGE_HASH: str - DATADOG_TRACE_ENABLED: str + DD_TRACE_ENABLED: str CPUS: str MEMORY: str STORAGE_DICT: DictStrStr @@ -510,7 +510,7 @@ def get_endpoint_resource_arguments_from_request( # In Circle CI, we use Redis on localhost instead of SQS broker_name = BrokerName.SQS.value if not CIRCLECI else BrokerName.REDIS.value broker_type = BrokerType.SQS.value if not CIRCLECI else BrokerType.REDIS.value - datadog_trace_enabled = hmi_config.datadog_trace_enabled + dd_trace_enabled = hmi_config.dd_trace_enabled if broker_type == BrokerType.REDIS.value: sqs_queue_url = "" @@ -573,7 +573,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -621,7 +621,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -671,7 +671,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -716,7 +716,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -763,7 +763,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -807,7 +807,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -853,7 +853,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -909,7 +909,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -967,7 +967,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -1019,7 +1019,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=dd_trace_enabled, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 63e94d0b..93a779c2 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -128,8 +128,8 @@ data: - --num-workers - "${PER_WORKER}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -399,8 +399,8 @@ data: - --num-workers - "${PER_WORKER}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -618,8 +618,8 @@ data: - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -886,8 +886,8 @@ data: - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1112,8 +1112,8 @@ data: - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1349,8 +1349,8 @@ data: - --num-workers - "${PER_WORKER}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1627,8 +1627,8 @@ data: - --num-workers - "${PER_WORKER}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1853,8 +1853,8 @@ data: - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -2128,8 +2128,8 @@ data: - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -2361,8 +2361,8 @@ data: - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -2810,7 +2810,7 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} - - name: DATADOG_TRACE_ENABLED + - name: DD_TRACE_ENABLED value: "false" - name: DD_ENV value: circleci @@ -2935,7 +2935,7 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} - - name: DATADOG_TRACE_ENABLED + - name: DD_TRACE_ENABLED value: "false" - name: DD_ENV value: circleci @@ -3081,7 +3081,7 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} - - name: DATADOG_TRACE_ENABLED + - name: DD_TRACE_ENABLED value: "false" - name: DD_ENV value: circleci diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 7073b39f..30ca9b7b 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -83,7 +83,7 @@ RESTRICTED_ENV_VARS_KEYS = { "BASE": [ - "DATADOG_TRACE_ENABLED", + "DD_TRACE_ENABLED", "DD_AGENT_HOST", "DD_ENV", "DD_SERVICE", diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 17e36639..a6c98e9b 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -52,7 +52,7 @@ cache_redis_url: redis://127.0.0.1:6379/15 s3_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune_repository/circleci" -datadog_trace_enabled: false +dd_trace_enabled: false istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" From f3ad7ec3b2f05f827b46d9476524569e58b5c416 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 25 Oct 2023 17:11:22 -0700 Subject: [PATCH 157/425] Allow fine-tuning hyperparameter to be Dict (#353) --- .../model_engine_server/domain/entities/common_types.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/entities/common_types.py b/model-engine/model_engine_server/domain/entities/common_types.py index 3556723c..ea0c2240 100644 --- a/model-engine/model_engine_server/domain/entities/common_types.py +++ b/model-engine/model_engine_server/domain/entities/common_types.py @@ -1,5 +1,7 @@ -from typing import Union +from typing import Any, Dict, Union CpuSpecificationType = Union[str, int, float] StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. -FineTuneHparamValueType = Union[str, int, float] # should suffice for now +FineTuneHparamValueType = Union[ + str, int, float, Dict[str, Any] +] # should just make this Any if we need to add more From 0b8a817f6b8c583999011192cefe6f015cc8b9c5 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 26 Oct 2023 13:36:34 -0700 Subject: [PATCH 158/425] adding real auth to integration tests (#352) * adding real auth to integration tests * use real auth * add default fakeuser * cleanup --- integration_tests/rest_api_utils.py | 15 ++++----------- integration_tests/test_bundles.py | 19 +++++++++---------- integration_tests/test_fine_tunes.py | 10 +++++----- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 604b8744..abec0cf4 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -15,17 +15,10 @@ print(f"Integration tests using gateway {BASE_PATH=}") DEFAULT_NETWORK_TIMEOUT_SEC = 10 -# Generate some fake 24-character user IDs. -# We don't want different people to get user ID collisions but at the same time we want people to -# consistently use the same user IDs so that they can clean up their extra endpoints. -USER_PREFIX = os.getenv("SERVICE_IDENTIFIER", "test")[:8] -USER_ID_0 = USER_PREFIX + "0" * (24 - len(USER_PREFIX)) -USER_ID_1 = USER_PREFIX + "1" * (24 - len(USER_PREFIX)) - -DEFAULT_USERS: Sequence[str] = ( - USER_ID_0, - USER_ID_1, -) +# Use the scale-launch-integration-tests id +USER_ID_0 = os.getenv("TEST_USER_ID", "fakeuser") + +DEFAULT_USERS: Sequence[str] = (USER_ID_0,) # type: ignore def echo_load_predict_fn(model): diff --git a/integration_tests/test_bundles.py b/integration_tests/test_bundles.py index cb8a45e7..3e8c47d8 100644 --- a/integration_tests/test_bundles.py +++ b/integration_tests/test_bundles.py @@ -5,7 +5,6 @@ CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, USER_ID_0, - USER_ID_1, create_model_bundle, ensure_launch_gateway_healthy, get_latest_model_bundle, @@ -16,12 +15,12 @@ @retry(stop=stop_after_attempt(10), wait=wait_fixed(30)) def model_bundles(): ensure_launch_gateway_healthy() - for user in [USER_ID_0, USER_ID_1]: - for create_bundle_request in [ - CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, - CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, - ]: - create_model_bundle(create_bundle_request, user, "v2") - bundle = get_latest_model_bundle(create_bundle_request["name"], user, "v2") - assert bundle["name"] == create_bundle_request["name"] - assert bundle["metadata"] == create_bundle_request["metadata"] + user = USER_ID_0 + for create_bundle_request in [ + CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, + CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, + ]: + create_model_bundle(create_bundle_request, user, "v2") + bundle = get_latest_model_bundle(create_bundle_request["name"], user, "v2") + assert bundle["name"] == create_bundle_request["name"] + assert bundle["metadata"] == create_bundle_request["metadata"] diff --git a/integration_tests/test_fine_tunes.py b/integration_tests/test_fine_tunes.py index a5aee7ac..024540e2 100644 --- a/integration_tests/test_fine_tunes.py +++ b/integration_tests/test_fine_tunes.py @@ -1,9 +1,12 @@ +import pytest + from .rest_api_utils import ( # CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, CREATE_FINE_TUNE_REQUEST, USER_ID_0, cancel_fine_tune_by_id, create_docker_image_batch_job_bundle, create_fine_tune, get_fine_tune_by_id, - USER_ID_1, + USER_ID_0, list_fine_tunes, ) +@pytest.mark.skip(reason="test doesn't currently work, needs to be implemented correctly") def test_fine_tunes() -> None: # TODO: get this test to work (move LLM fine tune repository to database rather than in S3) @@ -21,11 +24,8 @@ def test_fine_tunes() -> None: # num_jobs = len(list_response_0_before["jobs"]) # assert num_jobs >= 1 - list_response_1 = list_fine_tunes(USER_ID_1) + list_response_1 = list_fine_tunes(USER_ID_0) assert len(list_response_1["jobs"]) == 0 - # cancel_response = cancel_fine_tune_by_id(fine_tune_id, USER_ID_0) - # assert cancel_response["success"] - # list_response_0_after = list_fine_tunes(USER_ID_0) # assert len(list_response_0_after["jobs"]) == num_jobs - 1 From d29faffb70cc2930a2d1ff833becae6bf78e12e8 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Fri, 27 Oct 2023 09:58:20 -0700 Subject: [PATCH 159/425] add new llm-jp models to llm-engine (#354) * add new llm-jp models to llm-engine * update names for japanese llms --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 3bdec782..78218843 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -105,6 +105,8 @@ "code-llama-7b": "codellama/CodeLlama-7b-hf", "code-llama-13b": "codellama/CodeLlama-13b-hf", "code-llama-34b": "codellama/CodeLlama-34b-hf", + "llm-jp-13b-instruct-full": "llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", + "llm-jp-13b-instruct-full-dolly": "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", }, LLMInferenceFramework.VLLM: { "mpt-7b": "mosaicml/mpt-7b", From 8c21282deb954769ca1024887d714e1045b52b10 Mon Sep 17 00:00:00 2001 From: Jason Liang Date: Fri, 27 Oct 2023 16:00:16 -0700 Subject: [PATCH 160/425] Generalize SQS region (#355) * Generalize SQS region * Fix lint * Fix hmi_config * Remove DD change --- .../infra/gateways/resources/k8s_resource_types.py | 6 +++--- .../resources/live_sqs_endpoint_resource_delegate.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 1a7998e2..c86d18fd 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -496,9 +496,9 @@ def get_endpoint_resource_arguments_from_request( change_cause_message = ( f"Deployment at {datetime.utcnow()} UTC. " - f"Using deployment constructed from model bundle ID: {model_bundle.id}, " - f"model bundle name: {model_bundle.name}, " - f"endpoint ID: {model_endpoint_record.id}" + f"Using deployment constructed from model bundle ID {model_bundle.id}, " + f"model bundle name {model_bundle.name}, " + f"endpoint ID {model_endpoint_record.id}" ) priority = LAUNCH_DEFAULT_PRIORITY_CLASS diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py index 6d9f6597..f04d6b65 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py @@ -7,6 +7,7 @@ from aiobotocore.client import AioBaseClient from model_engine_server.common.config import hmi_config from model_engine_server.core.aws.roles import session +from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( @@ -21,7 +22,9 @@ def _create_async_sqs_client(sqs_profile: Optional[str]) -> AioBaseClient: - return session(role=sqs_profile, session_type=AioSession).client("sqs", region_name="us-west-2") + return session(role=sqs_profile, session_type=AioSession).client( + "sqs", region_name=infra_config().default_region + ) def _get_queue_policy(queue_name: str) -> str: From da9f82b98774db16adaf0dc054477ccee76e8d6f Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 2 Nov 2023 10:35:58 -0700 Subject: [PATCH 161/425] Track LLM Metrics (#356) * add emit_route_call_metric fn * add MonitoringMetricsGateway to external interfaces * record metrics on llm routes * missed change * instantiate to none * fix opt params * use fastapi dependency injection instead * change kwargs to args --- .../model_engine_server/api/dependencies.py | 2 ++ .../model_engine_server/api/llms_v1.py | 19 +++++++++++++++++-- .../core/auth/authentication_repository.py | 3 ++- .../gateways/monitoring_metrics_gateway.py | 17 +++++++++++++++++ .../fake_monitoring_metrics_gateway.py | 10 +++++++++- model-engine/tests/unit/conftest.py | 2 ++ 6 files changed, 49 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 08f362cc..0a055248 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -133,6 +133,7 @@ class ExternalInterfaces: filesystem_gateway: FilesystemGateway llm_artifact_gateway: LLMArtifactGateway cron_job_gateway: CronJobGateway + monitoring_metrics_gateway: MonitoringMetricsGateway def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway: @@ -279,6 +280,7 @@ def _get_external_interfaces( llm_artifact_gateway=llm_artifact_gateway, trigger_repository=trigger_repository, cron_job_gateway=cron_job_gateway, + monitoring_metrics_gateway=monitoring_metrics_gateway, ) return external_interfaces diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 7e73ef70..0dcc3faa 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -5,7 +5,7 @@ from typing import Optional import pytz -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Request from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, @@ -58,6 +58,7 @@ ObjectNotFoundException, UpstreamServiceError, ) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( CancelFineTuneV1UseCase, CreateFineTuneV1UseCase, @@ -77,7 +78,21 @@ from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase from sse_starlette.sse import EventSourceResponse -llm_router_v1 = APIRouter(prefix="/v1/llm") + +async def record_route_call( + request: Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +): + route = f"{request.method}_{request.url.path}".lower() + model_name = request.query_params.get("model_endpoint_name", None) + + external_interfaces.monitoring_metrics_gateway.emit_route_call_metric( + route, MetricMetadata(user=auth, model_name=model_name) + ) + + +llm_router_v1 = APIRouter(prefix="/v1/llm", dependencies=[Depends(record_route_call)]) logger = make_logger(logger_name()) diff --git a/model-engine/model_engine_server/core/auth/authentication_repository.py b/model-engine/model_engine_server/core/auth/authentication_repository.py index f3a60847..a4d36dc1 100644 --- a/model-engine/model_engine_server/core/auth/authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/authentication_repository.py @@ -7,7 +7,8 @@ class User: user_id: str team_id: str - is_privileged_user: bool + email: Optional[str] = None + is_privileged_user: bool = False class AuthenticationRepository(ABC): diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index 28a561cf..5e7e0382 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -7,6 +7,15 @@ """ from abc import ABC, abstractmethod +from typing import Optional + +from model_engine_server.core.auth.authentication_repository import User +from pydantic import BaseModel + + +class MetricMetadata(BaseModel): + user: User + model_name: Optional[str] class MonitoringMetricsGateway(ABC): @@ -64,3 +73,11 @@ def emit_database_cache_miss_metric(self): Missed database cache metric """ + + @abstractmethod + def emit_route_call_metric(self, route: str, metadata: MetricMetadata): + """ + Route call metric + + """ + pass diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py index 65b6cd7e..32c2b6f3 100644 --- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -1,6 +1,9 @@ from collections import defaultdict -from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.domain.gateways.monitoring_metrics_gateway import ( + MetricMetadata, + MonitoringMetricsGateway, +) class FakeMonitoringMetricsGateway(MonitoringMetricsGateway): @@ -15,6 +18,7 @@ def __init__(self): self.successful_hook = defaultdict(int) self.database_cache_hit = 0 self.database_cache_miss = 0 + self.route_call = defaultdict(int) def reset(self): self.attempted_build = 0 @@ -27,6 +31,7 @@ def reset(self): self.successful_hook = defaultdict(int) self.database_cache_hit = 0 self.database_cache_miss = 0 + self.route_call = defaultdict(int) def emit_attempted_build_metric(self): self.attempted_build += 1 @@ -57,3 +62,6 @@ def emit_database_cache_hit_metric(self): def emit_database_cache_miss_metric(self): self.database_cache_miss += 1 + + def emit_route_call_metric(self, route: str, _metadata: MetricMetadata): + self.route_call[route] += 1 diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index b784e5c4..92d2e88b 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -2093,6 +2093,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_model_bundle_repository = FakeModelBundleRepository( contents=fake_model_bundle_repository_contents ) + fake_monitoring_metrics_gateway = FakeMonitoringMetricsGateway() fake_model_endpoint_record_repository = FakeModelEndpointRecordRepository( contents=fake_model_endpoint_record_repository_contents, model_bundle_repository=fake_model_bundle_repository, @@ -2176,6 +2177,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: cron_job_gateway=fake_cron_job_gateway, filesystem_gateway=fake_file_system_gateway, llm_artifact_gateway=fake_llm_artifact_gateway, + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, ) try: yield repositories From 4c83d54a3e5d9604101c0c2185636b9fde06ce50 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:02:19 -0700 Subject: [PATCH 162/425] remove extra trace facet (#359) --- .../model_engine_server/api/batch_jobs_v1.py | 8 -------- .../api/docker_image_batch_job_bundles_v1.py | 5 ----- model-engine/model_engine_server/api/files_v1.py | 6 ------ model-engine/model_engine_server/api/llms_v1.py | 13 ------------- .../model_engine_server/api/model_bundles_v1.py | 6 ------ .../model_engine_server/api/model_bundles_v2.py | 6 ------ .../model_engine_server/api/model_endpoints_v1.py | 6 ------ model-engine/model_engine_server/api/tasks_v1.py | 5 ----- model-engine/model_engine_server/api/triggers_v1.py | 6 ------ .../model_engine_server/common/datadog_utils.py | 9 --------- 10 files changed, 70 deletions(-) diff --git a/model-engine/model_engine_server/api/batch_jobs_v1.py b/model-engine/model_engine_server/api/batch_jobs_v1.py index 6241b202..1724c5a7 100644 --- a/model-engine/model_engine_server/api/batch_jobs_v1.py +++ b/model-engine/model_engine_server/api/batch_jobs_v1.py @@ -7,7 +7,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.batch_jobs import ( CreateBatchJobV1Request, CreateBatchJobV1Response, @@ -55,7 +54,6 @@ async def create_batch_job( """ Runs a batch job. """ - add_trace_resource_name("batch_jobs_post") logger.info(f"POST /batch-jobs with {request} for {auth}") try: use_case = CreateBatchJobV1UseCase( @@ -85,7 +83,6 @@ async def get_batch_job( """ Gets a batch job. """ - add_trace_resource_name("batch_jobs_get") logger.info(f"GET /batch-jobs/{batch_job_id} for {auth}") try: use_case = GetBatchJobV1UseCase(batch_job_service=external_interfaces.batch_job_service) @@ -107,7 +104,6 @@ async def update_batch_job( """ Updates a batch job. """ - add_trace_resource_name("batch_jobs_put") logger.info(f"PUT /batch-jobs/{batch_job_id} for {auth}") try: use_case = UpdateBatchJobV1UseCase(batch_job_service=external_interfaces.batch_job_service) @@ -127,7 +123,6 @@ async def create_docker_image_batch_job( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> CreateDockerImageBatchJobV1Response: - add_trace_resource_name("batch_jobs_di_create") logger.info(f"POST /docker-image-batch-jobs with {request} for {auth}") try: use_case = CreateDockerImageBatchJobV1UseCase( @@ -166,7 +161,6 @@ async def get_docker_image_batch_job( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> GetDockerImageBatchJobV1Response: - add_trace_resource_name("batch_jobs_di_get") logger.info(f"GET /docker-image-batch-jobs/{batch_job_id} for {auth}") try: use_case = GetDockerImageBatchJobV1UseCase( @@ -191,7 +185,6 @@ async def list_docker_image_batch_jobs( """ Lists docker image batch jobs spawned by trigger with given ID """ - add_trace_resource_name("batch_jobs_di_get_trigger") logger.info(f"GET /docker-image-batch-jobs?trigger_id={trigger_id}") try: use_case = ListDockerImageBatchJobsV1UseCase( @@ -212,7 +205,6 @@ async def update_docker_image_batch_job( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> UpdateDockerImageBatchJobV1Response: - add_trace_resource_name("batch_jobs_di_put") logger.info(f"PUT /docker-image-batch-jobs/{batch_job_id} with {request} for {auth}") try: use_case = UpdateDockerImageBatchJobV1UseCase( diff --git a/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py index 4b2980be..be0b93ad 100644 --- a/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py +++ b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py @@ -6,7 +6,6 @@ get_external_interfaces, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, @@ -44,7 +43,6 @@ async def create_docker_image_batch_job_bundle( """ Creates a docker iamge batch job bundle """ - add_trace_resource_name("docker_image_batch_job_bundle_post") logger.info(f"POST /docker-image-batch-job-bundles with {request} for {auth}") try: use_case = CreateDockerImageBatchJobBundleV1UseCase( @@ -71,7 +69,6 @@ async def list_docker_image_batch_job_model_bundles( Lists docker image batch job bundles owned by current owner """ - add_trace_resource_name("docker_image_batch_job_bundle_get") logger.info( f"GET /docker-image-batch-job-bundles?bundle_name={bundle_name}&order_by={order_by} for auth" ) @@ -90,7 +87,6 @@ async def get_latest_docker_image_batch_job_bundle( external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> DockerImageBatchJobBundleV1Response: """Gets latest Docker Image Batch Job Bundle with given name owned by the current owner""" - add_trace_resource_name("docker_image_batch_job_bundle_latest_get") logger.info(f"GET /docker-image-batch-job-bundles/latest?bundle_name={bundle_name} for {auth}") try: use_case = GetLatestDockerImageBatchJobBundleByNameV1UseCase( @@ -114,7 +110,6 @@ async def get_docker_image_batch_job_model_bundle( external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> DockerImageBatchJobBundleV1Response: """Get details for a given DockerImageBatchJobBundle owned by the current owner""" - add_trace_resource_name("docker_image_batch_job_bundle_id_get") logger.info( f"GET /docker-image-batch-job-bundles/{docker_image_batch_job_bundle_id} for {auth}" ) diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py index dd52c10b..556566d5 100644 --- a/model-engine/model_engine_server/api/files_v1.py +++ b/model-engine/model_engine_server/api/files_v1.py @@ -7,7 +7,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.files import ( DeleteFileResponse, GetFileContentResponse, @@ -39,7 +38,6 @@ async def upload_file( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> UploadFileResponse: - add_trace_resource_name("files_upload") logger.info(f"POST /files with filename {file.filename} for {auth}") use_case = UploadFileUseCase( file_storage_gateway=external_interfaces.file_storage_gateway, @@ -57,7 +55,6 @@ async def get_file( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> GetFileResponse: - add_trace_resource_name("files_get") logger.info(f"GET /files/{file_id} for {auth}") try: use_case = GetFileUseCase( @@ -76,7 +73,6 @@ async def list_files( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> ListFilesResponse: - add_trace_resource_name("files_list") logger.info(f"GET /files for {auth}") use_case = ListFilesUseCase( file_storage_gateway=external_interfaces.file_storage_gateway, @@ -90,7 +86,6 @@ async def delete_file( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> DeleteFileResponse: - add_trace_resource_name("files_delete") logger.info(f"DELETE /files/{file_id} for {auth}") try: use_case = DeleteFileUseCase( @@ -113,7 +108,6 @@ async def get_file_content( """ Describe the LLM Model endpoint with given name. """ - add_trace_resource_name("files_content_get") logger.info(f"GET /files/{file_id}/content for {auth}") try: use_case = GetFileContentUseCase( diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 0dcc3faa..9e811f91 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -12,7 +12,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.llms import ( CancelFineTuneResponse, CompletionStreamV1Request, @@ -133,7 +132,6 @@ async def create_model_endpoint( """ Creates an LLM endpoint for the current user. """ - add_trace_resource_name("llm_model_endpoints_post") logger.info(f"POST /llm/model-endpoints with {request} for {auth}") try: create_model_bundle_use_case = CreateModelBundleV2UseCase( @@ -187,7 +185,6 @@ async def list_model_endpoints( """ Lists the LLM model endpoints owned by the current owner, plus all public_inference LLMs. """ - add_trace_resource_name("llm_model_endpoints_get") logger.info(f"GET /llm/model-endpoints?name={name}&order_by={order_by} for {auth}") use_case = ListLLMModelEndpointsV1UseCase( llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, @@ -206,7 +203,6 @@ async def get_model_endpoint( """ Describe the LLM Model endpoint with given name. """ - add_trace_resource_name("llm_model_endpoints_name_get") logger.info(f"GET /llm/model-endpoints/{model_endpoint_name} for {auth}") try: use_case = GetLLMModelEndpointByNameV1UseCase( @@ -230,7 +226,6 @@ async def create_completion_sync_task( """ Runs a sync prompt completion on an LLM. """ - add_trace_resource_name("llm_completion_sync_post") logger.info( f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}" ) @@ -275,7 +270,6 @@ async def create_completion_stream_task( """ Runs a stream prompt completion on an LLM. """ - add_trace_resource_name("llm_completion_stream_post") logger.info( f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" ) @@ -311,7 +305,6 @@ async def create_fine_tune( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> CreateFineTuneResponse: - add_trace_resource_name("fine_tunes_create") logger.info(f"POST /fine-tunes with {request} for {auth}") try: use_case = CreateFineTuneV1UseCase( @@ -340,7 +333,6 @@ async def get_fine_tune( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> GetFineTuneResponse: - add_trace_resource_name("fine_tunes_get") logger.info(f"GET /fine-tunes/{fine_tune_id} for {auth}") try: use_case = GetFineTuneV1UseCase( @@ -359,7 +351,6 @@ async def list_fine_tunes( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> ListFineTunesResponse: - add_trace_resource_name("fine_tunes_list") logger.info(f"GET /fine-tunes for {auth}") use_case = ListFineTunesV1UseCase( llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, @@ -373,7 +364,6 @@ async def cancel_fine_tune( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> CancelFineTuneResponse: - add_trace_resource_name("fine_tunes_cancel") logger.info(f"PUT /fine-tunes/{fine_tune_id}/cancel for {auth}") try: use_case = CancelFineTuneV1UseCase( @@ -393,7 +383,6 @@ async def get_fine_tune_events( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> GetFineTuneEventsResponse: - add_trace_resource_name("fine_tunes_events_get") logger.info(f"GET /fine-tunes/{fine_tune_id}/events for {auth}") try: use_case = GetFineTuneEventsV1UseCase( @@ -414,7 +403,6 @@ async def download_model_endpoint( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> ModelDownloadResponse: - add_trace_resource_name("model_endpoints_download") logger.info(f"POST /model-endpoints/download with {request} for {auth}") try: use_case = ModelDownloadV1UseCase( @@ -438,7 +426,6 @@ async def delete_llm_model_endpoint( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> DeleteLLMEndpointResponse: - add_trace_resource_name("llm_model_endpoints_delete") logger.info(f"DELETE /model-endpoints/{model_endpoint_name} for {auth}") try: use_case = DeleteLLMEndpointByNameUseCase( diff --git a/model-engine/model_engine_server/api/model_bundles_v1.py b/model-engine/model_engine_server/api/model_bundles_v1.py index c83bda5d..de73fb4c 100644 --- a/model-engine/model_engine_server/api/model_bundles_v1.py +++ b/model-engine/model_engine_server/api/model_bundles_v1.py @@ -9,7 +9,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CreateModelBundleV1Request, @@ -48,7 +47,6 @@ async def create_model_bundle( Creates a ModelBundle for the current user. """ logger.info(f"POST /model-bundles with {request} for {auth}") - add_trace_resource_name("model_bundles_post") try: use_case = CreateModelBundleV1UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -84,7 +82,6 @@ async def clone_model_bundle_with_changes( """ Creates a ModelBundle by cloning an existing one and then applying changes on top. """ - add_trace_resource_name("model_bundles_clone") try: use_case = CloneModelBundleV1UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -107,7 +104,6 @@ async def list_model_bundles( """ Lists the ModelBundles owned by the current owner. """ - add_trace_resource_name("model_bundles_get") logger.info(f"GET /model-bundles?model_name={model_name}&order_by={order_by} for {auth}") use_case = ListModelBundlesV1UseCase( model_bundle_repository=external_interfaces.model_bundle_repository @@ -124,7 +120,6 @@ async def get_latest_model_bundle( """ Gets the latest Model Bundle with the given name owned by the current owner. """ - add_trace_resource_name("model_bundles_latest_get") logger.info(f"GET /model-bundles/latest?model_name={model_name} for {auth}") try: use_case = GetLatestModelBundleByNameV1UseCase( @@ -149,7 +144,6 @@ async def get_model_bundle( """ Gets the details for a given ModelBundle owned by the current owner. """ - add_trace_resource_name("model_bundles_id_get") logger.info(f"GET /model-bundles/{model_bundle_id} for {auth}") try: use_case = GetModelBundleByIdV1UseCase( diff --git a/model-engine/model_engine_server/api/model_bundles_v2.py b/model-engine/model_engine_server/api/model_bundles_v2.py index 39f4a7d8..3376de70 100644 --- a/model-engine/model_engine_server/api/model_bundles_v2.py +++ b/model-engine/model_engine_server/api/model_bundles_v2.py @@ -9,7 +9,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV2Request, CreateModelBundleV2Request, @@ -48,7 +47,6 @@ async def create_model_bundle( Creates a ModelBundle for the current user. """ logger.info(f"POST /model-bundles with {request} for {auth}") - add_trace_resource_name("model_bundles_post") try: use_case = CreateModelBundleV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -84,7 +82,6 @@ async def clone_model_bundle_with_changes( """ Creates a ModelBundle by cloning an existing one and then applying changes on top. """ - add_trace_resource_name("model_bundles_clone") try: use_case = CloneModelBundleV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -107,7 +104,6 @@ async def list_model_bundles( """ Lists the ModelBundles owned by the current owner. """ - add_trace_resource_name("model_bundles_get") logger.info(f"GET /model-bundles?model_name={model_name}&order_by={order_by} for {auth}") use_case = ListModelBundlesV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository @@ -124,7 +120,6 @@ async def get_latest_model_bundle( """ Gets the latest Model Bundle with the given name owned by the current owner. """ - add_trace_resource_name("model_bundles_latest_get") logger.info(f"GET /model-bundles/latest?model_name={model_name} for {auth}") try: use_case = GetLatestModelBundleByNameV2UseCase( @@ -149,7 +144,6 @@ async def get_model_bundle( """ Gets the details for a given ModelBundle owned by the current owner. """ - add_trace_resource_name("model_bundles_id_get") logger.info(f"GET /model-bundles/{model_bundle_id} for {auth}") try: use_case = GetModelBundleByIdV2UseCase( diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index e761d2c5..3b45f071 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -12,7 +12,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, @@ -57,7 +56,6 @@ async def create_model_endpoint( """ Creates a Model for the current user. """ - add_trace_resource_name("model_endpoints_post") logger.info(f"POST /model-endpoints with {request} for {auth}") try: use_case = CreateModelEndpointV1UseCase( @@ -104,7 +102,6 @@ async def list_model_endpoints( """ Lists the Models owned by the current owner. """ - add_trace_resource_name("model_endpoints_get") logger.info(f"GET /model-endpoints?name={name}&order_by={order_by} for {auth}") use_case = ListModelEndpointsV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, @@ -123,7 +120,6 @@ async def get_model_endpoint( """ Describe the Model endpoint with given ID. """ - add_trace_resource_name("model_endpoints_id_get") logger.info(f"GET /model-endpoints/{model_endpoint_id} for {auth}") try: use_case = GetModelEndpointByIdV1UseCase( @@ -149,7 +145,6 @@ async def update_model_endpoint( """ Lists the Models owned by the current owner. """ - add_trace_resource_name("model_endpoints_id_put") logger.info(f"PUT /model-endpoints/{model_endpoint_id} with {request} for {auth}") try: use_case = UpdateModelEndpointByIdV1UseCase( @@ -192,7 +187,6 @@ async def delete_model_endpoint( """ Lists the Models owned by the current owner. """ - add_trace_resource_name("model_endpoints_id_delete") logger.info(f"DELETE /model-endpoints/{model_endpoint_id} for {auth}") try: use_case = DeleteModelEndpointByIdV1UseCase( diff --git a/model-engine/model_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py index 25b97838..524f2f46 100644 --- a/model-engine/model_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -6,7 +6,6 @@ get_external_interfaces_read_only, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, @@ -49,7 +48,6 @@ async def create_async_inference_task( """ Runs an async inference prediction. """ - add_trace_resource_name("task_async_post") logger.info(f"POST /async-tasks {request} to endpoint {model_endpoint_id} for {auth}") try: use_case = CreateAsyncInferenceTaskV1UseCase( @@ -79,7 +77,6 @@ def get_async_inference_task( """ Gets the status of an async inference task. """ - add_trace_resource_name("task_async_id_get") logger.info(f"GET /async-tasks/{task_id} for {auth}") try: use_case = GetAsyncInferenceTaskV1UseCase( @@ -103,7 +100,6 @@ async def create_sync_inference_task( """ Runs a sync inference prediction. """ - add_trace_resource_name("task_sync_post") logger.info(f"POST /sync-tasks with {request} to endpoint {model_endpoint_id} for {auth}") try: use_case = CreateSyncInferenceTaskV1UseCase( @@ -143,7 +139,6 @@ async def create_streaming_inference_task( """ Runs a streaming inference prediction. """ - add_trace_resource_name("task_streaming_post") logger.info(f"POST /streaming-tasks with {request} to endpoint {model_endpoint_id} for {auth}") try: use_case = CreateStreamingInferenceTaskV1UseCase( diff --git a/model-engine/model_engine_server/api/triggers_v1.py b/model-engine/model_engine_server/api/triggers_v1.py index 30c95acd..010140af 100644 --- a/model-engine/model_engine_server/api/triggers_v1.py +++ b/model-engine/model_engine_server/api/triggers_v1.py @@ -4,7 +4,6 @@ get_external_interfaces, verify_authentication, ) -from model_engine_server.common.datadog_utils import add_trace_resource_name from model_engine_server.common.dtos.triggers import ( CreateTriggerV1Request, CreateTriggerV1Response, @@ -48,7 +47,6 @@ async def create_trigger( """ Creates and runs a trigger """ - add_trace_resource_name("triggers_post") logger.info(f"POST /triggers with {request} for {auth}") try: use_case = CreateTriggerUseCase( @@ -102,7 +100,6 @@ async def list_triggers( """ Lists descriptions of all triggers """ - add_trace_resource_name("triggers_get") logger.info(f"GET /triggers for {auth}") use_case = ListTriggersUseCase(trigger_repository=external_interfaces.trigger_repository) return await use_case.execute(user=auth) @@ -117,7 +114,6 @@ async def get_trigger( """ Describes the trigger with the given ID """ - add_trace_resource_name("triggers_id_get") logger.info(f"GET /triggers/{trigger_id} for {auth}") try: use_case = GetTriggerUseCase(trigger_repository=external_interfaces.trigger_repository) @@ -136,7 +132,6 @@ async def update_trigger( """ Updates the trigger with the given ID """ - add_trace_resource_name("triggers_id_put") logger.info(f"PUT /triggers/{trigger_id} with {request} for {auth}") try: use_case = UpdateTriggerUseCase( @@ -162,7 +157,6 @@ async def delete_trigger( """ Deletes the trigger with the given ID """ - add_trace_resource_name("trigger_id_delete") logger.info(f"DELETE /triggers/{trigger_id} for {auth}") try: use_case = DeleteTriggerUseCase( diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py index c73fa2f9..3e3513cb 100644 --- a/model-engine/model_engine_server/common/datadog_utils.py +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -1,15 +1,6 @@ from ddtrace import tracer -def add_trace_resource_name(tag: str): - """Adds a custom tag to a given dd trace corresponding to the route - (e.g. get_model_bundles for GET /model-bundles, etc.) so that we can filter in Datadog easier - """ - current_span = tracer.current_span() - if current_span: - current_span.set_tag("launch.resource_name", tag) - - def add_trace_request_id(request_id: str): """Adds a custom tag to a given dd trace corresponding to the request id so that we can filter in Datadog easier From 97cb4e39e4eee3fce5b768abfb2dfa4633351d0e Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:35:14 -0700 Subject: [PATCH 163/425] Ianmacleod/add codellama instruct, refactor codellama models (#360) * add codellama-7b-instruct to llm-engine, name refactor for codellama * more name refactoring * update docs * add codellama-13b-instruct --- docs/model_zoo.md | 6 +++++- .../use_cases/llm_model_endpoint_use_cases.py | 16 ++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 148b1dfc..0431ea14 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -20,7 +20,11 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | | `mistral-7b` | ✅ | ✅ | vllm | | `mistral-7b-instruct` | ✅ | ✅ | vllm | -| `code-llama-7b` | ✅ | | text-generation-inference, vllm | +| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | +| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | +| `codellama-13b` | ✅ | | text-generation-inference, vllm | +| `codellama-13b-instruct` | ✅ | | text-generation-inference, vllm | +| `codellama-34b` | ✅ | | text-generation-inference, vllm | ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 78218843..cf9fbbf7 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -102,9 +102,11 @@ "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", "falcon-40b": "tiiuae/falcon-40b", "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", - "code-llama-7b": "codellama/CodeLlama-7b-hf", - "code-llama-13b": "codellama/CodeLlama-13b-hf", - "code-llama-34b": "codellama/CodeLlama-34b-hf", + "codellama-7b": "codellama/CodeLlama-7b-hf", + "codellama-7b-instruct": "codellama/CodeLlama-7b-Instruct-hf", + "codellama-13b": "codellama/CodeLlama-13b-hf", + "codellama-13b-instruct": "codellama/CodeLlama-13b-Instruct-hf", + "codellama-34b": "codellama/CodeLlama-34b-hf", "llm-jp-13b-instruct-full": "llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", "llm-jp-13b-instruct-full-dolly": "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", }, @@ -126,9 +128,11 @@ "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", "falcon-180b": "tiiuae/falcon-180B", "falcon-180b-chat": "tiiuae/falcon-180B-chat", - "code-llama-7b": "codellama/CodeLlama-7b-hf", - "code-llama-13b": "codellama/CodeLlama-13b-hf", - "code-llama-34b": "codellama/CodeLlama-34b-hf", + "codellama-7b": "codellama/CodeLlama-7b-hf", + "codellama-7b-instruct": "codellama/CodeLlama-7b-Instruct-hf", + "codellama-13b": "codellama/CodeLlama-13b-hf", + "codellama-13b-instruct": "codellama/CodeLlama-13b-Instruct-hf", + "codellama-34b": "codellama/CodeLlama-34b-hf", "mammoth-coder-llama-2-7b": "TIGER-Lab/MAmmoTH-Coder-7B", "mammoth-coder-llama-2-13b": "TIGER-Lab/MAmmoTH-Coder-13B", "mammoth-coder-llama-2-34b": "TIGER-Lab/MAmmoTH-Coder-34B", From f3118a672ca9fdd11b59221624d8c31e36980cf9 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:25:25 -0700 Subject: [PATCH 164/425] Various changes/bugfixes to chart/code to streamline deployment on different forms of infra (#339) * explicit 'run db init script' * add in dbInitScript default false value * add dbsecretawsprofile? * oops * make init_db script also use aws secret if available * make user inference image repository configurable * make the refresh codeartifact optional if we're unable to run the refresh codeartifact script * file not found error also * add optional image builder service account * mark todo * rename * use new hmi_config params * wip adjust kaniko cache name * kaniko cache * todo make a serviceaccount for the inference workers * inference service account template, kinda messy tbh * missed some spots * fix bug in helm chart * rerun integration tests * add to values_sample.yaml * fix number of zeros hehe --- charts/model-engine/templates/_helpers.tpl | 2 + .../templates/database_init_job.yaml | 2 +- .../service_account_image_builder.yaml | 19 ++++++ .../templates/service_account_inference.yaml | 18 ++++++ charts/model-engine/values.yaml | 2 + charts/model-engine/values_circleci.yaml | 4 ++ charts/model-engine/values_sample.yaml | 62 ++++++++++++++++--- .../model_engine_server/common/config.py | 4 ++ .../core/docker/kaniko_template.yaml | 1 + .../core/docker/remote_build.py | 38 ++++++++---- .../entrypoints/init_database.py | 25 ++++---- .../k8s_endpoint_resource_delegate.py | 1 - .../repositories/ecr_docker_repository.py | 2 + .../services/live_endpoint_builder_service.py | 9 ++- .../service_config_circleci.yaml | 4 ++ 15 files changed, 159 insertions(+), 34 deletions(-) create mode 100644 charts/model-engine/templates/service_account_image_builder.yaml create mode 100644 charts/model-engine/templates/service_account_inference.yaml diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 8e9f1da3..7dd0410d 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -256,6 +256,8 @@ env: value: {{ .Values.aws.profileName }} - name: ECR_READ_AWS_PROFILE value: {{ .Values.aws.profileName }} + - name: DB_SECRET_AWS_PROFILE + value: {{ .Values.aws.profileName }} - name: S3_WRITE_AWS_PROFILE value: {{ .Values.aws.s3WriteProfileName }} {{- end }} diff --git a/charts/model-engine/templates/database_init_job.yaml b/charts/model-engine/templates/database_init_job.yaml index c87b7e92..0c273de9 100644 --- a/charts/model-engine/templates/database_init_job.yaml +++ b/charts/model-engine/templates/database_init_job.yaml @@ -1,4 +1,4 @@ -{{- if .Values.secrets.kubernetesDatabaseSecretName }} +{{- if or (.Values.secrets.kubernetesDatabaseSecretName) (.Values.db.runDbInitScript) }} apiVersion: batch/v1 kind: Job metadata: diff --git a/charts/model-engine/templates/service_account_image_builder.yaml b/charts/model-engine/templates/service_account_image_builder.yaml new file mode 100644 index 00000000..8cdec485 --- /dev/null +++ b/charts/model-engine/templates/service_account_image_builder.yaml @@ -0,0 +1,19 @@ +{{- if and (.Values.imageBuilderServiceAccount) (.Values.imageBuilderServiceAccount.create) }} +{{- $serviceAccountNamespaces := (include "modelEngine.serviceAccountNamespaces" . | fromYaml) }} +{{- $annotations := .Values.imageBuilderServiceAccount.annotations }} +{{- $labels := include "modelEngine.labels" . }} +{{- range $namespace := (index $serviceAccountNamespaces "namespaces") }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: kaniko + namespace: {{- printf " %s" $namespace }} + labels: + {{- $labels | nindent 4 }} + {{- with $annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +--- +{{- end }} +{{- end }} \ No newline at end of file diff --git a/charts/model-engine/templates/service_account_inference.yaml b/charts/model-engine/templates/service_account_inference.yaml new file mode 100644 index 00000000..9be37377 --- /dev/null +++ b/charts/model-engine/templates/service_account_inference.yaml @@ -0,0 +1,18 @@ +{{- if and (.Values.serviceTemplate) (.Values.serviceTemplate.createServiceAccount) (.Values.serviceTemplate.serviceAccountAnnotations) (.Values.serviceTemplate.serviceAccountName) (.Values.config.values.launch.endpoint_namespace)}} +{{- $annotations := .Values.serviceTemplate.serviceAccountAnnotations }} +{{- $inferenceServiceAccountName := .Values.serviceTemplate.serviceAccountName }} +{{- $inferenceServiceAccountNamespace := .Values.config.values.launch.endpoint_namespace }} +{{- $labels := include "modelEngine.labels" . }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{- printf " %s" $inferenceServiceAccountName }} + namespace: {{- printf " %s" $inferenceServiceAccountNamespace }} + labels: + {{- $labels | nindent 4 }} + {{- with $annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +--- +{{- end }} \ No newline at end of file diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index 666f1c4d..b75b6efa 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -3,6 +3,8 @@ spellbook: enabled: false redis: auth: +db: + runDbInitScript: false balloonNodeSelector: node-lifecycle: normal nodeSelector: diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index a82e2bad..60bf462b 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -146,6 +146,10 @@ config: tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" + user_inference_base_repository: "launch/inference" + user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" + user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" + docker_image_layer_cache_repository: "kaniko-cache" hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET" # Service Account diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 75bb808e..eb9d695b 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -1,25 +1,29 @@ # This is a YAML-formatted file. # tag [required] is the LLM Engine docker image tag -tag: 41ecada1b51ce3a46bbc3190a36ed7890db370d3 +tag: 60ac144c55aad971cdd7f152f4f7816ce2fb7d2f # context is a user-specified deployment tag. Can be used to context: production image: # gatewayRepository [required] is the docker repository to pull the LLM Engine gateway image from - gatewayRepository: public.ecr.aws/b2z8n5q1/llm-engine + gatewayRepository: public.ecr.aws/b2z8n5q1/model-engine # builderRepository [required] is the docker repository to pull the LLM Engine endpoint builder image from - builderRepository: public.ecr.aws/b2z8n5q1/llm-engine + builderRepository: public.ecr.aws/b2z8n5q1/model-engine # cacherRepository [required] is the docker repository to pull the LLM Engine cacher image from - cacherRepository: public.ecr.aws/b2z8n5q1/llm-engine + cacherRepository: public.ecr.aws/b2z8n5q1/model-engine # forwarderRepository [required] is the docker repository to pull the LLM Engine forwarder image from - forwarderRepository: public.ecr.aws/b2z8n5q1/llm-engine + forwarderRepository: public.ecr.aws/b2z8n5q1/model-engine # pullPolicy is the docker image pull policy pullPolicy: Always secrets: - # kubernetesDatabaseSecretName [required] is the name of the secret that contains the database credentials + # kubernetesDatabaseSecretName or awsDatabaseSecretName [required] + # is the name of the secret that contains the database credentials kubernetesDatabaseSecretName: llm-engine-postgres-credentials +db: + runDbInitScript: false + # serviceAccount [required] specifies the service account for LLM Engine server deployments (e.g gateway, cache, and builder deployments). serviceAccount: annotations: @@ -29,11 +33,35 @@ serviceAccount: "helm.sh/hook-weight": "-2" namespaces: [] +imageBuilderServiceAccount: + create: true + annotations: + # eks.amazonaws.com/role-arn [required] is the ARN of the IAM role that the image builder service account will assume. Needs to have ecr permissions + eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/k8s-main-llm-engine-image-builder + # Reads from serviceAccount.namespaces to determine which namespaces to create the image builder service account in + # service specifies the service configuration for the main LLM Engine server. Users should setup their own ingress controller to expose the service. service: type: ClusterIP port: 80 +# virtualservice specifies the configuration of an Istio VirtualService +virtualservice: + enabled: true + annotations: { } + hostDomains: + - llm-engine.domain.com + gateways: + - default/internal-gateway + +hostDomain: + prefix: http:// + +# destinationrule specifies the configuration of an Istio DestinationRule +destinationrule: + enabled: true + annotations: { } + # replicaCount specifies the amount of replica pods for each deployment replicaCount: # gateway is the main LLM Engine server deployment @@ -92,6 +120,14 @@ serviceTemplate: drop: - all mountInfraConfig: true + # createServiceAccount/serviceAccountName/serviceAccountAnnotations specify whether to create a serviceAccount for + # inference pods. Assumes the inference pods run in a separate namespace to the LLM Engine control plane. + createServiceAccount: true + serviceAccountName: model-engine + serviceAccountAnnotations: + eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/llm-engine + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" # config specifes the `data` field of the service config map config: @@ -111,7 +147,7 @@ config: redis_host: llm-engine-prod-cache.use1.cache.amazonaws.com # s3_bucket [required] is the S3 bucket you wish to connect s3_bucket: "llm-engine" - llm_engine: + launch: # endpoint_namespace [required] is K8s namespace the endpoints will be created in endpoint_namespace: llm-engine # cache_redis_url [required] is the full url for the redis cluster you wish to connect @@ -120,6 +156,7 @@ config: s3_file_llm_fine_tuning_job_repository: "s3://llm-engine/llm-ft-job-repository" # dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster dd_trace_enabled: false + istio_enabled: true # Asynchronous endpoints configs (coming soon) sqs_profile: default @@ -155,6 +192,17 @@ config: "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" } + billing_queue_arn: "unused" + model_primitive_host: "unused" + hf_user_fine_tuned_weights_prefix: "s3://llm-engine/fine_tuned_weights" + + tgi_repository: "text-generation-inference" + vllm_repository: "vllm" + lightllm_repository: "lightllm" + user_inference_base_repository: "launch/inference" + user_inference_pytorch_repository: "launch/inference/pytorch" + user_inference_tensorflow_repository: "launch/inference/tf" + docker_image_layer_cache_repository: "launch-docker-build-cache" # Triton enhanced endpoints (coming soon) triton: diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 7d62b9dd..958881e1 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -57,6 +57,10 @@ class HostedModelInferenceServiceConfig: tgi_repository: str vllm_repository: str lightllm_repository: str + user_inference_base_repository: str + user_inference_pytorch_repository: str + user_inference_tensorflow_repository: str + docker_image_layer_cache_repository: str @classmethod def from_yaml(cls, yaml_path): diff --git a/model-engine/model_engine_server/core/docker/kaniko_template.yaml b/model-engine/model_engine_server/core/docker/kaniko_template.yaml index dfda89e3..a5bb5384 100644 --- a/model-engine/model_engine_server/core/docker/kaniko_template.yaml +++ b/model-engine/model_engine_server/core/docker/kaniko_template.yaml @@ -43,6 +43,7 @@ spec: env: - name: AWS_REGION value: us-west-2 + # TODO we need to parametrize AWS_REGION volumeMounts: - name: pipconf mountPath: /kaniko/pip diff --git a/model-engine/model_engine_server/core/docker/remote_build.py b/model-engine/model_engine_server/core/docker/remote_build.py index 26d58721..5b192064 100644 --- a/model-engine/model_engine_server/core/docker/remote_build.py +++ b/model-engine/model_engine_server/core/docker/remote_build.py @@ -124,6 +124,7 @@ def start_build_job( path_to_dockerfile: str, repotags: Iterable[str], use_cache: bool, + cache_name: str, build_args: Optional[Dict[str, str]] = None, custom_tags: Optional[Dict[str, str]] = None, ) -> str: @@ -172,7 +173,7 @@ def start_build_job( S3_BUCKET=S3_BUCKET, S3_FILE=s3_file_name, USE_CACHE="true" if use_cache else "false", - CACHE_REPO=f"{infra_config().docker_repo_prefix}/kaniko-cache", + CACHE_REPO=f"{infra_config().docker_repo_prefix}/{cache_name}", AWS_ACCESS_KEY_ID=aws_access_key_id, AWS_SECRET_ACCESS_KEY=aws_secret_access_key, NAMESPACE=NAMESPACE, @@ -196,15 +197,21 @@ def start_build_job( os.makedirs("/tmp") pip_conf_file = "/tmp/.codeartifact-pip-conf" aws_profile = infra_config().profile_ml_worker - subprocess.check_output( - [ - f"AWS_PROFILE={aws_profile} python scripts_py3/scale_scripts/exe/maybe_refresh_codeartifact.py --export {pip_conf_file}" - ], - cwd=str(MODELS_ROOT), - shell=True, - ) - with open(pip_conf_file) as f_conf: - pip_conf_base64 = b64encode(f_conf.read().encode("utf-8")).decode("utf-8") + try: + # nosemgrep + subprocess.check_output( + [ + f"AWS_PROFILE={aws_profile} python scripts_py3/scale_scripts/exe/maybe_refresh_codeartifact.py --export {pip_conf_file}" + ], + cwd=str(MODELS_ROOT), + shell=True, + ) + with open(pip_conf_file) as f_conf: + pip_conf_data = f_conf.read() + except (subprocess.CalledProcessError, FileNotFoundError): + print("WARNING: Failed to refresh CodeArtifact token secret, using empty secret") + pip_conf_data = "" + pip_conf_base64 = b64encode(pip_conf_data.encode("utf-8")).decode("utf-8") data = {"data": {"codeartifact_pip_conf": pip_conf_base64}} subprocess.check_output( ["kubectl", "patch", "secret", "codeartifact-pip-conf", f"-p={json.dumps(data)}"] @@ -223,6 +230,7 @@ def build_remote( repotags: Union[str, Iterable[str]], folders_to_include: Optional[List[str]] = None, use_cache: bool = True, + cache_name: str = "kaniko-cache", ignore_file: Optional[str] = None, build_args: Optional[Dict[str, str]] = None, custom_tags: Optional[Dict[str, str]] = None, @@ -284,7 +292,9 @@ def build_remote( folders_to_include=folders_to_include, ignore_file=ignore_file, ) - return start_build_job(s3_file_name, dockerfile, repotags, use_cache, build_args, custom_tags) + return start_build_job( + s3_file_name, dockerfile, repotags, use_cache, cache_name, build_args, custom_tags + ) def verify_and_reformat_as_relative_to(context: str, dockerfile: str) -> str: @@ -414,6 +424,7 @@ def build_remote_block( repotags: Union[str, Iterable[str]], folders_to_include: Optional[List[str]] = None, use_cache: bool = True, + cache_name: str = "kaniko-cache", ignore_file: Optional[str] = None, build_args: Optional[Dict[str, str]] = None, custom_tags: Optional[Dict[str, str]] = None, @@ -438,6 +449,7 @@ def build_remote_block( repotags, folders_to_include, use_cache, + cache_name, ignore_file, build_args, custom_tags, @@ -522,6 +534,8 @@ def build_remote_wrapper( custom_tags = json.loads(custom_tags) folders_to_include: Optional[List[str]] = folders.split(",") if folders is not None else None + cache_name = "kaniko-cache" + build_args = None if build_arg: build_arg_kvs = [arg.split("=") for arg in build_arg] @@ -534,6 +548,7 @@ def build_remote_wrapper( repotags=repotag, folders_to_include=folders_to_include, use_cache=not no_cache, + cache_name=cache_name, ignore_file=dockerignore, build_args=build_args, custom_tags=custom_tags, @@ -545,6 +560,7 @@ def build_remote_wrapper( repotags=repotag, folders_to_include=folders_to_include, use_cache=not no_cache, + cache_name=cache_name, ignore_file=dockerignore, build_args=build_args, custom_tags=custom_tags, diff --git a/model-engine/model_engine_server/entrypoints/init_database.py b/model-engine/model_engine_server/entrypoints/init_database.py index 30ca1a1c..5f80ef64 100644 --- a/model-engine/model_engine_server/entrypoints/init_database.py +++ b/model-engine/model_engine_server/entrypoints/init_database.py @@ -2,7 +2,7 @@ import os import psycopg2 -from model_engine_server.db.base import Base +from model_engine_server.db.base import Base, get_engine_url from model_engine_server.db.models import * from sqlalchemy import create_engine from sqlalchemy.engine import Engine @@ -38,13 +38,16 @@ def init_database_and_engine(database_url) -> Engine: if __name__ == "__main__": url = os.getenv("ML_INFRA_DATABASE_URL") - if url is not None: - for attempt in Retrying( - stop=stop_after_attempt(6), - wait=wait_exponential(), - reraise=True, - ): - with attempt: - init_database_and_engine(url) - - print(f"Successfully initialized database at {url}") + # If we are at this point, we want to init the db. + if url is None: + print("No k8s secret for DB url found, trying AWS secret") + url = get_engine_url(read_only=False, sync=True) + for attempt in Retrying( + stop=stop_after_attempt(6), + wait=wait_exponential(), + reraise=True, + ): + with attempt: + init_database_and_engine(url) + + print(f"Successfully initialized database at {url}") diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index ae8576e3..024ca99e 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -134,7 +134,6 @@ def get_kubernetes_autoscaling_client(): # pragma: no cover _kubernetes_autoscaling_api = kubernetes_asyncio.client.AutoscalingV2Api() else: _kubernetes_autoscaling_api = kubernetes_asyncio.client.AutoscalingV2beta2Api() - _kubernetes_autoscaling_api = kubernetes_asyncio.client.AutoscalingV2beta2Api() return _kubernetes_autoscaling_api diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index 47aeb61c..16c6b742 100644 --- a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -1,5 +1,6 @@ from typing import Optional +from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.core.config import infra_config from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists @@ -46,6 +47,7 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: repotags=[f"{image_params.repo}:{image_params.image_tag}"], folders_to_include=folders_to_include, build_args=build_args, + cache_name=hmi_config.docker_image_layer_cache_repository, ) return BuildImageResponse( status=build_result.status, logs=build_result.logs, job_name=build_result.job_name diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 30ca9b7b..bef91df0 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Sequence, Set from datadog import statsd +from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, @@ -498,7 +499,7 @@ def get_base_image_params( logger_adapter.info(f"inference_folder: {inference_folder}") logger_adapter.info(f"dockerfile: {inference_folder}/{dockerfile}") return BuildImageRequest( - repo="launch/inference", + repo=hmi_config.user_inference_base_repository, image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, # type: ignore base_path=base_path, @@ -529,7 +530,7 @@ def _get_user_image_params( dockerfile = "pytorch_or_tf.user.Dockerfile" service_image_tag = self._get_image_tag(base_image_tag, GIT_TAG, requirements_hash) - ecr_repo = "hosted-model-inference/async-pytorch" + ecr_repo = hmi_config.user_inference_pytorch_repository elif isinstance(env_params, TensorflowFramework): if build_endpoint_request.gpus > 0: raise NotImplementedError("Tensorflow GPU image not supported yet") @@ -541,7 +542,7 @@ def _get_user_image_params( raise ValueError("Tensorflow version must be specified if the framework is TF.") dockerfile = "pytorch_or_tf.user.Dockerfile" service_image_tag = self._get_image_tag(tensorflow_version, GIT_TAG, requirements_hash) - ecr_repo = "hosted-model-inference/async-tensorflow-cpu" + ecr_repo = hmi_config.user_inference_tensorflow_repository elif isinstance(env_params, CustomFramework): if ( env_params.image_tag is None or env_params.image_repository is None @@ -596,6 +597,7 @@ def _get_inject_bundle_image_params( bundle_id = model_bundle.id service_image_str = "-".join([base_image_params.image_tag, GIT_TAG, bundle_id]) + # nosemgrep service_image_hash = hashlib.md5(str(service_image_str).encode("utf-8")).hexdigest() service_image_tag = f"inject-bundle-image-{service_image_hash}" ecr_repo = base_image_params.repo @@ -803,6 +805,7 @@ def _get_restricted_env_vars(env_vars: Dict[str, str]) -> Set[str]: @staticmethod def _get_requirements_hash(requirements: List[str]) -> str: """Identifying hash for endpoint's Python requirements.""" + # nosemgrep return hashlib.md5("\n".join(sorted(requirements)).encode("utf-8")).hexdigest()[:6] @staticmethod diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index a6c98e9b..3438f65d 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -57,6 +57,10 @@ istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" +user_inference_base_repository: "launch/inference" +user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" +user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" +docker_image_layer_cache_repository: "kaniko-cache" # S3 access hf_user_fine_tuned_weights_prefix: "s3://test-bucket" From cd9054977373fcb2e65d42ea709aee827be9f9ea Mon Sep 17 00:00:00 2001 From: William Song Date: Fri, 3 Nov 2023 14:32:36 -0700 Subject: [PATCH 165/425] Add PR template (#341) * add PR template * fill in PR template content --- .github/PULL_REQUEST_TEMPLATE.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..a5eb802d --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +# Pull Request Summary + +_What is this PR changing? Why is this change being made? Any caveats you'd like to highlight? Link any relevant documents, links, or screenshots here if applicable._ + +## Test Plan and Usage Guide + +_How did you validate that your PR works correctly? How do you run or demo the code? Provide enough detail so a reviewer can reasonably reproduce the testing procedure. Paste example command line invocations if applicable._ From 5f7c5667b11a2aa803466e1d42d974760c0fa586 Mon Sep 17 00:00:00 2001 From: William Song Date: Fri, 3 Nov 2023 15:12:00 -0700 Subject: [PATCH 166/425] Unmount aws config from root (#361) * unmount aws config from root * move mount out of /root * modify main_env in python rather than templates --- .../service_template_config_map.yaml | 23 +++++++++++-------- .../gateways/resources/k8s_resource_types.py | 2 ++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 281bc2ea..15410122 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -243,6 +243,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -271,7 +273,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -309,7 +311,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -317,10 +319,6 @@ data: - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} {{- end }} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -565,6 +563,8 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" {{- $env_vars := $service_env | fromYaml }} {{- range $env_var := index $env_vars "env" }} {{- $env_var_name := index $env_var "name" }} @@ -601,7 +601,7 @@ data: memory: 32Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config {{- range $device := tuple "cpu" "gpu" }} docker-image-batch-job-{{- $device }}.yaml: |- @@ -657,6 +657,8 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" {{- $env_vars := $service_env | fromYaml }} {{- range $env_var := index $env_vars "env" }} {{- $env_var_name := index $env_var "name" }} @@ -684,7 +686,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -693,6 +695,9 @@ data: initContainers: - name: input-downloader image: {{ $gateway_repository }}:${GIT_TAG} + env: + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" command: - python - -m @@ -713,7 +718,7 @@ data: memory: 1Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index c86d18fd..ed957fcc 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -518,6 +518,8 @@ def get_endpoint_resource_arguments_from_request( if isinstance(flavor, RunnableImageLike) and flavor.env: main_env = [{"name": key, "value": value} for key, value in flavor.env.items()] main_env.append({"name": "AWS_PROFILE", "value": build_endpoint_request.aws_role}) + # NOTE: /opt/.aws/config is where service_template_config_map.yaml mounts the AWS config file, point to the mount for boto clients + main_env.append({"name": "AWS_CONFIG_FILE", "value": "/opt/.aws/config"}) infra_service_config_volume_mount_path = "/infra-config" forwarder_config_file_name = "service--forwarder.yaml" From 751bf3889973f2c8ecca94d10f16717342c448f4 Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Mon, 6 Nov 2023 11:48:18 -0800 Subject: [PATCH 167/425] Implement automated code coverage for CI (#362) * add test coverage * temp * change diff cover version * Revert "temp" This reverts commit 4eaec10b9cab0b44eb06c1158c9e5dcaa2085037. --- .circleci/config.yml | 3 ++- model-engine/requirements-test.txt | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 63763dea..b10843eb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -260,5 +260,6 @@ commands: name: Unit Tests command: | pushd model-engine - GIT_TAG=$(git rev-parse HEAD) WORKSPACE=.. pytest + GIT_TAG=$(git rev-parse HEAD) WORKSPACE=.. pytest --cov --cov-report=xml + diff-cover coverage.xml --compare-branch=origin/main --fail-under=80 popd diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index 9ad7b6e2..158e0743 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -2,6 +2,7 @@ multiprocess==0.70.14 pytest==7.2.0 pytest-asyncio==0.20.1 pytest-cov==2.10.0 +diff-cover==7.7.0 moto==3.1.12 coverage==5.5 mypy==1.3.0 From c0f08ed474ba86c906ba8d727803df5156b0ae5e Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 6 Nov 2023 18:20:07 -0800 Subject: [PATCH 168/425] Download only known files (#364) --- .../use_cases/llm_model_endpoint_use_cases.py | 44 ++++++++----------- .../tests/unit/domain/test_llm_use_cases.py | 28 +++++------- 2 files changed, 30 insertions(+), 42 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index cf9fbbf7..ed840ec2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -174,23 +174,22 @@ DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes -def _exclude_safetensors_or_bin(model_files: List[str]) -> Optional[str]: +def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]: """ - This function is used to determine whether to exclude "*.safetensors" or "*.bin" files - based on which file type is present more often in the checkpoint folder. The less - frequently present file type is excluded. - If both files are equally present, no exclusion string is returned. + This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files + based on which file type is present most often in the checkpoint folder. The most + frequently present file type is included. + In case of ties, priority is given to "*.safetensors", then "*.bin", then "*.pt". """ - exclude_str = None - if len([f for f in model_files if f.endswith(".safetensors")]) > len( - [f for f in model_files if f.endswith(".bin")] - ): - exclude_str = "*.bin" - elif len([f for f in model_files if f.endswith(".safetensors")]) < len( - [f for f in model_files if f.endswith(".bin")] - ): - exclude_str = "*.safetensors" - return exclude_str + num_safetensors = len([f for f in model_files if f.endswith(".safetensors")]) + num_bin = len([f for f in model_files if f.endswith(".bin")]) + num_pt = len([f for f in model_files if f.endswith(".pt")]) + maximum = max(num_safetensors, num_bin, num_pt) + if num_safetensors == maximum: + return "*.safetensors" + if num_bin == maximum: + return "*.bin" + return "*.pt" def _model_endpoint_entity_to_get_llm_model_endpoint_response( @@ -436,16 +435,11 @@ def load_model_weights_sub_commands( checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) model_files = [f for f in checkpoint_files if "model" in f] - exclude_str = _exclude_safetensors_or_bin(model_files) - - if exclude_str is None: - subcommands.append( - f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" - ) - else: - subcommands.append( - f"{s5cmd} --numworkers 512 cp --concurrency 10 --exclude '{exclude_str}' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" - ) + include_str = _include_safetensors_bin_or_pt(model_files) + file_selection_str = f"--include '*.model' --include '*.json' --include '{include_str}' --exclude 'optimizer*'" + subcommands.append( + f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) return subcommands diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c71995ea..ed1eee41 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -37,7 +37,7 @@ DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ModelDownloadV1UseCase, - _exclude_safetensors_or_bin, + _include_safetensors_bin_or_pt, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -1029,37 +1029,31 @@ async def test_delete_public_inference_model_raises_not_authorized( @pytest.mark.asyncio -async def test_exclude_safetensors_or_bin_majority_bin_returns_exclude_safetensors(): - fake_model_files = [ - "fake.bin", - "fake2.bin", - "fake3.safetensors", - "model.json", - "optimizer.pt", - ] - assert _exclude_safetensors_or_bin(fake_model_files) == "*.safetensors" +async def test_include_safetensors_bin_or_pt_majority_safetensors(): + fake_model_files = ["fake.bin", "fake2.safetensors", "model.json", "optimizer.pt"] + assert _include_safetensors_bin_or_pt(fake_model_files) == "*.safetensors" @pytest.mark.asyncio -async def test_exclude_safetensors_or_bin_majority_safetensors_returns_exclude_bin(): +async def test_include_safetensors_bin_or_pt_majority_bin(): fake_model_files = [ "fake.bin", - "fake2.safetensors", + "fake2.bin", "fake3.safetensors", "model.json", "optimizer.pt", + "fake4.pt", ] - assert _exclude_safetensors_or_bin(fake_model_files) == "*.bin" + assert _include_safetensors_bin_or_pt(fake_model_files) == "*.bin" @pytest.mark.asyncio -async def test_exclude_safetensors_or_bin_equal_bins_and_safetensors_returns_none(): +async def test_include_safetensors_bin_or_pt_majority_pt(): fake_model_files = [ "fake.bin", "fake2.safetensors", - "fake3.safetensors", - "fake4.bin", "model.json", "optimizer.pt", + "fake3.pt", ] - assert _exclude_safetensors_or_bin(fake_model_files) is None + assert _include_safetensors_bin_or_pt(fake_model_files) == "*.pt" From 8614dabfc8c0fc8b3b6f800cae6ec8c130db4f0f Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 7 Nov 2023 09:45:47 -0800 Subject: [PATCH 169/425] Documentation fix (#365) --- clients/python/llmengine/fine_tuning.py | 2 +- docs/api/data_types.md | 36 ++++++++++++------------- docs/api/python_client.md | 8 +++--- mkdocs.yml | 6 ++--- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index bf9dcf0d..b0f73d6b 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -59,7 +59,7 @@ def create( validation_file (`Optional[str]`): Publicly accessible URL or file ID referencing a CSV file for validation. The validation file is used to compute metrics which let LLM Engine pick the best fine-tuned checkpoint, which will be used for inference when fine-tuning is complete. - hyperparameters (`Optional[Dict[str, str]]`): + hyperparameters (`Optional[Dict[str, Union[str, int, float, Dict[str, Any]]]]`): A dict of hyperparameters to customize fine-tuning behavior. Currently supported hyperparameters: diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 12594058..44dd3d8f 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -1,14 +1,14 @@ # 🐍 Python Client Data Type Reference ::: llmengine.CompletionOutput - selection: + options: members: - text - num_prompt_tokens - num_completion_tokens ::: llmengine.CompletionStreamOutput - selection: + options: members: - text - finished @@ -16,40 +16,40 @@ - num_completion_tokens ::: llmengine.CompletionSyncResponse - selection: + options: members: - request_id - output ::: llmengine.CompletionStreamResponse - selection: + options: members: - request_id - output ::: llmengine.CreateFineTuneResponse - selection: + options: members: - id ::: llmengine.GetFineTuneResponse - selection: + options: members: - id - fine_tuned_model ::: llmengine.ListFineTunesResponse - selection: + options: members: - jobs ::: llmengine.CancelFineTuneResponse - selection: + options: members: - success ::: llmengine.GetLLMEndpointResponse - selection: + options: members: - name - source @@ -63,50 +63,50 @@ - spec ::: llmengine.ListLLMEndpointsResponse - selection: + options: members: - model_endpoints ::: llmengine.DeleteLLMEndpointResponse - selection: + options: members: - deleted ::: llmengine.ModelDownloadRequest - selection: + options: members: - model_name - download_format ::: llmengine.ModelDownloadResponse - selection: + options: members: - urls ::: llmengine.UploadFileResponse - selection: + options: members: - id ::: llmengine.GetFileResponse - selection: + options: members: - id - filename - size ::: llmengine.GetFileContentResponse - selection: + options: members: - id - content ::: llmengine.ListFilesResponse - selection: + options: members: - files ::: llmengine.DeleteFileResponse - selection: + options: members: - deleted diff --git a/docs/api/python_client.md b/docs/api/python_client.md index bdbc6f3e..d77d28bc 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -1,13 +1,13 @@ # 🐍 Python Client API Reference ::: llmengine.Completion - selection: + options: members: - create - acreate ::: llmengine.FineTune - selection: + options: members: - create - get @@ -16,7 +16,7 @@ - cancel ::: llmengine.Model - selection: + options: members: - create - get @@ -25,7 +25,7 @@ - download ::: llmengine.File - selection: + options: members: - upload - get diff --git a/mkdocs.yml b/mkdocs.yml index a24b4763..b719a2cb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -88,17 +88,17 @@ markdown_extensions: - neoteroi.cards - footnotes +watch: + - clients/python/llmengine + plugins: - search - mkdocstrings: - watch: - - clients/python/llmengine handlers: python: options: separate_signature: true line_length: 60 - rendering: show_root_heading: true show_root_full_path: false show_source: false From ce1de6b7a4827827955cd9a4a3fa79601116377b Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 7 Nov 2023 18:33:33 -0800 Subject: [PATCH 170/425] Change more AWS config mount paths (#367) --- charts/model-engine/templates/_helpers.tpl | 6 +- charts/model-engine/values_circleci.yaml | 2 +- .../core/docker/docker_image.py | 4 +- .../service_template_config_map_circleci.yaml | 58 +++++++++---------- 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 7dd0410d..d367f039 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -203,6 +203,8 @@ env: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -254,6 +256,8 @@ env: {{- if .Values.aws }} - name: AWS_PROFILE value: {{ .Values.aws.profileName }} + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: {{ .Values.aws.profileName }} - name: DB_SECRET_AWS_PROFILE @@ -384,7 +388,7 @@ volumeMounts: {{- define "modelEngine.forwarderVolumeMounts" }} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 60bf462b..1cc777e3 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -166,7 +166,7 @@ aws: configMap: name: default-config create: false - mountPath: /root/.aws/config + mountPath: /opt/.aws/config profileName: default s3WriteProfileName: default diff --git a/model-engine/model_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py index 66edb928..8d68f8c8 100644 --- a/model-engine/model_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -158,13 +158,13 @@ def build( command=test_command, volumes={ os.path.join(home_dir, ".aws"): { - "bind": "/root/.aws/config", + "bind": "/opt/.aws/config", "mode": "ro", } }, environment={ "AWS_PROFILE": infra_config().profile_ml_worker, - "AWS_CONFIG_FILE": "/root/.aws/config", + "AWS_CONFIG_FILE": "/opt/.aws/config", }, remove=True, ) diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 93a779c2..8e014e18 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -167,7 +167,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -217,7 +217,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -247,7 +247,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -438,7 +438,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -474,7 +474,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -659,7 +659,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -712,7 +712,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -742,7 +742,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -927,7 +927,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -966,7 +966,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -1153,7 +1153,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -1192,7 +1192,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -1388,7 +1388,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -1438,7 +1438,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -1470,7 +1470,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -1666,7 +1666,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -1704,7 +1704,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -1894,7 +1894,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -1947,7 +1947,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -1979,7 +1979,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -2169,7 +2169,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -2210,7 +2210,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -2402,7 +2402,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -2443,7 +2443,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -2872,7 +2872,7 @@ data: memory: 32Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config docker-image-batch-job-cpu.yaml: |- apiVersion: batch/v1 @@ -2982,7 +2982,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3011,7 +3011,7 @@ data: memory: 1Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3130,7 +3130,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3159,7 +3159,7 @@ data: memory: 1Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} From a5dbfa555e8e81aaf1e199417f23f2ef864ea0f8 Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Wed, 8 Nov 2023 18:38:03 -0800 Subject: [PATCH 171/425] Validating inference framework image tags (#357) * ensuring invalid image tag errors are surfaced to users clearly * adding new vllm version * update error message, handling for deepspeed * update conftest * update conftest * more fixes to tags * add unit test * fix * check ecr image * catch docker image exception * fix * revert removal commit * fix + refactor --------- Co-authored-by: Ian Macleod Co-authored-by: Ian Macleod <139901935+ian-scale@users.noreply.github.com> --- .../model_engine_server/api/llms_v1.py | 7 +++ .../use_cases/llm_model_endpoint_use_cases.py | 55 ++++++++++------- model-engine/tests/unit/conftest.py | 8 +-- model-engine/tests/unit/domain/conftest.py | 4 +- .../tests/unit/domain/test_llm_use_cases.py | 61 ++++++++++++++++++- 5 files changed, 106 insertions(+), 29 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 9e811f91..038cdc5f 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -42,6 +42,7 @@ make_logger, ) from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, EndpointDeleteFailedException, EndpointLabelsException, EndpointResourceInvalidRequestException, @@ -144,6 +145,7 @@ async def create_model_endpoint( model_bundle_repository=external_interfaces.model_bundle_repository, model_endpoint_service=external_interfaces.model_endpoint_service, llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + docker_repository=external_interfaces.docker_repository, ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: @@ -173,6 +175,11 @@ async def create_model_endpoint( status_code=404, detail="The specified model bundle could not be found.", ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail="The specified docker image could not be found.", + ) from exc @llm_router_v1.get("/model-endpoints", response_model=ListLLMModelEndpointsV1Response) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index ed840ec2..829f4801 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -47,6 +47,7 @@ StreamingEnhancedRunnableImageFlavor, ) from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, EndpointLabelsException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, @@ -57,6 +58,7 @@ ) from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway from model_engine_server.domain.repositories import ModelBundleRepository +from model_engine_server.domain.repositories.docker_repository import DockerRepository from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway @@ -254,12 +256,26 @@ def __init__( model_bundle_repository: ModelBundleRepository, model_endpoint_service: ModelEndpointService, llm_artifact_gateway: LLMArtifactGateway, + docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_model_bundle_use_case = create_model_bundle_use_case self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service self.llm_artifact_gateway = llm_artifact_gateway + self.docker_repository = docker_repository + + def check_docker_image_exists_for_image_tag( + self, framework_image_tag: str, repository_name: str + ): + if not self.docker_repository.image_exists( + image_tag=framework_image_tag, + repository_name=repository_name, + ): + raise DockerImageNotFoundException( + repository=repository_name, + tag=framework_image_tag, + ) async def create_model_bundle( self, @@ -276,6 +292,7 @@ async def create_model_bundle( ) -> ModelBundle: if source == LLMSource.HUGGING_FACE: if framework == LLMInferenceFramework.DEEPSPEED: + self.check_docker_image_exists_for_image_tag(framework_image_tag, "instant-llm") bundle_id = await self.create_deepspeed_bundle( user, model_name, @@ -284,6 +301,9 @@ async def create_model_bundle( endpoint_name, ) elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + self.check_docker_image_exists_for_image_tag( + framework_image_tag, hmi_config.tgi_repository + ) bundle_id = await self.create_text_generation_inference_bundle( user, model_name, @@ -294,6 +314,9 @@ async def create_model_bundle( checkpoint_path, ) elif framework == LLMInferenceFramework.VLLM: + self.check_docker_image_exists_for_image_tag( + framework_image_tag, hmi_config.vllm_repository + ) bundle_id = await self.create_vllm_bundle( user, model_name, @@ -304,6 +327,9 @@ async def create_model_bundle( checkpoint_path, ) elif framework == LLMInferenceFramework.LIGHTLLM: + self.check_docker_image_exists_for_image_tag( + framework_image_tag, hmi_config.lightllm_repository + ) bundle_id = await self.create_lightllm_bundle( user, model_name, @@ -713,7 +739,6 @@ async def execute( if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, LLMInferenceFramework.VLLM, - LLMInferenceFramework.LIGHTLLM, ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( @@ -952,10 +977,7 @@ def validate_and_update_completion_params( if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: request.top_k = None if request.top_k == -1 else request.top_k request.top_p = None if request.top_p == 1.0 else request.top_p - if inference_framework in [ - LLMInferenceFramework.VLLM, - LLMInferenceFramework.LIGHTLLM, - ]: + if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: request.top_k = -1 if request.top_k is None else request.top_k request.top_p = 1.0 if request.top_p is None else request.top_p else: @@ -965,10 +987,7 @@ def validate_and_update_completion_params( ) # presence_penalty, frequency_penalty - if inference_framework in [ - LLMInferenceFramework.VLLM, - LLMInferenceFramework.LIGHTLLM, - ]: + if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: request.presence_penalty = ( 0.0 if request.presence_penalty is None else request.presence_penalty ) @@ -1005,7 +1024,6 @@ def model_output_to_completion_output( with_token_probs: Optional[bool], ) -> CompletionOutput: model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) - if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: completion_token_count = len(model_output["token_probs"]["tokens"]) tokens = None @@ -1043,10 +1061,7 @@ def model_output_to_completion_output( tokens = None if with_token_probs: tokens = [ - TokenOutput( - token=model_output["tokens"][index], - log_prob=list(t.values())[0], - ) + TokenOutput(token=model_output["tokens"][index], log_prob=list(t.values())[0]) for index, t in enumerate(model_output["log_probs"]) ] return CompletionOutput( @@ -1160,8 +1175,7 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, - predict_request=inference_request, + topic=model_endpoint.record.destination, predict_request=inference_request ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: @@ -1204,8 +1218,7 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, - predict_request=inference_request, + topic=model_endpoint.record.destination, predict_request=inference_request ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1244,8 +1257,7 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, - predict_request=inference_request, + topic=model_endpoint.record.destination, predict_request=inference_request ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1287,8 +1299,7 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, - predict_request=inference_request, + topic=model_endpoint.record.destination, predict_request=inference_request ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 92d2e88b..e9dd1e44 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3690,7 +3690,7 @@ def llm_model_endpoint_sync_tgi( "model_name": "llama-7b", "source": "hugging_face", "inference_framework": "text_generation_inference", - "inference_framework_image_tag": "123", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, @@ -3752,7 +3752,7 @@ def llm_model_endpoint_sync_tgi( "source": "hugging_face", "status": "READY", "inference_framework": "text_generation_inference", - "inference_framework_image_tag": "123", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, "spec": { "id": "test_llm_model_endpoint_id_2", @@ -3765,7 +3765,7 @@ def llm_model_endpoint_sync_tgi( "model_name": "llama-7b", "source": "hugging_face", "inference_framework": "text_generation_inference", - "inference_framework_image_tag": "123", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, @@ -3887,7 +3887,7 @@ def llm_model_endpoint_text_generation_inference( "model_name": "llama-7b", "source": "hugging_face", "inference_framework": "text_generation_inference", - "inference_framework_image_tag": "123", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 6a958ed4..c27aaa52 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -254,7 +254,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque model_name="llama-2-7b", source="hugging_face", inference_framework="text_generation_inference", - inference_framework_image_tag="test_tag", + inference_framework_image_tag="0.9.4", num_shards=2, endpoint_type=ModelEndpointType.STREAMING, metadata={}, @@ -310,7 +310,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( model_name="mpt-7b", source="hugging_face", inference_framework="text_generation_inference", - inference_framework_image_tag="test_tag", + inference_framework_image_tag="0.9.4", num_shards=2, quantize=Quantization.BITSANDBYTES, endpoint_type=ModelEndpointType.ASYNC, diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index ed1eee41..c83e1049 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -14,8 +14,13 @@ ) from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.domain.entities import ModelEndpoint, ModelEndpointType +from model_engine_server.domain.entities import ( + LLMInferenceFramework, + ModelEndpoint, + ModelEndpointType, +) from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, LLMFineTuningQuotaReached, @@ -66,6 +71,7 @@ async def test_create_model_endpoint_use_case_success( model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) @@ -147,6 +153,56 @@ async def test_create_model_endpoint_use_case_success( assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] +@pytest.mark.asyncio +@pytest.mark.parametrize( + "valid, inference_framework, inference_framework_image_tag", + [ + (False, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "0.9.2"), + (True, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "0.9.3"), + (False, LLMInferenceFramework.VLLM, "0.1.6"), + (True, LLMInferenceFramework.VLLM, "0.1.3.6"), + ], +) +async def test_create_model_bundle_inference_framework_image_tag_validation( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_docker_repository_image_never_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, + valid, + inference_framework, + inference_framework_image_tag, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + + use_case = CreateLLMModelEndpointV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + + request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + request.inference_framework = inference_framework + request.inference_framework_image_tag = inference_framework_image_tag + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + if valid: + await use_case.execute(user=user, request=request) + else: + use_case.docker_repository = fake_docker_repository_image_never_exists + with pytest.raises(DockerImageNotFoundException): + await use_case.execute(user=user, request=request) + + @pytest.mark.asyncio async def test_create_model_endpoint_text_generation_inference_use_case_success( test_api_key: str, @@ -169,6 +225,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -224,6 +281,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -253,6 +311,7 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception( model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): From b4afd0889b813fbd7296cb301fc7fd25ca2c0c42 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:57:27 -0800 Subject: [PATCH 172/425] Ianmacleod/add codellama 34b (#369) * adding codellama 34b logic * adding 13b to docs --- docs/model_zoo.md | 7 ++++--- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 0431ea14..ebe9f082 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -22,9 +22,10 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `mistral-7b-instruct` | ✅ | ✅ | vllm | | `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | | `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | -| `codellama-13b` | ✅ | | text-generation-inference, vllm | -| `codellama-13b-instruct` | ✅ | | text-generation-inference, vllm | -| `codellama-34b` | ✅ | | text-generation-inference, vllm | +| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | +| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | +| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | +| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 829f4801..ab138747 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -109,6 +109,7 @@ "codellama-13b": "codellama/CodeLlama-13b-hf", "codellama-13b-instruct": "codellama/CodeLlama-13b-Instruct-hf", "codellama-34b": "codellama/CodeLlama-34b-hf", + "codellama-34b-instruct": "codellama/CodeLlama-34b-Instruct-hf", "llm-jp-13b-instruct-full": "llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", "llm-jp-13b-instruct-full-dolly": "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", }, @@ -135,6 +136,7 @@ "codellama-13b": "codellama/CodeLlama-13b-hf", "codellama-13b-instruct": "codellama/CodeLlama-13b-Instruct-hf", "codellama-34b": "codellama/CodeLlama-34b-hf", + "codellama-34b-instruct": "codellama/CodeLlama-34b-Instruct-hf", "mammoth-coder-llama-2-7b": "TIGER-Lab/MAmmoTH-Coder-7B", "mammoth-coder-llama-2-13b": "TIGER-Lab/MAmmoTH-Coder-13B", "mammoth-coder-llama-2-34b": "TIGER-Lab/MAmmoTH-Coder-34B", From 8bf0aa539fa5c42321e3b11279e93e0d31cddc9b Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Thu, 9 Nov 2023 14:36:33 -0800 Subject: [PATCH 173/425] Better error when model is not ready for predictions (#368) * add error * add unit tests --- ...eaming_model_endpoint_inference_gateway.py | 9 +++++- ...e_sync_model_endpoint_inference_gateway.py | 12 +++++-- ...eaming_model_endpoint_inference_gateway.py | 31 +++++++++++++++++++ ...e_sync_model_endpoint_inference_gateway.py | 30 ++++++++++++++++++ 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index e1519abc..9c2ff9b7 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -142,7 +142,11 @@ async def make_request_with_retries( stop_after_attempt(num_retries + 1), stop_after_delay(timeout_seconds) ), retry=retry_if_exception_type( - (TooManyRequestsException, NoHealthyUpstreamException) + ( + TooManyRequestsException, + NoHealthyUpstreamException, + aiohttp.ClientConnectorError, + ) ), wait=wait_exponential( multiplier=1, @@ -164,6 +168,9 @@ async def make_request_with_retries( elif type(e.last_attempt.exception()) == NoHealthyUpstreamException: logger.warning("Pods didn't spin up in time, returning 503 to client") raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + elif type(e.last_attempt.exception()) == aiohttp.ClientConnectorError: + logger.warning("ClientConnectorError, returning 503 to client") + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") else: logger.error("Unknown Exception Type") raise UpstreamServiceError(status_code=500, content=b"Unknown error") diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index dd427f93..add25b7b 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -122,10 +122,15 @@ async def make_request_with_retries( try: async for attempt in AsyncRetrying( stop=stop_any( - stop_after_attempt(num_retries + 1), stop_after_delay(timeout_seconds) + stop_after_attempt(num_retries + 1), + stop_after_delay(timeout_seconds), ), retry=retry_if_exception_type( - (TooManyRequestsException, NoHealthyUpstreamException) + ( + TooManyRequestsException, + NoHealthyUpstreamException, + aiohttp.ClientConnectorError, + ) ), wait=wait_exponential( multiplier=1, @@ -144,6 +149,9 @@ async def make_request_with_retries( elif type(e.last_attempt.exception()) == NoHealthyUpstreamException: logger.warning("Pods didn't spin up in time, returning 503 to client") raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + elif type(e.last_attempt.exception()) == aiohttp.ClientConnectorError: + logger.warning("ClientConnectorError, returning 503 to client") + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") else: logger.error("Unknown Exception Type") raise UpstreamServiceError(status_code=500, content=b"Unknown error") diff --git a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py index 13980fb9..e2cabc79 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Tuple from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import pytest from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, @@ -53,6 +54,22 @@ def _get_mock_client_session(fake_response: FakeResponse): return mock_client_session +def _get_mock_client_session_with_client_connector_error(): + mock_post = AsyncMock( + side_effect=aiohttp.ClientConnectorError(connection_key=None, os_error=OSError()) + ) + mock_client_session_val = AsyncMock() + mock_client_session_val.post = mock_post + mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) + + async def _aexit(*exc): + pass + + mock_client_session_val.__aexit__ = AsyncMock(side_effect=_aexit) + mock_client_session = MagicMock(return_value=mock_client_session_val) + return mock_client_session + + @pytest.mark.asyncio async def test_make_request_with_retries_success(): gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) @@ -102,6 +119,20 @@ async def test_make_request_with_retries_failed_traceback(): response +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_with_client_connector_error(): + gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) + + mock_client_session = _get_mock_client_session_with_client_connector_error() + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): + response + + @pytest.mark.asyncio async def test_streaming_predict_success( sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] diff --git a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py index 6241fe6e..806ee93f 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Tuple from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import pytest from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, @@ -37,6 +38,22 @@ def _get_mock_client_session(fake_response: FakeResponse): return mock_client_session +def _get_mock_client_session_with_client_connector_error(): + mock_post = AsyncMock( + side_effect=aiohttp.ClientConnectorError(connection_key=None, os_error=OSError()) + ) + mock_client_session_val = AsyncMock() + mock_client_session_val.post = mock_post + mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) + + async def _aexit(*exc): + pass + + mock_client_session_val.__aexit__ = AsyncMock(side_effect=_aexit) + mock_client_session = MagicMock(return_value=mock_client_session_val) + return mock_client_session + + @pytest.mark.asyncio async def test_make_request_with_retries_success(): gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) @@ -80,6 +97,19 @@ async def test_make_request_with_retries_failed_traceback(): await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_with_client_connector_error(): + gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + + mock_client_session = _get_mock_client_session_with_client_connector_error() + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) + + @pytest.mark.asyncio async def test_predict_success( sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] From 24d103797c3da531c4bdd79f153ab2e27a3c979f Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 13 Nov 2023 10:06:19 -0800 Subject: [PATCH 174/425] Improve metrics route team tags (#371) * add team email to user * format request route with path params --- model-engine/model_engine_server/api/llms_v1.py | 9 ++++++++- .../core/auth/authentication_repository.py | 5 +++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 038cdc5f..a15e3f20 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -79,12 +79,19 @@ from sse_starlette.sse import EventSourceResponse +def format_request_route(request: Request) -> str: + url_path = request.url.path + for path_param in request.path_params: + url_path = url_path.replace(request.path_params[path_param], f":{path_param}") + return f"{request.method}_{url_path}".lower() + + async def record_route_call( request: Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ): - route = f"{request.method}_{request.url.path}".lower() + route = format_request_route(request) model_name = request.query_params.get("model_endpoint_name", None) external_interfaces.monitoring_metrics_gateway.emit_route_call_metric( diff --git a/model-engine/model_engine_server/core/auth/authentication_repository.py b/model-engine/model_engine_server/core/auth/authentication_repository.py index a4d36dc1..2d3b591e 100644 --- a/model-engine/model_engine_server/core/auth/authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/authentication_repository.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional @@ -7,7 +7,8 @@ class User: user_id: str team_id: str - email: Optional[str] = None + email: Optional[str] = field(repr=False, default=None) + team_email: Optional[str] = field(repr=False, default=None) is_privileged_user: bool = False From 3e55f49e5b6c8c7b1672ebbe4d92ae4e5a2f1ac2 Mon Sep 17 00:00:00 2001 From: William Song Date: Mon, 13 Nov 2023 14:43:08 -0800 Subject: [PATCH 175/425] Enable custom istio metric tags with Telemetry API (#373) Add custom tags for istio metrics --- .../_istio-attribute-match-conditions.tpl | 117 ++++++++++++++++++ .../model-engine/templates/istio-metrics.yaml | 34 +++++ 2 files changed, 151 insertions(+) create mode 100644 charts/model-engine/templates/_istio-attribute-match-conditions.tpl create mode 100644 charts/model-engine/templates/istio-metrics.yaml diff --git a/charts/model-engine/templates/_istio-attribute-match-conditions.tpl b/charts/model-engine/templates/_istio-attribute-match-conditions.tpl new file mode 100644 index 00000000..6e9feeb1 --- /dev/null +++ b/charts/model-engine/templates/_istio-attribute-match-conditions.tpl @@ -0,0 +1,117 @@ +{{- /* Generated from the OpenAPI schema with model-engine-internal/scripts/generate_istio_metric_tags.py */}} +{{- define "modelEngine.istioAttributeMatchConditions" -}} +- condition: request.method == 'GET' && request.url_path == '/healthcheck' + value: get_/healthcheck +- condition: request.method == 'GET' && request.url_path == '/healthz' + value: get_/healthz +- condition: request.method == 'GET' && request.url_path == '/readyz' + value: get_/readyz +- condition: request.method == 'POST' && request.url_path == '/v1/async-tasks' + value: post_/v1/async-tasks +- condition: request.method == 'GET' && request.url_path.matches('^/v1/async-tasks/[[:alnum:]-_]*$') + value: get_/v1/async-tasks/_task_id +- condition: request.method == 'POST' && request.url_path == '/v1/batch-jobs' + value: post_/v1/batch-jobs +- condition: request.method == 'GET' && request.url_path.matches('^/v1/batch-jobs/[[:alnum:]-_]*$') + value: get_/v1/batch-jobs/_batch_job_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/batch-jobs/[[:alnum:]-_]*$') + value: put_/v1/batch-jobs/_batch_job_id +- condition: request.method == 'GET' && request.url_path == '/v1/docker-image-batch-job-bundles' + value: get_/v1/docker-image-batch-job-bundles +- condition: request.method == 'POST' && request.url_path == '/v1/docker-image-batch-job-bundles' + value: post_/v1/docker-image-batch-job-bundles +- condition: request.method == 'GET' && request.url_path == '/v1/docker-image-batch-job-bundles/latest' + value: get_/v1/docker-image-batch-job-bundles/latest +- condition: request.method == 'GET' && request.url_path.matches('^/v1/docker-image-batch-job-bundles/[[:alnum:]-_]*$') + value: get_/v1/docker-image-batch-job-bundles/_docker_image_batch_job_bundle_id +- condition: request.method == 'GET' && request.url_path == '/v1/docker-image-batch-jobs' + value: get_/v1/docker-image-batch-jobs +- condition: request.method == 'POST' && request.url_path == '/v1/docker-image-batch-jobs' + value: post_/v1/docker-image-batch-jobs +- condition: request.method == 'GET' && request.url_path.matches('^/v1/docker-image-batch-jobs/[[:alnum:]-_]*$') + value: get_/v1/docker-image-batch-jobs/_batch_job_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/docker-image-batch-jobs/[[:alnum:]-_]*$') + value: put_/v1/docker-image-batch-jobs/_batch_job_id +- condition: request.method == 'GET' && request.url_path == '/v1/files' + value: get_/v1/files +- condition: request.method == 'POST' && request.url_path == '/v1/files' + value: post_/v1/files +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/files/[[:alnum:]-_]*$') + value: delete_/v1/files/_file_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/files/[[:alnum:]-_]*$') + value: get_/v1/files/_file_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/files/[[:alnum:]-_]*/content$') + value: get_/v1/files/_file_id/content +- condition: request.method == 'POST' && request.url_path == '/v1/llm/completions-stream' + value: post_/v1/llm/completions-stream +- condition: request.method == 'POST' && request.url_path == '/v1/llm/completions-sync' + value: post_/v1/llm/completions-sync +- condition: request.method == 'GET' && request.url_path == '/v1/llm/fine-tunes' + value: get_/v1/llm/fine-tunes +- condition: request.method == 'POST' && request.url_path == '/v1/llm/fine-tunes' + value: post_/v1/llm/fine-tunes +- condition: request.method == 'GET' && request.url_path.matches('^/v1/llm/fine-tunes/[[:alnum:]-_]*$') + value: get_/v1/llm/fine-tunes/_fine_tune_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/llm/fine-tunes/[[:alnum:]-_]*/cancel$') + value: put_/v1/llm/fine-tunes/_fine_tune_id/cancel +- condition: request.method == 'GET' && request.url_path.matches('^/v1/llm/fine-tunes/[[:alnum:]-_]*/events$') + value: get_/v1/llm/fine-tunes/_fine_tune_id/events +- condition: request.method == 'GET' && request.url_path == '/v1/llm/model-endpoints' + value: get_/v1/llm/model-endpoints +- condition: request.method == 'POST' && request.url_path == '/v1/llm/model-endpoints' + value: post_/v1/llm/model-endpoints +- condition: request.method == 'POST' && request.url_path == '/v1/llm/model-endpoints/download' + value: post_/v1/llm/model-endpoints/download +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/llm/model-endpoints/[[:alnum:]-_]*$') + value: delete_/v1/llm/model-endpoints/_model_endpoint_name +- condition: request.method == 'GET' && request.url_path.matches('^/v1/llm/model-endpoints/[[:alnum:]-_]*$') + value: get_/v1/llm/model-endpoints/_model_endpoint_name +- condition: request.method == 'GET' && request.url_path == '/v1/model-bundles' + value: get_/v1/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v1/model-bundles' + value: post_/v1/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v1/model-bundles/clone-with-changes' + value: post_/v1/model-bundles/clone-with-changes +- condition: request.method == 'GET' && request.url_path == '/v1/model-bundles/latest' + value: get_/v1/model-bundles/latest +- condition: request.method == 'GET' && request.url_path.matches('^/v1/model-bundles/[[:alnum:]-_]*$') + value: get_/v1/model-bundles/_model_bundle_id +- condition: request.method == 'GET' && request.url_path == '/v1/model-endpoints' + value: get_/v1/model-endpoints +- condition: request.method == 'POST' && request.url_path == '/v1/model-endpoints' + value: post_/v1/model-endpoints +- condition: request.method == 'GET' && request.url_path == '/v1/model-endpoints-api' + value: get_/v1/model-endpoints-api +- condition: request.method == 'GET' && request.url_path == '/v1/model-endpoints-schema.json' + value: get_/v1/model-endpoints-schema.json +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/model-endpoints/[[:alnum:]-_]*$') + value: delete_/v1/model-endpoints/_model_endpoint_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/model-endpoints/[[:alnum:]-_]*$') + value: get_/v1/model-endpoints/_model_endpoint_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/model-endpoints/[[:alnum:]-_]*$') + value: put_/v1/model-endpoints/_model_endpoint_id +- condition: request.method == 'POST' && request.url_path == '/v1/streaming-tasks' + value: post_/v1/streaming-tasks +- condition: request.method == 'POST' && request.url_path == '/v1/sync-tasks' + value: post_/v1/sync-tasks +- condition: request.method == 'GET' && request.url_path == '/v1/triggers' + value: get_/v1/triggers +- condition: request.method == 'POST' && request.url_path == '/v1/triggers' + value: post_/v1/triggers +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/triggers/[[:alnum:]-_]*$') + value: delete_/v1/triggers/_trigger_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/triggers/[[:alnum:]-_]*$') + value: get_/v1/triggers/_trigger_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/triggers/[[:alnum:]-_]*$') + value: put_/v1/triggers/_trigger_id +- condition: request.method == 'GET' && request.url_path == '/v2/model-bundles' + value: get_/v2/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v2/model-bundles' + value: post_/v2/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v2/model-bundles/clone-with-changes' + value: post_/v2/model-bundles/clone-with-changes +- condition: request.method == 'GET' && request.url_path == '/v2/model-bundles/latest' + value: get_/v2/model-bundles/latest +- condition: request.method == 'GET' && request.url_path.matches('^/v2/model-bundles/[[:alnum:]-_]*$') + value: get_/v2/model-bundles/_model_bundle_id +{{- end -}} diff --git a/charts/model-engine/templates/istio-metrics.yaml b/charts/model-engine/templates/istio-metrics.yaml new file mode 100644 index 00000000..be5a2e90 --- /dev/null +++ b/charts/model-engine/templates/istio-metrics.yaml @@ -0,0 +1,34 @@ +apiVersion: telemetry.istio.io/v1alpha1 +kind: Telemetry +metadata: + name: custom-tags + namespace: istio-system +spec: + metrics: + - overrides: + - match: + metric: REQUEST_COUNT + mode: CLIENT_AND_SERVER + tagOverrides: + request_operation: + value: istio_requestOperation + providers: + - name: prometheus +--- +apiVersion: extensions.istio.io/v1alpha1 +kind: WasmPlugin +metadata: + name: istio-attributegen-filter + namespace: istio-system +spec: + imagePullPolicy: Always + phase: AUTHN + pluginConfig: + attributes: + - match: + {{- include "modelEngine.istioAttributeMatchConditions" . | nindent 6 }} + output_attribute: istio_requestOperation + selector: + matchLabels: + {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} + url: https://storage.googleapis.com/istio-build/proxy/attributegen-359dcd3a19f109c50e97517fe6b1e2676e870c4d.wasm From 0e47fc81a0a0421ba231b84258ed0d247b504191 Mon Sep 17 00:00:00 2001 From: William Song Date: Tue, 14 Nov 2023 15:32:45 -0800 Subject: [PATCH 176/425] use modelEngine fullname (#374) --- charts/model-engine/templates/istio-metrics.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/templates/istio-metrics.yaml b/charts/model-engine/templates/istio-metrics.yaml index be5a2e90..7020e793 100644 --- a/charts/model-engine/templates/istio-metrics.yaml +++ b/charts/model-engine/templates/istio-metrics.yaml @@ -1,7 +1,7 @@ apiVersion: telemetry.istio.io/v1alpha1 kind: Telemetry metadata: - name: custom-tags + name: {{ include "modelEngine.fullname" . }}-custom-tags namespace: istio-system spec: metrics: @@ -18,7 +18,7 @@ spec: apiVersion: extensions.istio.io/v1alpha1 kind: WasmPlugin metadata: - name: istio-attributegen-filter + name: {{ include "modelEngine.fullname" . }}-attributegen namespace: istio-system spec: imagePullPolicy: Always From b3193979648ad35ca65bbe93e6ead5c13a86b2ea Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 14 Nov 2023 16:08:53 -0800 Subject: [PATCH 177/425] Forward HTTP status code for sync requests (#375) * Forward HTTP status code for sync requests * don't return json response for celery forwarding results * fix unit tests * forward for all sync requests --- ...-runnable-img-converted-from-artifact.yaml | 1 + .../inference/configs/service--forwarder.yaml | 1 + .../configs/service--http_forwarder.yaml | 1 + .../inference/forwarding/echo_server.py | 7 ++ .../inference/forwarding/forwarding.py | 28 ++++++-- .../tests/unit/inference/test_forwarding.py | 68 ++++++++++++++++++- 6 files changed, 98 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml b/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml index 6fd6f920..0c9b43b4 100644 --- a/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml +++ b/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml @@ -9,6 +9,7 @@ forwarder: model_engine_unwrap: false serialize_results_as_string: false wrap_response: false + forward_http_status: true async: user_port: 5005 user_hostname: "localhost" diff --git a/model-engine/model_engine_server/inference/configs/service--forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--forwarder.yaml index 9e284230..fea277db 100644 --- a/model-engine/model_engine_server/inference/configs/service--forwarder.yaml +++ b/model-engine/model_engine_server/inference/configs/service--forwarder.yaml @@ -8,6 +8,7 @@ forwarder: batch_route: null model_engine_unwrap: true serialize_results_as_string: true + forward_http_status: true async: user_port: 5005 user_hostname: "localhost" diff --git a/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml index f0e3eef1..10052970 100644 --- a/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml +++ b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml @@ -8,6 +8,7 @@ forwarder: batch_route: null model_engine_unwrap: true serialize_results_as_string: true + forward_http_status: true stream: user_port: 5005 user_hostname: "localhost" diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py index 0a44b832..db6c0b3c 100644 --- a/model-engine/model_engine_server/inference/forwarding/echo_server.py +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -5,6 +5,7 @@ import subprocess from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse app = FastAPI() @@ -21,6 +22,12 @@ async def predict(request: Request): return await request.json() +@app.post("/predict500") +async def predict500(request: Request): + response = JSONResponse(content=await request.json(), status_code=500) + return response + + @app.post("/stream") async def stream(request: Request): value = (await request.body()).decode() diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 196942d5..099fe7d4 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -8,6 +8,7 @@ import requests import sseclient import yaml +from fastapi.responses import JSONResponse from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import get_endpoint_config @@ -122,6 +123,7 @@ class Forwarder(ModelEngineSerializationMixin): serialize_results_as_string: bool post_inference_hooks_handler: PostInferenceHooksHandler wrap_response: bool + forward_http_status: bool def __call__(self, json_payload: Any) -> Any: request_obj = EndpointPredictV1Request.parse_obj(json_payload) @@ -131,13 +133,14 @@ def __call__(self, json_payload: Any) -> Any: logger.info(f"Accepted request, forwarding {json_payload_repr=}") try: - response: Any = requests.post( + response_raw: Any = requests.post( self.predict_endpoint, json=json_payload, headers={ "Content-Type": "application/json", }, - ).json() + ) + response = response_raw.json() except Exception: logger.exception( f"Failed to get response for request ({json_payload_repr}) " @@ -145,18 +148,27 @@ def __call__(self, json_payload: Any) -> Any: ) raise if isinstance(response, dict): - logger.info(f"Got response from user-defined service: {response.keys()=}") + logger.info( + f"Got response from user-defined service: {response.keys()=}, {response_raw.status_code=}" + ) elif isinstance(response, list): - logger.info(f"Got response from user-defined service: {len(response)=}") + logger.info( + f"Got response from user-defined service: {len(response)=}, {response_raw.status_code=}" + ) else: - logger.info(f"Got response from user-defined service: {response=}") + logger.info( + f"Got response from user-defined service: {response=}, {response_raw.status_code=}" + ) if self.wrap_response: response = self.get_response_payload(using_serialize_results_as_string, response) # TODO: we actually want to do this after we've returned the response. self.post_inference_hooks_handler.handle(request_obj, response) - return response + if self.forward_http_status: + return JSONResponse(content=response, status_code=response_raw.status_code) + else: + return response @dataclass(frozen=True) @@ -180,6 +192,7 @@ class LoadForwarder: model_engine_unwrap: bool = True serialize_results_as_string: bool = True wrap_response: bool = True + forward_http_status: bool = False def load(self, resources: Path, cache: Any) -> Forwarder: if self.use_grpc: @@ -278,6 +291,7 @@ def endpoint(route: str) -> str: serialize_results_as_string=serialize_results_as_string, post_inference_hooks_handler=handler, wrap_response=self.wrap_response, + forward_http_status=self.forward_http_status, ) @@ -492,7 +506,7 @@ def _set_value(config: dict, key_path: List[str], value: Any) -> None: """ key = key_path[0] if len(key_path) == 1: - config[key] = value + config[key] = value if not value.isdigit() else int(value) else: if key not in config: config[key] = dict() diff --git a/model-engine/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py index 283af031..07117967 100644 --- a/model-engine/tests/unit/inference/test_forwarding.py +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -4,6 +4,7 @@ from unittest import mock import pytest +from fastapi.responses import JSONResponse from model_engine_server.core.utils.env import environment from model_engine_server.domain.entities import ModelEndpointConfig from model_engine_server.inference.forwarding.forwarding import ( @@ -33,6 +34,19 @@ class mocked_static_status_code: def mocked_post(*args, **kwargs): # noqa @dataclass class mocked_static_json: + status_code: int = 200 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +def mocked_post_500(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 500 + def json(self) -> dict: return PAYLOAD # type: ignore @@ -85,16 +99,27 @@ def test_forwarders(post_inference_hooks_handler): serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, + forward_http_status=True, ) json_response = fwd({"ignore": "me"}) _check(json_response) def _check(json_response) -> None: + json_response = ( + json.loads(json_response.body.decode("utf-8")) + if isinstance(json_response, JSONResponse) + else json_response + ) assert json_response == {"result": PAYLOAD} def _check_responses_not_wrapped(json_response) -> None: + json_response = ( + json.loads(json_response.body.decode("utf-8")) + if isinstance(json_response, JSONResponse) + else json_response + ) assert json_response == PAYLOAD @@ -121,12 +146,18 @@ def test_forwarders_serialize_results_as_string(post_inference_hooks_handler): serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, + forward_http_status=True, ) json_response = fwd({"ignore": "me"}) _check_serialized(json_response) def _check_serialized(json_response) -> None: + json_response = ( + json.loads(json_response.body.decode("utf-8")) + if isinstance(json_response, JSONResponse) + else json_response + ) assert isinstance(json_response["result"], str) assert len(json_response) == 1, f"expecting only 'result' key, but got {json_response=}" assert json.loads(json_response["result"]) == PAYLOAD @@ -141,10 +172,10 @@ def test_forwarders_override_serialize_results(post_inference_hooks_handler): serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, + forward_http_status=True, ) json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: False}) _check(json_response) - assert json_response == {"result": PAYLOAD} fwd = Forwarder( "ignored", @@ -152,6 +183,7 @@ def test_forwarders_override_serialize_results(post_inference_hooks_handler): serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, + forward_http_status=True, ) json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: True}) _check_serialized(json_response) @@ -166,11 +198,43 @@ def test_forwarder_does_not_wrap_response(post_inference_hooks_handler): serialize_results_as_string=False, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=False, + forward_http_status=True, ) json_response = fwd({"ignore": "me"}) _check_responses_not_wrapped(json_response) +@mock.patch("requests.post", mocked_post_500) +@mock.patch("requests.get", mocked_get) +def test_forwarder_return_status_code(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=False, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me"}) + _check_responses_not_wrapped(json_response) + assert json_response.status_code == 500 + + +@mock.patch("requests.post", mocked_post_500) +@mock.patch("requests.get", mocked_get) +def test_forwarder_dont_return_status_code(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=False, + forward_http_status=False, + ) + json_response = fwd({"ignore": "me"}) + assert json_response == PAYLOAD + + @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) @mock.patch( @@ -219,6 +283,7 @@ def test_forwarder_serialize_within_args(post_inference_hooks_handler): serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, + forward_http_status=True, ) # expected: no `serialize_results_as_string` at top-level nor in 'args' json_response = fwd({"something": "to ignore", "args": {"my": "payload", "is": "here"}}) @@ -237,6 +302,7 @@ def test_forwarder_serialize_within_args(post_inference_hooks_handler): serialize_results_as_string=True, post_inference_hooks_handler=post_inference_hooks_handler, wrap_response=True, + forward_http_status=True, ) json_response = fwd(payload) _check_serialized(json_response) From 4e2ea6cd22364fd72424a7ee5d95a1dba03f74c2 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 14 Nov 2023 21:05:24 -0800 Subject: [PATCH 178/425] Integrate TensorRT-LLM (#358) * TRT-LLM WIP * Integrate TensorRT-LLM * fix * revert * formatting * fix * comments --- charts/model-engine/values_circleci.yaml | 1 + .../model_engine_server/common/config.py | 1 + .../domain/entities/llm_entity.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 220 ++++++++++++++- .../inference/tensorrt-llm/Dockerfile | 12 + .../tensorrt-llm/launch_triton_server.py | 33 +++ .../inference/tensorrt-llm/requirements.txt | 2 + .../triton_model_repo/ensemble/1/.tmp | 0 .../triton_model_repo/ensemble/config.pbtxt | 255 ++++++++++++++++++ .../postprocessing/1/model.py | 156 +++++++++++ .../postprocessing/config.pbtxt | 69 +++++ .../preprocessing/1/model.py | 224 +++++++++++++++ .../preprocessing/config.pbtxt | 99 +++++++ .../triton_model_repo/tensorrt_llm/1/.gitkeep | 0 .../tensorrt_llm/config.pbtxt | 208 ++++++++++++++ model-engine/mypy.ini | 2 +- model-engine/requirements.in | 3 +- model-engine/requirements.txt | 31 ++- .../service_config_circleci.yaml | 1 + model-engine/setup.cfg | 2 + model-engine/tests/unit/conftest.py | 73 +++++ model-engine/tests/unit/domain/conftest.py | 59 +++- .../tests/unit/domain/test_llm_use_cases.py | 145 ++++++++++ 23 files changed, 1589 insertions(+), 8 deletions(-) create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp create mode 100755 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py create mode 100755 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 1cc777e3..a5e29f87 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -146,6 +146,7 @@ config: tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" + tensorrt_llm_repository: "tensorrt-llm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 958881e1..f2b33eea 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -57,6 +57,7 @@ class HostedModelInferenceServiceConfig: tgi_repository: str vllm_repository: str lightllm_repository: str + tensorrt_llm_repository: str user_inference_base_repository: str user_inference_pytorch_repository: str user_inference_tensorflow_repository: str diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index 0624857f..30ec8993 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -12,6 +12,7 @@ class LLMInferenceFramework(str, Enum): TEXT_GENERATION_INFERENCE = "text_generation_inference" VLLM = "vllm" LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt_llm" class Quantization(str, Enum): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index ab138747..1130e6e2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -62,6 +62,10 @@ from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +# Hack for TensorRT-LLM. Remove when it supports returning output tokens only +# See https://github.com/NVIDIA/TensorRT-LLM/issues/227 +from transformers import AutoTokenizer + from ...common.datadog_utils import add_trace_request_id from ..authorization.live_authorization_module import LiveAuthorizationModule from .model_bundle_use_cases import CreateModelBundleV2UseCase @@ -150,6 +154,9 @@ "llama-2-70b": "meta-llama/Llama-2-70b-hf", "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", }, + LLMInferenceFramework.TENSORRT_LLM: { + "llama-2-7b": "huggyllama/llama-7b", # Hack to get tokenizer for llama without sign in to huggingface + }, } _SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = { @@ -157,6 +164,7 @@ LLMInferenceFramework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES], LLMInferenceFramework.VLLM: [Quantization.AWQ], LLMInferenceFramework.LIGHTLLM: [], + LLMInferenceFramework.TENSORRT_LLM: [], } # We need a dict where if we need to override we can @@ -340,6 +348,14 @@ async def create_model_bundle( num_shards, checkpoint_path, ) + elif framework == LLMInferenceFramework.TENSORRT_LLM: + bundle_id = await self.create_tensorrt_llm_bundle( + user, + framework_image_tag, + endpoint_name, + num_shards, + checkpoint_path, + ) else: raise ObjectHasInvalidValueException( f"Framework {framework} is not supported for source {source}." @@ -384,7 +400,7 @@ async def create_text_generation_inference_bundle( ) else: raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: final_weights_folder = _SUPPORTED_MODEL_NAMES[ @@ -471,6 +487,32 @@ def load_model_weights_sub_commands( return subcommands + def load_model_files_sub_commands_trt_llm( + self, + checkpoint_path, + ): + """ + This function generate subcommands to load model files for TensorRT-LLM. + Each model checkpoint is constituted of two folders: `model_weights` which stores the model engine files, + and `model_tokenizer` which stores the model tokenizer files. + See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt + and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt + """ + subcommands = [] + + base_path = checkpoint_path.split("/")[-1] + + if base_path.endswith(".tar"): + raise ObjectHasInvalidValueException( + "Checkpoint for TensorRT-LLM models must be a folder, not a tar file." + ) + else: + subcommands.append( + f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" + ) + + return subcommands + async def create_deepspeed_bundle( self, user: User, @@ -587,7 +629,7 @@ async def create_vllm_bundle( ) else: raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] @@ -678,7 +720,7 @@ async def create_lightllm_bundle( ) else: raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] @@ -721,6 +763,70 @@ async def create_lightllm_bundle( ) ).model_bundle_id + async def create_tensorrt_llm_bundle( + self, + user: User, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + checkpoint_path: Optional[str], + ): + command = [] + + subcommands = [] + if checkpoint_path is not None: + if checkpoint_path.startswith("s3://"): + subcommands += self.load_model_files_sub_commands_trt_llm( + checkpoint_path, + ) + else: + raise ObjectHasInvalidValueException( + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." + ) + else: + raise ObjectHasInvalidValueException( + "Checkpoint must be provided for TensorRT-LLM models." + ) + + subcommands.append( + f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/" + ) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.tensorrt_llm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/v2/health/ready", + # See https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md + predict_route="/v2/models/ensemble/generate", + streaming_predict_route="/v2/models/ensemble/generate_stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: @@ -741,6 +847,8 @@ async def execute( if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + LLMInferenceFramework.TENSORRT_LLM, ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( @@ -1002,9 +1110,26 @@ def validate_and_update_completion_params( "presence_penalty and frequency_penalty are only supported in vllm, lightllm." ) + # return_token_log_probs + if inference_framework in [ + LLMInferenceFramework.DEEPSPEED, + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + pass + else: + if request.return_token_log_probs: + raise ObjectHasInvalidValueException( + "return_token_log_probs is only supported in deepspeed, text-generation-inference, vllm, lightllm." + ) + return request +tokenizer_cache: Dict[str, AutoTokenizer] = {} + + class CompletionSyncV1UseCase: """ Use case for running a prompt completion on an LLM endpoint. @@ -1024,6 +1149,7 @@ def model_output_to_completion_output( model_output: Dict[str, Any], model_endpoint: ModelEndpoint, with_token_probs: Optional[bool], + prompt: Optional[str] = None, ) -> CompletionOutput: model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: @@ -1083,6 +1209,28 @@ def model_output_to_completion_output( num_completion_tokens=model_output["count_output_tokens"], tokens=tokens, ) + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + if not model_content.model_name: + raise InvalidRequestException( + f"Invalid endpoint {model_content.name} has no base model" + ) + if not prompt: + raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.") + if model_content.model_name not in tokenizer_cache: + tokenizer_cache[model_content.model_name] = AutoTokenizer.from_pretrained( + _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.TENSORRT_LLM][ + model_content.model_name + ] + ) + tokenizer = tokenizer_cache[model_content.model_name] + prompt_tokens = tokenizer.encode(prompt) + + return CompletionOutput( + text=model_output["text_output"][ + len(prompt) + 4 : + ], # Output is " prompt output" + num_completion_tokens=len(model_output["token_ids"]) - len(prompt_tokens), + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -1187,6 +1335,7 @@ async def execute( predict_result.result["result"][0], model_endpoint, request.return_token_log_probs, + request.prompt, ), ) else: @@ -1317,6 +1466,42 @@ async def execute( output, model_endpoint, request.return_token_log_probs ), ) + elif endpoint_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + # TODO: Stop sequences is buggy and return token logprobs are not supported + # TODO: verify the implementation of presence_penalty and repetition_penalty + # and see if they fit our existing definition of presence_penalty and frequency_penalty + # Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu + trt_llm_args: Any = { + "text_input": request.prompt, + "max_tokens": request.max_new_tokens, + "stop_words": request.stop_sequences if request.stop_sequences else "", + "bad_words": "", + "temperature": request.temperature, + } + + inference_request = SyncEndpointPredictV1Request( + args=trt_llm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + return CompletionSyncV1Response( + request_id=request_id, + output=None, + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, model_endpoint, request.return_token_log_probs, request.prompt + ), + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {endpoint_content.inference_framework}" @@ -1471,6 +1656,19 @@ async def execute( args["parameters"]["do_sample"] = False if request.return_token_log_probs: args["parameters"]["return_details"] = True + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + # TODO: Stop sequences is buggy and return token logprobs are not supported + # TODO: verify the implementation of presence_penalty and repetition_penalty + # and see if they fit our existing definition of presence_penalty and frequency_penalty + # Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu + args = { + "text_input": request.prompt, + "max_tokens": request.max_new_tokens, + "stop_words": request.stop_sequences if request.stop_sequences else "", + "bad_words": "", + "temperature": request.temperature, + "stream": True, + } else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -1606,6 +1804,22 @@ async def execute( request_id=request_id, output=None, ) + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + if res.status == TaskStatus.SUCCESS and result is not None: + num_completion_tokens += 1 + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["text_output"], + finished=False, # Tracked by https://github.com/NVIDIA/TensorRT-LLM/issues/240 + num_completion_tokens=num_completion_tokens, + ), + ) + else: + yield CompletionStreamV1Response( + request_id=request_id, + output=None, + ) else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile new file mode 100644 index 00000000..7bae22fd --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile @@ -0,0 +1,12 @@ +FROM nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 + +COPY requirements.txt /workspace/requirements.txt +WORKDIR /workspace +RUN pip install -r requirements.txt + +# Install s5cmd +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + +COPY launch_triton_server.py /workspace/launch_triton_server.py +COPY triton_model_repo /workspace/model_repo \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py new file mode 100644 index 00000000..0ce46d2b --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py @@ -0,0 +1,33 @@ +import argparse +import subprocess +from pathlib import Path + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=1, help="world size, only support tensor parallelism now" + ) + parser.add_argument("--tritonserver", type=str, default="/opt/tritonserver/bin/tritonserver") + parser.add_argument( + "--http-port", + type=int, + default=5005, + help="Default HTTP port to 5005. See llm-engine/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml", + ) + path = str(Path(__file__).parent.absolute()) + "/../all_models/gpt" + parser.add_argument("--model_repo", type=str, default=path) + return parser.parse_args() + + +def get_cmd(world_size, tritonserver, model_repo, http_port): + cmd = "mpirun --allow-run-as-root " + for i in range(world_size): + cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address ipv6:[::1] --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : " + return cmd + + +if __name__ == "__main__": + args = parse_arguments() + cmd = get_cmd(int(args.world_size), args.tritonserver, args.model_repo, args.http_port) + subprocess.call(cmd, shell=True) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt new file mode 100644 index 00000000..e2e60684 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt @@ -0,0 +1,2 @@ +sentencepiece==0.1.99 +protobuf==4.24.4 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt new file mode 100755 index 00000000..7a7662d3 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt @@ -0,0 +1,255 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +max_batch_size: 128 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1, -1 ] + }, + { + name: "token_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "max_tokens" + } + input_map { + key: "BAD_WORDS_DICT" + value: "bad_words" + } + input_map { + key: "STOP_WORDS_DICT" + value: "stop_words" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "top_k" + } + input_map { + key: "runtime_top_p" + value: "top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "length_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: "presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "stream" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "text_output" + } + output_map { + key: "OUTPUT_TOKEN_IDS" + value: "token_ids" + } + } + ] +} diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py new file mode 100644 index 00000000..1cd809d9 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py @@ -0,0 +1,156 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side="left") + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + output_token_ids_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT_TOKEN_IDS" + ) + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + self.output_token_ids_dtype = pb_utils.triton_string_to_numpy( + output_token_ids_config["data_type"] + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + tokens_batch = pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH").as_numpy() + + # Reshape Input + # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) + # tokens_batch = tokens_batch.T + + # Postprocessing output data. + outputs = self._postprocessing(tokens_batch) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + output_tensor = pb_utils.Tensor("OUTPUT", np.array(outputs).astype(self.output_dtype)) + + output_token_ids = pb_utils.Tensor( + "OUTPUT_TOKEN_IDS", np.array(tokens_batch).astype(self.output_token_ids_dtype) + ) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor, output_token_ids] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _postprocessing(self, tokens_batch): + outputs = [] + for beam_tokens in tokens_batch: + for tokens in beam_tokens: + output = self.tokenizer.decode(tokens) + outputs.append(output.encode("utf8")) + return outputs diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt new file mode 100755 index 00000000..cc61a24e --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt @@ -0,0 +1,69 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "postprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + }, + { + name: "OUTPUT_TOKEN_IDS" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "model_tokenizer" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "llama" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py new file mode 100644 index 00000000..b5996f87 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py @@ -0,0 +1,224 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import csv +import json +from typing import List + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side="left") + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0] + + # Parse model output configs and convert Triton types to numpy types + input_names = ["INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS"] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, input_name)["data_type"] + ), + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, "REQUEST_OUTPUT_LEN" + ).as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name(request, "BAD_WORDS_DICT").as_numpy() + stop_words_dict = pb_utils.get_input_tensor_by_name( + request, "STOP_WORDS_DICT" + ).as_numpy() + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = self._to_word_list_format(bad_words_dict) + stop_words = self._to_word_list_format(stop_words_dict) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + input_id_tensor = pb_utils.Tensor( + "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) + ) + request_input_len_tensor = pb_utils.Tensor( + "REQUEST_INPUT_LEN", + np.array(request_input_len).astype(self.request_input_len_dtype), + ) + request_output_len_tensor = pb_utils.Tensor("REQUEST_OUTPUT_LEN", request_output_len) + bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) + stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + input_id_tensor, + bad_words_ids_tensor, + stop_words_ids_tensor, + request_input_len_tensor, + request_output_len_tensor, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [torch.IntTensor(self.tokenizer.encode(s[0].decode())) for s in query] + start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + + start_ids = pad_sequence(start_ids, batch_first=True, padding_value=self.pad_id) + # input_len = min(start_lengths) + # attn_mask = torch.ones((batch_size, input_len, input_len)).tril() + + return start_ids, start_lengths + + def _to_word_list_format(self, word_dict: List[List[str]]): + """ + format of word_dict + len(word_dict) should be same to batch_size + word_dict[i] means the words for batch i + len(word_dict[i]) must be 1, which means it only contains 1 string + This string can contains several sentences and split by ",". + For example, if word_dict[2] = " I am happy, I am sad", then this function will return + the ids for two short sentences " I am happy" and " I am sad". + """ + assert self.tokenizer is not None, "need to set tokenizer" + + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + if isinstance(word_dict_item[0], bytes): + word_dict_item = [word_dict_item[0].decode()] + + words = list(csv.reader(word_dict_item))[0] + for word in words: + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt new file mode 100644 index 00000000..89d9c91e --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt @@ -0,0 +1,99 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "preprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "model_tokenizer" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "llama" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..e24a95b4 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt @@ -0,0 +1,208 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 128 + +model_transaction_policy { + decoupled: true +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "./model_weights" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "${batch_scheduler_policy}" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "0.9" + } +} +parameters: { + key: "max_num_sequences" + value: { + string_value: "${max_num_sequences}" + } +} +parameters: { + key: "enable_trt_overlap" + value: { + string_value: "${enable_trt_overlap}" + } +} diff --git a/model-engine/mypy.ini b/model-engine/mypy.ini index 9abfbeaa..82c6107a 100644 --- a/model-engine/mypy.ini +++ b/model-engine/mypy.ini @@ -6,7 +6,7 @@ namespace_packages = True explicit_package_bases = True strict_optional = True plugins = pydantic.mypy -exclude = clients +exclude = clients|.*/triton_model_repo/.* [mypy-model_engine_server.cli.*] ignore_errors = True diff --git a/model-engine/requirements.in b/model-engine/requirements.in index e173eeef..ecdf78a1 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -45,7 +45,8 @@ sseclient-py==1.7.2 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 tqdm~=4.64 +transformers==4.34.1 twine==3.7.1 uvicorn==0.17.6 uvloop==0.17.0 -yarl~=1.4 +yarl~=1.4 \ No newline at end of file diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index c367da1e..87adb372 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -133,10 +133,16 @@ exceptiongroup==1.1.3 # cattrs fastapi==0.78.0 # via -r model-engine/requirements.in +filelock==3.13.1 + # via + # huggingface-hub + # transformers frozenlist==1.3.3 # via # aiohttp # aiosignal +fsspec==2023.10.0 + # via huggingface-hub gitdb==4.0.10 # via gitpython gitdb2==2.0.6 @@ -160,6 +166,10 @@ hpack==4.0.0 # via h2 httptools==0.5.0 # via -r model-engine/requirements.in +huggingface-hub==0.17.3 + # via + # tokenizers + # transformers hypercorn==0.14.4 # via quart hyperframe==6.0.1 @@ -175,7 +185,7 @@ importlib-metadata==6.8.0 # keyring # quart # twine -importlib-resources==6.0.1 +importlib-resources==6.1.0 # via # alembic # jsonschema @@ -249,6 +259,8 @@ mypy-boto3-sqs==1.26.148 # via boto3-stubs mypy-extensions==1.0.0 # via typing-inspect +numpy==1.24.4 + # via transformers oauthlib==3.2.2 # via requests-oauthlib orjson==3.8.6 @@ -258,7 +270,9 @@ packaging==23.1 # build # ddtrace # deprecation + # huggingface-hub # marshmallow + # transformers pep517==0.13.0 # via build pg8000==1.29.8 @@ -314,9 +328,11 @@ python-multipart==0.0.6 # via -r model-engine/requirements.in pyyaml==6.0 # via + # huggingface-hub # kubeconfig # kubernetes # kubernetes-asyncio + # transformers quart==0.18.3 # via -r model-engine/requirements.in readme-renderer==40.0 @@ -327,15 +343,19 @@ referencing==0.30.2 # via # jsonschema # jsonschema-specifications +regex==2023.10.3 + # via transformers requests==2.31.0 # via # -r model-engine/requirements.in # datadog # docker + # huggingface-hub # kubernetes # requests-auth-aws-sigv4 # requests-oauthlib # requests-toolbelt + # transformers # twine requests-auth-aws-sigv4==0.7 # via -r model-engine/requirements.in @@ -355,6 +375,8 @@ rsa==4.9 # via google-auth s3transfer==0.6.1 # via boto3 +safetensors==0.4.0 + # via transformers scramp==1.4.4 # via pg8000 secretstorage==3.3.3 @@ -403,6 +425,8 @@ testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 # via -r model-engine/requirements.in +tokenizers==0.14.1 + # via transformers tomli==2.0.1 # via # build @@ -411,7 +435,11 @@ tomli==2.0.1 tqdm==4.65.0 # via # -r model-engine/requirements.in + # huggingface-hub + # transformers # twine +transformers==4.34.1 + # via -r model-engine/requirements.in twine==3.7.1 # via -r model-engine/requirements.in types-awscrt==0.16.23 @@ -430,6 +458,7 @@ typing-extensions==4.7.1 # cattrs # datadog-api-client # ddtrace + # huggingface-hub # kombu # mypy-boto3-cloudformation # mypy-boto3-dynamodb diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 3438f65d..68172acf 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -57,6 +57,7 @@ istio_enabled: true tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" +tensorrt_llm_repository: "tensorrt-llm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index c47c17ed..f40a2dd1 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -31,5 +31,7 @@ addopts = --mypy --mypy-ini-file=mypy.ini --ignore=clients +# Need to specify this since pytest override mypy.ini See https://github.com/realpython/pytest-mypy/issues/123 + --ignore-glob=*triton_model_repo* # --pylint # --pylint-rcfile=setup.cfg diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index e9dd1e44..445ab83d 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3942,3 +3942,76 @@ def llm_model_endpoint_text_generation_inference( image="test_image", ), ) + + +@pytest.fixture +def llm_model_endpoint_trt_llm( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + return ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_3", + name="test_llm_model_endpoint_name_trt_llm", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-2-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "23.10", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.STREAMING, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_trt_llm", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index c27aaa52..06310666 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -330,6 +330,61 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( ) +@pytest.fixture +def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_trt_llm_streaming", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="tensorrt_llm", + inference_framework_image_tag="23.10", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage=None, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", + ) + + +@pytest.fixture +def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_tgi_async", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="tensorrt_llm", + inference_framework_image_tag="23.10", + num_shards=2, + quantize=Quantization.BITSANDBYTES, + endpoint_type=ModelEndpointType.ASYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage=None, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", + ) + + @pytest.fixture def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( @@ -386,7 +441,7 @@ def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEn @pytest.fixture def completion_sync_request() -> CompletionSyncV1Request: return CompletionSyncV1Request( - prompt="test_prompt_1", + prompt="What is machine learning?", max_new_tokens=10, temperature=0.5, return_token_log_probs=True, @@ -396,7 +451,7 @@ def completion_sync_request() -> CompletionSyncV1Request: @pytest.fixture def completion_stream_request() -> CompletionStreamV1Request: return CompletionStreamV1Request( - prompt="test_prompt_1", + prompt="What is machine learning?", max_new_tokens=10, temperature=0.5, ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c83e1049..d7ec41f0 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -260,6 +260,63 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( ) +@pytest.mark.asyncio +async def test_create_model_endpoint_trt_llm_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_trt_llm_request_async: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_trt_llm_request_streaming: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + request=create_llm_model_endpoint_trt_llm_request_streaming, + ) + assert response_1.endpoint_creation_task_id + assert isinstance(response_1, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_trt_llm_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_trt_llm_request_streaming.model_name, + "source": create_llm_model_endpoint_trt_llm_request_streaming.source, + "inference_framework": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_trt_llm_request_streaming.num_shards, + "quantize": create_llm_model_endpoint_trt_llm_request_streaming.quantize, + } + } + + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, + request=create_llm_model_endpoint_trt_llm_request_async, + ) + + @pytest.mark.asyncio async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception( test_api_key: str, @@ -545,6 +602,39 @@ async def test_completion_sync_text_generation_inference_use_case_success( ) +@pytest.mark.asyncio +async def test_completion_sync_trt_llm_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_trt_llm: ModelEndpoint, + completion_sync_request: CompletionSyncV1Request, +): + completion_sync_request.return_token_log_probs = False # not yet supported + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": '{"model_name": "ensemble", "model_version": "1", "sequence_end": false, "sequence_id": 0, "sequence_start": false, "text_output": " What is machine learning? Machine learning is a branch", "token_ids": [1, 1724, 338, 4933, 6509, 29973, 6189, 6509, 338, 263, 5443]}' + }, + traceback=None, + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_sync_request, + ) + assert response_1.output == CompletionOutput( + text=" Machine learning is a branch", + num_completion_tokens=5, + ) + + @pytest.mark.asyncio async def test_completion_sync_use_case_predict_failed( test_api_key: str, @@ -777,6 +867,61 @@ async def test_completion_stream_text_generation_inference_use_case_success( i += 1 +@pytest.mark.asyncio +async def test_completion_stream_trt_llm_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_trt_llm: ModelEndpoint, + completion_stream_request: CompletionStreamV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "Machine", "token_ids": 6189}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "learning", "token_ids": 6509}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "is", "token_ids": 338}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "a", "token_ids": 263}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "branch", "token_ids": 5443}}, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_stream_request, + ) + output_texts = ["Machine", "learning", "is", "a", "branch"] + i = 0 + async for message in response_1: + assert message.dict()["request_id"] + assert message.dict()["output"]["text"] == output_texts[i] + assert message.dict()["output"]["num_completion_tokens"] == i + 1 + i += 1 + + @pytest.mark.asyncio async def test_create_llm_fine_tune_model_name_valid(): assert is_model_name_suffix_valid("model-name") From 5e4d6626a3b16a4da80db9959c67ca19db85412b Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Wed, 15 Nov 2023 10:41:38 -0800 Subject: [PATCH 179/425] Fine-tuning e2e integration test (#372) * make test work * add status checking * fix * test * wget fix * final fixes * move namespace --- .circleci/config.yml | 10 ++- charts/model-engine/values_circleci.yaml | 2 +- integration_tests/test_fine_tunes.py | 73 ++++++++++++++----- .../inference/pytorch_or_tf.base.Dockerfile | 2 +- 4 files changed, 64 insertions(+), 23 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b10843eb..abeb67f0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -121,15 +121,16 @@ jobs: command: | sudo apt-get update && sudo apt-get install -y expect pushd $HOME/project/.circleci/resources + kubectl create namespace model-engine kubectl apply -f redis-k8s.yaml kubectl apply -f postgres-k8s.yaml kubectl create secret generic model-engine-postgres-credentials --from-literal=database_url=postgresql://postgres:circle_test@postgres.default:5432/circle_test + kubectl create secret generic model-engine-postgres-credentials --from-literal=database_url=postgresql://postgres:circle_test@postgres.default:5432/circle_test -n model-engine export ISTIO_VERSION=1.15.0 popd curl -L https://istio.io/downloadIstio | TARGET_ARCH=x86_64 sh - install istio-${ISTIO_VERSION}/bin/istioctl $HOME/bin $HOME/bin/istioctl install --set profile=demo -y - kubectl create namespace model-engine kubectl create configmap default-config --from-literal=config="$(cat $HOME/project/.circleci/resources/.minikube-config-map | envsubst)" kubectl create configmap default-config --namespace model-engine --from-literal=config="$(cat $HOME/project/.circleci/resources/.minikube-config-map | envsubst)" cat $HOME/project/.circleci/resources/.minikube-registry-creds | envsubst | expect @@ -142,7 +143,7 @@ jobs: name: Pre-load integration test images to minikube command: | docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile \ - --build-arg BASE_IMAGE=pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime \ + --build-arg BASE_IMAGE=python:3.8-slim \ --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-base-requirements.txt" \ -t temp:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1 . @@ -179,7 +180,10 @@ jobs: command: | pushd $HOME/project kubectl port-forward svc/model-engine 5001:80 & - GIT_TAG=$CIRCLE_SHA1 pytest integration_tests + export AWS_ACCESS_KEY_ID=$CIRCLECI_AWS_ACCESS_KEY + export AWS_SECRET_ACCESS_KEY=$CIRCLECI_AWS_SECRET_KEY + export GIT_TAG=$CIRCLE_SHA1 + pytest integration_tests executors: ubuntu-large: diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index a5e29f87..3c2c94ec 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -140,7 +140,7 @@ config: billing_queue_arn: none cache_redis_url: redis://redis-message-broker-master.default/15 - s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET" + s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository" dd_trace_enabled: false istio_enabled: true tgi_repository: "text-generation-inference" diff --git a/integration_tests/test_fine_tunes.py b/integration_tests/test_fine_tunes.py index 024540e2..22593eca 100644 --- a/integration_tests/test_fine_tunes.py +++ b/integration_tests/test_fine_tunes.py @@ -1,31 +1,68 @@ -import pytest +import json +import os +import time -from .rest_api_utils import ( # CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, CREATE_FINE_TUNE_REQUEST, USER_ID_0, cancel_fine_tune_by_id, create_docker_image_batch_job_bundle, create_fine_tune, get_fine_tune_by_id, +import boto3 +import smart_open + +from .rest_api_utils import ( + CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, + CREATE_FINE_TUNE_REQUEST, USER_ID_0, + cancel_fine_tune_by_id, + create_docker_image_batch_job_bundle, + create_fine_tune, + get_fine_tune_by_id, list_fine_tunes, ) +MAX_RETRIES = 10 + -@pytest.mark.skip(reason="test doesn't currently work, needs to be implemented correctly") def test_fine_tunes() -> None: - # TODO: get this test to work (move LLM fine tune repository to database rather than in S3) + di_batch_job_id = create_docker_image_batch_job_bundle( + CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 + )["docker_image_batch_job_bundle_id"] + data = { + "test_base_model-lora": { + "docker_image_batch_job_bundle_id": di_batch_job_id, + "launch_bundle_config": {}, + "launch_endpoint_config": {}, + "default_hparams": {}, + "required_params": [], + } + } - # di_batch_job_id = create_docker_image_batch_job_bundle( - # CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 - # )["docker_image_batch_job_bundle_id"] + if os.getenv("CIRCLECI") == "true": + session = boto3.Session() + aws_s3_bucket = os.getenv("CIRCLECI_AWS_S3_BUCKET") + client = session.client("s3") + with smart_open.open( + f"s3://{aws_s3_bucket}/fine_tune_repository", + "w", + transport_params={"client": client}, + ) as f: + json.dump(data, f) - # create_response = create_fine_tune(CREATE_FINE_TUNE_REQUEST, USER_ID_0) - # fine_tune_id = create_response["id"] + create_response = create_fine_tune(CREATE_FINE_TUNE_REQUEST, USER_ID_0) + fine_tune_id = create_response["id"] - # get_response = get_fine_tune_by_id(fine_tune_id, USER_ID_0) - # assert get_response["id"] == fine_tune_id + get_response = get_fine_tune_by_id(fine_tune_id, USER_ID_0) + num_retries = 0 + while get_response["status"] not in ["SUCCESS", "FAILURE"]: + if num_retries >= MAX_RETRIES: + raise Exception("Fine tune job did not complete in time.") + num_retries += 1 + get_response = get_fine_tune_by_id(fine_tune_id, USER_ID_0) + time.sleep(10) + assert get_response["id"] == fine_tune_id + assert get_response["status"] == "SUCCESS" - # list_response_0_before = list_fine_tunes(USER_ID_0) - # num_jobs = len(list_response_0_before["jobs"]) - # assert num_jobs >= 1 + list_response_0_before = list_fine_tunes(USER_ID_0) + num_jobs = len(list_response_0_before["jobs"]) + assert num_jobs >= 1 - list_response_1 = list_fine_tunes(USER_ID_0) - assert len(list_response_1["jobs"]) == 0 + cancel_fine_tune_by_id(fine_tune_id, USER_ID_0) - # list_response_0_after = list_fine_tunes(USER_ID_0) - # assert len(list_response_0_after["jobs"]) == num_jobs - 1 + list_response_0_after = list_fine_tunes(USER_ID_0) + assert len(list_response_0_after["jobs"]) == num_jobs - 1 diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile index 01cbdf0c..8d8d3378 100644 --- a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile @@ -27,7 +27,7 @@ RUN apt-get update && apt-get install -y \ && rm -rf /var/lib/apt/lists/* # Apparently wget has a vulnerability so we remove it here -RUN apt-get remove wget -y +RUN dpkg -l | grep wget && apt-get remove wget -y || echo "wget not installed, skipping removal" # Create a virtualenv for python so we install our packages in the right place # Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs From 5b6aeff6b6636838d31c90d7f3f3f6d915390a6f Mon Sep 17 00:00:00 2001 From: Sam Denton <106690182+sam-scale@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:33:17 -0800 Subject: [PATCH 180/425] Found a bug in the codellama vllm model_len logic. (#380) * Found a bug in the codellama vllm model_len logic. Also, let's just avoid the vLLM error by making sure max_num_batched_tokens >= max_model_len * nevermind I realized that if statement will never happen here. --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 1130e6e2..f4f741e3 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -174,9 +174,9 @@ "mammoth-coder": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, # Based on config here: https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B/blob/main/config.json#L12 # Can also see 13B, 34B there too - "code-llama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, + "codellama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 - # Can also see 13B, 34B there too + # Can also see 13B, 34B there too. Note, codellama is one word. "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, } From 043f83a923d93806fbe0c7e5172423785ae4cd43 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 15 Nov 2023 14:12:36 -0800 Subject: [PATCH 181/425] Fix sample.yaml (#381) --- charts/model-engine/values_sample.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index eb9d695b..8b32aee0 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -199,6 +199,7 @@ config: tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" + tensorrt_llm_repository: "tensorrt-llm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "launch/inference/pytorch" user_inference_tensorflow_repository: "launch/inference/tf" From 257ea6c7da5b3dcf6764078f31e2c57bfc7c7f52 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Wed, 15 Nov 2023 14:40:17 -0800 Subject: [PATCH 182/425] count prompt tokens (#366) * count prompt tokens, use tokenizer if needed * docstrings * fix tests and code cov * add download files from s3 fn * use same helpers and add docstring * change to namedtuple * add s3 repo locations * fallback read from s3 * refactor tokenizer laod * edit tests * refactor _SUPPORTED_MODELS_BY_FRAMEWORK * updates for tests * move to utils file * move some fns over * use lru cache * move model info * root to opt * add log and adjust integration test * refocus logs * change empty string to optional * mock count tokens for unit tests * change 1 mock * add unit tests * config change * comments pt 1 * move internal logic to plugins file * replace usage of utils file * rearrange test mock * only return prompt tokens count on last token in stream * fix mock * reorganize imports * inject in external interfaces * make changes to tests * fix tests * adjust test * oops test * add more tests --- charts/model-engine/values_circleci.yaml | 2 +- integration_tests/test_endpoints.py | 15 + .../model_engine_server/api/dependencies.py | 6 + .../model_engine_server/api/llms_v1.py | 2 + .../model_engine_server/common/dtos/llms.py | 4 +- .../domain/gateways/llm_artifact_gateway.py | 19 ++ .../domain/repositories/__init__.py | 2 + .../repositories/tokenizer_repository.py | 18 ++ .../use_cases/llm_model_endpoint_use_cases.py | 278 +++++++++++------- .../infra/gateways/s3_llm_artifact_gateway.py | 47 ++- .../infra/repositories/__init__.py | 2 + .../repositories/live_tokenizer_repository.py | 149 ++++++++++ model-engine/requirements.in | 2 + model-engine/requirements.txt | 2 + .../service_config_circleci.yaml | 2 +- model-engine/setup.cfg | 5 +- model-engine/tests/unit/conftest.py | 26 +- .../tests/unit/domain/test_llm_use_cases.py | 58 +++- .../gateways/test_s3_llm_artifact_gateway.py | 85 ++++++ .../test_live_tokenizer_repository.py | 62 ++++ 20 files changed, 652 insertions(+), 134 deletions(-) create mode 100644 model-engine/model_engine_server/domain/repositories/tokenizer_repository.py create mode 100644 model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py create mode 100644 model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py create mode 100644 model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 3c2c94ec..f73ffea6 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -151,7 +151,7 @@ config: user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" docker_image_layer_cache_repository: "kaniko-cache" - hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET" + hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET/model-weights" # Service Account serviceAccount: diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index ad40a2d9..dbb68ec8 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -2,6 +2,7 @@ import time import pytest +from model_engine_server.common.env_vars import CIRCLECI from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed from .rest_api_utils import ( @@ -234,3 +235,17 @@ def test_sync_streaming_model_endpoint(capsys): ) finally: delete_model_endpoint(create_endpoint_request["name"], user) + + +@pytest.mark.skipif(CIRCLECI, reason="skip on circleci since need to figure out s3 access") +def test_models_tokenizers() -> None: + from model_engine_server.infra.gateways.s3_llm_artifact_gateway import S3LLMArtifactGateway + from model_engine_server.infra.repositories import LiveTokenizerRepository + from model_engine_server.infra.repositories.live_tokenizer_repository import ( + SUPPORTED_MODELS_INFO, + ) + + llm_artifact_gateway = S3LLMArtifactGateway() + tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) + for model_name in SUPPORTED_MODELS_INFO: + tokenizer_repository.load_tokenizer(model_name) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 0a055248..b68080b8 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -34,6 +34,7 @@ DockerRepository, LLMFineTuneEventsRepository, ModelBundleRepository, + TokenizerRepository, TriggerRepository, ) from model_engine_server.domain.services import ( @@ -87,6 +88,7 @@ DbTriggerRepository, ECRDockerRepository, FakeDockerRepository, + LiveTokenizerRepository, RedisModelEndpointCacheRepository, S3FileLLMFineTuneEventsRepository, S3FileLLMFineTuneRepository, @@ -134,6 +136,7 @@ class ExternalInterfaces: llm_artifact_gateway: LLMArtifactGateway cron_job_gateway: CronJobGateway monitoring_metrics_gateway: MonitoringMetricsGateway + tokenizer_repository: TokenizerRepository def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway: @@ -260,6 +263,8 @@ def _get_external_interfaces( docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() + tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) + external_interfaces = ExternalInterfaces( docker_repository=docker_repository, model_bundle_repository=model_bundle_repository, @@ -281,6 +286,7 @@ def _get_external_interfaces( trigger_repository=trigger_repository, cron_job_gateway=cron_job_gateway, monitoring_metrics_gateway=monitoring_metrics_gateway, + tokenizer_repository=tokenizer_repository, ) return external_interfaces diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index a15e3f20..153b4c2c 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -247,6 +247,7 @@ async def create_completion_sync_task( use_case = CompletionSyncV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, ) return await use_case.execute( user=auth, model_endpoint_name=model_endpoint_name, request=request @@ -290,6 +291,7 @@ async def create_completion_stream_task( use_case = CompletionStreamV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, ) response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request) diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index bf0b7519..c0e6b9fc 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -80,7 +80,7 @@ class GetLLMModelEndpointV1Response(BaseModel): """ name: str - model_name: Optional[str] = None + model_name: str source: LLMSource status: ModelEndpointStatus inference_framework: LLMInferenceFramework @@ -143,6 +143,7 @@ class TokenOutput(BaseModel): class CompletionOutput(BaseModel): text: str + num_prompt_tokens: int num_completion_tokens: int tokens: Optional[List[TokenOutput]] = None @@ -198,6 +199,7 @@ class CompletionStreamV1Request(BaseModel): class CompletionStreamOutput(BaseModel): text: str finished: bool + num_prompt_tokens: Optional[int] = None num_completion_tokens: Optional[int] = None token: Optional[TokenOutput] = None diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py index 21e3c697..017bedea 100644 --- a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py @@ -11,6 +11,21 @@ class LLMArtifactGateway(ABC): def list_files(self, path: str, **kwargs) -> List[str]: """ Gets a list of files from a given path. + + Args: + path (str): path to list files + """ + pass + + @abstractmethod + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + """ + Download files from a given path to a target path. + + Args: + path (str): path to list files + target_path (str): local path to download files + overwrite (bool): whether to overwrite existing local files """ pass @@ -18,5 +33,9 @@ def list_files(self, path: str, **kwargs) -> List[str]: def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: """ Gets a list of URLs for all files associated with a given model. + + Args: + owner (str): owner of the model + model_name (str): name of the model """ pass diff --git a/model-engine/model_engine_server/domain/repositories/__init__.py b/model-engine/model_engine_server/domain/repositories/__init__.py index 00718521..56ec32e7 100644 --- a/model-engine/model_engine_server/domain/repositories/__init__.py +++ b/model-engine/model_engine_server/domain/repositories/__init__.py @@ -4,6 +4,7 @@ from .docker_repository import DockerRepository from .llm_fine_tune_events_repository import LLMFineTuneEventsRepository from .model_bundle_repository import ModelBundleRepository +from .tokenizer_repository import TokenizerRepository from .trigger_repository import TriggerRepository __all__: Sequence[str] = [ @@ -11,5 +12,6 @@ "DockerImageBatchJobBundleRepository", "LLMFineTuneEventsRepository", "ModelBundleRepository", + "TokenizerRepository", "TriggerRepository", ] diff --git a/model-engine/model_engine_server/domain/repositories/tokenizer_repository.py b/model-engine/model_engine_server/domain/repositories/tokenizer_repository.py new file mode 100644 index 00000000..f8ba740a --- /dev/null +++ b/model-engine/model_engine_server/domain/repositories/tokenizer_repository.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + +from transformers import AutoTokenizer + + +class TokenizerRepository(ABC): + @abstractmethod + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + """ + Loads a tokenizer from a model name. + + Args: + model_name: The model name to load the tokenizer for. + + Returns: + A tokenizer. + """ + pass diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index f4f741e3..3668a895 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -57,14 +57,14 @@ UpstreamServiceError, ) from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway -from model_engine_server.domain.repositories import ModelBundleRepository -from model_engine_server.domain.repositories.docker_repository import DockerRepository +from model_engine_server.domain.repositories import ( + DockerRepository, + ModelBundleRepository, + TokenizerRepository, +) from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway - -# Hack for TensorRT-LLM. Remove when it supports returning output tokens only -# See https://github.com/NVIDIA/TensorRT-LLM/issues/227 -from transformers import AutoTokenizer +from model_engine_server.infra.repositories.live_tokenizer_repository import SUPPORTED_MODELS_INFO from ...common.datadog_utils import add_trace_request_id from ..authorization.live_authorization_module import LiveAuthorizationModule @@ -80,83 +80,90 @@ logger = make_logger(logger_name()) -_SUPPORTED_MODEL_NAMES = { - LLMInferenceFramework.DEEPSPEED: { - "mpt-7b": "mosaicml/mpt-7b", - "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", - "gpt-j-6b": "EleutherAI/gpt-j-6b", - "gpt-j-6b-zh-en": "EleutherAI/gpt-j-6b", - "gpt4all-j": "nomic-ai/gpt4all-j", - "dolly-v2-12b": "databricks/dolly-v2-12b", - "stablelm-tuned-7b": "StabilityAI/stablelm-tuned-alpha-7b", - "flan-t5-xxl": "google/flan-t5-xxl", - "llama-7b": "decapoda-research/llama-7b-hf", - "vicuna-13b": "eachadea/vicuna-13b-1.1", - }, - LLMInferenceFramework.TEXT_GENERATION_INFERENCE: { - "mpt-7b": "mosaicml/mpt-7b", - "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", - "flan-t5-xxl": "google/flan-t5-xxl", - "llama-7b": "decapoda-research/llama-7b-hf", - "llama-2-7b": "meta-llama/Llama-2-7b-hf", - "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", - "llama-2-13b": "meta-llama/Llama-2-13b-hf", - "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", - "llama-2-70b": "meta-llama/Llama-2-70b-hf", - "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", - "falcon-7b": "tiiuae/falcon-7b", - "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", - "falcon-40b": "tiiuae/falcon-40b", - "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", - "codellama-7b": "codellama/CodeLlama-7b-hf", - "codellama-7b-instruct": "codellama/CodeLlama-7b-Instruct-hf", - "codellama-13b": "codellama/CodeLlama-13b-hf", - "codellama-13b-instruct": "codellama/CodeLlama-13b-Instruct-hf", - "codellama-34b": "codellama/CodeLlama-34b-hf", - "codellama-34b-instruct": "codellama/CodeLlama-34b-Instruct-hf", - "llm-jp-13b-instruct-full": "llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", - "llm-jp-13b-instruct-full-dolly": "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", - }, - LLMInferenceFramework.VLLM: { - "mpt-7b": "mosaicml/mpt-7b", - "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", - "llama-7b": "decapoda-research/llama-7b-hf", - "llama-2-7b": "meta-llama/Llama-2-7b-hf", - "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", - "llama-2-13b": "meta-llama/Llama-2-13b-hf", - "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", - "llama-2-70b": "meta-llama/Llama-2-70b-hf", - "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", - "falcon-7b": "tiiuae/falcon-7b", - "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", - "falcon-40b": "tiiuae/falcon-40b", - "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", - "mistral-7b": "mistralai/Mistral-7B-v0.1", - "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", - "falcon-180b": "tiiuae/falcon-180B", - "falcon-180b-chat": "tiiuae/falcon-180B-chat", - "codellama-7b": "codellama/CodeLlama-7b-hf", - "codellama-7b-instruct": "codellama/CodeLlama-7b-Instruct-hf", - "codellama-13b": "codellama/CodeLlama-13b-hf", - "codellama-13b-instruct": "codellama/CodeLlama-13b-Instruct-hf", - "codellama-34b": "codellama/CodeLlama-34b-hf", - "codellama-34b-instruct": "codellama/CodeLlama-34b-Instruct-hf", - "mammoth-coder-llama-2-7b": "TIGER-Lab/MAmmoTH-Coder-7B", - "mammoth-coder-llama-2-13b": "TIGER-Lab/MAmmoTH-Coder-13B", - "mammoth-coder-llama-2-34b": "TIGER-Lab/MAmmoTH-Coder-34B", - }, - LLMInferenceFramework.LIGHTLLM: { - "llama-7b": "decapoda-research/llama-7b-hf", - "llama-2-7b": "meta-llama/Llama-2-7b-hf", - "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", - "llama-2-13b": "meta-llama/Llama-2-13b-hf", - "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", - "llama-2-70b": "meta-llama/Llama-2-70b-hf", - "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", - }, - LLMInferenceFramework.TENSORRT_LLM: { - "llama-2-7b": "huggyllama/llama-7b", # Hack to get tokenizer for llama without sign in to huggingface - }, + +_SUPPORTED_MODELS_BY_FRAMEWORK = { + LLMInferenceFramework.DEEPSPEED: set( + [ + "mpt-7b", + "mpt-7b-instruct", + "flan-t5-xxl", + "llama-7b", + "gpt-j-6b", + "gpt-j-6b-zh-en", + "gpt4all-j", + "dolly-v2-12b", + "stablelm-tuned-7b", + "vicuna-13b", + ] + ), + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: set( + [ + "mpt-7b", + "mpt-7b-instruct", + "flan-t5-xxl", + "llama-7b", + "llama-2-7b", + "llama-2-7b-chat", + "llama-2-13b", + "llama-2-13b-chat", + "llama-2-70b", + "llama-2-70b-chat", + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "codellama-7b", + "codellama-7b-instruct", + "codellama-13b", + "codellama-13b-instruct", + "codellama-34b", + "codellama-34b-instruct", + "llm-jp-13b-instruct-full", + "llm-jp-13b-instruct-full-dolly", + ] + ), + LLMInferenceFramework.VLLM: set( + [ + "mpt-7b", + "mpt-7b-instruct", + "llama-7b", + "llama-2-7b", + "llama-2-7b-chat", + "llama-2-13b", + "llama-2-13b-chat", + "llama-2-70b", + "llama-2-70b-chat", + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "falcon-180b", + "falcon-180b-chat", + "codellama-7b", + "codellama-7b-instruct", + "codellama-13b", + "codellama-13b-instruct", + "codellama-34b", + "codellama-34b-instruct", + "mistral-7b", + "mistral-7b-instruct", + "mammoth-coder-llama-2-7b", + "mammoth-coder-llama-2-13b", + "mammoth-coder-llama-2-34b", + ] + ), + LLMInferenceFramework.LIGHTLLM: set( + [ + "llama-7b", + "llama-2-7b", + "llama-2-7b-chat", + "llama-2-13b", + "llama-2-13b-chat", + "llama-2-70b", + "llama-2-70b-chat", + ] + ), + LLMInferenceFramework.TENSORRT_LLM: set(["llama-2-7b"]), } _SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = { @@ -186,6 +193,14 @@ DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes +def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: + """ + Count the number of tokens in the input string. + """ + tokenizer = tokenizer_repository.load_tokenizer(model_name) + return len(tokenizer.encode(input)) + + def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]: """ This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files @@ -228,7 +243,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( def validate_model_name(model_name: str, inference_framework: LLMInferenceFramework) -> None: - if model_name not in _SUPPORTED_MODEL_NAMES[inference_framework]: + if model_name not in _SUPPORTED_MODELS_BY_FRAMEWORK[inference_framework]: raise ObjectHasInvalidValueException( f"Model name {model_name} is not supported for inference framework {inference_framework}." ) @@ -403,9 +418,7 @@ async def create_text_generation_inference_bundle( f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: - final_weights_folder = _SUPPORTED_MODEL_NAMES[ - LLMInferenceFramework.TEXT_GENERATION_INFERENCE - ][model_name] + final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo subcommands.append( f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}" @@ -632,7 +645,7 @@ async def create_vllm_bundle( f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: - final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] + final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo if max_model_len: subcommands.append( @@ -723,7 +736,7 @@ async def create_lightllm_bundle( f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) else: - final_weights_folder = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.VLLM][model_name] + final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo subcommands.append( f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto" @@ -1127,9 +1140,6 @@ def validate_and_update_completion_params( return request -tokenizer_cache: Dict[str, AutoTokenizer] = {} - - class CompletionSyncV1UseCase: """ Use case for running a prompt completion on an LLM endpoint. @@ -1139,17 +1149,19 @@ def __init__( self, model_endpoint_service: ModelEndpointService, llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, ): self.model_endpoint_service = model_endpoint_service self.llm_model_endpoint_service = llm_model_endpoint_service self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository def model_output_to_completion_output( self, model_output: Dict[str, Any], model_endpoint: ModelEndpoint, + prompt: str, with_token_probs: Optional[bool], - prompt: Optional[str] = None, ) -> CompletionOutput: model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: @@ -1159,6 +1171,11 @@ def model_output_to_completion_output( tokens = deepspeed_result_to_tokens(model_output) return CompletionOutput( text=model_output["text"], + num_prompt_tokens=count_tokens( + prompt, + model_content.model_name, + self.tokenizer_repository, + ), num_completion_tokens=completion_token_count, tokens=tokens, ) @@ -1172,7 +1189,7 @@ def model_output_to_completion_output( ] return CompletionOutput( text=model_output["generated_text"], - # len(model_output["details"]["prefill"]) does not return the correct value reliably + num_prompt_tokens=len(model_output["details"]["prefill"]), num_completion_tokens=model_output["details"]["generated_tokens"], tokens=tokens, ) @@ -1194,6 +1211,7 @@ def model_output_to_completion_output( ] return CompletionOutput( text=model_output["text"], + num_prompt_tokens=model_output["count_prompt_tokens"], num_completion_tokens=model_output["count_output_tokens"], tokens=tokens, ) @@ -1206,6 +1224,11 @@ def model_output_to_completion_output( ] return CompletionOutput( text=model_output["generated_text"][0], + num_prompt_tokens=count_tokens( + prompt, + model_content.model_name, + self.tokenizer_repository, + ), num_completion_tokens=model_output["count_output_tokens"], tokens=tokens, ) @@ -1216,20 +1239,14 @@ def model_output_to_completion_output( ) if not prompt: raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.") - if model_content.model_name not in tokenizer_cache: - tokenizer_cache[model_content.model_name] = AutoTokenizer.from_pretrained( - _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.TENSORRT_LLM][ - model_content.model_name - ] - ) - tokenizer = tokenizer_cache[model_content.model_name] - prompt_tokens = tokenizer.encode(prompt) - + num_prompt_tokens = count_tokens( + prompt, model_content.model_name, self.tokenizer_repository + ) return CompletionOutput( - text=model_output["text_output"][ - len(prompt) + 4 : - ], # Output is " prompt output" - num_completion_tokens=len(model_output["token_ids"]) - len(prompt_tokens), + # Output is " prompt output" + text=model_output["text_output"][(len(prompt) + 4) :], + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=len(model_output["token_ids"]) - num_prompt_tokens, ) else: raise EndpointUnsupportedInferenceTypeException( @@ -1334,8 +1351,8 @@ async def execute( output=self.model_output_to_completion_output( predict_result.result["result"][0], model_endpoint, - request.return_token_log_probs, request.prompt, + request.return_token_log_probs, ), ) else: @@ -1383,7 +1400,7 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.return_token_log_probs + output, model_endpoint, request.prompt, request.return_token_log_probs ), ) elif endpoint_content.inference_framework == LLMInferenceFramework.VLLM: @@ -1421,7 +1438,7 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.return_token_log_probs + output, model_endpoint, request.prompt, request.return_token_log_probs ), ) elif endpoint_content.inference_framework == LLMInferenceFramework.LIGHTLLM: @@ -1463,7 +1480,7 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.return_token_log_probs + output, model_endpoint, request.prompt, request.return_token_log_probs ), ) elif endpoint_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: @@ -1499,7 +1516,7 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.return_token_log_probs, request.prompt + output, model_endpoint, request.prompt, request.return_token_log_probs ), ) else: @@ -1517,10 +1534,12 @@ def __init__( self, model_endpoint_service: ModelEndpointService, llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, ): self.model_endpoint_service = model_endpoint_service self.llm_model_endpoint_service = llm_model_endpoint_service self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository async def execute( self, user: User, model_endpoint_name: str, request: CompletionStreamV1Request @@ -1591,6 +1610,7 @@ async def execute( request = validated_request args: Any = None + num_prompt_tokens = None if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: args = { "prompts": [request.prompt], @@ -1605,6 +1625,11 @@ async def execute( if request.stop_sequences is not None: # Deepspeed models only accepts one stop sequence args["stop_sequence"] = request.stop_sequences[0] + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: args = { "inputs": request.prompt, @@ -1621,6 +1646,11 @@ async def execute( args["parameters"]["top_p"] = request.top_p else: args["parameters"]["do_sample"] = False + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) elif model_content.inference_framework == LLMInferenceFramework.VLLM: args = { "prompt": request.prompt, @@ -1656,6 +1686,11 @@ async def execute( args["parameters"]["do_sample"] = False if request.return_token_log_probs: args["parameters"]["return_details"] = True + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: # TODO: Stop sequences is buggy and return token logprobs are not supported # TODO: verify the implementation of presence_penalty and repetition_penalty @@ -1669,6 +1704,12 @@ async def execute( "temperature": request.temperature, "stream": True, } + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) + else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -1694,6 +1735,7 @@ async def execute( output=CompletionStreamOutput( text=result["result"]["token"], finished=False, + num_prompt_tokens=None, num_completion_tokens=None, ), ) @@ -1706,6 +1748,7 @@ async def execute( output=CompletionStreamOutput( text=result["result"]["response"][0]["text"], finished=True, + num_prompt_tokens=num_prompt_tokens, num_completion_tokens=completion_token_count, ), ) @@ -1737,6 +1780,7 @@ async def execute( output=CompletionStreamOutput( text=result["result"]["token"]["text"], finished=finished, + num_prompt_tokens=num_prompt_tokens if finished else None, num_completion_tokens=num_completion_tokens, token=token, ), @@ -1767,11 +1811,14 @@ async def execute( token=result["result"]["text"], log_prob=list(result["result"]["log_probs"].values())[0], ) + finished = result["result"]["finished"] + num_prompt_tokens = result["result"]["count_prompt_tokens"] yield CompletionStreamV1Response( request_id=request_id, output=CompletionStreamOutput( text=result["result"]["text"], - finished=result["result"]["finished"], + finished=finished, + num_prompt_tokens=num_prompt_tokens if finished else None, num_completion_tokens=result["result"]["count_output_tokens"], token=token, ), @@ -1790,11 +1837,13 @@ async def execute( token=result["result"]["token"]["text"], log_prob=result["result"]["token"]["logprob"], ) + finished = result["result"]["finished"] yield CompletionStreamV1Response( request_id=request_id, output=CompletionStreamOutput( text=result["result"]["token"]["text"], - finished=result["result"]["finished"], + finished=finished, + num_prompt_tokens=num_prompt_tokens if finished else None, num_completion_tokens=num_completion_tokens, token=token, ), @@ -1812,6 +1861,7 @@ async def execute( output=CompletionStreamOutput( text=result["result"]["text_output"], finished=False, # Tracked by https://github.com/NVIDIA/TensorRT-LLM/issues/240 + num_prompt_tokens=num_prompt_tokens, num_completion_tokens=num_completion_tokens, ), ) diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index 12f03d2a..2582d40d 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -3,9 +3,12 @@ import boto3 from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.utils.url import parse_attachment_url from model_engine_server.domain.gateways import LLMArtifactGateway +logger = make_logger(logger_name()) + class S3LLMArtifactGateway(LLMArtifactGateway): """ @@ -24,25 +27,45 @@ def list_files(self, path: str, **kwargs) -> List[str]: bucket = parsed_remote.bucket key = parsed_remote.key - # Using resource's bucket object to get its objects with specific prefix s3_bucket = s3.Bucket(bucket) files = [obj.key for obj in s3_bucket.objects.filter(Prefix=key)] return files + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url(path) + bucket = parsed_remote.bucket + key = parsed_remote.key + + s3_bucket = s3.Bucket(bucket) + downloaded_files: List[str] = [] + for obj in s3_bucket.objects.filter(Prefix=key): + file_path_suffix = obj.key.replace(key, "").lstrip("/") + local_path = os.path.join(target_path, file_path_suffix).rstrip("/") + + if not overwrite and os.path.exists(local_path): + downloaded_files.append(local_path) + continue + + local_dir = "/".join(local_path.split("/")[:-1]) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + logger.info(f"Downloading {obj.key} to {local_path}") + s3_bucket.download_file(obj.key, local_path) + downloaded_files.append(local_path) + return downloaded_files + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: s3 = self._get_s3_resource(kwargs) - # parsing prefix to get S3 bucket name - bucket_name = hmi_config.hf_user_fine_tuned_weights_prefix.replace("s3://", "").split("/")[ - 0 - ] - bucket = s3.Bucket(bucket_name) + parsed_remote = parse_attachment_url(hmi_config.hf_user_fine_tuned_weights_prefix) + bucket = parsed_remote.bucket + fine_tuned_weights_prefix = parsed_remote.key + + s3_bucket = s3.Bucket(bucket) model_files: List[str] = [] model_cache_name = get_model_cache_directory_name(model_name) - # parsing prefix to get /hosted-model-inference/fine_tuned_weights - fine_tuned_weights_prefix = "/".join( - hmi_config.hf_user_fine_tuned_weights_prefix.split("/")[-2:] - ) prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" - for obj in bucket.objects.filter(Prefix=prefix): - model_files.append(f"s3://{bucket_name}/{obj.key}") + for obj in s3_bucket.objects.filter(Prefix=prefix): + model_files.append(f"s3://{bucket}/{obj.key}") return model_files diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index 93fd708b..42e9988c 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -9,6 +9,7 @@ from .ecr_docker_repository import ECRDockerRepository from .fake_docker_repository import FakeDockerRepository from .feature_flag_repository import FeatureFlagRepository +from .live_tokenizer_repository import LiveTokenizerRepository from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository from .model_endpoint_record_repository import ModelEndpointRecordRepository @@ -27,6 +28,7 @@ "ECRDockerRepository", "FakeDockerRepository", "FeatureFlagRepository", + "LiveTokenizerRepository", "LLMFineTuneRepository", "ModelEndpointRecordRepository", "ModelEndpointCacheRepository", diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py new file mode 100644 index 00000000..e107e117 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -0,0 +1,149 @@ +import os +from collections import namedtuple +from functools import lru_cache +from typing import Dict, Optional + +from huggingface_hub import list_repo_refs +from huggingface_hub.utils._errors import RepositoryNotFoundError +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway +from model_engine_server.domain.repositories.tokenizer_repository import TokenizerRepository +from transformers import AutoTokenizer + +logger = make_logger(logger_name()) + + +TOKENIZER_FILES_REQUIRED = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", +] +TOKENIZER_FILES_OPTIONAL = [ + "tokenizer.model", +] +TOKENIZER_TARGET_DIR = "/opt/.cache/model_engine_server/tokenizers" + + +ModelInfo = namedtuple("ModelInfo", ["hf_repo", "s3_repo"]) + + +def get_default_supported_models_info() -> Dict[str, ModelInfo]: + return { + "mpt-7b": ModelInfo("mosaicml/mpt-7b", None), + "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", None), + "flan-t5-xxl": ModelInfo("google/flan-t5-xxl", None), + "llama-7b": ModelInfo("decapoda-research/llama-7b-hf", None), + "llama-2-7b": ModelInfo("huggyllama/llama-7b", None), + "llama-2-7b-chat": ModelInfo("meta-llama/Llama-2-7b-chat-hf", None), + "llama-2-13b": ModelInfo("meta-llama/Llama-2-13b-hf", None), + "llama-2-13b-chat": ModelInfo("meta-llama/Llama-2-13b-chat-hf", None), + "llama-2-70b": ModelInfo("meta-llama/Llama-2-70b-hf", None), + "llama-2-70b-chat": ModelInfo("meta-llama/Llama-2-70b-chat-hf", None), + "falcon-7b": ModelInfo("tiiuae/falcon-7b", None), + "falcon-7b-instruct": ModelInfo("tiiuae/falcon-7b-instruct", None), + "falcon-40b": ModelInfo("tiiuae/falcon-40b", None), + "falcon-40b-instruct": ModelInfo("tiiuae/falcon-40b-instruct", None), + "falcon-180b": ModelInfo("tiiuae/falcon-180B", None), + "falcon-180b-chat": ModelInfo("tiiuae/falcon-180B-chat", None), + "codellama-7b": ModelInfo("codellama/CodeLlama-7b-hf", None), + "codellama-7b-instruct": ModelInfo("codellama/CodeLlama-7b-Instruct-hf", None), + "codellama-13b": ModelInfo("codellama/CodeLlama-13b-hf", None), + "codellama-13b-instruct": ModelInfo("codellama/CodeLlama-13b-Instruct-hf", None), + "codellama-34b": ModelInfo("codellama/CodeLlama-34b-hf", None), + "codellama-34b-instruct": ModelInfo("codellama/CodeLlama-34b-Instruct-hf", None), + "llm-jp-13b-instruct-full": ModelInfo("llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", None), + "llm-jp-13b-instruct-full-dolly": ModelInfo( + "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", None + ), + "mistral-7b": ModelInfo("mistralai/Mistral-7B-v0.1", None), + "mistral-7b-instruct": ModelInfo("mistralai/Mistral-7B-Instruct-v0.1", None), + "mammoth-coder-llama-2-7b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-7B", None), + "mammoth-coder-llama-2-13b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-13B", None), + "mammoth-coder-llama-2-34b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-34B", None), + "gpt-j-6b": ModelInfo("EleutherAI/gpt-j-6b", None), + "gpt-j-6b-zh-en": ModelInfo("EleutherAI/gpt-j-6b", None), + "gpt4all-j": ModelInfo("nomic-ai/gpt4all-j", None), + "dolly-v2-12b": ModelInfo("databricks/dolly-v2-12b", None), + "stablelm-tuned-7b": ModelInfo("StabilityAI/stablelm-tuned-alpha-7b", None), + "vicuna-13b": ModelInfo("eachadea/vicuna-13b-1.1", None), + } + + +def get_supported_models_info() -> Dict[str, ModelInfo]: + try: + from plugins.live_tokenizer_repository import ( + get_supported_models_info as get_custom_supported_models_info, + ) + + return get_custom_supported_models_info() + except ModuleNotFoundError: + return get_default_supported_models_info() + + +SUPPORTED_MODELS_INFO = get_supported_models_info() + + +def get_models_s3_uri(*args, **kwargs) -> str: + try: + from plugins.live_tokenizer_repository import get_models_s3_uri as get_custom_models_s3_uri + + return get_custom_models_s3_uri(*args, **kwargs) + except ModuleNotFoundError: + raise NotImplementedError + + +def get_models_local_dir_path(model_name: str) -> str: + """ + Get the local directory path for a given model. + """ + return f"{TOKENIZER_TARGET_DIR}/{model_name}" + + +class LiveTokenizerRepository(TokenizerRepository): + def __init__(self, llm_artifact_gateway: LLMArtifactGateway): + self.llm_artifact_gateway = llm_artifact_gateway + + def _load_tokenizer_from_s3(self, model_name: str, s3_prefix: Optional[str]) -> Optional[str]: + """ + Download tokenizer files from S3 to the local filesystem. + """ + if not s3_prefix: + return None + + model_tokenizer_dir = get_models_local_dir_path(model_name) + + for file in TOKENIZER_FILES_REQUIRED: + s3_path = get_models_s3_uri(s3_prefix, file) + target_path = os.path.join(model_tokenizer_dir, file) + self.llm_artifact_gateway.download_files(s3_path, target_path) + + for file in TOKENIZER_FILES_OPTIONAL: + s3_path = get_models_s3_uri(s3_prefix, file) + target_path = os.path.join(model_tokenizer_dir, file) + try: + self.llm_artifact_gateway.download_files(s3_path, target_path) + except Exception: + pass + + return model_tokenizer_dir + + @lru_cache(maxsize=32) + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + model_info = SUPPORTED_MODELS_INFO[model_name] + + model_location = None + try: + if not model_info.hf_repo: + raise RepositoryNotFoundError("No HF repo specified for model.") + list_repo_refs(model_info.hf_repo) # check if model exists in Hugging Face Hub + model_location = model_info.hf_repo + # AutoTokenizer handles file downloads for HF repos + except RepositoryNotFoundError: + model_location = self._load_tokenizer_from_s3(model_name, model_info.s3_repo) + + if not model_location: + raise ObjectNotFoundException(f"Tokenizer not found for model {model_name}.") + + logger.info(f"Loading tokenizer for model {model_name} from {model_location}.") + return AutoTokenizer.from_pretrained(model_location) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index ecdf78a1..abf0809b 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -37,6 +37,7 @@ quart==0.18.3 requests-auth-aws-sigv4~=0.7 requests~=2.25 rich~=12.6 +sentencepiece==0.1.99 sh~=1.13 smart-open~=5.2 sqlalchemy[asyncio]==2.0.4 @@ -44,6 +45,7 @@ sse-starlette==1.6.1 sseclient-py==1.7.2 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 +transformers==4.34.1 tqdm~=4.64 transformers==4.34.1 twine==3.7.1 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 87adb372..e61d22e3 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -381,6 +381,8 @@ scramp==1.4.4 # via pg8000 secretstorage==3.3.3 # via keyring +sentencepiece==0.1.99 + # via -r model-engine/requirements.in sh==1.14.3 # via -r model-engine/requirements.in six==1.16.0 diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 68172acf..683a3e2b 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -64,4 +64,4 @@ user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-c docker_image_layer_cache_repository: "kaniko-cache" # S3 access -hf_user_fine_tuned_weights_prefix: "s3://test-bucket" +hf_user_fine_tuned_weights_prefix: "s3://test-bucket/model-weights" diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index f40a2dd1..1566418e 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -3,8 +3,7 @@ test=pytest [coverage:run] omit = - hosted_model_inference/start_server.py, - hosted_model_inference/start_service_builder.py + model_engine_server/entrypoints/* # TODO: Fix pylint errors # [pylint] @@ -26,7 +25,7 @@ addopts = --verbose --durations=0 --cache-clear - --cov=hosted_model_inference + --cov=model_engine_server --cov-report=term-missing --mypy --mypy-ini-file=mypy.ini diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 445ab83d..3528d558 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -106,6 +106,7 @@ DockerRepository, LLMFineTuneEventsRepository, ModelBundleRepository, + TokenizerRepository, TriggerRepository, ) from model_engine_server.domain.services import ( @@ -151,6 +152,7 @@ from model_engine_server.infra.services.live_llm_model_endpoint_service import ( LiveLLMModelEndpointService, ) +from transformers import AutoTokenizer def _translate_fake_model_endpoint_orm_to_model_endpoint_record( @@ -748,7 +750,12 @@ async def initialize_events(self, user_id: str, model_endpoint_name: str): class FakeLLMArtifactGateway(LLMArtifactGateway): def __init__(self): self.existing_models = [] - self.s3_bucket = {"fake-checkpoint": ["fake.bin, fake2.bin", "fake3.safetensors"]} + self.s3_bucket = { + "fake-checkpoint": ["fake.bin, fake2.bin", "fake3.safetensors"], + "llama-7b/tokenizer.json": ["llama-7b/tokenizer.json"], + "llama-7b/tokenizer_config.json": ["llama-7b/tokenizer_config.json"], + "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"], + } self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} def _add_model(self, owner: str, model_name: str): @@ -758,6 +765,10 @@ def list_files(self, path: str, **kwargs) -> List[str]: if path in self.s3_bucket: return self.s3_bucket[path] + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + if path in self.s3_bucket: + return self.s3_bucket[path] + def get_model_weights_urls(self, owner: str, model_name: str): if (owner, model_name) in self.existing_models: return self.urls @@ -1803,6 +1814,11 @@ async def delete_model_endpoint(self, model_endpoint_id: str) -> None: del self.db[model_endpoint_id] +class FakeTokenizerRepository(TokenizerRepository): + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + return AutoTokenizer.from_pretrained(model_name) + + class FakeLLMModelEndpointService(LLMModelEndpointService): db: Dict[str, ModelEndpoint] @@ -2071,6 +2087,11 @@ def fake_image_cache_service( ) +@pytest.fixture +def fake_tokenizer_repository() -> TokenizerRepository: + return FakeTokenizerRepository() + + @pytest.fixture def get_repositories_generator_wrapper(): def get_repositories_generator( @@ -2155,6 +2176,8 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: ) fake_llm_fine_tuning_events_repository = FakeLLMFineTuneEventsRepository() fake_file_storage_gateway = FakeFileStorageGateway(fake_file_storage_gateway_contents) + fake_tokenizer_repository = FakeTokenizerRepository() + repositories = ExternalInterfaces( docker_repository=FakeDockerRepository( fake_docker_repository_image_always_exists, False @@ -2178,6 +2201,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: filesystem_gateway=fake_file_system_gateway, llm_artifact_gateway=fake_llm_artifact_gateway, monitoring_metrics_gateway=fake_monitoring_metrics_gateway, + tokenizer_repository=fake_tokenizer_repository, ) try: yield repositories diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index d7ec41f0..589e453b 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple +from typing import Any, List, Tuple from unittest import mock import pytest @@ -410,11 +410,24 @@ async def test_get_llm_model_endpoint_use_case_raises_not_authorized( ) +def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa + class mocked_encode: + def encode(self, input: str) -> List[Any]: # noqa + return [1] * 7 + + return mocked_encode() + + @pytest.mark.asyncio +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.AutoTokenizer.from_pretrained", + mocked_auto_tokenizer_from_pretrained, +) async def test_completion_sync_use_case_success( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): @@ -465,6 +478,7 @@ async def test_completion_sync_use_case_success( use_case = CompletionSyncV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -474,6 +488,7 @@ async def test_completion_sync_use_case_success( ) assert response_1.output == CompletionOutput( text="I am a newbie to the world of programming.", + num_prompt_tokens=7, num_completion_tokens=11, tokens=[ TokenOutput(token="I", log_prob=-2.3025850929940455), @@ -492,10 +507,15 @@ async def test_completion_sync_use_case_success( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=5, +) async def test_completion_sync_text_generation_inference_use_case_success( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_text_generation_inference: ModelEndpoint, completion_sync_request: CompletionSyncV1Request, ): @@ -578,6 +598,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( use_case = CompletionSyncV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -587,6 +608,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( ) assert response_1.output == CompletionOutput( text=" Deep Learning is a new type of machine learning", + num_prompt_tokens=5, num_completion_tokens=9, tokens=[ TokenOutput(token=" Deep", log_prob=0.0), @@ -603,10 +625,15 @@ async def test_completion_sync_text_generation_inference_use_case_success( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=6, +) async def test_completion_sync_trt_llm_use_case_success( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_trt_llm: ModelEndpoint, completion_sync_request: CompletionSyncV1Request, ): @@ -622,6 +649,7 @@ async def test_completion_sync_trt_llm_use_case_success( use_case = CompletionSyncV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -631,6 +659,7 @@ async def test_completion_sync_trt_llm_use_case_success( ) assert response_1.output == CompletionOutput( text=" Machine learning is a branch", + num_prompt_tokens=6, num_completion_tokens=5, ) @@ -640,6 +669,7 @@ async def test_completion_sync_use_case_predict_failed( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): @@ -654,6 +684,7 @@ async def test_completion_sync_use_case_predict_failed( use_case = CompletionSyncV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -669,6 +700,7 @@ async def test_completion_sync_use_case_predict_failed_with_errors( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_sync_tgi: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): @@ -688,6 +720,7 @@ async def test_completion_sync_use_case_predict_failed_with_errors( use_case = CompletionSyncV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(UpstreamServiceError): @@ -703,6 +736,7 @@ async def test_completion_sync_use_case_not_sync_endpoint_raises( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_async: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): @@ -710,6 +744,7 @@ async def test_completion_sync_use_case_not_sync_endpoint_raises( use_case = CompletionSyncV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(EndpointUnsupportedInferenceTypeException): @@ -721,10 +756,15 @@ async def test_completion_sync_use_case_not_sync_endpoint_raises( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) async def test_completion_stream_use_case_success( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_streaming: ModelEndpoint, completion_stream_request: CompletionStreamV1Request, ): @@ -789,6 +829,7 @@ async def test_completion_stream_use_case_success( use_case = CompletionStreamV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = use_case.execute( @@ -802,15 +843,21 @@ async def test_completion_stream_use_case_success( assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] if i == 6: + assert message.dict()["output"]["num_prompt_tokens"] == 7 assert message.dict()["output"]["num_completion_tokens"] == 6 i += 1 @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) async def test_completion_stream_text_generation_inference_use_case_success( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_text_generation_inference: ModelEndpoint, completion_stream_request: CompletionStreamV1Request, ): @@ -850,6 +897,7 @@ async def test_completion_stream_text_generation_inference_use_case_success( use_case = CompletionStreamV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = use_case.execute( @@ -863,15 +911,21 @@ async def test_completion_stream_text_generation_inference_use_case_success( assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] if i == 5: + assert message.dict()["output"]["num_prompt_tokens"] == 7 assert message.dict()["output"]["num_completion_tokens"] == 6 i += 1 @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) async def test_completion_stream_trt_llm_use_case_success( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, + fake_tokenizer_repository, llm_model_endpoint_trt_llm: ModelEndpoint, completion_stream_request: CompletionStreamV1Request, ): @@ -906,6 +960,7 @@ async def test_completion_stream_trt_llm_use_case_success( use_case = CompletionStreamV1UseCase( model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = use_case.execute( @@ -918,6 +973,7 @@ async def test_completion_stream_trt_llm_use_case_success( async for message in response_1: assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] + assert message.dict()["output"]["num_prompt_tokens"] == 7 assert message.dict()["output"]["num_completion_tokens"] == i + 1 i += 1 diff --git a/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py b/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py new file mode 100644 index 00000000..7dcf19a6 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py @@ -0,0 +1,85 @@ +from typing import List +from unittest import mock + +import pytest +from model_engine_server.common.config import hmi_config +from model_engine_server.infra.gateways.s3_llm_artifact_gateway import S3LLMArtifactGateway + + +@pytest.fixture +def llm_artifact_gateway(): + gateway = S3LLMArtifactGateway() + return gateway + + +@pytest.fixture +def fake_files(): + return ["fake-prefix/fake1", "fake-prefix/fake2", "fake-prefix/fake3"] + + +def mock_boto3_session(fake_files: List[str]): + mock_session = mock.Mock() + mock_bucket = mock.Mock() + mock_objects = mock.Mock() + + def filter_files(*args, **kwargs): + prefix = kwargs["Prefix"] + return [mock.Mock(key=file) for file in fake_files if file.startswith(prefix)] + + mock_session.return_value.resource.return_value.Bucket.return_value = mock_bucket + mock_bucket.objects = mock_objects + mock_objects.filter.side_effect = filter_files + + mock_bucket.download_file.return_value = None + return mock_session + + +@mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.os.makedirs", + lambda *args, **kwargs: None, # noqa +) +def test_s3_llm_artifact_gateway_download_folder(llm_artifact_gateway, fake_files): + prefix = "/".join(fake_files[0].split("/")[:-1]) + uri_prefix = f"s3://fake-bucket/{prefix}" + target_dir = "fake-target" + + expected_files = [f"{target_dir}/{file.split('/')[-1]}" for file in fake_files] + with mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", + mock_boto3_session(fake_files), + ): + assert llm_artifact_gateway.download_files(uri_prefix, target_dir) == expected_files + + +@mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.os.makedirs", + lambda *args, **kwargs: None, # noqa +) +def test_s3_llm_artifact_gateway_download_file(llm_artifact_gateway, fake_files): + file = fake_files[1] + uri = f"s3://fake-bucket/{file}" + target = f"fake-target/{file}" + + with mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", + mock_boto3_session(fake_files), + ): + assert llm_artifact_gateway.download_files(uri, target) == [target] + + +def test_s3_llm_artifact_gateway_get_model_weights(llm_artifact_gateway): + owner = "fakeuser" + model_name = "fakemodel" + fake_files = [f"{owner}/models--{model_name}/fake1", f"{owner}/models--{model_name}/fake2"] + + s3_prefix = hmi_config.hf_user_fine_tuned_weights_prefix + weights_prefix = "/".join(s3_prefix.replace("s3://", "").split("/")[1:]) + fake_model_weights = [f"{weights_prefix}/{file}" for file in fake_files] + expected_model_files = [f"{s3_prefix}/{file}" for file in fake_files] + with mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", + mock_boto3_session(fake_model_weights), + ): + assert ( + llm_artifact_gateway.get_model_weights_urls(owner, model_name) == expected_model_files + ) diff --git a/model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py b/model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py new file mode 100644 index 00000000..b82d78f4 --- /dev/null +++ b/model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py @@ -0,0 +1,62 @@ +from typing import Any, List +from unittest import mock + +import pytest +from model_engine_server.infra.repositories.live_tokenizer_repository import ( + LiveTokenizerRepository, + ModelInfo, +) + + +@pytest.fixture +def tokenizer_repository(fake_llm_artifact_gateway): + repository = LiveTokenizerRepository(fake_llm_artifact_gateway) + return repository + + +def mocked_get_models_s3_uri(*args, **kwargs): # noqa + return f"s3://fake-bucket/{args[0]}/{args[1]}" + + +def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa + class mocked_encode: + def encode(self, input: str) -> List[Any]: + return [1] * len(input) + + return mocked_encode() + + +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.SUPPORTED_MODELS_INFO", + {"llama-7b": ModelInfo("llama-7b", None)}, +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.list_repo_refs", + lambda *args, **kwargs: None, # noqa +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.AutoTokenizer.from_pretrained", + mocked_auto_tokenizer_from_pretrained, +) +def test_load_tokenizer_from_hf(tokenizer_repository): + tokenizer = tokenizer_repository.load_tokenizer("llama-7b") + + assert tokenizer.encode("fake input") == [1] * len("fake input") + + +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.SUPPORTED_MODELS_INFO", + {"llama-7b": ModelInfo(None, "llama-7b")}, +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.get_models_s3_uri", + mocked_get_models_s3_uri, +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.AutoTokenizer.from_pretrained", + mocked_auto_tokenizer_from_pretrained, +) +def test_load_tokenizer_from_s3(tokenizer_repository): + tokenizer = tokenizer_repository.load_tokenizer("llama-7b") + + assert tokenizer.encode("fake input") == [1] * len("fake input") From 2221de022e2578616c450fe884840871b3335f60 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 16 Nov 2023 10:12:45 -0800 Subject: [PATCH 183/425] Fix integration test (#383) --- .../model-engine/templates/service_template_config_map.yaml | 4 ++-- .../templates/service_template_config_map_circleci.yaml | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 15410122..b738ebb2 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -572,7 +572,7 @@ data: {{- tuple $env_var | toYaml | nindent 16 }} {{- end }} {{- end }} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: - dumb-init - -- @@ -666,7 +666,7 @@ data: {{- tuple $env_var | toYaml | nindent 16 }} {{- end }} {{- end }} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 8e014e18..bfe5c492 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -2843,7 +2843,7 @@ data: value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: - dumb-init - -- @@ -2968,7 +2968,7 @@ data: value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. @@ -3114,7 +3114,7 @@ data: value: ${GIT_TAG} - name: GIT_TAG value: ${GIT_TAG} - imagePullPolicy: Always + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. From df3738c4bae781901b83af4cb657fb047b5c6dbf Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 16 Nov 2023 10:52:50 -0800 Subject: [PATCH 184/425] emit metrics on token counts (#382) * emit metrics on token counts * remove print --- .../model_engine_server/api/llms_v1.py | 47 +++++++++++++++---- .../model_engine_server/common/dtos/llms.py | 9 ++++ .../gateways/monitoring_metrics_gateway.py | 9 ++++ .../fake_monitoring_metrics_gateway.py | 6 +++ 4 files changed, 62 insertions(+), 9 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 153b4c2c..6fc2a1d1 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -5,7 +5,7 @@ from typing import Optional import pytz -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, @@ -32,6 +32,7 @@ ModelDownloadResponse, StreamError, StreamErrorContent, + TokenUsage, ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User @@ -86,17 +87,20 @@ def format_request_route(request: Request) -> str: return f"{request.method}_{url_path}".lower() -async def record_route_call( +async def get_metric_metadata( request: Request, auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -): - route = format_request_route(request) +) -> MetricMetadata: model_name = request.query_params.get("model_endpoint_name", None) + return MetricMetadata(user=auth, model_name=model_name) - external_interfaces.monitoring_metrics_gateway.emit_route_call_metric( - route, MetricMetadata(user=auth, model_name=model_name) - ) + +async def record_route_call( + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + route: str = Depends(format_request_route), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +): + external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(route, metric_metadata) llm_router_v1 = APIRouter(prefix="/v1/llm", dependencies=[Depends(record_route_call)]) @@ -234,8 +238,10 @@ async def get_model_endpoint( async def create_completion_sync_task( model_endpoint_name: str, request: CompletionSyncV1Request, + background_tasks: BackgroundTasks, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), ) -> CompletionSyncV1Response: """ Runs a sync prompt completion on an LLM. @@ -249,9 +255,20 @@ async def create_completion_sync_task( llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, tokenizer_repository=external_interfaces.tokenizer_repository, ) - return await use_case.execute( + response = await use_case.execute( user=auth, model_endpoint_name=model_endpoint_name, request=request ) + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=response.output.num_prompt_tokens if response.output else None, + num_completion_tokens=response.output.num_completion_tokens + if response.output + else None, + ), + metric_metadata, + ) + return response except UpstreamServiceError: request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) logger.exception(f"Upstream service error for request {request_id}") @@ -279,8 +296,10 @@ async def create_completion_sync_task( async def create_completion_stream_task( model_endpoint_name: str, request: CompletionStreamV1Request, + background_tasks: BackgroundTasks, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), ) -> EventSourceResponse: """ Runs a stream prompt completion on an LLM. @@ -299,6 +318,16 @@ async def event_generator(): try: async for message in response: yield {"data": message.json()} + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=message.output.num_prompt_tokens if message.output else None, + num_completion_tokens=message.output.num_completion_tokens + if message.output + else None, + ), + metric_metadata, + ) except (InvalidRequestException, ObjectHasInvalidValueException) as exc: yield handle_streaming_exception(exc, 400, str(exc)) except ( diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index c0e6b9fc..dd2e06a0 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -233,6 +233,15 @@ class CompletionStreamV1Response(BaseModel): """Error of the response (if any).""" +class TokenUsage(BaseModel): + num_prompt_tokens: Optional[int] = 0 + num_completion_tokens: Optional[int] = 0 + + @property + def num_total_tokens(self) -> int: + return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) + + class CreateFineTuneRequest(BaseModel): model: str training_file: str diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index 5e7e0382..9bca6a0d 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from typing import Optional +from model_engine_server.common.dtos.llms import TokenUsage from model_engine_server.core.auth.authentication_repository import User from pydantic import BaseModel @@ -81,3 +82,11 @@ def emit_route_call_metric(self, route: str, metadata: MetricMetadata): """ pass + + @abstractmethod + def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata): + """ + Token count metrics + + """ + pass diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py index 32c2b6f3..9b63a135 100644 --- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -1,5 +1,6 @@ from collections import defaultdict +from model_engine_server.common.dtos.llms import TokenUsage from model_engine_server.domain.gateways.monitoring_metrics_gateway import ( MetricMetadata, MonitoringMetricsGateway, @@ -19,6 +20,7 @@ def __init__(self): self.database_cache_hit = 0 self.database_cache_miss = 0 self.route_call = defaultdict(int) + self.token_count = 0 def reset(self): self.attempted_build = 0 @@ -32,6 +34,7 @@ def reset(self): self.database_cache_hit = 0 self.database_cache_miss = 0 self.route_call = defaultdict(int) + self.token_count = 0 def emit_attempted_build_metric(self): self.attempted_build += 1 @@ -65,3 +68,6 @@ def emit_database_cache_miss_metric(self): def emit_route_call_metric(self, route: str, _metadata: MetricMetadata): self.route_call[route] += 1 + + def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata): + self.token_count += token_usage.num_total_tokens From df26b0a947140985de8b8b1b3285f3192f274270 Mon Sep 17 00:00:00 2001 From: Sam Denton <106690182+sam-scale@users.noreply.github.com> Date: Thu, 16 Nov 2023 14:45:24 -0800 Subject: [PATCH 185/425] Increase llama-2 max_input_tokens (#384) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 3668a895..9e94859a 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -399,7 +399,7 @@ async def create_text_generation_inference_bundle( max_input_length = 1024 max_total_tokens = 2048 if "llama-2" in model_name: - max_input_length = 2048 + max_input_length = 4095 max_total_tokens = 4096 subcommands = [] From d478ee5e53de685e61238756e0de036de9f676df Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 16 Nov 2023 18:05:20 -0800 Subject: [PATCH 186/425] Revert "Found a bug in the codellama vllm model_len logic. (#380)" (#386) This reverts commit 5b6aeff6b6636838d31c90d7f3f3f6d915390a6f. --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 9e94859a..d379aac2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -181,9 +181,9 @@ "mammoth-coder": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, # Based on config here: https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B/blob/main/config.json#L12 # Can also see 13B, 34B there too - "codellama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, + "code-llama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 - # Can also see 13B, 34B there too. Note, codellama is one word. + # Can also see 13B, 34B there too "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, } From e71326d246503afab0cd08c432d882e43a1bdd14 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 16 Nov 2023 22:06:56 -0800 Subject: [PATCH 187/425] Some updates to integration tests (#385) * Some updates to integration tests * fix * comment * better env var --- integration_tests/test_endpoints.py | 5 +++-- integration_tests/test_fine_tunes.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index dbb68ec8..880acd7c 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -2,7 +2,6 @@ import time import pytest -from model_engine_server.common.env_vars import CIRCLECI from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed from .rest_api_utils import ( @@ -237,7 +236,9 @@ def test_sync_streaming_model_endpoint(capsys): delete_model_endpoint(create_endpoint_request["name"], user) -@pytest.mark.skipif(CIRCLECI, reason="skip on circleci since need to figure out s3 access") +@pytest.mark.skipif( + reason="Need to update the following test to hit remote service to be integration test" +) def test_models_tokenizers() -> None: from model_engine_server.infra.gateways.s3_llm_artifact_gateway import S3LLMArtifactGateway from model_engine_server.infra.repositories import LiveTokenizerRepository diff --git a/integration_tests/test_fine_tunes.py b/integration_tests/test_fine_tunes.py index 22593eca..89d9e447 100644 --- a/integration_tests/test_fine_tunes.py +++ b/integration_tests/test_fine_tunes.py @@ -3,6 +3,7 @@ import time import boto3 +import pytest import smart_open from .rest_api_utils import ( @@ -19,6 +20,10 @@ MAX_RETRIES = 10 +@pytest.mark.skipif( + not os.getenv("FINE_TUNE_TEST_READY"), + reason="Skipping fine tune tests when test templates are not set up.", +) def test_fine_tunes() -> None: di_batch_job_id = create_docker_image_batch_job_bundle( CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 From 4888ecf231c05e45566c6d20db89a62106189cc9 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 17 Nov 2023 13:23:55 -0800 Subject: [PATCH 188/425] Celery autoscaler (#378) --- charts/model-engine/templates/_helpers.tpl | 26 +- .../celery_autoscaler_stateful_set.yaml | 87 +++ charts/model-engine/values_circleci.yaml | 3 + charts/model-engine/values_sample.yaml | 8 +- .../core/celery/__init__.py | 5 +- .../core/celery/celery_autoscaler.py | 631 ++++++++++++++++++ model-engine/requirements.in | 1 + model-engine/requirements.txt | 2 + 8 files changed, 757 insertions(+), 6 deletions(-) create mode 100644 charts/model-engine/templates/celery_autoscaler_stateful_set.yaml create mode 100644 model-engine/model_engine_server/core/celery/celery_autoscaler.py diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index d367f039..b4737392 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -30,6 +30,14 @@ If release name contains chart name it will be used as a full name. {{ .Values.hostDomain.prefix }}{{ include "modelEngine.fullname" . }}.{{ .Release.Namespace }}:{{ .Values.service.port }} {{- end }} +{{- define "modelEngine.celeryautoscalername" -}} +{{- if .Values.serviceIdentifier }} +{{- printf "celery-autoscaler-%s-%s" .Values.celeryBrokerType .Values.serviceIdentifier }} +{{- else }} +{{- printf "celery-autoscaler-%s" .Values.celeryBrokerType }} +{{- end }} +{{- end }} + {{/* Create chart name and version as used by the chart label. */}} @@ -37,17 +45,21 @@ Create chart name and version as used by the chart label. {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} {{- end }} +{{- define "modelEngine.baseLabels" -}} +team: infra +app.kubernetes.io/version: {{ .Values.tag }} +tags.datadoghq.com/version: {{ .Values.tag }} +tags.datadoghq.com/env: {{ .Values.context }} +{{- end }} + {{/* Common labels */}} {{- define "modelEngine.labels" -}} -team: infra +{{- include "modelEngine.baseLabels" . | printf "%s\n" -}} product: model-engine helm.sh/chart: {{ include "modelEngine.chart" . }} app.kubernetes.io/managed-by: {{ .Release.Service }} -app.kubernetes.io/version: {{ .Values.tag }} -tags.datadoghq.com/version: {{ .Values.tag }} -tags.datadoghq.com/env: {{ .Values.context }} {{- end }} {{- define "modelEngine.selectorLabels.builder" -}} @@ -62,6 +74,12 @@ app: {{ include "modelEngine.cachername" . }} app: {{ include "modelEngine.fullname" . -}} {{- end }} +{{- define "modelEngine.selectorLabels.celeryAutoscaler" -}} +app: {{ include "modelEngine.celeryautoscalername" . }} +product: common +tags.datadoghq.com/service: {{ include "modelEngine.celeryautoscalername" . -}} +{{- end }} + {{- define "modelEngine.baseTemplateLabels" -}} user_id: ${OWNER} team: ${TEAM} diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml new file mode 100644 index 00000000..fb8c393b --- /dev/null +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -0,0 +1,87 @@ +{{- if .Values.celery_autoscaler.enabled }} +{{- $app := include "modelEngine.celeryautoscalername" . }} +{{- $env := .Values.context }} +{{- $tag := .Values.tag }} +{{- $message_broker := .Values.celeryBrokerType }} +{{- $num_shards := .Values.celery_autoscaler.num_shards }} +{{- $broker_name := ternary "redis-elasticache-message-broker-master" "sqs-message-broker-master" (eq $message_broker "elasticache") }} +apiVersion: apps/v1 +kind: StatefulSet +metadata: + labels: + {{- include "modelEngine.baseLabels" . | nindent 4 }} + {{- include "modelEngine.selectorLabels.celeryAutoscaler" . | nindent 4 }} + name: {{ $app }} +spec: + serviceName: {{ $app }} + replicas: {{ $num_shards }} + selector: + matchLabels: + app: {{ $app }} + template: + metadata: + annotations: + ad.datadoghq.com/main.logs: '[{"service": "{{ $app }}", "source": "python"}]' + sidecar.istio.io/inject: "false" + labels: + {{- include "modelEngine.baseLabels" . | nindent 8 }} + {{- include "modelEngine.selectorLabels.celeryAutoscaler" . | nindent 8 }} + spec: + containers: + - args: + - ddtrace-run + - python + - -m + - model_engine_server.core.celery.celery_autoscaler + env: + - name: AWS_PROFILE + value: {{ .Values.aws.profileName }} + - name: AWS_CONFIG_FILE + value: /opt/.aws/config + - name: DD_TRACE_ENABLED + value: 'false' + - name: DD_SERVICE + value: {{ $app }} + - name: DD_ENV + value: {{ $env }} + - name: DD_VERSION + value: {{ $tag }} + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + - name: BROKER_NAME + value: {{ $broker_name }} + - name: REDIS_BROKER_NAME + value: {{ $broker_name }} + - name: CELERY_ELASTICACHE_ENABLED + value: {{ (eq $message_broker "elasticache") | squote }} + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: NUM_SHARDS + value: '{{ $num_shards }}' + image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}" + imagePullPolicy: Always + name: main + resources: + requests: + cpu: 1000m + volumeMounts: + - mountPath: /opt/.aws/config + name: config-volume + subPath: config + nodeSelector: + node-lifecycle: normal + tolerations: + - key: CriticalAddonsOnly + operator: Equal + value: 'true' + effect: NoSchedule + serviceAccountName: {{ include "modelEngine.fullname" $ }} + volumes: + - configMap: + name: {{ .Values.aws.configMap.name }} + name: config-volume +{{- end }} \ No newline at end of file diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index f73ffea6..fd07f361 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -63,6 +63,9 @@ autoscaling: prewarming: enabled: false +celery_autoscaler: + enabled: false + podDisruptionBudget: enabled: true minAvailable: 1 diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 8b32aee0..2ff37197 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -91,6 +91,12 @@ autoscaling: prewarming: enabled: false +# for async endpoints, Celery autoscaler scales the number of pods based on number of requests +# num_shards is number of instances of the autoscaler +celery_autoscaler: + enabled: true + num_shards: 3 + podDisruptionBudget: enabled: true minAvailable: 1 @@ -239,5 +245,5 @@ imageCache: operator: "Exists" effect: "NoSchedule" -# celeryBrokerType specifies the celery broker type for async endpoints (coming soon) +# celeryBrokerType specifies the celery broker type for async endpoints, either "sqs" or "elasticache" celeryBrokerType: sqs diff --git a/model-engine/model_engine_server/core/celery/__init__.py b/model-engine/model_engine_server/core/celery/__init__.py index cb4eb189..af024891 100644 --- a/model-engine/model_engine_server/core/celery/__init__.py +++ b/model-engine/model_engine_server/core/celery/__init__.py @@ -1,8 +1,11 @@ from typing import Sequence -from .app import TaskVisibility, celery_app +from .app import TaskVisibility, celery_app, get_all_db_indexes, get_redis_host_port, inspect_app __all__: Sequence[str] = ( "celery_app", + "get_all_db_indexes", + "get_redis_host_port", + "inspect_app", "TaskVisibility", ) diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py new file mode 100644 index 00000000..b5b44a78 --- /dev/null +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -0,0 +1,631 @@ +import asyncio as aio +import dataclasses +import hashlib +import logging +import os +import time +from abc import ABC, abstractmethod +from bisect import bisect +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from math import ceil +from typing import Any, DefaultDict, Dict, List, Set, Tuple + +import aioredis +import stringcase +from celery.app.control import Inspect +from datadog import statsd +from kubernetes_asyncio import client +from kubernetes_asyncio import config as kube_config +from kubernetes_asyncio.client.rest import ApiException +from kubernetes_asyncio.config.config_exception import ConfigException +from model_engine_server.core.aws.roles import session +from model_engine_server.core.celery import ( + TaskVisibility, + celery_app, + get_all_db_indexes, + get_redis_host_port, + inspect_app, +) +from model_engine_server.core.loggers import logger_name, make_logger + + +def excluded_namespaces(): + try: + from plugins.celery_autoscaler_dependencies import CELERY_AUTOSCALER_EXCLUDED_NAMESPACES + + return CELERY_AUTOSCALER_EXCLUDED_NAMESPACES + except ModuleNotFoundError: + return [] + + +ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master" +SQS_BROKER = "sqs-message-broker-master" + +UPDATE_DEPLOYMENT_MAX_RETRIES = 10 + +SQS_SAMPLE_COUNT = 10 + +logger = make_logger(logger_name()) + +autoscaler_broker = os.environ.get("BROKER_NAME", SQS_BROKER) +aws_profile = os.environ.get("AWS_PROFILE") + + +@dataclasses.dataclass +class CeleryAutoscalerParams: + queue: str + broker: str = SQS_BROKER + task_visibility: TaskVisibility = TaskVisibility.VISIBILITY_1H + per_worker: int = 1 + min_workers: int = 0 + max_workers: int = 1 + + +def _hash_any_to_int(data: Any): + return int(hashlib.md5(str(data).encode()).hexdigest(), 16) + + +async def list_deployments(core_api, apps_api) -> Dict[Tuple[str, str], CeleryAutoscalerParams]: + namespaces = await core_api.list_namespace() + celery_deployments_params = {} + for namespace in namespaces.items: + namespace_name = namespace.metadata.name + if namespace_name in excluded_namespaces(): + continue + namespace_start_time = time.time() + deployments = await apps_api.list_namespaced_deployment(namespace=namespace_name) + logger.info( + f"list_namespaced_deployment with {namespace_name} took {time.time() - namespace_start_time} seconds" + ) + for deployment in deployments.items: + deployment_name = deployment.metadata.name + annotations = deployment.metadata.annotations + + if not annotations: + continue + + # Parse parameters + params = {} + + if "celery.scaleml.autoscaler/broker" in annotations: + deployment_broker = annotations["celery.scaleml.autoscaler/broker"] + else: + deployment_broker = ELASTICACHE_REDIS_BROKER + + if deployment_broker != autoscaler_broker: + logger.debug( + f"Skipping deployment {deployment_name}; deployment's broker {deployment_broker} is not {autoscaler_broker}" + ) + continue + + for f in dataclasses.fields(CeleryAutoscalerParams): + k = f.name + v = annotations.get(f"celery.scaleml.autoscaler/{stringcase.camelcase(k)}") + if not v: + continue + + try: + if k == "task_visibility": + v = TaskVisibility.from_name(v) + v = f.type(v) + except (ValueError, KeyError): + logger.exception(f"Unable to convert {f.name}: {v} to {f.type}") + + params[k] = v + + try: + celery_autoscaler_params = CeleryAutoscalerParams(**params) + except TypeError: + logger.debug( + f"Missing params, skipping deployment : {deployment_name} in {namespace_name}" + ) + continue + + celery_deployments_params[(deployment_name, namespace_name)] = celery_autoscaler_params + + return celery_deployments_params + + +class InstanceLogger(logging.LoggerAdapter): + def process(self, msg, kwargs): + return "%s %s" % (self.extra["name"], msg), kwargs + + +class Instance: + def __init__(self, api, name, namespace, params: CeleryAutoscalerParams, env): + self.api = api + self.name = name + self.namespace = namespace + self.params = params + self.history: List[Tuple[float, float]] = [] + self.logger = InstanceLogger(logger, {"name": name}) + self.env = env + + async def check_queue_size_and_update_deployment(self, queue_size: int) -> None: + workers_wanted = ceil(queue_size / self.params.per_worker) + + time_now = time.monotonic() + self.history.append((workers_wanted, time_now)) + + # Take last 10 minutes + times = [t for _, t in self.history] + evict = bisect(times, time_now - 600) + self.history = self.history[evict:] + + workers_wanted = max(self.history)[0] # type: ignore + workers_wanted = min(self.params.max_workers, workers_wanted) + workers_wanted = max(self.params.min_workers, workers_wanted) + + await self.update_deployment(workers_wanted) + + async def update_deployment(self, workers_wanted) -> None: + for _ in range(UPDATE_DEPLOYMENT_MAX_RETRIES): + try: + dep = await self.api.read_namespaced_deployment( + name=self.name, namespace=self.namespace + ) + + if dep.spec.replicas == workers_wanted: + self.logger.debug("Deployment not updated.") + break + + dep.spec.replicas = workers_wanted + + await self.api.patch_namespaced_deployment( + name=self.name, + namespace=self.namespace, + body=dep, + ) + + self.logger.info(f"Deployment updated. replicas={dep.spec.replicas}") + emit_health_metric("scaling_succeeded", self.env) + return + except ApiException as exc: + if exc.status == 409: + self.logger.info("409 retry") + continue + elif exc.status == 404: + self.logger.warning("404 not found") + return + emit_health_metric("scaling_failed", self.env) + raise + else: + emit_health_metric("scaling_failed", self.env) + raise Exception("Ran out of retries updating deployment") + + +@dataclasses.dataclass +class QueueSizes: + """Obtained from Inspect.active()""" + + active: int = 0 + + """Obtained from Inspect.active() + """ + reserved: int = 0 + + """Computed by summing Redis queue lengths across all db_indexes. + """ + enqueued: int = 0 + + """The sum of all of other fields. + """ + total: int = 0 + + # Ignoring these other Inspect categories for now, since they have a different structure + # from 'active' and 'reserved'. We can add them later if we want - it'd just require some + # more complexity to parse them out. + # + # scheduled: int = 0 + # revoked: int = 0 + # registered: int = 0 + + +@dataclasses.dataclass +class WorkerMetrics: + """ + Key: db_index + Value: number of workers + """ + + worker_counts: DefaultDict[int, int] + + +@dataclasses.dataclass +class BrokerMetrics: + """ + Key: (queue_name, db_index) + Value: QueueSizes + """ + + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes] + + """" + Represents the number of active redis client connections + """ + connection_count: int + + """ + Represents the max number of redis client connections allowed + """ + max_connections: int + + +@dataclasses.dataclass +class Metrics: + worker_metrics: WorkerMetrics + broker_metrics: BrokerMetrics + + +def emit_metrics( + metrics: Metrics, + env: str, +) -> None: + """ + Emits a given mapping of queue sizes to Datadog. + """ + queue_sizes = metrics.broker_metrics.queue_sizes + for q, queue_size in queue_sizes.items(): + queue_name, _ = q + tags = [ + f"env:{env}", + f"queue:{queue_name}", + ] + + for metric_name, metric_value in queue_size.__dict__.items(): + statsd.gauge(f"celery.queue_size.{metric_name}", metric_value, tags=tags) + + # Redis-specific, can be ignored for sqs (worker_counts should be empty anyways) + for db_index, worker_count in metrics.worker_metrics.worker_counts.items(): + task_visibility = TaskVisibility(db_index).name.lower() + tags = [ + f"env:{env}", + f"task_visibility:{task_visibility}", + ] + statsd.gauge("celery.worker_count", worker_count, tags=tags) + + if metrics.broker_metrics.connection_count is not None: + tags = [ + f"env:{env}", + ] + statsd.gauge( + "celery.connection_count", + metrics.broker_metrics.connection_count, + tags=tags, + ) + + if metrics.broker_metrics.max_connections is not None: + tags = [ + f"env:{env}", + ] + statsd.gauge("celery.max_connections", metrics.broker_metrics.max_connections, tags=tags) + + +def emit_health_metric(metric_name: str, env: str): + tags = [f"env:{env}"] + statsd.increment(f"celery_autoscaler.{metric_name}", tags=tags) + + +class AutoscalerBroker(ABC): + """ + Base class for autoscaler brokers. + """ + + @abstractmethod + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + """ + Calculates broker related metrics. + + Args: + queues: a set of (queue_name, db_index) + queue_sizes: number of active and reserved tasks for each queue + + Returns: broker metrics + """ + + +class RedisBroker(AutoscalerBroker): + def __init__(self, use_elasticache: bool, initialized: bool = False): + self.use_elasticache = use_elasticache + self.initialized = initialized + + async def _init_client(self): + ( + host, + port, + ) = ( + get_redis_host_port() + ) # Switches the redis instance based on CELERY_ELASTICACHE_ENABLED's value + self.redis = { + db_index: aioredis.client.Redis.from_url(f"redis://{host}:{port}/{db_index}") + for db_index in get_all_db_indexes() + } + self.initialized = True + + async def _get_queue_sizes( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ): + if not self.initialized: + await self._init_client() + + for queue_name, db_index in queues: + q = (queue_name, db_index) + enqueued = await self.redis[db_index].llen(queue_name) + queue_sizes[q].enqueued += enqueued + queue_sizes[q].total += enqueued + return queue_sizes + + async def _get_connection_count(self): + redis_client = next(iter(self.redis.values()), None) # get any redis client + + if redis_client is not None: + if ( + self.use_elasticache + ): # We are using elasticache which doesn't allow us to do `CONFIG GET` + info = await redis_client.info() + connection_count = info.get("connected_clients") + max_connections = info.get("maxclients") + else: + (info, config) = await aio.gather( + redis_client.info(), + redis_client.config_get("maxclients"), + ) + max_connections = config.get("maxclients") + connection_count = info.get("connected_clients") + + return connection_count, max_connections + + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + queue_sizes = await self._get_queue_sizes(queues, queue_sizes) + connection_count, max_connections = await self._get_connection_count() + return BrokerMetrics( + queue_sizes=queue_sizes, + connection_count=connection_count, + max_connections=max_connections, + ) + + +class SQSBroker(AutoscalerBroker): + @staticmethod + def _get_sqs_queue_size(queue_name: str): + sqs_client = session(aws_profile).client("sqs", region_name="us-west-2") + try: + total_start_time = time.time() + queue_size_hist = [] + reserved_size_hist = [] + # We intentionally launch several requests to the same queue. + # We have found multiple samples results in more accurate length estimates compared to a single request. + # Performance-wise: The first request takes ~0.5s, subsequent requests take ~0.005s + for _ in range(SQS_SAMPLE_COUNT): + response = sqs_client.get_queue_attributes( + QueueUrl=queue_name, + AttributeNames=[ + "ApproximateNumberOfMessages", + "ApproximateNumberOfMessagesNotVisible", + ], + ) + queue_size_hist.append(int(response["Attributes"]["ApproximateNumberOfMessages"])) + reserved_size_hist.append( + int(response["Attributes"]["ApproximateNumberOfMessagesNotVisible"]) + ) + total_end_time = time.time() + queue_size = max(queue_size_hist) + # SQS's ApproximateNumberOfMessagesNotVisible should correspond to celery's + # number of active + number of reserved tasks + reserved_size = max(reserved_size_hist) + logger.info( + f"SQS {queue_name} total: {total_end_time - total_start_time} seconds, queue size {queue_size}, reserved size {reserved_size}" + ) + + except sqs_client.exceptions.QueueDoesNotExist as e: + logger.info(f"Queue does not exist {queue_name}: {e}") + queue_size = 0 + reserved_size = 0 + except Exception as e: + logger.error(f"Failed to get queue attributes {queue_name}: {e}") + queue_size = 0 + reserved_size = 0 + return queue_size, reserved_size + + def _get_queue_sizes( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ): + queue_names = [queue_name for queue_name, _ in queues] + with ThreadPoolExecutor() as executor: + results = executor.map(SQSBroker._get_sqs_queue_size, queue_names) + + for q, (enqueued, reserved) in zip(queues, results): + queue_sizes[q].enqueued += enqueued + queue_sizes[q].reserved += reserved + queue_sizes[q].total += enqueued + reserved + return queue_sizes + + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + queue_sizes = self._get_queue_sizes(queues, queue_sizes) + return BrokerMetrics( + queue_sizes=queue_sizes, + connection_count=None, + max_connections=None, + ) # connection_count and max_connections are redis-specific metrics + + +def get_worker_metrics( + inspect: Dict[int, Inspect], + queues: Set[Tuple[str, int]], +) -> Tuple[WorkerMetrics, DefaultDict[Tuple[str, int], QueueSizes]]: + """ + Given a set of Celery Inspect results for each db connection, + computes the number of workers for each db connection, and number of active and reserved tasks. + + In the case of SQS this will return no data for queue_sizes/worker counts, as inspect is empty + """ + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes] = defaultdict(QueueSizes) + worker_counts: DefaultDict[int, int] = defaultdict(int) + for db_index, insp in inspect.items(): + insp_categories = { + "active": insp.active(), + "reserved": insp.reserved(), + } + + worker_ping = insp.ping() + if worker_ping: + worker_counts[db_index] = len(worker_ping.values()) + + for insp_key, worker_group in filter(lambda x: x[1], insp_categories.items()): + for task_list in worker_group.values(): + for task in task_list: + queue_name = task["delivery_info"]["routing_key"] + q = (queue_name, db_index) + + if q in queues: + queue_sizes[q].__dict__[insp_key] += 1 + queue_sizes[q].total += 1 + return WorkerMetrics(worker_counts=worker_counts), queue_sizes + + +async def get_metrics( + broker: AutoscalerBroker, + inspect: Dict[int, Inspect], + queues: Set[Tuple[str, int]], +) -> Metrics: + """ + Given a set of Redis db connections and Celery Inspect results for each db connection, + computes worker and broker metrics. + """ + + worker_metrics, active_reserved_queue_sizes = get_worker_metrics(inspect, queues) + broker_metrics = await broker.get_broker_metrics(queues, active_reserved_queue_sizes) + + return Metrics( + worker_metrics=worker_metrics, + broker_metrics=broker_metrics, + ) + + +async def main(): + instances: Dict[Tuple[str, str], Instance] = {} + try: + kube_config.load_incluster_config() + except ConfigException: + logger.info("No incluster kubernetes config, falling back to local") + await kube_config.load_kube_config() + + core_api = client.CoreV1Api() + apps_api = client.AppsV1Api() + + BROKER_NAME_TO_CLASS = { + ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True), + SQS_BROKER: SQSBroker(), + } + + broker = BROKER_NAME_TO_CLASS[autoscaler_broker] + broker_type = "redis" if isinstance(broker, RedisBroker) else "sqs" + + if broker_type == "redis": + inspect = { + db_index: inspect_app( + app=celery_app( + None, broker_type=broker_type, task_visibility=db_index, aws_role=aws_profile + ) + ) + for db_index in get_all_db_indexes() + } + elif broker_type == "sqs": + # for sqs we will get active/reserved counts directly from sqs as opposed to using + # an inspect object + inspect = {} + else: + raise ValueError("broker_type not redis or sqs, how did we get here?") + + env = os.getenv("DD_ENV") + instance_count = int(os.getenv("POD_NAME", "pod-0").split("-")[-1]) + num_shards = int(os.getenv("NUM_SHARDS", 1)) + + env = f"{env}-{autoscaler_broker}" + + while True: + try: + loop_start = time.time() + deployments = await list_deployments(core_api=core_api, apps_api=apps_api) + logger.info(f"list_deployments took {time.time() - loop_start} seconds") + celery_queues = set() + celery_queues_params = [] + for deployment_and_namespace, params in sorted( + deployments.items() + ): # sort for a bit more determinism + # Hash the deployment / namespace to deterministically partition the deployments. + # Skip all deployments not in this partition. + if _hash_any_to_int(deployment_and_namespace) % num_shards != instance_count: + continue + + deployment_name, namespace = deployment_and_namespace + instance = instances.get(deployment_and_namespace) + if instance is None or instance.params != params: + instances[deployment_and_namespace] = Instance( + apps_api, deployment_name, namespace, params, env + ) + + # We're treating a queue as a pair consisting of a (queue_name, db_index). + # This means that two queues that happen to have the same name are treated + # as semantically distinct if they have different db_indexes. + celery_queues.add((params.queue, params.task_visibility.value)) + celery_queues_params.append(params.__dict__) + + # Clean up instances not in set + for deployment_and_namespace in set(instances) - set(deployments): + del instances[deployment_and_namespace] + + # Get queue sizes + # (queue_name, db_index) -> QueueSizes + start_get_metrics = time.time() + metrics = await get_metrics(broker, inspect=inspect, queues=celery_queues) + logger.info(f"get_metrics took {time.time() - start_get_metrics} seconds") + + queue_sizes = metrics.broker_metrics.queue_sizes + for k, v in sorted(queue_sizes.items()): + queue_name, _ = k + logger.info(f"Inflight : {queue_name} : {v.total}") + + emit_metrics(metrics=metrics, env=env) + + # Update scaling + for instance in instances.values(): + queue_size = queue_sizes[ + (instance.params.queue, int(instance.params.task_visibility)) + ] + try: + await instance.check_queue_size_and_update_deployment(queue_size.total) + except Exception as e: + logger.exception(f"Failed to update {instance.name}: {e}") + + # Wait before next iteration + iteration_len = time.time() - loop_start + logger.info(f"Iteration length: {iteration_len} seconds.") + if iteration_len < 3: + await aio.sleep(3 - iteration_len) + + emit_health_metric("heartbeat", env) + except Exception as e: + logger.exception(f"Error in deployment loop: {e}") + continue + + +if __name__ == "__main__": + aio.run(main()) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index abf0809b..8575237b 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -43,6 +43,7 @@ smart-open~=5.2 sqlalchemy[asyncio]==2.0.4 sse-starlette==1.6.1 sseclient-py==1.7.2 +stringcase==1.2.0 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 transformers==4.34.1 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index e61d22e3..19e4edf5 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -417,6 +417,8 @@ starlette==0.19.1 # via # fastapi # sse-starlette +stringcase==1.2.0 + # via -r model-engine/requirements.in tblib==2.0.0 # via celery tenacity==6.2.0 From 4d72b23780d160f77da64892910d3bde15bab79f Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 20 Nov 2023 16:30:54 -0800 Subject: [PATCH 189/425] Don't install Celery autoscaler for test deployments (#388) --- .../model-engine/templates/celery_autoscaler_stateful_set.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml index fb8c393b..8dafd7e5 100644 --- a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -1,4 +1,5 @@ {{- if .Values.celery_autoscaler.enabled }} +{{- if not .Values.serviceIdentifier }} {{- $app := include "modelEngine.celeryautoscalername" . }} {{- $env := .Values.context }} {{- $tag := .Values.tag }} @@ -84,4 +85,5 @@ spec: - configMap: name: {{ .Values.aws.configMap.name }} name: config-volume +{{- end }} {{- end }} \ No newline at end of file From 3c0f1681fe59946c089be5d45d9d0a94d6f96d88 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 27 Nov 2023 10:06:38 -0800 Subject: [PATCH 190/425] LLM update API route (#387) --- .../model_engine_server/api/llms_v1.py | 74 +++++++- .../api/model_endpoints_v1.py | 11 -- .../model_engine_server/common/dtos/llms.py | 48 ++++- .../domain/entities/llm_entity.py | 1 + .../model_engine_server/domain/exceptions.py | 6 - .../use_cases/llm_model_endpoint_use_cases.py | 179 +++++++++++++++++- model-engine/tests/unit/domain/conftest.py | 12 ++ .../tests/unit/domain/test_llm_use_cases.py | 136 +++++++++++-- 8 files changed, 421 insertions(+), 46 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 6fc2a1d1..79961e2d 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -33,6 +33,8 @@ StreamError, StreamErrorContent, TokenUsage, + UpdateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Response, ) from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.core.auth.authentication_repository import User @@ -54,7 +56,6 @@ LLMFineTuningQuotaReached, ObjectAlreadyExistsException, ObjectHasInvalidValueException, - ObjectNotApprovedException, ObjectNotAuthorizedException, ObjectNotFoundException, UpstreamServiceError, @@ -70,11 +71,13 @@ from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, + CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ListLLMModelEndpointsV1UseCase, ModelDownloadV1UseCase, + UpdateLLMModelEndpointV1UseCase, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase from sse_starlette.sse import EventSourceResponse @@ -151,13 +154,16 @@ async def create_model_endpoint( docker_repository=external_interfaces.docker_repository, model_primitive_gateway=external_interfaces.model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=create_model_bundle_use_case, model_bundle_repository=external_interfaces.model_bundle_repository, - model_endpoint_service=external_interfaces.model_endpoint_service, llm_artifact_gateway=external_interfaces.llm_artifact_gateway, docker_repository=external_interfaces.docker_repository, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, + model_endpoint_service=external_interfaces.model_endpoint_service, + ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: raise HTTPException( @@ -176,11 +182,6 @@ async def create_model_endpoint( status_code=400, detail=str(exc), ) from exc - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -234,6 +235,63 @@ async def get_model_endpoint( ) from exc +@llm_router_v1.put( + "/model-endpoints/{model_endpoint_name}", response_model=UpdateLLMModelEndpointV1Response +) +async def update_model_endpoint( + model_endpoint_name: str, + request: UpdateLLMModelEndpointV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateLLMModelEndpointV1Response: + """ + Updates an LLM endpoint for the current user. + """ + logger.info(f"PUT /llm/model-endpoints/{model_endpoint_name} with {request} for {auth}") + try: + create_model_bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=external_interfaces.model_bundle_repository, + docker_repository=external_interfaces.docker_repository, + model_primitive_gateway=external_interfaces.model_primitive_gateway, + ) + create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=create_model_bundle_use_case, + model_bundle_repository=external_interfaces.model_bundle_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + docker_repository=external_interfaces.docker_repository, + ) + use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + ) + return await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + except EndpointLabelsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except EndpointResourceInvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified LLM endpoint could not be found.", + ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail="The specified docker image could not be found.", + ) from exc + + @llm_router_v1.post("/completions-sync", response_model=CompletionSyncV1Response) async def create_completion_sync_task( model_endpoint_name: str, diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index 3b45f071..807393cd 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -31,7 +31,6 @@ ExistingEndpointOperationInProgressException, ObjectAlreadyExistsException, ObjectHasInvalidValueException, - ObjectNotApprovedException, ObjectNotAuthorizedException, ObjectNotFoundException, ) @@ -80,11 +79,6 @@ async def create_model_endpoint( status_code=400, detail=str(exc), ) from exc - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -154,11 +148,6 @@ async def update_model_endpoint( return await use_case.execute( user=auth, model_endpoint_id=model_endpoint_id, request=request ) - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except EndpointLabelsException as exc: raise HTTPException( status_code=400, diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index dd2e06a0..346c9ae2 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -87,6 +87,7 @@ class GetLLMModelEndpointV1Response(BaseModel): inference_framework_image_tag: Optional[str] = None num_shards: Optional[int] = None quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None spec: Optional[GetModelEndpointV1Response] = None @@ -94,7 +95,52 @@ class ListLLMModelEndpointsV1Response(BaseModel): model_endpoints: List[GetLLMModelEndpointV1Response] -# Delete and update use the default Launch endpoint APIs. +class UpdateLLMModelEndpointV1Request(BaseModel): + # LLM specific fields + model_name: Optional[str] + source: Optional[LLMSource] + inference_framework_image_tag: Optional[str] + num_shards: Optional[int] + """ + Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models + """ + + quantize: Optional[Quantization] + """ + Whether to quantize the model. Only affect behavior for text-generation-inference models + """ + + checkpoint_path: Optional[str] + """ + Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] + post_inference_hooks: Optional[List[str]] + cpus: Optional[CpuSpecificationType] + gpus: Optional[int] + memory: Optional[StorageSpecificationType] + gpu_type: Optional[GpuType] + storage: Optional[StorageSpecificationType] + optimize_costs: Optional[bool] + min_workers: Optional[int] + max_workers: Optional[int] + per_worker: Optional[int] + labels: Optional[Dict[str, str]] + prewarm: Optional[bool] + high_priority: Optional[bool] + billing_tags: Optional[Dict[str, Any]] + default_callback_url: Optional[HttpUrl] + default_callback_auth: Optional[CallbackAuth] + public_inference: Optional[bool] + + +class UpdateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +# Delete uses the default Launch endpoint APIs. class CompletionSyncV1Request(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index 30ec8993..4da8c278 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -28,3 +28,4 @@ class LLMMetadata: inference_framework_image_tag: str num_shards: int quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 934a5e21..b78bb281 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -31,12 +31,6 @@ class ObjectHasInvalidValueException(DomainException, ValueError): """ -class ObjectNotApprovedException(DomainException): - """ - Thrown when a required object is not approved, e.g. for a Bundle in review. - """ - - @dataclass class DockerImageNotFoundException(DomainException): """ diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index d379aac2..af14bf06 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -27,6 +27,8 @@ ModelDownloadRequest, ModelDownloadResponse, TokenOutput, + UpdateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Response, ) from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy @@ -48,6 +50,7 @@ ) from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, + EndpointInfraStateNotFound, EndpointLabelsException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, @@ -70,6 +73,7 @@ from ..authorization.live_authorization_module import LiveAuthorizationModule from .model_bundle_use_cases import CreateModelBundleV2UseCase from .model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, _handle_post_inference_hooks, model_endpoint_entity_to_get_model_endpoint_response, validate_billing_tags, @@ -237,6 +241,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( inference_framework_image_tag=llm_metadata["inference_framework_image_tag"], num_shards=llm_metadata["num_shards"], quantize=llm_metadata.get("quantize"), + checkpoint_path=llm_metadata.get("checkpoint_path"), spec=model_endpoint_entity_to_get_model_endpoint_response(model_endpoint), ) return response @@ -274,19 +279,17 @@ def validate_quantization( ) -class CreateLLMModelEndpointV1UseCase: +class CreateLLMModelBundleV1UseCase: def __init__( self, create_model_bundle_use_case: CreateModelBundleV2UseCase, model_bundle_repository: ModelBundleRepository, - model_endpoint_service: ModelEndpointService, llm_artifact_gateway: LLMArtifactGateway, docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_model_bundle_use_case = create_model_bundle_use_case self.model_bundle_repository = model_bundle_repository - self.model_endpoint_service = model_endpoint_service self.llm_artifact_gateway = llm_artifact_gateway self.docker_repository = docker_repository @@ -302,7 +305,7 @@ def check_docker_image_exists_for_image_tag( tag=framework_image_tag, ) - async def create_model_bundle( + async def execute( self, user: User, endpoint_name: str, @@ -840,6 +843,17 @@ async def create_tensorrt_llm_bundle( ) ).model_bundle_id + +class CreateLLMModelEndpointV1UseCase: + def __init__( + self, + create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, + model_endpoint_service: ModelEndpointService, + ): + self.authz_module = LiveAuthorizationModule() + self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case + self.model_endpoint_service = model_endpoint_service + async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: @@ -865,10 +879,10 @@ async def execute( ]: if request.endpoint_type != ModelEndpointType.STREAMING: raise ObjectHasInvalidValueException( - f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM and LightLLM." + f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM." ) - bundle = await self.create_model_bundle( + bundle = await self.create_llm_model_bundle_use_case.execute( user, endpoint_name=request.name, model_name=request.model_name, @@ -908,6 +922,7 @@ async def execute( inference_framework_image_tag=request.inference_framework_image_tag, num_shards=request.num_shards, quantize=request.quantize, + checkpoint_path=request.checkpoint_path, ) ) @@ -1025,6 +1040,158 @@ async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndp return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) +class UpdateLLMModelEndpointV1UseCase: + def __init__( + self, + create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + ): + self.authz_module = LiveAuthorizationModule() + self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + + async def execute( + self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request + ) -> UpdateLLMModelEndpointV1Response: + if request.labels is not None: + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) + validate_post_inference_hooks(user, request.post_inference_hooks) + + model_endpoint = await self.llm_model_endpoint_service.get_llm_model_endpoint( + model_endpoint_name + ) + if not model_endpoint: + raise ObjectNotFoundException + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + raise ObjectNotAuthorizedException + + endpoint_record = model_endpoint.record + model_endpoint_id = endpoint_record.id + bundle = endpoint_record.current_model_bundle + + # TODO: We may want to consider what happens if an endpoint gets stuck in UPDATE_PENDING + # on first creating it, and we need to find a way to get it unstuck. This would end up + # causing endpoint.infra_state to be None. + if model_endpoint.infra_state is None: + error_msg = f"Endpoint infra state not found for {model_endpoint_name=}" + logger.error(error_msg) + raise EndpointInfraStateNotFound(error_msg) + + infra_state = model_endpoint.infra_state + + if ( + request.model_name + or request.source + or request.inference_framework_image_tag + or request.num_shards + or request.quantize + or request.checkpoint_path + ): + llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {}) + inference_framework = llm_metadata["inference_framework"] + + model_name = request.model_name or llm_metadata["model_name"] + source = request.source or llm_metadata["source"] + inference_framework_image_tag = ( + request.inference_framework_image_tag + or llm_metadata["inference_framework_image_tag"] + ) + num_shards = request.num_shards or llm_metadata["num_shards"] + quantize = request.quantize or llm_metadata.get("quantize") + checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") + + validate_model_name(model_name, inference_framework) + validate_num_shards( + num_shards, inference_framework, request.gpus or infra_state.resource_state.gpus + ) + validate_quantization(quantize, inference_framework) + + bundle = await self.create_llm_model_bundle_use_case.execute( + user, + endpoint_name=model_endpoint_name, + model_name=model_name, + source=source, + framework=inference_framework, + framework_image_tag=inference_framework_image_tag, + endpoint_type=endpoint_record.endpoint_type, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + ) + + metadata = endpoint_record.metadata or {} + metadata["_llm"] = asdict( + LLMMetadata( + model_name=model_name, + source=source, + inference_framework=inference_framework, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + ) + ) + request.metadata = metadata + + # For resources that are not specified in the update endpoint request, pass in resource from + # infra_state to make sure that after the update, all resources are valid and in sync. + # E.g. If user only want to update gpus and leave gpu_type as None, we use the existing gpu_type + # from infra_state to avoid passing in None to validate_resource_requests. + validate_resource_requests( + bundle=bundle, + cpus=request.cpus or infra_state.resource_state.cpus, + memory=request.memory or infra_state.resource_state.memory, + storage=request.storage or infra_state.resource_state.storage, + gpus=request.gpus or infra_state.resource_state.gpus, + gpu_type=request.gpu_type or infra_state.resource_state.gpu_type, + ) + + validate_deployment_resources( + min_workers=request.min_workers, + max_workers=request.max_workers, + endpoint_type=endpoint_record.endpoint_type, + ) + + if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata: + raise ObjectHasInvalidValueException( + f"{CONVERTED_FROM_ARTIFACT_LIKE_KEY} is a reserved metadata key and cannot be used by user." + ) + + updated_endpoint_record = await self.model_endpoint_service.update_model_endpoint( + model_endpoint_id=model_endpoint_id, + model_bundle_id=bundle.id, + metadata=request.metadata, + post_inference_hooks=request.post_inference_hooks, + cpus=request.cpus, + gpus=request.gpus, + memory=request.memory, + gpu_type=request.gpu_type, + storage=request.storage, + optimize_costs=request.optimize_costs, + min_workers=request.min_workers, + max_workers=request.max_workers, + per_worker=request.per_worker, + labels=request.labels, + prewarm=request.prewarm, + high_priority=request.high_priority, + default_callback_url=request.default_callback_url, + default_callback_auth=request.default_callback_auth, + public_inference=request.public_inference, + ) + _handle_post_inference_hooks( + created_by=endpoint_record.created_by, + name=updated_endpoint_record.name, + post_inference_hooks=request.post_inference_hooks, + ) + + return UpdateLLMModelEndpointV1Response( + endpoint_creation_task_id=updated_endpoint_record.creation_task_id # type: ignore + ) + + class DeleteLLMEndpointByNameUseCase: """ Use case for deleting an LLM Model Endpoint of a given user by endpoint name. diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 06310666..f433071c 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -7,6 +7,7 @@ CompletionStreamV1Request, CompletionSyncV1Request, CreateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Request, ) from model_engine_server.common.dtos.model_bundles import ( CreateModelBundleV1Request, @@ -218,6 +219,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", ) @@ -247,6 +249,16 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req ) +@pytest.fixture +def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request( + checkpoint_path="s3://test_checkpoint_path", + memory="4G", + min_workers=0, + max_workers=1, + ) + + @pytest.fixture def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 589e453b..c4fbb31f 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -11,6 +11,7 @@ CreateLLMModelEndpointV1Response, ModelDownloadRequest, TokenOutput, + UpdateLLMModelEndpointV1Request, ) from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.core.auth.authentication_repository import User @@ -38,10 +39,12 @@ from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, + CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, ModelDownloadV1UseCase, + UpdateLLMModelEndpointV1UseCase, _include_safetensors_bin_or_pt, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -66,13 +69,17 @@ async def test_create_model_endpoint_use_case_success( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) assert response_1.endpoint_creation_task_id @@ -93,6 +100,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_async.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_async.num_shards, "quantize": None, + "checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path, } } @@ -115,6 +123,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_sync.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_sync.num_shards, "quantize": None, + "checkpoint_path": None, } } @@ -139,6 +148,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, + "checkpoint_path": None, } } @@ -182,14 +192,16 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() request.inference_framework = inference_framework @@ -198,7 +210,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( if valid: await use_case.execute(user=user, request=request) else: - use_case.docker_repository = fake_docker_repository_image_never_exists + llm_bundle_use_case.docker_repository = fake_docker_repository_image_never_exists with pytest.raises(DockerImageNotFoundException): await use_case.execute(user=user, request=request) @@ -220,13 +232,16 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( user=user, @@ -250,6 +265,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_text_generation_inference_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_text_generation_inference_request_streaming.num_shards, "quantize": create_llm_model_endpoint_text_generation_inference_request_streaming.quantize, + "checkpoint_path": create_llm_model_endpoint_text_generation_inference_request_streaming.checkpoint_path, } } @@ -277,13 +293,16 @@ async def test_create_model_endpoint_trt_llm_use_case_success( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( user=user, @@ -307,6 +326,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_trt_llm_request_streaming.num_shards, "quantize": create_llm_model_endpoint_trt_llm_request_streaming.quantize, + "checkpoint_path": create_llm_model_endpoint_trt_llm_request_streaming.checkpoint_path, } } @@ -333,13 +353,16 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): await use_case.execute( @@ -363,13 +386,16 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception( docker_repository=fake_docker_repository_image_always_exists, model_primitive_gateway=fake_model_primitive_gateway, ) - use_case = CreateLLMModelEndpointV1UseCase( + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( create_model_bundle_use_case=bundle_use_case, model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, llm_artifact_gateway=fake_llm_artifact_gateway, docker_repository=fake_docker_repository_image_always_exists, ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): await use_case.execute( @@ -410,6 +436,88 @@ async def test_get_llm_model_endpoint_use_case_raises_not_authorized( ) +@pytest.mark.asyncio +async def test_update_model_endpoint_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + fake_llm_model_endpoint_service, + create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + update_llm_model_endpoint_request: UpdateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + create_use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + ) + update_use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + fake_llm_model_endpoint_service.add_model_endpoint(endpoint) + + update_response = await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request, + ) + assert update_response.endpoint_creation_task_id + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_streaming.model_name, + "source": create_llm_model_endpoint_request_streaming.source, + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, + "quantize": None, + "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + } + } + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.deployment_state.min_workers + == update_llm_model_endpoint_request.min_workers + ) + assert ( + endpoint.infra_state.deployment_state.max_workers + == update_llm_model_endpoint_request.max_workers + ) + + def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa class mocked_encode: def encode(self, input: str) -> List[Any]: # noqa From 37814eed4f8a8319d34983e1cb33222e46f662a3 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 27 Nov 2023 11:08:04 -0800 Subject: [PATCH 191/425] adding zephyr 7b (#389) * adding zephyr 7b * update tokenizer repo --- docs/model_zoo.md | 2 ++ .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++++ .../infra/repositories/live_tokenizer_repository.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index ebe9f082..35c93e5d 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -26,6 +26,8 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | | `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | | `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | +| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | +| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index af14bf06..35f52e49 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -124,6 +124,8 @@ "codellama-34b-instruct", "llm-jp-13b-instruct-full", "llm-jp-13b-instruct-full-dolly", + "zephyr-7b-alpha", + "zephyr-7b-beta", ] ), LLMInferenceFramework.VLLM: set( @@ -154,6 +156,8 @@ "mammoth-coder-llama-2-7b", "mammoth-coder-llama-2-13b", "mammoth-coder-llama-2-34b", + "zephyr-7b-alpha", + "zephyr-7b-beta", ] ), LLMInferenceFramework.LIGHTLLM: set( diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index e107e117..873f2e65 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -67,6 +67,8 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "dolly-v2-12b": ModelInfo("databricks/dolly-v2-12b", None), "stablelm-tuned-7b": ModelInfo("StabilityAI/stablelm-tuned-alpha-7b", None), "vicuna-13b": ModelInfo("eachadea/vicuna-13b-1.1", None), + "zephyr-7b-alpha": ModelInfo("HuggingFaceH4/zephyr-7b-alpha", None), + "zephyr-7b-beta": ModelInfo("HuggingFaceH4/zephyr-7b-beta", None), } From 4483dff763986c7f40579ac148ec2a221a757af0 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:16:47 -0800 Subject: [PATCH 192/425] update tensor-rt llm in enum (#390) * update tensor-rt llm in enum * fix to be the same as in the egp and spellbook-backend --- clients/python/llmengine/data_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index f5a5a0b2..cea75176 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -16,6 +16,7 @@ class LLMInferenceFramework(str, Enum): TEXT_GENERATION_INFERENCE = "text_generation_inference" VLLM = "vllm" LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt-llm" class LLMSource(str, Enum): From de7a493916b0e7ff34538d3eebdebd049c80e8da Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:26:40 -0800 Subject: [PATCH 193/425] pypi version bump (#391) --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 694a969f..cbc5efa2 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b19" +__version__ = "0.0.0b20" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 19225a91..1d45ccc5 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta19" +version = "0.0.0.beta20" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 8a9895b2..3c5f5b5f 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta19", + version="0.0.0.beta20", packages=find_packages(), ) From cccbd3eb057a27853531fd55d0ed9ce7bae712f5 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 28 Nov 2023 16:39:40 -0800 Subject: [PATCH 194/425] Change middleware format (#393) --- model-engine/model_engine_server/api/app.py | 61 +++++++++++---------- model-engine/setup.cfg | 1 + 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index b3a41dfd..a26a8ddc 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -27,10 +27,42 @@ logger_name, make_logger, ) +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware logger = make_logger(logger_name()) -app = FastAPI(title="launch", version="1.0.0", redoc_url="/api") + +class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) + return await call_next(request) + except Exception as e: + tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": str(e), + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Unhandled exception: %s", structured_log) + return JSONResponse( + { + "status_code": 500, + "content": { + "error": "Internal error occurred. Our team has been notified.", + "timestamp": timestamp, + "request_id": request_id, + }, + } + ) + + +app = FastAPI( + title="launch", version="1.0.0", redoc_url="/api", middleware=[Middleware(CustomMiddleware)] +) app.include_router(batch_job_router_v1) app.include_router(inference_task_router_v1) @@ -44,33 +76,6 @@ app.include_router(trigger_router_v1) -@app.middleware("http") -async def dispatch(request: Request, call_next): - try: - LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) - return await call_next(request) - except Exception as e: - tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) - request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) - timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") - structured_log = { - "error": str(e), - "request_id": str(request_id), - "traceback": "".join(tb_str), - } - logger.error("Unhandled exception: %s", structured_log) - return JSONResponse( - { - "status_code": 500, - "content": { - "error": "Internal error occurred. Our team has been notified.", - "timestamp": timestamp, - "request_id": request_id, - }, - } - ) - - # TODO: Remove this once we have a better way to serve internal docs INTERNAL_DOCS_PATH = str(Path(__file__).parents[3] / "launch_internal/site") if os.path.exists(INTERNAL_DOCS_PATH): diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index 1566418e..a5f56d8a 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -4,6 +4,7 @@ test=pytest [coverage:run] omit = model_engine_server/entrypoints/* + model_engine_server/api/app.py # TODO: Fix pylint errors # [pylint] From 1ee6fbea9f71db6affac7ba5c666555d0cd561d9 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 29 Nov 2023 01:07:12 -0800 Subject: [PATCH 195/425] Fix custom framework Dockerfile (#395) --- model-engine/model_engine_server/inference/base.Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-engine/model_engine_server/inference/base.Dockerfile b/model-engine/model_engine_server/inference/base.Dockerfile index 34f09ab5..88a7f7bf 100644 --- a/model-engine/model_engine_server/inference/base.Dockerfile +++ b/model-engine/model_engine_server/inference/base.Dockerfile @@ -3,6 +3,8 @@ FROM ${BASE_IMAGE} WORKDIR /app +RUN rm -rf /var/lib/apt/lists/* + # Install basic packages. RUN apt-get update && apt-get install -y \ apt-utils \ From 3adbd5956186531a1bf8ba9ca11c52f980a5a252 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Wed, 29 Nov 2023 11:12:28 -0800 Subject: [PATCH 196/425] fixing enum value (#396) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index cbc5efa2..cc8f28be 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b20" +__version__ = "0.0.0b21" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index cea75176..2a37e912 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -16,7 +16,7 @@ class LLMInferenceFramework(str, Enum): TEXT_GENERATION_INFERENCE = "text_generation_inference" VLLM = "vllm" LIGHTLLM = "lightllm" - TENSORRT_LLM = "tensorrt-llm" + TENSORRT_LLM = "tensorrt_llm" class LLMSource(str, Enum): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 1d45ccc5..f9809cd2 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta20" +version = "0.0.0.beta21" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 3c5f5b5f..907a44b7 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta20", + version="0.0.0.beta21", packages=find_packages(), ) From 8501db049a8f630daf526595e94b00c11f17a4e4 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 30 Nov 2023 16:16:27 -0800 Subject: [PATCH 197/425] overriding model length for zephyr 7b alpha (#398) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 35f52e49..08bdc7fd 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -194,6 +194,7 @@ # Can also see 13B, 34B there too "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, + "zephyr": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, } From 994448394272453a6f4aec4087763dd50e5a1c0e Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 1 Dec 2023 00:18:03 -0800 Subject: [PATCH 198/425] time completions use case (#397) * time use case * name * update fake --- model-engine/model_engine_server/api/llms_v1.py | 15 ++++++++++----- .../model_engine_server/common/dtos/llms.py | 14 ++++++++++++++ .../gateways/fake_monitoring_metrics_gateway.py | 3 +++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 79961e2d..64bbfc5f 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -44,6 +44,7 @@ logger_name, make_logger, ) +from model_engine_server.core.utils.timer import timer from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, EndpointDeleteFailedException, @@ -313,9 +314,10 @@ async def create_completion_sync_task( llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, tokenizer_repository=external_interfaces.tokenizer_repository, ) - response = await use_case.execute( - user=auth, model_endpoint_name=model_endpoint_name, request=request - ) + with timer() as use_case_timer: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) background_tasks.add_task( external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, TokenUsage( @@ -323,6 +325,7 @@ async def create_completion_sync_task( num_completion_tokens=response.output.num_completion_tokens if response.output else None, + total_duration=use_case_timer.duration, ), metric_metadata, ) @@ -374,8 +377,9 @@ async def create_completion_stream_task( async def event_generator(): try: - async for message in response: - yield {"data": message.json()} + with timer() as use_case_timer: + async for message in response: + yield {"data": message.json()} background_tasks.add_task( external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, TokenUsage( @@ -383,6 +387,7 @@ async def event_generator(): num_completion_tokens=message.output.num_completion_tokens if message.output else None, + total_duration=use_case_timer.duration, ), metric_metadata, ) diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 346c9ae2..fc531c1f 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -280,13 +280,27 @@ class CompletionStreamV1Response(BaseModel): class TokenUsage(BaseModel): + """ + Token usage for a prompt completion task. + """ + num_prompt_tokens: Optional[int] = 0 num_completion_tokens: Optional[int] = 0 + total_duration: Optional[float] = None + """Includes time spent waiting for the model to be ready.""" @property def num_total_tokens(self) -> int: return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) + @property + def total_tokens_per_second(self) -> float: + return ( + self.num_total_tokens / self.total_duration + if self.total_duration and self.total_duration > 0 + else 0.0 + ) + class CreateFineTuneRequest(BaseModel): model: str diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py index 9b63a135..dc419a07 100644 --- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -21,6 +21,7 @@ def __init__(self): self.database_cache_miss = 0 self.route_call = defaultdict(int) self.token_count = 0 + self.total_tokens_per_second = 0 def reset(self): self.attempted_build = 0 @@ -35,6 +36,7 @@ def reset(self): self.database_cache_miss = 0 self.route_call = defaultdict(int) self.token_count = 0 + self.total_tokens_per_second = 0 def emit_attempted_build_metric(self): self.attempted_build += 1 @@ -71,3 +73,4 @@ def emit_route_call_metric(self, route: str, _metadata: MetricMetadata): def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata): self.token_count += token_usage.num_total_tokens + self.total_tokens_per_second = token_usage.total_tokens_per_second From 8f657c76d0759258e4047e9d7adaee33bdb4a223 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 5 Dec 2023 18:22:26 -0800 Subject: [PATCH 199/425] update docs to show model len / context windows (#401) * update docs to show model len / context windows * make title clearer * make title clearer pt2 --- docs/model_zoo.md | 52 +++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 35c93e5d..50326b6f 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -2,32 +2,32 @@ Scale hosts the following models in the LLM Engine Model Zoo: -| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | -| --------------------- | ------------------------ | -------------------------- | ------------------------------ | -| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | -| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | -| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | -| `llama-2-13b` | ✅ | | text-generation-inference, vllm | -| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | -| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | -| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | -| `falcon-7b` | ✅ | | text-generation-inference, vllm | -| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | -| `falcon-40b` | ✅ | | text-generation-inference, vllm | -| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | -| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | -| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | -| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | -| `mistral-7b` | ✅ | ✅ | vllm | -| `mistral-7b-instruct` | ✅ | ✅ | vllm | -| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | -| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | -| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | -| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | -| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | -| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | -| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | -| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | +| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | Inference max total tokens (prompt + response) | +| --------------------- | ------------------------ | -------------------------- | ------------------------------ | ------------------------------ | +| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | 2048 | +| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | 4096| +| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | 4096| +| `llama-2-13b` | ✅ | | text-generation-inference, vllm | 4096| +| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096| +| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096| +| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096| +| `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | 2048 | +| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | 2048 | +| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | 2048 | +| `mistral-7b` | ✅ | ✅ | vllm | 8000 | +| `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | +| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | +| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | ## Usage From 69e07ff29f8c414e480d91e6ca3ad763ae7866ba Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 6 Dec 2023 18:35:25 -0800 Subject: [PATCH 200/425] Add MultiprocessingConcurrencyLimiter to gateway (#399) --- .../templates/gateway_deployment.yaml | 7 --- model-engine/model_engine_server/api/app.py | 48 ++++++++++++++----- .../model_engine_server/api/worker.py | 5 +- .../common/concurrency_limiter.py | 36 ++++++++++++++ .../inference/forwarding/http_forwarder.py | 33 +------------ .../sync_inference/fastapi_server.py | 31 +----------- 6 files changed, 77 insertions(+), 83 deletions(-) create mode 100644 model-engine/model_engine_server/common/concurrency_limiter.py diff --git a/charts/model-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml index a58717a3..e5283319 100644 --- a/charts/model-engine/templates/gateway_deployment.yaml +++ b/charts/model-engine/templates/gateway_deployment.yaml @@ -49,13 +49,6 @@ spec: port: 5000 periodSeconds: 2 failureThreshold: 30 - livenessProbe: - httpGet: - path: /healthz - port: 5000 - initialDelaySeconds: 5 - periodSeconds: 2 - failureThreshold: 10 command: - dumb-init - -- diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index a26a8ddc..90f5620c 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -5,7 +5,7 @@ from pathlib import Path import pytz -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, HTTPException, Request, Response from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1 @@ -21,6 +21,7 @@ from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 from model_engine_server.api.tasks_v1 import inference_task_router_v1 from model_engine_server.api.triggers_v1 import trigger_router_v1 +from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, @@ -32,12 +33,34 @@ logger = make_logger(logger_name()) +# Allows us to make the Uvicorn worker concurrency in model_engine_server/api/worker.py very high +MAX_CONCURRENCY = 500 + +concurrency_limiter = MultiprocessingConcurrencyLimiter( + concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True +) + +healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"] + class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): try: LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) - return await call_next(request) + # we intentionally exclude healthcheck routes from the concurrency limiter + if request.url.path in healthcheck_routes: + return await call_next(request) + with concurrency_limiter: + return await call_next(request) + except HTTPException as e: + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + return JSONResponse( + status_code=e.status_code, + content={ + "error": e.detail, + "timestamp": timestamp, + }, + ) except Exception as e: tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) @@ -49,14 +72,12 @@ async def dispatch(self, request: Request, call_next): } logger.error("Unhandled exception: %s", structured_log) return JSONResponse( - { - "status_code": 500, - "content": { - "error": "Internal error occurred. Our team has been notified.", - "timestamp": timestamp, - "request_id": request_id, - }, - } + status_code=500, + content={ + "error": "Internal error occurred. Our team has been notified.", + "timestamp": timestamp, + "request_id": request_id, + }, ) @@ -91,9 +112,10 @@ def load_redis(): get_or_create_aioredis_pool() -@app.get("/healthcheck") -@app.get("/healthz") -@app.get("/readyz") def healthcheck() -> Response: """Returns 200 if the app is healthy.""" return Response(status_code=200) + + +for endpoint in healthcheck_routes: + app.get(endpoint)(healthcheck) diff --git a/model-engine/model_engine_server/api/worker.py b/model-engine/model_engine_server/api/worker.py index d08113b5..289640c8 100644 --- a/model-engine/model_engine_server/api/worker.py +++ b/model-engine/model_engine_server/api/worker.py @@ -1,8 +1,9 @@ from uvicorn.workers import UvicornWorker -# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit, before adding rate limiting just increase the concurrency +# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit # We'll autoscale at target concurrency of a much lower number (around 50), and this just makes sure we don't 503 with bursty traffic -CONCURRENCY_LIMIT = 1000 +# We set this very high since model_engine_server/api/app.py sets a lower per-pod concurrency at which we start returning 429s +CONCURRENCY_LIMIT = 10000 class LaunchWorker(UvicornWorker): diff --git a/model-engine/model_engine_server/common/concurrency_limiter.py b/model-engine/model_engine_server/common/concurrency_limiter.py new file mode 100644 index 00000000..b4e10c81 --- /dev/null +++ b/model-engine/model_engine_server/common/concurrency_limiter.py @@ -0,0 +1,36 @@ +from multiprocessing import BoundedSemaphore +from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType +from typing import Optional + +from fastapi import HTTPException +from model_engine_server.core.loggers import logger_name, make_logger + +logger = make_logger(logger_name()) + + +class MultiprocessingConcurrencyLimiter: + def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): + self.concurrency = concurrency + if concurrency is not None: + if concurrency < 1: + raise ValueError("Concurrency should be at least 1") + self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency) + self.blocking = ( + not fail_on_concurrency_limit + ) # we want to block if we want to queue up requests + else: + self.semaphore = None + self.blocking = False # Unused + + def __enter__(self): + logger.debug("Entering concurrency limiter semaphore") + if self.semaphore and not self.semaphore.acquire(block=self.blocking): + logger.warning(f"Too many requests (max {self.concurrency}), returning 429") + raise HTTPException(status_code=429, detail="Too many requests") + # Just raises an HTTPException. + # __exit__ should not run; otherwise the release() doesn't have an acquire() + + def __exit__(self, type, value, traceback): + logger.debug("Exiting concurrency limiter semaphore") + if self.semaphore: + self.semaphore.release() diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index 5943bc50..f121bec2 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -3,11 +3,9 @@ import os import subprocess from functools import lru_cache -from multiprocessing import BoundedSemaphore -from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType -from typing import Optional -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI +from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.forwarding.forwarding import ( @@ -21,33 +19,6 @@ app = FastAPI() -class MultiprocessingConcurrencyLimiter: - def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): - if concurrency is not None: - if concurrency < 1: - raise ValueError("Concurrency should be at least 1") - self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency) - self.blocking = ( - not fail_on_concurrency_limit - ) # we want to block if we want to queue up requests - else: - self.semaphore = None - self.blocking = False # Unused - - def __enter__(self): - logger.debug("Entering concurrency limiter semaphore") - if self.semaphore and not self.semaphore.acquire(block=self.blocking): - logger.warning("Too many requests, returning 429") - raise HTTPException(status_code=429, detail="Too many requests") - # Just raises an HTTPException. - # __exit__ should not run; otherwise the release() doesn't have an acquire() - - def __exit__(self, type, value, traceback): - logger.debug("Exiting concurrency limiter semaphore") - if self.semaphore: - self.semaphore.release() - - @app.get("/healthz") @app.get("/readyz") def healthcheck(): diff --git a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py index bec1c50c..02b68eca 100644 --- a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py @@ -1,10 +1,8 @@ import traceback from functools import wraps -from multiprocessing import BoundedSemaphore -from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType -from typing import Optional from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status +from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import ( @@ -25,33 +23,6 @@ logger = make_logger(logger_name()) -class MultiprocessingConcurrencyLimiter: - def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): - if concurrency is not None: - if concurrency < 1: - raise ValueError("Concurrency should be at least 1") - self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency) - self.blocking = ( - not fail_on_concurrency_limit - ) # we want to block if we want to queue up requests - else: - self.semaphore = None - self.blocking = False # Unused - - def __enter__(self): - logger.debug("Entering concurrency limiter semaphore") - if self.semaphore and not self.semaphore.acquire(block=self.blocking): - logger.warning("Too many requests, returning 429") - raise HTTPException(status_code=429, detail="Too many requests") - # Just raises an HTTPException. - # __exit__ should not run; otherwise the release() doesn't have an acquire() - - def __exit__(self, type, value, traceback): - logger.debug("Exiting concurrency limiter semaphore") - if self.semaphore: - self.semaphore.release() - - def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter): def _inner(flask_func): @wraps(flask_func) From b349a0d8786517e8d708b8b4147a45d51ee96bdb Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:58:21 -0800 Subject: [PATCH 201/425] change code-llama to codellama (#400) * change code-llama to codellama * use both code-llama and codellama temporarily --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 08bdc7fd..359c525b 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -190,6 +190,10 @@ # Based on config here: https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B/blob/main/config.json#L12 # Can also see 13B, 34B there too "code-llama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, + "codellama": { + "max_model_len": 16384, + "max_num_batched_tokens": 16384, + }, # setting both for backwards compatibility, will phase code-llama out in a future pr # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 # Can also see 13B, 34B there too "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, From cefef80e958adea86e18de4499ffa290e4b38afd Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 8 Dec 2023 13:16:24 -0800 Subject: [PATCH 202/425] fix completions request id (#402) * fix completions request id --- .../model_engine_server/common/datadog_utils.py | 7 ++++++- model-engine/model_engine_server/common/dtos/llms.py | 4 ++-- .../domain/use_cases/llm_model_endpoint_use_cases.py | 12 ++++++++---- model-engine/tests/unit/domain/test_llm_use_cases.py | 3 --- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py index 3e3513cb..5707d964 100644 --- a/model-engine/model_engine_server/common/datadog_utils.py +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -1,10 +1,15 @@ +from typing import Optional + from ddtrace import tracer -def add_trace_request_id(request_id: str): +def add_trace_request_id(request_id: Optional[str]): """Adds a custom tag to a given dd trace corresponding to the request id so that we can filter in Datadog easier """ + if not request_id: + return + current_span = tracer.current_span() if current_span: current_span.set_tag("launch.request_id", request_id) diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index fc531c1f..6e991e45 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -199,7 +199,7 @@ class CompletionSyncV1Response(BaseModel): Response object for a synchronous prompt completion task. """ - request_id: str + request_id: Optional[str] output: Optional[CompletionOutput] = None @@ -273,7 +273,7 @@ class CompletionStreamV1Response(BaseModel): Response object for a stream prompt completion task. """ - request_id: str + request_id: Optional[str] output: Optional[CompletionStreamOutput] = None error: Optional[StreamError] = None """Error of the response (if any).""" diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 359c525b..dcf0d0d2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -9,7 +9,6 @@ import os from dataclasses import asdict from typing import Any, AsyncIterable, Dict, List, Optional, Union -from uuid import uuid4 from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.llms import ( @@ -35,7 +34,12 @@ from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) from model_engine_server.domain.entities import ( LLMInferenceFramework, LLMMetadata, @@ -1448,7 +1452,7 @@ async def execute( ObjectNotAuthorizedException: If the owner does not own the model endpoint. """ - request_id = str(uuid4()) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( @@ -1736,7 +1740,7 @@ async def execute( ObjectNotAuthorizedException: If the owner does not own the model endpoint. """ - request_id = str(uuid4()) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c4fbb31f..31a579a6 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -948,7 +948,6 @@ async def test_completion_stream_use_case_success( output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] i = 0 async for message in response_1: - assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] if i == 6: assert message.dict()["output"]["num_prompt_tokens"] == 7 @@ -1016,7 +1015,6 @@ async def test_completion_stream_text_generation_inference_use_case_success( output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] i = 0 async for message in response_1: - assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] if i == 5: assert message.dict()["output"]["num_prompt_tokens"] == 7 @@ -1079,7 +1077,6 @@ async def test_completion_stream_trt_llm_use_case_success( output_texts = ["Machine", "learning", "is", "a", "branch"] i = 0 async for message in response_1: - assert message.dict()["request_id"] assert message.dict()["output"]["text"] == output_texts[i] assert message.dict()["output"]["num_prompt_tokens"] == 7 assert message.dict()["output"]["num_completion_tokens"] == i + 1 From 04a59084b14d62f336a9274d62ba776136fed3db Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 8 Dec 2023 14:53:37 -0800 Subject: [PATCH 203/425] Allow latest inference framework tag (#403) --- clients/python/llmengine/model.py | 2 +- .../model_engine_server/api/llms_v1.py | 2 + .../model_engine_server/core/docker/ecr.py | 12 +++++ .../domain/repositories/docker_repository.py | 11 +++++ .../use_cases/llm_model_endpoint_use_cases.py | 44 +++++++++++++------ .../repositories/ecr_docker_repository.py | 4 ++ .../repositories/fake_docker_repository.py | 3 ++ model-engine/setup.cfg | 1 + model-engine/tests/unit/conftest.py | 3 ++ model-engine/tests/unit/domain/conftest.py | 3 +- .../tests/unit/domain/test_llm_use_cases.py | 12 ++++- 11 files changed, 79 insertions(+), 18 deletions(-) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 3bd88944..fa84d1e3 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -67,7 +67,7 @@ def create( Name of the base model inference_framework_image_tag (`str`): - Image tag for the inference framework + Image tag for the inference framework. Use "latest" for the most recent image source (`LLMSource`): Source of the LLM. Currently only HuggingFace is supported diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 64bbfc5f..8de1551d 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -164,6 +164,7 @@ async def create_model_endpoint( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, + docker_repository=external_interfaces.docker_repository, ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: @@ -265,6 +266,7 @@ async def update_model_endpoint( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + docker_repository=external_interfaces.docker_repository, ) return await use_case.execute( user=auth, model_endpoint_name=model_endpoint_name, request=request diff --git a/model-engine/model_engine_server/core/docker/ecr.py b/model-engine/model_engine_server/core/docker/ecr.py index aaf9ef6f..fcd324b9 100644 --- a/model-engine/model_engine_server/core/docker/ecr.py +++ b/model-engine/model_engine_server/core/docker/ecr.py @@ -97,3 +97,15 @@ def ecr_exists_for_repo(repo_name: str, image_tag: Optional[str] = None): return True except ecr.exceptions.ImageNotFoundException: return False + + +def get_latest_image_tag(repository_name: str): + ecr = boto3.client("ecr", region_name=infra_config().default_region) + images = ecr.describe_images( + registryId=infra_config().ml_account_id, + repositoryName=repository_name, + filter=DEFAULT_FILTER, + maxResults=1000, + )["imageDetails"] + latest_image = max(images, key=lambda image: image["imagePushedAt"]) + return latest_image["imageTags"][0] diff --git a/model-engine/model_engine_server/domain/repositories/docker_repository.py b/model-engine/model_engine_server/domain/repositories/docker_repository.py index b2d410a1..f8ba774c 100644 --- a/model-engine/model_engine_server/domain/repositories/docker_repository.py +++ b/model-engine/model_engine_server/domain/repositories/docker_repository.py @@ -49,6 +49,17 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: """ pass + @abstractmethod + def get_latest_image_tag(self, repository_name: str) -> str: + """ + Returns the Docker image tag of the most recently pushed image in the given repository + + Args: + repository_name: the name of the repository containing the image. + + Returns: the tag of the latest Docker image. + """ + def is_repo_name(self, repo_name: str): # We assume repository names must start with a letter and can only contain lowercase letters, numbers, hyphens, underscores, and forward slashes. # Based-off ECR naming standards diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index dcf0d0d2..31ebfd35 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -89,6 +89,14 @@ logger = make_logger(logger_name()) +INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { + LLMInferenceFramework.DEEPSPEED: "instant-llm", + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: hmi_config.tgi_repository, + LLMInferenceFramework.VLLM: hmi_config.vllm_repository, + LLMInferenceFramework.LIGHTLLM: hmi_config.lightllm_repository, + LLMInferenceFramework.TENSORRT_LLM: hmi_config.tensorrt_llm_repository, +} + _SUPPORTED_MODELS_BY_FRAMEWORK = { LLMInferenceFramework.DEEPSPEED: set( [ @@ -332,8 +340,10 @@ async def execute( checkpoint_path: Optional[str], ) -> ModelBundle: if source == LLMSource.HUGGING_FACE: + self.check_docker_image_exists_for_image_tag( + framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework] + ) if framework == LLMInferenceFramework.DEEPSPEED: - self.check_docker_image_exists_for_image_tag(framework_image_tag, "instant-llm") bundle_id = await self.create_deepspeed_bundle( user, model_name, @@ -342,9 +352,6 @@ async def execute( endpoint_name, ) elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - self.check_docker_image_exists_for_image_tag( - framework_image_tag, hmi_config.tgi_repository - ) bundle_id = await self.create_text_generation_inference_bundle( user, model_name, @@ -355,9 +362,6 @@ async def execute( checkpoint_path, ) elif framework == LLMInferenceFramework.VLLM: - self.check_docker_image_exists_for_image_tag( - framework_image_tag, hmi_config.vllm_repository - ) bundle_id = await self.create_vllm_bundle( user, model_name, @@ -368,9 +372,6 @@ async def execute( checkpoint_path, ) elif framework == LLMInferenceFramework.LIGHTLLM: - self.check_docker_image_exists_for_image_tag( - framework_image_tag, hmi_config.lightllm_repository - ) bundle_id = await self.create_lightllm_bundle( user, model_name, @@ -862,10 +863,12 @@ def __init__( self, create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, model_endpoint_service: ModelEndpointService, + docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case self.model_endpoint_service = model_endpoint_service + self.docker_repository = docker_repository async def execute( self, user: User, request: CreateLLMModelEndpointV1Request @@ -895,6 +898,11 @@ async def execute( f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM." ) + if request.inference_framework_image_tag == "latest": + request.inference_framework_image_tag = self.docker_repository.get_latest_image_tag( + INFERENCE_FRAMEWORK_REPOSITORY[request.inference_framework] + ) + bundle = await self.create_llm_model_bundle_use_case.execute( user, endpoint_name=request.name, @@ -1059,11 +1067,13 @@ def __init__( create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, model_endpoint_service: ModelEndpointService, llm_model_endpoint_service: LLMModelEndpointService, + docker_repository: DockerRepository, ): self.authz_module = LiveAuthorizationModule() self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case self.model_endpoint_service = model_endpoint_service self.llm_model_endpoint_service = llm_model_endpoint_service + self.docker_repository = docker_repository async def execute( self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request @@ -1106,12 +1116,18 @@ async def execute( llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {}) inference_framework = llm_metadata["inference_framework"] + if request.inference_framework_image_tag == "latest": + inference_framework_image_tag = self.docker_repository.get_latest_image_tag( + INFERENCE_FRAMEWORK_REPOSITORY[inference_framework] + ) + else: + inference_framework_image_tag = ( + request.inference_framework_image_tag + or llm_metadata["inference_framework_image_tag"] + ) + model_name = request.model_name or llm_metadata["model_name"] source = request.source or llm_metadata["source"] - inference_framework_image_tag = ( - request.inference_framework_image_tag - or llm_metadata["inference_framework_image_tag"] - ) num_shards = request.num_shards or llm_metadata["num_shards"] quantize = request.quantize or llm_metadata.get("quantize") checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py index 16c6b742..d283c4c4 100644 --- a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -3,6 +3,7 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.core.config import infra_config +from model_engine_server.core.docker.ecr import get_latest_image_tag from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists from model_engine_server.core.docker.remote_build import build_remote_block from model_engine_server.core.loggers import logger_name, make_logger @@ -52,3 +53,6 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: return BuildImageResponse( status=build_result.status, logs=build_result.logs, job_name=build_result.job_name ) + + def get_latest_image_tag(self, repository_name: str) -> str: + return get_latest_image_tag(repository_name=repository_name) diff --git a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py index b7fa39a6..2d12de6e 100644 --- a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py @@ -19,3 +19,6 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str: def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: raise NotImplementedError("FakeDockerRepository build_image() not implemented") + + def get_latest_image_tag(self, repository_name: str) -> str: + raise NotImplementedError("FakeDockerRepository get_latest_image_tag() not implemented") diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index a5f56d8a..053cae1e 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -5,6 +5,7 @@ test=pytest omit = model_engine_server/entrypoints/* model_engine_server/api/app.py + model_engine_server/core/docker/ecr.py # TODO: Fix pylint errors # [pylint] diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 3528d558..03ae16b8 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -672,6 +672,9 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: raise Exception("I hope you're handling this!") return BuildImageResponse(status=True, logs="", job_name="test-job-name") + def get_latest_image_tag(self, repository_name: str) -> str: + return "fake_docker_repository_latest_image_tag" + class FakeModelEndpointCacheRepository(ModelEndpointCacheRepository): def __init__(self): diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index f433071c..798af362 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -203,7 +203,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request model_name="mpt-7b", source="hugging_face", inference_framework="deepspeed", - inference_framework_image_tag="test_tag", + inference_framework_image_tag="latest", num_shards=2, endpoint_type=ModelEndpointType.ASYNC, metadata={}, @@ -252,6 +252,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req @pytest.fixture def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: return UpdateLLMModelEndpointV1Request( + inference_framework_image_tag="latest", checkpoint_path="s3://test_checkpoint_path", memory="4G", min_workers=0, diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 31a579a6..99fab709 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -78,6 +78,7 @@ async def test_create_model_endpoint_use_case_success( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) @@ -97,7 +98,7 @@ async def test_create_model_endpoint_use_case_success( "model_name": create_llm_model_endpoint_request_async.model_name, "source": create_llm_model_endpoint_request_async.source, "inference_framework": create_llm_model_endpoint_request_async.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_async.inference_framework_image_tag, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", "num_shards": create_llm_model_endpoint_request_async.num_shards, "quantize": None, "checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path, @@ -201,6 +202,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() @@ -241,6 +243,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -302,6 +305,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -362,6 +366,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -395,6 +400,7 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception( use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -463,11 +469,13 @@ async def test_update_model_endpoint_use_case_success( create_use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) update_use_case = UpdateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) @@ -501,7 +509,7 @@ async def test_update_model_endpoint_use_case_success( "model_name": create_llm_model_endpoint_request_streaming.model_name, "source": create_llm_model_endpoint_request_streaming.source, "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, From c9ceab9df3054f4865b9edb7109937b539c23727 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:19:49 -0800 Subject: [PATCH 204/425] Bump helm chart version 0.1.0 to 0.1.1 (#406) --- charts/model-engine/Chart.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/charts/model-engine/Chart.yaml b/charts/model-engine/Chart.yaml index 16f2c405..2991cf1e 100644 --- a/charts/model-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 +version: 0.1.1 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to From 5ec6adaecdfc71a0b394f537caa152721f86ec36 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:43:33 -0800 Subject: [PATCH 205/425] 4x sqlalchemy pool size (#405) * 4x sqlalchemy pool size * don't update nullpool --- model-engine/model_engine_server/db/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index 0f882ea3..4469b30b 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -79,25 +79,32 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool echo=False, future=True, pool_pre_ping=True, + pool_size=20, + max_overflow=30, ) pg_engine_read_only = create_engine( get_engine_url(read_only=True, sync=True), echo=False, future=True, pool_pre_ping=True, + pool_size=20, + max_overflow=30, ) pg_engine_async = create_async_engine( get_engine_url(read_only=False, sync=False), echo=False, future=True, pool_pre_ping=True, + pool_size=20, + max_overflow=30, ) pg_engine_read_only_async = create_async_engine( get_engine_url(read_only=True, sync=False), echo=False, future=True, pool_pre_ping=True, - max_overflow=5, + pool_size=20, + max_overflow=30, ) pg_engine_async_null_pool = create_async_engine( get_engine_url(read_only=False, sync=False), From 353e246ed0f132700e0c8022e25e48a21979599b Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 11 Dec 2023 16:06:01 -0800 Subject: [PATCH 206/425] bump datadog module to 0.47.0 for ipv6 support for dogstatsd (#407) --- .../model_engine_server/inference/requirements_base.txt | 2 +- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt index cedabe42..14f6577a 100644 --- a/model-engine/model_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -2,7 +2,7 @@ aioredis~=2.0 boto3>=1.28.38 celery[redis,sqs,tblib]==5.3.1 datadog-api-client==2.11.0 -datadog~=0.46.0 +datadog~=0.47.0 fastapi==0.78.0 # Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it importlib-metadata<5.0;python_version<"3.8" diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 8575237b..eb2d393e 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -14,7 +14,7 @@ cloudpickle==2.1.0 croniter==1.4.1 dataclasses-json>=0.5.7 datadog-api-client==2.11.0 -datadog~=0.46.0 +datadog~=0.47.0 ddtrace==1.8.3 deprecation~=2.1 docker~=5.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 19e4edf5..9a9a062a 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -111,7 +111,7 @@ cryptography==41.0.3 # via secretstorage dataclasses-json==0.5.9 # via -r model-engine/requirements.in -datadog==0.46.0 +datadog==0.47.0 # via -r model-engine/requirements.in datadog-api-client==2.11.0 # via -r model-engine/requirements.in From 74cc915b45d9b76bf8dc80e033cfb397ea247e19 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:08:30 -0800 Subject: [PATCH 207/425] Fix autoscaler node selector (#409) Un-hardcode the existing nodeSelector --- charts/model-engine/Chart.yaml | 2 +- .../templates/celery_autoscaler_stateful_set.yaml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/Chart.yaml b/charts/model-engine/Chart.yaml index 2991cf1e..175dba37 100644 --- a/charts/model-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.1 +version: 0.1.2 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml index 8dafd7e5..768fdb8b 100644 --- a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -73,8 +73,10 @@ spec: - mountPath: /opt/.aws/config name: config-volume subPath: config + {{ with .Values.nodeSelector }} nodeSelector: - node-lifecycle: normal + {{- toYaml . | nindent 8 }} + {{- end }} tolerations: - key: CriticalAddonsOnly operator: Equal From 5b649727390a5c48015330df9765b6eab5ef81c4 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 12 Dec 2023 16:05:30 -0800 Subject: [PATCH 208/425] Log request sizes (#410) --- model-engine/model_engine_server/api/app.py | 1 + model-engine/model_engine_server/core/loggers.py | 1 + 2 files changed, 2 insertions(+) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 90f5620c..851f0183 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -47,6 +47,7 @@ class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): try: LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) + LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) # we intentionally exclude healthcheck routes from the concurrency limiter if request.url.path in healthcheck_routes: return await call_next(request) diff --git a/model-engine/model_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py index 30b0deee..3a28d450 100644 --- a/model-engine/model_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -42,6 +42,7 @@ class LoggerTagKey(str, Enum): REQUEST_ID = "request_id" TEAM_ID = "team_id" USER_ID = "user_id" + REQUEST_SIZE = "request_size" class LoggerTagManager: From 474155e33f8b912d311953040b4d799f8c78ce57 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 14 Dec 2023 14:26:24 -0800 Subject: [PATCH 209/425] add support for mixtral-8x7b and mixtral-8x7b-instruct (#408) * bump datadog module to 0.47.0 for ipv6 support for dogstatsd * add mixtral-8x7b and mixtral-8x7b-instruct * update context window * docker update * install megablocks --- docs/model_zoo.md | 54 ++++++++++--------- .../use_cases/llm_model_endpoint_use_cases.py | 3 ++ .../inference/vllm/Dockerfile | 7 ++- .../inference/vllm/requirements.txt | 4 +- .../repositories/live_tokenizer_repository.py | 2 + 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 50326b6f..18610abb 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -2,32 +2,34 @@ Scale hosts the following models in the LLM Engine Model Zoo: -| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | Inference max total tokens (prompt + response) | -| --------------------- | ------------------------ | -------------------------- | ------------------------------ | ------------------------------ | -| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | 2048 | -| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | 4096| -| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | 4096| -| `llama-2-13b` | ✅ | | text-generation-inference, vllm | 4096| -| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096| -| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096| -| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096| -| `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | -| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | -| `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | -| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | -| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | 2048 | -| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | 2048 | -| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | 2048 | -| `mistral-7b` | ✅ | ✅ | vllm | 8000 | -| `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | -| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | -| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | +| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | Inference max total tokens (prompt + response) | +| ------------------------ | ------------------------ | -------------------------- | ------------------------------------------ | ---------------------------------------------- | +| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | 2048 | +| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | +| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-13b` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | +| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | 2048 | +| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | 2048 | +| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | 2048 | +| `mistral-7b` | ✅ | ✅ | vllm | 8000 | +| `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | +| `mixtral-8x7b` | ✅ | | vllm | 32768 | +| `mixtral-8x7b-instruct` | ✅ | | vllm | 32768 | +| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | +| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 31ebfd35..35689a22 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -165,6 +165,8 @@ "codellama-34b-instruct", "mistral-7b", "mistral-7b-instruct", + "mixtral-8x7b", + "mixtral-8x7b-instruct", "mammoth-coder-llama-2-7b", "mammoth-coder-llama-2-13b", "mammoth-coder-llama-2-34b", @@ -210,6 +212,7 @@ # Can also see 13B, 34B there too "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, + "mixtral": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, "zephyr": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, } diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index d03a2c03..6f1a00c5 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -1,8 +1,13 @@ -FROM nvcr.io/nvidia/pytorch:22.12-py3 +FROM nvcr.io/nvidia/pytorch:23.09-py3 RUN pip uninstall torch -y COPY requirements.txt /workspace/requirements.txt RUN pip install -r requirements.txt + +# install special version of megablocks +RUN pip install git+https://github.com/stanford-futuredata/megablocks.git@5897cd6f254b7b3edf7a708a3a3314ecb54b6f78#egg=megablocks + RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + COPY vllm_server.py /workspace/vllm_server.py diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index b5407ab9..e2c3aa08 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ ray==2.6.3 -vllm==0.2.0 -pydantic==1.10.12 +vllm==0.2.5 +pydantic==1.10.13 diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 873f2e65..b586bc9c 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -58,6 +58,8 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: ), "mistral-7b": ModelInfo("mistralai/Mistral-7B-v0.1", None), "mistral-7b-instruct": ModelInfo("mistralai/Mistral-7B-Instruct-v0.1", None), + "mixtral-8x7b": ModelInfo("mistralai/Mixtral-8x7B-v0.1", None), + "mixtral-8x7b-instruct": ModelInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", None), "mammoth-coder-llama-2-7b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-7B", None), "mammoth-coder-llama-2-13b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-13B", None), "mammoth-coder-llama-2-34b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-34B", None), From d915f5bb63463efadb01c55b7bf6f758c1402105 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 21 Dec 2023 17:02:33 -0800 Subject: [PATCH 210/425] Make sure metadata is not incorrectly wiped during endpoint update (#413) --- .../infra/services/live_model_endpoint_service.py | 2 +- .../services/test_live_model_endpoint_service.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index ede0c39e..b9ffc260 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -312,7 +312,7 @@ async def update_model_endpoint( if record.current_model_bundle.id != model_bundle_id: if metadata is None: - metadata = {} + metadata = record.metadata if record.metadata is not None else {} # MODEL_BUNDLE_CHANGED_KEY will be checked during _create_deployment in K8SEndpointResourceDelegate metadata[MODEL_BUNDLE_CHANGED_KEY] = True diff --git a/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py index 66969005..52886431 100644 --- a/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py +++ b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py @@ -238,6 +238,17 @@ async def test_create_update_model_endpoint_success( assert model_endpoint.infra_state.deployment_state.max_workers == update_kwargs["max_workers"] assert model_endpoint.infra_state.labels == update_kwargs["labels"] + # Now update min_worker only + update_kwargs: Any = dict( + min_workers=2, + ) + updated_model_endpoint_record = await fake_live_model_endpoint_service.update_model_endpoint( + model_endpoint_id=model_endpoint_record.id, **update_kwargs + ) + + # Make sure metadata is not updated + assert updated_model_endpoint_record.metadata == {"some_new_key": "some_new_values"} + @pytest.mark.skip(reason="Exception is temporarily disabled due to lock flakiness") @pytest.mark.asyncio From 6bbcb6c3ec1f94c19bc81a6fdf96d6614dcd7999 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 21 Dec 2023 17:36:03 -0800 Subject: [PATCH 211/425] Always return output for completions sync response (#412) * Always return output for completions sync response * fix tests and dependencies * more tests * add more tests * revert dependency changes: --- .../model_engine_server/api/llms_v1.py | 8 +- .../use_cases/llm_model_endpoint_use_cases.py | 40 ++- model-engine/requirements.txt | 33 +- model-engine/tests/unit/api/conftest.py | 4 + model-engine/tests/unit/api/test_llms.py | 19 +- model-engine/tests/unit/conftest.py | 290 +++++++++++++++++- .../tests/unit/domain/test_llm_use_cases.py | 136 +++++--- 7 files changed, 442 insertions(+), 88 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 8de1551d..34f6b7f9 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -332,12 +332,14 @@ async def create_completion_sync_task( metric_metadata, ) return response - except UpstreamServiceError: + except UpstreamServiceError as exc: request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) - logger.exception(f"Upstream service error for request {request_id}") + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) raise HTTPException( status_code=500, - detail=f"Upstream service error for request_id {request_id}.", + detail=f"Upstream service error for request_id {request_id}", ) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 35689a22..0693db9c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1555,9 +1555,11 @@ async def execute( ), ) else: - return CompletionSyncV1Response( - request_id=request_id, - output=None, + raise UpstreamServiceError( + status_code=500, + content=predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"", ) elif ( endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE @@ -1589,9 +1591,11 @@ async def execute( ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: - return CompletionSyncV1Response( - request_id=request_id, - output=None, + raise UpstreamServiceError( + status_code=500, + content=predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"", ) output = json.loads(predict_result.result["result"]) @@ -1628,9 +1632,11 @@ async def execute( ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: - return CompletionSyncV1Response( - request_id=request_id, - output=None, + raise UpstreamServiceError( + status_code=500, + content=predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"", ) output = json.loads(predict_result.result["result"]) @@ -1670,9 +1676,11 @@ async def execute( ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: - return CompletionSyncV1Response( - request_id=request_id, - output=None, + raise UpstreamServiceError( + status_code=500, + content=predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"", ) output = json.loads(predict_result.result["result"]) @@ -1706,9 +1714,11 @@ async def execute( ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: - return CompletionSyncV1Response( - request_id=request_id, - output=None, + raise UpstreamServiceError( + status_code=500, + content=predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"", ) output = json.loads(predict_result.result["result"]) diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 9a9a062a..2a7390f8 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -54,7 +54,9 @@ boto3==1.28.1 # celery # kombu boto3-stubs[essential]==1.26.67 - # via -r model-engine/requirements.in + # via + # -r model-engine/requirements.in + # boto3-stubs botocore==1.31.1 # via # -r model-engine/requirements.in @@ -71,15 +73,15 @@ cachetools==5.3.1 cattrs==23.1.2 # via ddtrace celery[redis,sqs,tblib]==5.3.1 - # via -r model-engine/requirements.in + # via + # -r model-engine/requirements.in + # celery certifi==2023.7.22 # via # datadog-api-client # kubernetes # kubernetes-asyncio # requests -cffi==1.15.1 - # via cryptography charset-normalizer==3.2.0 # via # aiohttp @@ -107,8 +109,6 @@ commonmark==0.9.1 # via rich croniter==1.4.1 # via -r model-engine/requirements.in -cryptography==41.0.3 - # via secretstorage dataclasses-json==0.5.9 # via -r model-engine/requirements.in datadog==0.47.0 @@ -127,7 +127,7 @@ docutils==0.20.1 # via readme-renderer envier==0.4.0 # via ddtrace -exceptiongroup==1.1.3 +exceptiongroup==1.2.0 # via # anyio # cattrs @@ -185,7 +185,7 @@ importlib-metadata==6.8.0 # keyring # quart # twine -importlib-resources==6.1.0 +importlib-resources==6.1.1 # via # alembic # jsonschema @@ -195,10 +195,6 @@ itsdangerous==2.1.2 # via quart jaraco-classes==3.3.0 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.0.3 # via # -r model-engine/requirements.in @@ -300,8 +296,6 @@ pyasn1==0.5.0 # rsa pyasn1-modules==0.3.0 # via google-auth -pycparser==2.21 - # via cffi pycurl==7.45.2 # via # -r model-engine/requirements.in @@ -326,7 +320,7 @@ python-dateutil==2.8.2 # pg8000 python-multipart==0.0.6 # via -r model-engine/requirements.in -pyyaml==6.0 +pyyaml==6.0.1 # via # huggingface-hub # kubeconfig @@ -379,8 +373,6 @@ safetensors==0.4.0 # via transformers scramp==1.4.4 # via pg8000 -secretstorage==3.3.3 - # via keyring sentencepiece==0.1.99 # via -r model-engine/requirements.in sh==1.14.3 @@ -409,6 +401,7 @@ sqlalchemy[asyncio]==2.0.4 # via # -r model-engine/requirements.in # alembic + # sqlalchemy sse-starlette==1.6.1 # via -r model-engine/requirements.in sseclient-py==1.7.2 @@ -525,8 +518,4 @@ zipp==3.16.0 # importlib-resources # The following packages are considered to be unsafe in a requirements file: -setuptools==68.0.0 - # via - # gunicorn - # kubernetes - # kubernetes-asyncio +# setuptools diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index 9b085915..b77071b7 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -106,6 +106,7 @@ def get_test_client( fake_file_system_gateway_contents=None, fake_trigger_repository_contents=None, fake_cron_job_gateway_contents=None, + fake_sync_inference_content=None, ) -> TestClient: if fake_docker_image_batch_job_gateway_contents is None: fake_docker_image_batch_job_gateway_contents = {} @@ -131,6 +132,8 @@ def get_test_client( fake_trigger_repository_contents = {} if fake_cron_job_gateway_contents is None: fake_cron_job_gateway_contents = {} + if fake_sync_inference_content is None: + fake_sync_inference_content = {} app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper( fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists, fake_model_bundle_repository_contents=fake_model_bundle_repository_contents, @@ -145,6 +148,7 @@ def get_test_client( fake_file_system_gateway_contents=fake_file_system_gateway_contents, fake_trigger_repository_contents=fake_trigger_repository_contents, fake_cron_job_gateway_contents=fake_cron_job_gateway_contents, + fake_sync_inference_content=fake_sync_inference_content, ) app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[ get_external_interfaces diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 32178b49..b904ed44 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -4,6 +4,7 @@ import pytest from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.domain.entities import ModelEndpoint @@ -102,6 +103,17 @@ def test_completion_sync_success( fake_batch_job_record_repository_contents={}, fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, + fake_sync_inference_content=SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": """{ + "text": "output", + "count_prompt_tokens": 1, + "count_output_tokens": 1 + }""" + }, + traceback=None, + ), ) response_1 = client.post( f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", @@ -109,7 +121,12 @@ def test_completion_sync_success( json=completion_sync_request, ) assert response_1.status_code == 200 - assert response_1.json()["output"] is None + assert response_1.json()["output"] == { + "text": "output", + "num_completion_tokens": 1, + "num_prompt_tokens": 1, + "tokens": None, + } assert response_1.json().keys() == {"output", "request_id"} diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 03ae16b8..684b4fac 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1434,12 +1434,15 @@ async def streaming_predict( class FakeSyncModelEndpointInferenceGateway(SyncModelEndpointInferenceGateway): - def __init__(self): - self.response = SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result=None, - traceback=None, - ) + def __init__(self, fake_sync_inference_content=None): + if not fake_sync_inference_content: + self.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result=None, + traceback=None, + ) + else: + self.response = fake_sync_inference_content async def predict( self, topic: str, predict_request: EndpointPredictV1Request @@ -2111,6 +2114,7 @@ def get_repositories_generator( fake_file_storage_gateway_contents, fake_trigger_repository_contents, fake_file_system_gateway_contents, + fake_sync_inference_content, ): def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_file_system_gateway = FakeFilesystemGateway() @@ -2131,7 +2135,9 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: streaming_model_endpoint_inference_gateway = ( FakeStreamingModelEndpointInferenceGateway() ) - sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() + sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway( + fake_sync_inference_content + ) inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway(), @@ -3584,7 +3590,7 @@ def llm_model_endpoint_sync( "_llm": { "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", + "inference_framework": "vllm", "inference_framework_image_tag": "123", "num_shards": 4, } @@ -3646,7 +3652,7 @@ def llm_model_endpoint_sync( "model_name": "llama-7b", "source": "hugging_face", "status": "READY", - "inference_framework": "deepspeed", + "inference_framework": "vllm", "inference_framework_image_tag": "123", "num_shards": 4, "spec": { @@ -3659,7 +3665,7 @@ def llm_model_endpoint_sync( "_llm": { "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", + "inference_framework": "vllm", "inference_framework_image_tag": "123", "num_shards": 4, } @@ -3833,6 +3839,270 @@ def llm_model_endpoint_sync_tgi( return model_endpoint, model_endpoint_json +@pytest.fixture +def llm_model_endpoint_sync_lightllm( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "lightllm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.SYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "lightllm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "sync", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "lightllm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": "1", + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + +@pytest.fixture +def llm_model_endpoint_sync_trt_llm( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.SYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "sync", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": "1", + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + @pytest.fixture def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint: # model_bundle_5 is a runnable bundle diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 99fab709..a9d32975 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1,3 +1,4 @@ +import json from typing import Any, List, Tuple from unittest import mock @@ -552,41 +553,39 @@ async def test_completion_sync_use_case_success( SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, result={ - "result": [ + "result": json.dumps( { - "error": None, "text": "I am a newbie to the world of programming.", - "token_probs": { - "tokens": [ - "I", - " am", - " a", - " new", - "bie", - " to", - " the", - " world", - " of", - " programming", - ".", - ], - "token_probs": [ - 0.1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - ], - }, - "tokens_consumed": 25, + "tokens": [ + "I", + " am", + " a", + " new", + "bie", + " to", + " the", + " world", + " of", + " programming", + ".", + ], + "log_probs": [ + {1: -2.3025850929940455}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + ], + "count_prompt_tokens": 7, + "count_output_tokens": 11, } - ] + ) }, traceback=None, ) @@ -803,12 +802,75 @@ async def test_completion_sync_use_case_predict_failed( tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_sync[0].record.name, - request=completion_sync_request, + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed_lightllm( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync_lightllm: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_lightllm[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( + SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + result=None, + traceback="failed to predict", + ) ) - assert response_1.output is None + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync_lightllm[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed_trt_llm( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync_trt_llm: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + completion_sync_request.return_token_log_probs = False # not yet supported + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_trt_llm[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( + SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + result=None, + traceback="failed to predict", + ) + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync_trt_llm[0].record.name, + request=completion_sync_request, + ) @pytest.mark.asyncio From d0061f24bc64c708f6716234dc88f206d29a7f60 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Wed, 27 Dec 2023 23:45:31 -0800 Subject: [PATCH 212/425] handle update endpoint errors (#414) * handle invalid values errors * update test cases * fix when gpus = 0 * fix when gpu_type is None * throw explicit 500 for EndpointInfraStateNotFound --- .../api/model_endpoints_v1.py | 25 +- .../use_cases/model_endpoint_use_cases.py | 19 +- .../domain/test_model_endpoint_use_cases.py | 298 ++++++++++++++++++ 3 files changed, 328 insertions(+), 14 deletions(-) diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index 807393cd..eece8be3 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -26,6 +26,7 @@ from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, + EndpointInfraStateNotFound, EndpointLabelsException, EndpointResourceInvalidRequestException, ExistingEndpointOperationInProgressException, @@ -67,14 +68,11 @@ async def create_model_endpoint( status_code=400, detail="The specified model endpoint already exists.", ) from exc - except EndpointLabelsException as exc: - raise HTTPException( - status_code=400, - detail=str(exc), - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except EndpointResourceInvalidRequestException as exc: + except ( + EndpointLabelsException, + ObjectHasInvalidValueException, + EndpointResourceInvalidRequestException, + ) as exc: raise HTTPException( status_code=400, detail=str(exc), @@ -148,7 +146,11 @@ async def update_model_endpoint( return await use_case.execute( user=auth, model_endpoint_id=model_endpoint_id, request=request ) - except EndpointLabelsException as exc: + except ( + EndpointLabelsException, + ObjectHasInvalidValueException, + EndpointResourceInvalidRequestException, + ) as exc: raise HTTPException( status_code=400, detail=str(exc), @@ -163,6 +165,11 @@ async def update_model_endpoint( status_code=409, detail="Existing operation on endpoint in progress, try again later.", ) from exc + except EndpointInfraStateNotFound as exc: + raise HTTPException( + status_code=500, + detail="Endpoint infra state not found, try again later.", + ) from exc @model_endpoint_router_v1.delete( diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 8c58cdcc..1f128e65 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -362,13 +362,22 @@ async def execute( # infra_state to make sure that after the update, all resources are valid and in sync. # E.g. If user only want to update gpus and leave gpu_type as None, we use the existing gpu_type # from infra_state to avoid passing in None to validate_resource_requests. + raw_request = request.dict(exclude_unset=True) validate_resource_requests( bundle=bundle, - cpus=request.cpus or infra_state.resource_state.cpus, - memory=request.memory or infra_state.resource_state.memory, - storage=request.storage or infra_state.resource_state.storage, - gpus=request.gpus or infra_state.resource_state.gpus, - gpu_type=request.gpu_type or infra_state.resource_state.gpu_type, + cpus=(request.cpus if "cpus" in raw_request else infra_state.resource_state.cpus), + memory=( + request.memory if "memory" in raw_request else infra_state.resource_state.memory + ), + storage=( + request.storage if "storage" in raw_request else infra_state.resource_state.storage + ), + gpus=(request.gpus if "gpus" in raw_request else infra_state.resource_state.gpus), + gpu_type=( + request.gpu_type + if "gpu_type" in raw_request + else infra_state.resource_state.gpu_type + ), ) validate_deployment_resources( diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index 49e017fa..d0b27514 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -27,6 +27,7 @@ ObjectNotFoundException, ) from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, CreateModelEndpointV1UseCase, DeleteModelEndpointByIdV1UseCase, GetModelEndpointByIdV1UseCase, @@ -855,6 +856,303 @@ async def test_update_model_endpoint_team_success( assert isinstance(response, UpdateModelEndpointV1Response) +@pytest.mark.asyncio +async def test_update_model_endpoint_use_case_raises_invalid_value_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = update_model_endpoint_request.copy() + request.metadata = {CONVERTED_FROM_ARTIFACT_LIKE_KEY: False} + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + +@pytest.mark.asyncio +async def test_update_model_endpoint_use_case_raises_resource_request_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_bundle_4: ModelBundle, + model_bundle_6: ModelBundle, + model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage: ModelBundle, + model_endpoint_1: ModelEndpoint, + model_endpoint_2: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_bundle_repository.add_model_bundle(model_bundle_4) + fake_model_bundle_repository.add_model_bundle(model_bundle_6) + fake_model_bundle_repository.add_model_bundle( + model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage + ) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_2) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = update_model_endpoint_request.copy() + request.cpus = -1 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.cpus = float("inf") + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.memory = "invalid_memory_amount" + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.memory = float("inf") + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.storage = "invalid_storage_amount" + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.storage = float("inf") + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + # specific to sync endpoint + request = update_model_endpoint_request.copy() + request.min_workers = 0 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_2.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.max_workers = 2**63 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.gpus = 0 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.gpu_type = None + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.gpu_type = "invalid_gpu_type" + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + instance_limits = REQUESTS_BY_GPU_TYPE[model_endpoint_1.infra_state.resource_state.gpu_type] + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_1.id + # Test that request.cpus + FORWARDER_CPU_USAGE > instance_limits["cpus"] should fail + request.cpus = instance_limits["cpus"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_1.id + # Test that request.memory + FORWARDER_MEMORY_USAGE > instance_limits["memory"] should fail + request.memory = instance_limits["memory"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_1.id + # Test that request.storage + FORWARDER_STORAGE_USAGE > STORAGE_LIMIT should fail + request.storage = STORAGE_LIMIT + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_4.id + # Test that request.cpus + FORWARDER_CPU_USAGE > instance_limits["cpus"] should fail + request.cpus = instance_limits["cpus"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_4.id + # Test that request.memory + FORWARDER_MEMORY_USAGE > instance_limits["memory"] should fail + request.memory = instance_limits["memory"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_4.id + # Test that request.storage + FORWARDER_STORAGE_USAGE > STORAGE_LIMIT should fail + request.storage = STORAGE_LIMIT + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + # Test TritonEnhancedRunnableImageFlavor specific validation logic + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # TritonEnhancedRunnableImageFlavor requires gpu >= 1 + request.gpus = 0.9 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # TritonEnhancedRunnableImageFlavor requires gpu_type be specified + request.gpu_type = None + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # Test that request.cpus + FORWARDER_CPU_USAGE + triton_num_cpu > instance_limits["cpu"] should fail + request.cpus = instance_limits["cpus"] - FORWARDER_CPU_USAGE + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # Test that request.memory + FORWARDER_MEMORY_USAGE + triton_memory > instance_limits["memory"] should fail + request.memory = parse_mem_request(instance_limits["memory"]) - parse_mem_request( + FORWARDER_MEMORY_USAGE + ) + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # Test that request.storage + FORWARDER_STORAGE_USAGE + triton_storage > STORAGE_LIMIT should fail + request.storage = parse_mem_request(STORAGE_LIMIT) - parse_mem_request(FORWARDER_STORAGE_USAGE) + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + # Test triton_num_cpu >= 1 + request.model_bundle_id = ( + model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage.id + ) + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + @pytest.mark.asyncio async def test_update_model_endpoint_raises_not_found( fake_model_bundle_repository, From 1f9c4619a0787cdc1f78b6fdb4ffcdd5ec1d36ed Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 2 Jan 2024 16:24:43 -0800 Subject: [PATCH 213/425] [bug-fix] LLM Artifact Gateway .list_files() (#416) * fix test to catch use case * fix parsing for prefixes --- model-engine/model_engine_server/core/utils/url.py | 4 ++-- .../infra/gateways/s3_llm_artifact_gateway.py | 8 +++++--- .../unit/infra/gateways/test_s3_llm_artifact_gateway.py | 8 +++++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/core/utils/url.py b/model-engine/model_engine_server/core/utils/url.py index ec80747c..358f316c 100644 --- a/model-engine/model_engine_server/core/utils/url.py +++ b/model-engine/model_engine_server/core/utils/url.py @@ -32,7 +32,7 @@ class InvalidAttachmentUrl(ValueError): pass -def parse_attachment_url(url: str) -> ParsedURL: +def parse_attachment_url(url: str, clean_key: bool = True) -> ParsedURL: """Extracts protocol, bucket, region, and key from the :param:`url`. :raises: InvalidAttachmentUrl Iff the input `url` is not a valid AWS S3 or GCS url. @@ -102,5 +102,5 @@ def clean(v): protocol=clean(protocol), bucket=clean(bucket), region=clean(region), - key=clean(key), + key=clean(key) if clean_key else key, ) diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index 2582d40d..ebc6b2fd 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -23,7 +23,7 @@ def _get_s3_resource(self, kwargs): def list_files(self, path: str, **kwargs) -> List[str]: s3 = self._get_s3_resource(kwargs) - parsed_remote = parse_attachment_url(path) + parsed_remote = parse_attachment_url(path, clean_key=False) bucket = parsed_remote.bucket key = parsed_remote.key @@ -33,7 +33,7 @@ def list_files(self, path: str, **kwargs) -> List[str]: def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: s3 = self._get_s3_resource(kwargs) - parsed_remote = parse_attachment_url(path) + parsed_remote = parse_attachment_url(path, clean_key=False) bucket = parsed_remote.bucket key = parsed_remote.key @@ -58,7 +58,9 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: s3 = self._get_s3_resource(kwargs) - parsed_remote = parse_attachment_url(hmi_config.hf_user_fine_tuned_weights_prefix) + parsed_remote = parse_attachment_url( + hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False + ) bucket = parsed_remote.bucket fine_tuned_weights_prefix = parsed_remote.key diff --git a/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py b/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py index 7dcf19a6..9e989959 100644 --- a/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py @@ -14,7 +14,7 @@ def llm_artifact_gateway(): @pytest.fixture def fake_files(): - return ["fake-prefix/fake1", "fake-prefix/fake2", "fake-prefix/fake3"] + return ["fake-prefix/fake1", "fake-prefix/fake2", "fake-prefix/fake3", "fake-prefix-ext/fake1"] def mock_boto3_session(fake_files: List[str]): @@ -39,11 +39,13 @@ def filter_files(*args, **kwargs): lambda *args, **kwargs: None, # noqa ) def test_s3_llm_artifact_gateway_download_folder(llm_artifact_gateway, fake_files): - prefix = "/".join(fake_files[0].split("/")[:-1]) + prefix = "/".join(fake_files[0].split("/")[:-1]) + "/" uri_prefix = f"s3://fake-bucket/{prefix}" target_dir = "fake-target" - expected_files = [f"{target_dir}/{file.split('/')[-1]}" for file in fake_files] + expected_files = [ + f"{target_dir}/{file.split('/')[-1]}" for file in fake_files if file.startswith(prefix) + ] with mock.patch( "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", mock_boto3_session(fake_files), From 756682f159ac2bca31eddcc559c5fb049f3460ac Mon Sep 17 00:00:00 2001 From: William Song Date: Wed, 3 Jan 2024 16:06:39 -0800 Subject: [PATCH 214/425] enable sensitive log mode (#415) Enable sensitive log mode --- charts/model-engine/values_circleci.yaml | 1 + charts/model-engine/values_sample.yaml | 1 + .../model_engine_server/api/llms_v1.py | 15 ++++---- .../model_engine_server/common/config.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 3 ++ .../service_config_circleci.yaml | 1 + model-engine/tests/unit/api/test_llms.py | 35 ++++++++++++------- 7 files changed, 39 insertions(+), 18 deletions(-) diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index fd07f361..758bf472 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -146,6 +146,7 @@ config: s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository" dd_trace_enabled: false istio_enabled: true + sensitive_log_mode: false tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 2ff37197..65e41928 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -163,6 +163,7 @@ config: # dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster dd_trace_enabled: false istio_enabled: true + sensitive_log_mode: false # Asynchronous endpoints configs (coming soon) sqs_profile: default diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 34f6b7f9..352adeb1 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -12,6 +12,7 @@ get_external_interfaces_read_only, verify_authentication, ) +from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.llms import ( CancelFineTuneResponse, CompletionStreamV1Request, @@ -307,9 +308,10 @@ async def create_completion_sync_task( """ Runs a sync prompt completion on an LLM. """ - logger.info( - f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}" - ) + if not hmi_config.sensitive_log_mode: + logger.info( + f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}" + ) try: use_case = CompletionSyncV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, @@ -369,9 +371,10 @@ async def create_completion_stream_task( """ Runs a stream prompt completion on an LLM. """ - logger.info( - f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" - ) + if not hmi_config.sensitive_log_mode: # pragma: no cover + logger.info( + f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" + ) use_case = CompletionStreamV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index f2b33eea..d2736576 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -62,6 +62,7 @@ class HostedModelInferenceServiceConfig: user_inference_pytorch_repository: str user_inference_tensorflow_repository: str docker_image_layer_cache_repository: str + sensitive_log_mode: bool @classmethod def from_yaml(cls, yaml_path): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 0693db9c..67f26fe1 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -682,6 +682,9 @@ async def create_vllm_bundle( else: raise InvalidRequestException(f"Quantization {quantize} is not supported by vLLM.") + if hmi_config.sensitive_log_mode: # pragma: no cover + subcommands[-1] = subcommands[-1] + " --disable-log-requests" + command = [ "/bin/bash", "-c", diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 683a3e2b..d37172ec 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -54,6 +54,7 @@ s3_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune dd_trace_enabled: false istio_enabled: true +sensitive_log_mode: false tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index b904ed44..7bf64660 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -1,6 +1,6 @@ import json -import re from typing import Any, Dict, Tuple +from unittest import mock import pytest from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response @@ -156,12 +156,14 @@ def test_completion_sync_endpoint_not_found_returns_404( assert response_1.status_code == 404 +# When enabling this test, other tests fail with "RunTumeError got Future attached to a different loop" +# https://github.com/encode/starlette/issues/1315#issuecomment-980784457 @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") def test_completion_stream_success( llm_model_endpoint_streaming: ModelEndpoint, completion_stream_request: Dict[str, Any], get_test_client_wrapper, -): +): # pragma: no cover client = get_test_client_wrapper( fake_docker_repository_image_always_exists=True, fake_model_bundle_repository_contents={}, @@ -175,19 +177,28 @@ def test_completion_stream_success( fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, ) - response_1 = client.post( - f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", - auth=("no_user", ""), - json=completion_stream_request, - stream=True, - ) + with mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=5, + ): + response_1 = client.post( + f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + stream=True, + ) assert response_1.status_code == 200 count = 0 for message in response_1: - assert re.fullmatch( - 'data: {"request_id"}: ".*", "output": null}\r\n\r\n', - message.decode("utf-8"), - ) + decoded_message = message.decode("utf-8") + assert decoded_message.startswith("data: "), "SSE does not start with 'data: '" + + # strip 'data: ' prefix from Server-sent events format + json_str = decoded_message[len("data: ") :] + parsed_data = json.loads(json_str.strip()) + assert parsed_data["request_id"] is not None + assert parsed_data["output"] is None + assert parsed_data["error"] is None count += 1 assert count == 1 From 13fa6eb1ef572ae7305d47b3344c060ca3100496 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 8 Jan 2024 13:45:40 -0800 Subject: [PATCH 215/425] Throughput benchmark script (#411) * Move throughput benchmark script * support local * update * output csv * append * add concurrency * fix yaml --------- Co-authored-by: Ubuntu --- scripts/requirements.txt | 5 + scripts/throughput_benchmarks.py | 389 +++++++++++++++++++++++++++++++ 2 files changed, 394 insertions(+) create mode 100644 scripts/requirements.txt create mode 100644 scripts/throughput_benchmarks.py diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 00000000..18993b10 --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,5 @@ +numpy==1.24.4 +typer==0.9.0 +lorem-text==2.1 +transformers==4.36.0 +chardet==5.2.0 \ No newline at end of file diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py new file mode 100644 index 00000000..a60267e5 --- /dev/null +++ b/scripts/throughput_benchmarks.py @@ -0,0 +1,389 @@ +import csv +import json +import os +import queue +import random +import threading +import time +import traceback +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import numpy as np +import requests +import typer +from lorem_text import lorem +from transformers import AutoTokenizer + +AUTH_USER_ID = os.getenv("AUTH_USER_ID") +GATEWAY_URL = os.getenv("GATEWAY_URL") +app = typer.Typer(name="throughput-benchmarks", add_completion=False) + +MAX_CONTEXT_WINDOW = 4096 + + +@dataclass +class BenchmarkConfig: + def __init__(self, input_token_count, output_token_count_mean): + self.input_token_count = input_token_count + self.output_token_count_mean = output_token_count_mean + # Here we assume 3x standard deviation is enough to cover the range of output token counts. + # Also assume 3x stddev is rougly half of the mean. + self.output_token_count_std = output_token_count_mean / 6.0 + + def __repr__(self) -> str: + return f"BenchmarkConfig(input_token_count={self.input_token_count}, output_token_count_mean={self.output_token_count_mean}, output_token_count_std={self.output_token_count_std})" + + +HF_MODEL_MAPPING = { + "llama-2-7b": "meta-llama/Llama-2-7b-hf", + "llama-2-13b": "meta-llama/Llama-2-13b-hf", +} + + +class InferenceFramework(Enum): + TEXT_GENERATION_INFERENCE = "tgi" + VLLM = "vllm" + LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt-llm" + + @classmethod + def from_value(cls, value): + for member in cls: + if member.value == value: + return member + raise ValueError(f"No member with value {value} in {cls.__name__}") + + +def send_request(url, request, user=None): + start = time.time() + response = requests.post( + url, + json=request, + auth=(user, ""), + stream=True, + ) + first_line = True + for byte_payload in response.iter_lines(): + if first_line: + time_to_first_token = time.time() - start + first_line = False + + # Skip line + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + # Event data + if payload.startswith("data:"): + payload_data = payload.lstrip("data:").rstrip("/n") + payload_json = json.loads(payload_data) + + return { + "payload": payload_json, + "time_to_first_token": time_to_first_token, + "total_time": time.time() - start, + } + + +def pull_and_send_request_from_queue( + model: str, + request_queue: queue.Queue, + result_queue: queue.Queue, + use_localhost: bool, + framework: InferenceFramework, + local_port: int = 5005, +): + while not request_queue.empty(): + request = request_queue.get() + if use_localhost: + if framework == InferenceFramework.VLLM: + response = send_request(f"http://localhost:{local_port}/stream", request) + response["num_completion_tokens"] = response["payload"]["count_output_tokens"] + else: + raise NotImplementedError() + else: + response = send_request( + f"{GATEWAY_URL}/v1/llm/completions-stream?model_endpoint_name={model}", + request, + AUTH_USER_ID, + ) + response["num_completion_tokens"] = response["payload"]["output"][ + "num_completion_tokens" + ] + + result_queue.put(response) + + +def generate_request( + framework: InferenceFramework, prompt: str, output_token_count: int, localhost: bool +): + if not localhost: + return {"prompt": prompt, "max_new_tokens": output_token_count, "temperature": 0.0} + + if framework == InferenceFramework.TEXT_GENERATION_INFERENCE: + return { + "parameters": { + "do_sample": False, + "max_new_tokens": output_token_count, + "details": False, + }, + "inputs": prompt, + } + elif framework == InferenceFramework.VLLM: + return { + "prompt": prompt, + "max_tokens": output_token_count, + "temperature": 0, + "stream": True, + } + elif framework == InferenceFramework.LIGHTLLM: + return { + "parameters": { + "do_sample": False, + "max_new_tokens": output_token_count, + }, + "inputs": prompt, + } + elif framework == InferenceFramework.TENSORRT_LLM: + return { + "max_tokens": output_token_count, + "text_input": prompt, + "bad_words": "", + "stop_words": "", + } + else: + raise NotImplementedError() + + +def send_requests( + model: str, + prompt: str, + output_token_counts: List[int], + use_localhost: bool, + concurrency: int, + framework: InferenceFramework, + local_port: int = 5005, +): + thread_results: queue.Queue = queue.Queue() + requests_queue: queue.Queue = queue.Queue() + for output_token_count in output_token_counts: + request = generate_request(framework, prompt, output_token_count, use_localhost) + requests_queue.put(request) + threads = [] + for i in range(concurrency): + thread = threading.Thread( + target=pull_and_send_request_from_queue, + args=( + model, + requests_queue, + thread_results, + use_localhost, + framework, + local_port, + ), + ) + thread.start() + threads.append(thread) + + for thread in threads: + thread.join() + + results = [] + while not thread_results.empty(): + results.append(thread_results.get()) + + return results + + +def generate_prompt(num, hf_model): + random.seed(1) + text = lorem.words(num // 2) # Roughly 2 tokens per lorem word + tokenizer = AutoTokenizer.from_pretrained(hf_model) + return tokenizer.decode(tokenizer.encode(text)[: num - 2]) + + +def generate_output_token_counts(mean, std, num, input_token_count): + output = np.random.normal(mean, std, num).astype(int).tolist() + + for i in range(len(output)): + output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count) + return output + + +def run_benchmark( + model: str, + framework: InferenceFramework, + hf_model: str, + config: BenchmarkConfig, + num_trials: int, + use_localhost: bool, + concurrency: int, + verbose: bool, + local_port: int, +): + prompt = generate_prompt(config.input_token_count, hf_model) + + prompt_num_tokens = config.input_token_count + + output_token_counts = generate_output_token_counts( + config.output_token_count_mean, + config.output_token_count_std, + num_trials, + config.input_token_count, + ) + + start = time.time() + results = send_requests( + model, + prompt, + output_token_counts, + use_localhost, + concurrency, + framework, + local_port=local_port, + ) + end = time.time() + elapsed = end - start + results = [result for result in results if result is not None] + + num_sampled_tokens = sum([result["num_completion_tokens"] for result in results]) + num_prompt_tokens = prompt_num_tokens * len(results) + n = len(results) + time_to_process_prompt = [] + time_per_completion = [] + time_to_first_token = [] + inter_token_latency = [] + for result in results: + avg_time_per_token = (result["total_time"] - result["time_to_first_token"]) / ( + result["num_completion_tokens"] - 1 + ) + time_to_first_token.append(result["time_to_first_token"]) + time_to_process_prompt.append(result["time_to_first_token"] - avg_time_per_token) + time_per_completion.append(result["total_time"] - time_to_process_prompt[-1]) + inter_token_latency.append(avg_time_per_token) + + total_num_tokens = num_sampled_tokens + num_prompt_tokens + avg_prefill_time = sum(time_to_process_prompt) / n + avg_completion_time = sum(time_per_completion) / n + + statistics = { + "concurrency": concurrency, + "avg_prompt_throughput": num_prompt_tokens + / (elapsed * avg_prefill_time / (avg_prefill_time + avg_completion_time)), + "avg_time_to_first_token": sum(time_to_first_token) / n, + "avg_sampling_throughput": num_sampled_tokens + / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)), + "avg_total_throughput": total_num_tokens / elapsed, + "avg_per_session_sampling_throughput": num_sampled_tokens + / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)) + / concurrency, + "avg_inter_token_latency": sum(inter_token_latency) / n, + "num_prompt_tokens": prompt_num_tokens, + "avg_num_sampled_tokens": num_sampled_tokens / n, + "elapsed_time": elapsed, + "avg_prefill_time": avg_prefill_time, + "avg_completion_time": avg_completion_time, + "num_requests": num_trials, + "num_successful_requests": n, + "total_num_tokens": total_num_tokens, + "total_num_sampled_tokens": num_sampled_tokens, + } + if verbose: + print(f"Statistics: {statistics}") + + # Sleep for 1 seconds between each benchmark. + time.sleep(1) + + return statistics + + +@app.command() +def run_benchmarks( + model: str, + framework: str, + input_token_count: int, + output_token_count_mean: int, + num_trials: int = 50, + output_file: Optional[str] = None, + use_localhost: bool = False, + concurrency: int = 1, + verbose: bool = False, + hf_model: Optional[str] = None, + local_port: int = 5005, +): + """Run benchmarks.""" + all_statistics = [] + config = BenchmarkConfig(input_token_count, output_token_count_mean) + try: + if verbose: + print(f"Running benchmark for config {config}") + if hf_model is None: + if model not in HF_MODEL_MAPPING: + raise ValueError( + f"--hf-model must be specified for model {model} since it's not in default mapping." + ) + hf_model = HF_MODEL_MAPPING[model] + statistics = run_benchmark( + model, + InferenceFramework.from_value(framework), + hf_model, + config, + num_trials, + use_localhost, + concurrency, + verbose, + local_port, + ) + all_statistics.append(statistics) + except Exception: + traceback.print_exc() + + if output_file is not None: + header = all_statistics[0].keys() + + with open(output_file, "a") as csvfile: + csv_writer = csv.DictWriter(csvfile, fieldnames=header) + csv_writer.writeheader() + csv_writer.writerows(all_statistics) + + +@app.command() +def run_benchmarks_concurrency_range( + model: str, + framework: str, + input_token_count: int, + output_token_count_mean: int, + num_trials_per_concurrency: int = 5, + output_file: Optional[str] = None, + use_localhost: bool = False, + concurrency_min: int = 1, + concurrency_max: int = 1, + verbose: bool = False, + hf_model: Optional[str] = None, + local_port: int = 5005, +): + if output_file is not None: + # Create empty file + with open(output_file, "w"): + pass + for concurrency in range(concurrency_min, concurrency_max + 1): + run_benchmarks( + model, + framework, + input_token_count, + output_token_count_mean, + num_trials_per_concurrency * concurrency, + output_file, + use_localhost, + concurrency, + verbose, + hf_model, + local_port, + ) + + +if __name__ == "__main__": + app() From e8bb27c22efef63943aefd1c6d98fec98c60b511 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 10 Jan 2024 13:15:26 -0800 Subject: [PATCH 216/425] Upgrade vllm to 0.2.7 (#417) * Upgrade vllm to 0.2.7 * remove megablocks * fix ipv6 --- model-engine/model_engine_server/inference/vllm/Dockerfile | 3 --- .../model_engine_server/inference/vllm/requirements.txt | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index 6f1a00c5..907e795d 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -4,9 +4,6 @@ RUN pip uninstall torch -y COPY requirements.txt /workspace/requirements.txt RUN pip install -r requirements.txt -# install special version of megablocks -RUN pip install git+https://github.com/stanford-futuredata/megablocks.git@5897cd6f254b7b3edf7a708a3a3314ecb54b6f78#egg=megablocks - RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index e2c3aa08..4cc6239a 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ ray==2.6.3 -vllm==0.2.5 +git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm pydantic==1.10.13 From a5bfdb7dd2b0266974fe71726f8028bde1dedcc4 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 17 Jan 2024 12:45:19 -0800 Subject: [PATCH 217/425] LLM batch completions API (#418) * wip * wip * wip * batch run files * wip * fix * nonindex * fixes * log * fix vllm version * fix config * aws profile * add dumb-init and ddtrace * use batch to start * fix path * delete temp s3 files * fix unit test * Add unit tests * mypy * fix tests * comments --- .../service_template_config_map.yaml | 3 + charts/model-engine/values_circleci.yaml | 1 + charts/model-engine/values_sample.yaml | 1 + .../model_engine_server/api/llms_v1.py | 26 ++ .../model_engine_server/common/config.py | 1 + .../model_engine_server/common/dtos/llms.py | 119 +++++++- .../domain/entities/batch_job_entity.py | 1 + .../docker_image_batch_job_gateway.py | 3 + .../use_cases/llm_model_endpoint_use_cases.py | 168 ++++++++++- .../inference/batch_inference/Dockerfile_vllm | 18 ++ .../inference/batch_inference/__init__.py | 0 .../batch_inference/build_and_upload_image.sh | 21 ++ .../batch_inference/requirements.txt | 6 + .../batch_inference/sample_config.json | 11 + .../batch_inference/sample_data.json | 9 + .../inference/batch_inference/vllm_batch.py | 209 +++++++++++++ .../inference/vllm/vllm_server.py | 2 - .../live_docker_image_batch_job_gateway.py | 8 + .../gateways/resources/k8s_resource_types.py | 1 + .../service_config_circleci.yaml | 1 + model-engine/tests/unit/api/conftest.py | 21 ++ model-engine/tests/unit/api/test_llms.py | 22 ++ model-engine/tests/unit/conftest.py | 2 + model-engine/tests/unit/domain/conftest.py | 23 ++ .../tests/unit/domain/test_llm_use_cases.py | 77 +++++ model-engine/tests/unit/inference/conftest.py | 93 ++++++ .../tests/unit/inference/test_vllm_batch.py | 274 ++++++++++++++++++ .../unit/infra/gateways/k8s_fake_objects.py | 7 + 28 files changed, 1114 insertions(+), 14 deletions(-) create mode 100644 model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm create mode 100644 model-engine/model_engine_server/inference/batch_inference/__init__.py create mode 100755 model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh create mode 100644 model-engine/model_engine_server/inference/batch_inference/requirements.txt create mode 100644 model-engine/model_engine_server/inference/batch_inference/sample_config.json create mode 100644 model-engine/model_engine_server/inference/batch_inference/sample_data.json create mode 100644 model-engine/model_engine_server/inference/batch_inference/vllm_batch.py create mode 100644 model-engine/tests/unit/inference/conftest.py create mode 100644 model-engine/tests/unit/inference/test_vllm_batch.py diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index b738ebb2..199a5b1d 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -615,6 +615,9 @@ data: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + completions: ${BATCH_JOB_NUM_WORKERS} + parallelism: ${BATCH_JOB_NUM_WORKERS} + completionMode: "Indexed" template: metadata: labels: diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 758bf472..8d841c86 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -151,6 +151,7 @@ config: vllm_repository: "vllm" lightllm_repository: "lightllm" tensorrt_llm_repository: "tensorrt-llm" + batch_inference_vllm_repository: "llm-engine/batch-infer-vllm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 65e41928..f8c3c66a 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -207,6 +207,7 @@ config: vllm_repository: "vllm" lightllm_repository: "lightllm" tensorrt_llm_repository: "tensorrt-llm" + batch_inference_vllm_repository: "llm-engine/batch-infer-vllm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "launch/inference/pytorch" user_inference_tensorflow_repository: "launch/inference/tf" diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 352adeb1..c076e2b1 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -19,6 +19,8 @@ CompletionStreamV1Response, CompletionSyncV1Request, CompletionSyncV1Response, + CreateBatchCompletionsRequest, + CreateBatchCompletionsResponse, CreateFineTuneRequest, CreateFineTuneResponse, CreateLLMModelEndpointV1Request, @@ -73,6 +75,7 @@ from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, + CreateBatchCompletionsUseCase, CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, DeleteLLMEndpointByNameUseCase, @@ -568,3 +571,26 @@ async def delete_llm_model_endpoint( status_code=500, detail="deletion of endpoint failed.", ) from exc + + +@llm_router_v1.post("/batch-completions", response_model=CreateBatchCompletionsResponse) +async def create_batch_completions( + request: CreateBatchCompletionsRequest, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CreateBatchCompletionsResponse: + logger.info(f"POST /batch-completions with {request} for {auth}") + try: + use_case = CreateBatchCompletionsUseCase( + docker_image_batch_job_gateway=external_interfaces.docker_image_batch_job_gateway, + docker_repository=external_interfaces.docker_repository, + docker_image_batch_job_bundle_repo=external_interfaces.docker_image_batch_job_bundle_repository, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except (InvalidRequestException, ObjectHasInvalidValueException) as exc: + raise HTTPException(status_code=400, detail=str(exc)) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index d2736576..86625450 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -58,6 +58,7 @@ class HostedModelInferenceServiceConfig: vllm_repository: str lightllm_repository: str tensorrt_llm_repository: str + batch_inference_vllm_repository: str user_inference_base_repository: str user_inference_pytorch_repository: str user_inference_tensorflow_repository: str diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 6e991e45..11c21e24 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -30,21 +30,21 @@ class CreateLLMModelEndpointV1Request(BaseModel): # LLM specific fields model_name: str source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED - inference_framework_image_tag: str + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM + inference_framework_image_tag: str = "latest" num_shards: int = 1 """ - Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models + Number of shards to distribute the model onto GPUs. """ quantize: Optional[Quantization] = None """ - Whether to quantize the model. Only affect behavior for text-generation-inference models + Whether to quantize the model. """ checkpoint_path: Optional[str] = None """ - Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models + Path to the checkpoint to load the model from. """ # General endpoint fields @@ -102,17 +102,17 @@ class UpdateLLMModelEndpointV1Request(BaseModel): inference_framework_image_tag: Optional[str] num_shards: Optional[int] """ - Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models + Number of shards to distribute the model onto GPUs. """ quantize: Optional[Quantization] """ - Whether to quantize the model. Only affect behavior for text-generation-inference models + Whether to quantize the model. """ checkpoint_path: Optional[str] """ - Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models + Path to the checkpoint to load the model from. """ # General endpoint fields @@ -220,7 +220,7 @@ class CompletionStreamV1Request(BaseModel): """ return_token_log_probs: Optional[bool] = False """ - Whether to return the log probabilities of the tokens. Only affects behavior for text-generation-inference models + Whether to return the log probabilities of the tokens. """ presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) """ @@ -359,3 +359,104 @@ class ModelDownloadResponse(BaseModel): class DeleteLLMEndpointResponse(BaseModel): deleted: bool + + +class CreateBatchCompletionsRequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + + +class CreateBatchCompletionsModelConfig(BaseModel): + model: str + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + labels: Dict[str, str] + """ + Labels to attach to the batch inference job. + """ + num_shards: Optional[int] = 1 + """ + Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. + System may decide to use a different number than the given value. + """ + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + seed: Optional[int] = None + """ + Random seed for the model. + """ + + +class CreateBatchCompletionsRequest(BaseModel): + """ + Request object for batch completions. + """ + + input_data_path: Optional[str] + output_data_path: str + """ + Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. + """ + content: Optional[CreateBatchCompletionsRequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + model_config: CreateBatchCompletionsModelConfig + """ + Model configuration for the batch inference. Hardware configurations are inferred. + """ + data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) + """ + Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. + """ + max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600) + """ + Maximum runtime of the batch inference in seconds. Default to one day. + """ + + +class CreateBatchCompletionsResponse(BaseModel): + job_id: str + + +class GetBatchCompletionsResponse(BaseModel): + progress: float + """ + Progress of the batch inference in percentage from 0 to 100. + """ + finished: bool diff --git a/model-engine/model_engine_server/domain/entities/batch_job_entity.py b/model-engine/model_engine_server/domain/entities/batch_job_entity.py index e80f9fd4..6bf51b0d 100644 --- a/model-engine/model_engine_server/domain/entities/batch_job_entity.py +++ b/model-engine/model_engine_server/domain/entities/batch_job_entity.py @@ -61,3 +61,4 @@ class DockerImageBatchJob(BaseModel): status: BatchJobStatus # the status map relatively nicely onto BatchJobStatus annotations: Optional[Dict[str, str]] = None override_job_max_runtime_s: Optional[int] = None + num_workers: Optional[int] = 1 diff --git a/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py b/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py index 43af4e04..66c23368 100644 --- a/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py @@ -26,6 +26,7 @@ async def create_docker_image_batch_job( mount_location: Optional[str], annotations: Optional[Dict[str, str]] = None, override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> str: """ Create a docker image batch job @@ -42,6 +43,8 @@ async def create_docker_image_batch_job( annotations: K8s annotations resource_requests: The resource requests for the batch job. mount_location: Location on filesystem where runtime-provided file contents get mounted + override_job_max_runtime_s: Optional override for the maximum runtime of the job + num_workers: num of pods to run in this job. Coordination needs to happen between the workers. Returns: diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 67f26fe1..dee24e5c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -4,13 +4,16 @@ Read model endpoint creation logs: GET model-endpoints//creation-logs """ +import datetime import json import math import os +import re from dataclasses import asdict from typing import Any, AsyncIterable, Dict, List, Optional, Union from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( CompletionOutput, CompletionStreamOutput, @@ -18,6 +21,8 @@ CompletionStreamV1Response, CompletionSyncV1Request, CompletionSyncV1Response, + CreateBatchCompletionsRequest, + CreateBatchCompletionsResponse, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, DeleteLLMEndpointResponse, @@ -41,6 +46,7 @@ make_logger, ) from model_engine_server.domain.entities import ( + GpuType, LLMInferenceFramework, LLMMetadata, LLMSource, @@ -52,6 +58,9 @@ RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor, ) +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( + DockerImageBatchJobBundle, +) from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, EndpointInfraStateNotFound, @@ -63,8 +72,10 @@ ObjectNotFoundException, UpstreamServiceError, ) +from model_engine_server.domain.gateways import DockerImageBatchJobGateway from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway from model_engine_server.domain.repositories import ( + DockerImageBatchJobBundleRepository, DockerRepository, ModelBundleRepository, TokenizerRepository, @@ -486,8 +497,6 @@ def load_model_weights_sub_commands( subcommands = [] s5cmd = "s5cmd" - base_path = checkpoint_path.split("/")[-1] - # This is a hack for now to skip installing s5cmd for text-generation-inference:0.9.3-launch_s3, # which has s5cmd binary already baked in. Otherwise, install s5cmd if it's not already available if ( @@ -498,6 +507,15 @@ def load_model_weights_sub_commands( else: s5cmd = "./s5cmd" + subcommands.extend( + self.get_s5cmd_copy_command(checkpoint_path, final_weights_folder, subcommands, s5cmd) + ) + + return subcommands + + def get_s5cmd_copy_command(self, checkpoint_path, final_weights_folder, s5cmd): + subcommands = [] + base_path = checkpoint_path.split("/")[-1] if base_path.endswith(".tar"): # If the checkpoint file is a tar file, extract it into final_weights_folder subcommands.extend( @@ -517,7 +535,6 @@ def load_model_weights_sub_commands( subcommands.append( f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) - return subcommands def load_model_files_sub_commands_trt_llm( @@ -2119,3 +2136,148 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl public_file_name = model_file.rsplit("/", 1)[-1] urls[public_file_name] = self.filesystem_gateway.generate_signed_url(model_file) return ModelDownloadResponse(urls=urls) + + +def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJobResourceRequests: + if "mixtral-8x7b" in model_name: + cpus = "20" + gpus = 2 + memory = "160Gi" + storage = "160Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A100E + else: + numbers = re.findall(r"\d+", model_name) + if len(numbers) == 0: + raise ObjectHasInvalidValueException( + f"Model {model_name} is not supported for batch completions." + ) + + b_params = int(numbers[-1]) + if b_params <= 7: + cpus = "10" + gpus = 1 + memory = "24Gi" + storage = "80Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A10 + elif b_params <= 13: + cpus = "20" + gpus = 2 + memory = "48Gi" + storage = "80Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A10 + elif b_params <= 34: + cpus = "40" + gpus = 4 + memory = "96Gi" + storage = "96Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A10 + elif b_params <= 70: + cpus = "20" + gpus = 2 + memory = "160Gi" + storage = "160Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A100E + else: + raise ObjectHasInvalidValueException( + f"Model {model_name} is not supported for batch completions." + ) + + return CreateDockerImageBatchJobResourceRequests( + cpus=cpus, gpus=gpus, memory=memory, storage=storage, gpu_type=gpu_type + ) + + +class CreateBatchCompletionsUseCase: + def __init__( + self, + docker_image_batch_job_gateway: DockerImageBatchJobGateway, + docker_repository: DockerRepository, + docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository, + ): + self.docker_image_batch_job_gateway = docker_image_batch_job_gateway + self.docker_repository = docker_repository + self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo + + async def create_batch_job_bundle( + self, + user: User, + request: CreateBatchCompletionsRequest, + hardware: CreateDockerImageBatchJobResourceRequests, + ) -> DockerImageBatchJobBundle: + bundle_name = ( + f"{request.model_config.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" + ) + + image_tag = self.docker_repository.get_latest_image_tag( + hmi_config.batch_inference_vllm_repository + ) + + config_file_path = "/opt/config.json" + + assert hardware.gpu_type is not None + + batch_bundle = ( + await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( + name=bundle_name, + created_by=user.user_id, + owner=user.team_id, + image_repository=hmi_config.batch_inference_vllm_repository, + image_tag=image_tag, + command=[ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ], + env={"CONFIG_FILE": config_file_path}, + mount_location=config_file_path, + cpus=str(hardware.cpus), + memory=str(hardware.memory), + storage=str(hardware.storage), + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + public=False, + ) + ) + return batch_bundle + + async def execute( + self, user: User, request: CreateBatchCompletionsRequest + ) -> CreateBatchCompletionsResponse: + hardware = infer_hardware_from_model_name(request.model_config.model) + # Reconcile gpus count with num_shards from request + assert hardware.gpus is not None + if request.model_config.num_shards: + hardware.gpus = max(hardware.gpus, request.model_config.num_shards) + request.model_config.num_shards = hardware.gpus + + batch_bundle = await self.create_batch_job_bundle(user, request, hardware) + + validate_resource_requests( + bundle=batch_bundle, + cpus=hardware.cpus, + memory=hardware.memory, + storage=hardware.storage, + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + ) + + if request.max_runtime_sec is None or request.max_runtime_sec < 1: + raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") + + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( + created_by=user.user_id, + owner=user.team_id, + job_config=request.dict(), + env=batch_bundle.env, + command=batch_bundle.command, + repo=batch_bundle.image_repository, + tag=batch_bundle.image_tag, + resource_requests=hardware, + labels=request.model_config.labels, + mount_location=batch_bundle.mount_location, + override_job_max_runtime_s=request.max_runtime_sec, + num_workers=request.data_parallelism, + ) + return CreateBatchCompletionsResponse(job_id=job_id) diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm new file mode 100644 index 00000000..92820714 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm @@ -0,0 +1,18 @@ +FROM nvcr.io/nvidia/pytorch:23.09-py3 + +RUN apt-get update && \ + apt-get install -y dumb-init && \ + apt-get autoremove -y && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +RUN pip uninstall torch -y +COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt +RUN pip install -r requirements.txt + +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + +COPY model-engine /workspace/model-engine +RUN pip install -e /workspace/model-engine +COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py diff --git a/model-engine/model_engine_server/inference/batch_inference/__init__.py b/model-engine/model_engine_server/inference/batch_inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh b/model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh new file mode 100755 index 00000000..2bd519ed --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Build and push batch inference vLLM docker image to AWS ECR. + +set -eo pipefail + +if [ -z "$1" ]; then + echo "Must supply AWS account ID" + exit 1; +fi + +if [ -z "$2" ]; then + echo "Must supply the image tag" + exit 1; +fi + +IMAGE_TAG=$2 +ACCOUNT=$1 +aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com +DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG -f Dockerfile_vllm ../../../../ +docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt new file mode 100644 index 00000000..5b7cf76a --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -0,0 +1,6 @@ +ray==2.6.3 +git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm +pydantic==1.10.13 +boto3==1.34.15 +smart-open==6.4.0 +ddtrace==2.4.0 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_config.json b/model-engine/model_engine_server/inference/batch_inference/sample_config.json new file mode 100644 index 00000000..366d9785 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/sample_config.json @@ -0,0 +1,11 @@ +{ + "input_data_path":"./sample_data.json", + "output_data_path":"./sample_output.json", + "model_config":{ + "model":"llama-2-7b", + "checkpoint_path":"my_path", + "num_shards": 1, + "labels": {"team": "my_team"} + }, + "data_parallelism":2 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_data.json b/model-engine/model_engine_server/inference/batch_inference/sample_data.json new file mode 100644 index 00000000..87eb3169 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/sample_data.json @@ -0,0 +1,9 @@ +{ + "prompts":[ + "deep learning is", + "san francisco is" + ], + "max_new_tokens": 100, + "temperature": 0.0, + "return_token_log_probs": true +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py new file mode 100644 index 00000000..6c0c76db --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -0,0 +1,209 @@ +import asyncio +import json +import os +import subprocess +import time +from urllib.parse import urlparse + +import boto3 +import smart_open +from model_engine_server.common.dtos.llms import ( + CompletionOutput, + CreateBatchCompletionsRequest, + CreateBatchCompletionsRequestContent, + TokenOutput, +) + +CONFIG_FILE = os.getenv("CONFIG_FILE") +AWS_REGION = os.getenv("AWS_REGION", "us-west-2") + +os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + + +def get_s3_client(): + session = boto3.Session(profile_name=os.getenv("S3_WRITE_AWS_PROFILE")) + return session.client("s3", region_name=AWS_REGION) + + +def download_model(checkpoint_path, final_weights_folder): + s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + process = subprocess.Popen( + s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + for line in process.stdout: + print(line) + + process.wait() + + if process.returncode != 0: + stderr_lines = [] + for line in iter(process.stderr.readline, ""): + stderr_lines.append(line.strip()) + + raise IOError(f"Error downloading model weights: {stderr_lines}") + + +def file_exists(path): + try: + with smart_open.open(path, "r"): + return True + except FileNotFoundError: + return False + + +def parse_s3_url(s3_url): + parsed_url = urlparse(s3_url) + + if parsed_url.scheme != "s3": + raise ValueError(f'The URL scheme is not "s3": {s3_url}') + + bucket = parsed_url.netloc + key = parsed_url.path.lstrip("/") + + return bucket, key + + +def wait_for_all_chunks(request): + # Max wait time is controlled by the batch job timeout + while True: + print("Waiting for all chunks to be written...") + all_chunks_exist = True + for i in range(request.data_parallelism): + chunk_file = f"{request.output_data_path}.{i}" + if not file_exists(chunk_file): + print(f"Chunk {chunk_file} does not exist yet") + all_chunks_exist = False + break + if all_chunks_exist: + break + time.sleep(5) + print("All chunks written") + + +def combine_all_chunks(request): + print("Combining chunks...") + with smart_open.open(request.output_data_path, "w") as f: + f.write("[") + for i in range(request.data_parallelism): + if i > 0: + f.write(",") + chunk_file = f"{request.output_data_path}.{i}" + with smart_open.open(chunk_file, "r") as chunk_f: + chunk_data = chunk_f.read() + f.write(chunk_data[1:-1]) # Remove leading and trailing brackets + f.write("]") + print("Chunks combined") + + +def delete_s3_chunks(request): + print("Deleting S3 chunks...") + for i in range(request.data_parallelism): + chunk_file = f"{request.output_data_path}.{i}" + bucket, key = parse_s3_url(chunk_file) + get_s3_client().delete_object(Bucket=bucket, Key=key) + print("Chunks deleted") + + +async def batch_inference(): + job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0)) + + request = CreateBatchCompletionsRequest.parse_file(CONFIG_FILE) + + if request.model_config.checkpoint_path is not None: + download_model(request.model_config.checkpoint_path, "./model_weights") + + content = request.content + if content is None: + with smart_open.open(request.input_data_path, "r") as f: + content = CreateBatchCompletionsRequestContent.parse_raw(f.read()) + + model = ( + "./model_weights" if request.model_config.checkpoint_path else request.model_config.model + ) + + results_generators = await generate_with_vllm(request, content, model, job_index) + + outputs = [] + for generator in results_generators: + last_output_text = "" + tokens = [] + async for request_output in generator: + token_text = request_output.outputs[-1].text[len(last_output_text) :] + log_probs = ( + request_output.outputs[0].logprobs[-1] if content.return_token_log_probs else None + ) + last_output_text = request_output.outputs[-1].text + + if content.return_token_log_probs: + tokens.append( + TokenOutput( + token=token_text, + log_prob=log_probs[request_output.outputs[0].token_ids[-1]], + ) + ) + + num_prompt_tokens = len(request_output.prompt_token_ids) + num_completion_tokens = len(request_output.outputs[0].token_ids) + + output = CompletionOutput( + text=request_output.outputs[0].text, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + if content.return_token_log_probs: + output.tokens = tokens + + outputs.append(output.dict()) + + if request.data_parallelism == 1: + with smart_open.open(request.output_data_path, "w") as f: + f.write(json.dumps(outputs)) + else: + chunk_file = f"{request.output_data_path}.{job_index}" + with smart_open.open(chunk_file, "w") as f: + f.write(json.dumps(outputs)) + if job_index == 0: + wait_for_all_chunks(request) + combine_all_chunks(request) + if request.output_data_path.startswith("s3://"): + delete_s3_chunks(request) + + +async def generate_with_vllm(request, content, model, job_index): + from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams + from vllm.utils import random_uuid + + engine_args = AsyncEngineArgs( + model=model, + quantization=request.model_config.quantize, + tensor_parallel_size=request.model_config.num_shards, + seed=request.model_config.seed or 0, + ) + + llm = AsyncLLMEngine.from_engine_args(engine_args) + + # Add the requests to the engine. + sampling_params = SamplingParams( + max_tokens=content.max_new_tokens, + temperature=content.temperature, + stop=content.stop_sequences, + logprobs=1 if content.return_token_log_probs else None, + presence_penalty=content.presence_penalty or 0.0, + frequency_penalty=content.frequency_penalty or 0.0, + top_k=content.top_k or -1, + top_p=content.top_p or 1.0, + ) + + results_generators = [] + prompts_per_pod = len(content.prompts) // request.data_parallelism + for prompt in content.prompts[prompts_per_pod * job_index : prompts_per_pod * (job_index + 1)]: + request_id = random_uuid() + results_generator = await llm.add_request( + request_id, prompt, sampling_params, None, time.monotonic() + ) + results_generators.append(results_generator) + return results_generators + + +if __name__ == "__main__": + asyncio.run(batch_inference()) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 954c143a..9c66ae7a 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -39,8 +39,6 @@ async def generate(request: Request) -> Response: results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case - # TODO: vLLM spends a long time decoding text repeatedly, that for every new token `text` is regenerated, - # (see detokenize_incrementally) which we should definitely optimize away. async def stream_results() -> AsyncGenerator[str, None]: last_output_text = "" async for request_output in results_generator: diff --git a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py index 8ad2c09a..c3a86326 100644 --- a/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py @@ -145,6 +145,7 @@ async def create_docker_image_batch_job( mount_location: Optional[str], annotations: Optional[Dict[str, str]] = None, override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> str: await maybe_load_kube_config() @@ -161,7 +162,9 @@ async def create_docker_image_batch_job( labels=labels, annotations=annotations, override_job_max_runtime_s=override_job_max_runtime_s, + num_workers=num_workers, ) + logger.info(resource_spec) batch_client = get_kubernetes_batch_client() @@ -191,6 +194,7 @@ def _generate_job_spec( labels: Dict[str, str], annotations: Optional[Dict[str, str]] = None, override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> Tuple[str, Dict[str, Any]]: job_id = _get_job_id() job_name = _k8s_job_name_from_id(job_id) # why do we even have job_name and id @@ -237,6 +241,7 @@ def _generate_job_spec( GPU_TYPE=resource_requests.gpu_type.value, GPUS=resource_requests.gpus or 1, REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", + BATCH_JOB_NUM_WORKERS=num_workers or 1, ) else: resource_key = "docker-image-batch-job-cpu.yaml" @@ -266,6 +271,7 @@ def _generate_job_spec( FILE_CONTENTS_B64ENCODED=job_config_b64encoded, AWS_ROLE=infra_config().profile_ml_inference_worker, REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", + BATCH_JOB_NUM_WORKERS=num_workers or 1, ) resource_spec = load_k8s_yaml(resource_key, substitution_kwargs) @@ -336,6 +342,7 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker completed_at=job.status.completion_time, status=status, annotations=annotations, + num_workers=job.spec.completions, ) async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatchJob]: @@ -374,6 +381,7 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc status=_parse_job_status_from_k8s_obj( job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)] ), + num_workers=job.spec.completions, ) for job in jobs.items ] diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index ed957fcc..e9d4a657 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -187,6 +187,7 @@ class _DockerImageBatchJobArguments(_JobArguments): LOCAL_FILE_NAME: str FILE_CONTENTS_B64ENCODED: str COMMAND: List[str] + BATCH_JOB_NUM_WORKERS: int class _GpuArguments(TypedDict): diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index d37172ec..a42fdc1b 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -59,6 +59,7 @@ tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" tensorrt_llm_repository: "tensorrt-llm" +batch_inference_vllm_repository: "llm-engine/batch-infer-vllm" user_inference_base_repository: "launch/inference" user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index b77071b7..703c7c12 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -1268,3 +1268,24 @@ def trigger_2(test_api_key) -> Tuple[Trigger, Any]: "default_job_metadata": {"team": "infra", "product": "my_product_two"}, } return trigger, trigger_json + + +@pytest.fixture +def create_batch_completions_request() -> Dict[str, Any]: + return { + "input_data_path": "test_input_data_path", + "output_data_path": "test_output_data_path", + "content": { + "prompts": ["what is 1+1?"], + "max_new_tokens": 10, + "temperature": 0.1, + }, + "model_config": { + "model": "mpt-7b", + "checkpoint_path": "test_checkpoint_path", + "labels": [], + "num_shards": 2, + }, + "data_parallelism": 1, + "max_runtime_sec": 86400, + } diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 7bf64660..7ab55908 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -231,3 +231,25 @@ def test_completion_stream_endpoint_not_found_returns_404( for message in response_1: assert "404" in message.decode("utf-8") + + +def test_create_batch_completions_success( + create_batch_completions_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + response_1 = client.post( + "/v1/llm/batch-completions", + auth=(test_api_key, ""), + json=create_batch_completions_request, + ) + assert response_1.status_code == 200 diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 684b4fac..eeef573c 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1249,6 +1249,7 @@ async def create_docker_image_batch_job( mount_location: Optional[str], annotations: Optional[Dict[str, str]] = None, override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> str: job_id = f"ft-{self.id}" self.id += 1 @@ -1262,6 +1263,7 @@ async def create_docker_image_batch_job( status=BatchJobStatus.RUNNING, annotations=annotations, override_job_max_runtime_s=override_job_max_runtime_s, + num_workers=num_workers, ) return job_id diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 798af362..bbf25058 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -6,6 +6,9 @@ from model_engine_server.common.dtos.llms import ( CompletionStreamV1Request, CompletionSyncV1Request, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsRequest, + CreateBatchCompletionsRequestContent, CreateLLMModelEndpointV1Request, UpdateLLMModelEndpointV1Request, ) @@ -468,3 +471,23 @@ def completion_stream_request() -> CompletionStreamV1Request: max_new_tokens=10, temperature=0.5, ) + + +@pytest.fixture +def create_batch_completions_request() -> CreateBatchCompletionsRequest: + return CreateBatchCompletionsRequest( + input_data_path="test_input_data_path", + output_data_path="test_output_data_path", + content=CreateBatchCompletionsRequestContent( + prompts=["What is machine learning?"], + max_new_tokens=10, + temperature=0.5, + ), + model_config=CreateBatchCompletionsModelConfig( + model="mpt-7b", + checkpoint_path="test_checkpoint_path", + labels=[], + num_shards=2, + ), + data_parallelism=2, + ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index a9d32975..c5176e03 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -7,6 +7,7 @@ CompletionOutput, CompletionStreamV1Request, CompletionSyncV1Request, + CreateBatchCompletionsRequest, CreateFineTuneRequest, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, @@ -40,13 +41,16 @@ from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CompletionStreamV1UseCase, CompletionSyncV1UseCase, + CreateBatchCompletionsUseCase, CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, DeleteLLMEndpointByNameUseCase, GetLLMModelEndpointByNameV1UseCase, + GpuType, ModelDownloadV1UseCase, UpdateLLMModelEndpointV1UseCase, _include_safetensors_bin_or_pt, + infer_hardware_from_model_name, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -1492,3 +1496,76 @@ async def test_include_safetensors_bin_or_pt_majority_pt(): "fake3.pt", ] assert _include_safetensors_bin_or_pt(fake_model_files) == "*.pt" + + +def test_infer_hardware_from_model_name(): + hardware = infer_hardware_from_model_name("mixtral-8x7b") + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + + hardware = infer_hardware_from_model_name("llama-2-7b") + assert hardware.cpus == "10" + assert hardware.gpus == 1 + assert hardware.memory == "24Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + + hardware = infer_hardware_from_model_name("llama-2-13b") + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "48Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + + hardware = infer_hardware_from_model_name("codellama-34b") + assert hardware.cpus == "40" + assert hardware.gpus == 4 + assert hardware.memory == "96Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + + hardware = infer_hardware_from_model_name("llama-2-70b") + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + + with pytest.raises(ObjectHasInvalidValueException): + infer_hardware_from_model_name("unsupported_model") + + with pytest.raises(ObjectHasInvalidValueException): + infer_hardware_from_model_name("falcon-180b") + + +@pytest.mark.asyncio +async def test_create_batch_completions( + fake_docker_image_batch_job_gateway, + fake_docker_repository_image_always_exists, + fake_docker_image_batch_job_bundle_repository, + test_api_key: str, + create_batch_completions_request: CreateBatchCompletionsRequest, +): + use_case = CreateBatchCompletionsUseCase( + docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, + docker_repository=fake_docker_repository_image_always_exists, + docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + result = await use_case.execute(user, create_batch_completions_request) + + job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job(result.job_id) + assert job.num_workers == create_batch_completions_request.data_parallelism + + bundle = list(fake_docker_image_batch_job_bundle_repository.db.values())[0] + assert bundle.command == [ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ] diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py new file mode 100644 index 00000000..4d0ec72c --- /dev/null +++ b/model-engine/tests/unit/inference/conftest.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock + +import pytest +from model_engine_server.common.dtos.llms import ( + CompletionOutput, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsRequest, + CreateBatchCompletionsRequestContent, + TokenOutput, +) + + +@pytest.fixture +def create_batch_completions_request(): + return CreateBatchCompletionsRequest( + model_config=CreateBatchCompletionsModelConfig( + checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={} + ), + data_parallelism=1, + input_data_path="input_data_path", + output_data_path="output_data_path", + ) + + +@pytest.fixture +def create_batch_completions_request_content(): + return CreateBatchCompletionsRequestContent( + prompts=["prompt1", "prompt2"], + max_new_tokens=100, + temperature=0.8, + return_token_log_probs=True, + ) + + +@pytest.fixture +def create_vllm_request_outputs(): + mock_vllm_request_output1 = MagicMock() + mock_vllm_request_output1.outputs = [ + MagicMock(text="text1"), + ] + mock_vllm_request_output1.prompt_token_ids = [1, 2, 3] + mock_vllm_request_output1.outputs[0].token_ids = [4] + mock_vllm_request_output1.outputs[0].logprobs = [{4: 0.1}] + + mock_vllm_request_output2 = MagicMock() + mock_vllm_request_output2.outputs = [ + MagicMock(text="text1 text2"), + ] + mock_vllm_request_output2.prompt_token_ids = [1, 2, 3] + mock_vllm_request_output2.outputs[0].token_ids = [4, 5] + mock_vllm_request_output2.outputs[0].logprobs = [{4: 0.1, 5: 0.2}] + + mock_vllm_request_output3 = MagicMock() + mock_vllm_request_output3.outputs = [ + MagicMock(text="text1 text2 text3"), + ] + mock_vllm_request_output3.prompt_token_ids = [1, 2, 3] + mock_vllm_request_output3.outputs[0].token_ids = [4, 5, 6] + mock_vllm_request_output3.outputs[0].logprobs = [{4: 0.1, 5: 0.2, 6: 0.3}] + return [mock_vllm_request_output1, mock_vllm_request_output2, mock_vllm_request_output3] + + +@pytest.fixture +def mock_s3_client(): + mock_s3_client = MagicMock() + mock_s3_client.delete_object.return_value = None + return mock_s3_client + + +@pytest.fixture +def mock_process(): + mock_process = MagicMock() + mock_process.stdout = [] + mock_process.stderr.readline.side_effect = [ + "error", + ] + mock_process.returncode = 0 + mock_process.wait.return_value = None + return mock_process + + +@pytest.fixture +def mock_completion_output(): + return CompletionOutput( + text="text1 text2 text3", + num_prompt_tokens=3, + num_completion_tokens=3, + tokens=[ + TokenOutput(token="text1", log_prob=0.1), + TokenOutput(token=" text2", log_prob=0.2), + TokenOutput(token=" text3", log_prob=0.3), + ], + ) diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py new file mode 100644 index 00000000..dd717c04 --- /dev/null +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -0,0 +1,274 @@ +import json +from unittest.mock import MagicMock, call, mock_open, patch + +import pytest +from model_engine_server.inference.batch_inference.vllm_batch import batch_inference + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +async def test_batch_inference( + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_request, + create_batch_completions_request, + create_batch_completions_request_content, + create_vllm_request_outputs, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + mock_create_batch_completions_request_content.parse_raw.return_value = ( + create_batch_completions_request_content + ) + + mock_results_generator = MagicMock() + mock_results_generator.__aiter__.return_value = create_vllm_request_outputs + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_results_generator] + + # Call the function + await batch_inference() + + # Assertions + mock_create_batch_completions_request.parse_file.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +async def test_batch_inference_failed_to_download_model( + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_request, + create_batch_completions_request, + create_batch_completions_request_content, + create_vllm_request_outputs, + mock_s3_client, + mock_process, +): + # Mock the necessary objects and data + mock_process.returncode = 1 + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + mock_create_batch_completions_request_content.parse_raw.return_value = ( + create_batch_completions_request_content + ) + + # Call the function + with pytest.raises(IOError): + await batch_inference() + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.os.getenv") +async def test_batch_inference_two_workers( + mock_getenv, + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_request, + create_batch_completions_request, + create_batch_completions_request_content, + create_vllm_request_outputs, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + create_batch_completions_request.data_parallelism = 2 + mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + mock_create_batch_completions_request_content.parse_raw.return_value = ( + create_batch_completions_request_content + ) + + mock_results_generator = MagicMock() + mock_results_generator.__aiter__.return_value = create_vllm_request_outputs + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_results_generator] + + indexes = [1, 0] + + def side_effect(key, default): + if key == "JOB_COMPLETION_INDEX": + return indexes.pop(0) + return default + + mock_getenv.side_effect = side_effect + # Batch completion worker 1 + await batch_inference() + + # Assertions + mock_create_batch_completions_request.parse_file.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path.1", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + # Batch completion worker 0 + await batch_inference() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path.1", "r"), + call("output_data_path.0", "w"), + call("output_data_path.0", "r"), + call("output_data_path", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + call().write("["), + call().write(","), + call().write("]"), + ], + any_order=True, + ) + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.os.getenv") +async def test_batch_inference_delete_chunks( + mock_getenv, + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_request, + create_batch_completions_request, + create_batch_completions_request_content, + create_vllm_request_outputs, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + create_batch_completions_request.data_parallelism = 2 + create_batch_completions_request.output_data_path = "s3://bucket/key" + mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + mock_create_batch_completions_request_content.parse_raw.return_value = ( + create_batch_completions_request_content + ) + + mock_results_generator = MagicMock() + mock_results_generator.__aiter__.return_value = create_vllm_request_outputs + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_results_generator] + + indexes = [1, 0] + + def side_effect(key, default): + if key == "JOB_COMPLETION_INDEX": + return indexes.pop(0) + return default + + mock_getenv.side_effect = side_effect + # Batch completion worker 1 + await batch_inference() + + # Assertions + mock_create_batch_completions_request.parse_file.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("s3://bucket/key.1", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + # Batch completion worker 0 + await batch_inference() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("s3://bucket/key.1", "r"), + call("s3://bucket/key.0", "w"), + call("s3://bucket/key.0", "r"), + call("s3://bucket/key", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + call().write("["), + call().write(","), + call().write("]"), + ], + any_order=True, + ) + + mock_s3_client.delete_object.assert_has_calls( + [call(Bucket="bucket", Key="key.0"), call(Bucket="bucket", Key="key.1")] + ) diff --git a/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py b/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py index 55039109..257b0db5 100644 --- a/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py +++ b/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py @@ -32,10 +32,17 @@ class FakeK8sV1JobStatus: completion_time: Optional[datetime] = None +@dataclass +class FakeK8sV1JobSpec: + completions: int = 1 + parallelism: int = 1 + + @dataclass class FakeK8sV1Job: metadata: FakeK8sV1ObjectMeta = FakeK8sV1ObjectMeta() status: FakeK8sV1JobStatus = FakeK8sV1JobStatus() + spec: FakeK8sV1JobSpec = FakeK8sV1JobSpec() # TODO: spec, api_version, kind From db11cd779a71f41c8d386e8cab2a4a1d1b76bc4c Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 17 Jan 2024 14:31:44 -0800 Subject: [PATCH 218/425] Small update to vllm batch (#419) --- .../inference/batch_inference/requirements.txt | 3 ++- .../inference/batch_inference/vllm_batch.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index 5b7cf76a..9e8d1188 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,5 +1,6 @@ ray==2.6.3 -git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm +#git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm +vllm==0.2.5 pydantic==1.10.13 boto3==1.34.15 smart-open==6.4.0 diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 6c0c76db..20e0459d 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -13,6 +13,7 @@ CreateBatchCompletionsRequestContent, TokenOutput, ) +from tqdm import tqdm CONFIG_FILE = os.getenv("CONFIG_FILE") AWS_REGION = os.getenv("AWS_REGION", "us-west-2") @@ -123,11 +124,16 @@ async def batch_inference(): results_generators = await generate_with_vllm(request, content, model, job_index) + bar = tqdm(total=len(content.prompts), desc="Processed prompts") + outputs = [] for generator in results_generators: last_output_text = "" tokens = [] async for request_output in generator: + if request_output.finished: + bar.update(1) + token_text = request_output.outputs[-1].text[len(last_output_text) :] log_probs = ( request_output.outputs[0].logprobs[-1] if content.return_token_log_probs else None @@ -155,6 +161,8 @@ async def batch_inference(): outputs.append(output.dict()) + bar.close() + if request.data_parallelism == 1: with smart_open.open(request.output_data_path, "w") as f: f.write(json.dumps(outputs)) @@ -178,6 +186,7 @@ async def generate_with_vllm(request, content, model, job_index): quantization=request.model_config.quantize, tensor_parallel_size=request.model_config.num_shards, seed=request.model_config.seed or 0, + disable_log_requests=True, ) llm = AsyncLLMEngine.from_engine_args(engine_args) From 53a1918ef3568b674b59a4e4e772501a7e1a1d69 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 18 Jan 2024 20:55:12 -0800 Subject: [PATCH 219/425] sensitive content flag (#421) --- charts/model-engine/values_sample.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index f8c3c66a..7ed16da6 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -202,7 +202,7 @@ config: billing_queue_arn: "unused" model_primitive_host: "unused" hf_user_fine_tuned_weights_prefix: "s3://llm-engine/fine_tuned_weights" - + sensitive_log_mode: false tgi_repository: "text-generation-inference" vllm_repository: "vllm" lightllm_repository: "lightllm" From fc9a50300c0f513cd3d11a1d74924af97a534d16 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 22 Jan 2024 11:06:08 -0800 Subject: [PATCH 220/425] Revert a broken refactoring (#423) * Revert a broken refactoring * fix --- .../use_cases/llm_model_endpoint_use_cases.py | 9 +--- .../tests/unit/domain/test_llm_use_cases.py | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index dee24e5c..8a18273e 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -507,14 +507,6 @@ def load_model_weights_sub_commands( else: s5cmd = "./s5cmd" - subcommands.extend( - self.get_s5cmd_copy_command(checkpoint_path, final_weights_folder, subcommands, s5cmd) - ) - - return subcommands - - def get_s5cmd_copy_command(self, checkpoint_path, final_weights_folder, s5cmd): - subcommands = [] base_path = checkpoint_path.split("/")[-1] if base_path.endswith(".tar"): # If the checkpoint file is a tar file, extract it into final_weights_folder @@ -535,6 +527,7 @@ def get_s5cmd_copy_command(self, checkpoint_path, final_weights_folder, s5cmd): subcommands.append( f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) + return subcommands def load_model_files_sub_commands_trt_llm( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c5176e03..10b37c7d 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -284,6 +284,56 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( ) +def test_load_model_weights_sub_commands( + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + + framework = LLMInferenceFramework.VLLM + framework_image_tag = "0.2.7" + checkpoint_path = "fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder", + ] + assert expected_result == subcommands + + framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE + framework_image_tag = "1.0.0" + checkpoint_path = "fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + "s5cmd > /dev/null || conda install -c conda-forge -y s5cmd", + "s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder", + ] + assert expected_result == subcommands + + @pytest.mark.asyncio async def test_create_model_endpoint_trt_llm_use_case_success( test_api_key: str, From 112513cd46a50f4351e93e6e6648ad51c58d9b02 Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:38:54 -0800 Subject: [PATCH 221/425] [Logging I/O] Post inference hooks as background tasks (#422) * changes for forwarder to run locally * forwarder hooks as background tasks and testing code * hooks for celery forwarder * revert local changes for testing * revert unncessary things * remove space * remove print statement + fix unit test * move logic to after_return * load json response in handler * add temp unit test for post inference hooks handler * add another temp unit test for json handling * not cover handle line for now --- .../inference/forwarding/celery_forwarder.py | 81 ++++++++-------- .../inference/forwarding/forwarding.py | 4 - .../inference/forwarding/http_forwarder.py | 13 ++- .../inference/post_inference_hooks.py | 12 ++- .../unit/inference/test_http_forwarder.py | 94 +++++++++++++++++-- 5 files changed, 149 insertions(+), 55 deletions(-) diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 6206f711..9ed5e4dd 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -5,6 +5,7 @@ from celery import Celery, Task, states from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.celery import TaskVisibility, celery_app from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger @@ -25,45 +26,6 @@ class ErrorResponse(TypedDict): error_metadata: str -class ErrorHandlingTask(Task): - """Sets a 'custom' field with error in the Task response for FAILURE. - - Used when services are ran via the Celery backend. - """ - - def after_return( - self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo - ) -> None: - """Handler that ensures custom error response information is available whenever a Task fails. - - Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value - :param:`retval` is an `Exception`, this handler extracts information from the `Exception` - and constructs a custom error response JSON value (see :func:`error_response` for details). - - This handler then re-propagates the Celery-required exception information (`"exc_type"` and - `"exc_message"`) while adding this new error response information under the `"custom"` key. - """ - if status == states.FAILURE and isinstance(retval, Exception): - logger.warning(f"Setting custom error response for failed task {task_id}") - - info: dict = raw_celery_response(self.backend, task_id) - result: dict = info["result"] - err: Exception = retval - - error_payload = error_response("Internal failure", err) - - # Inspired by pattern from: - # https://www.distributedpython.com/2018/09/28/celery-task-states/ - self.update_state( - state=states.FAILURE, - meta={ - "exc_type": result["exc_type"], - "exc_message": result["exc_message"], - "custom": json.dumps(error_payload, indent=False), - }, - ) - - def raw_celery_response(backend, task_id: str) -> Dict[str, Any]: key_info: str = backend.get_key_for_task(task_id) info_as_str: str = backend.get(key_info) @@ -103,6 +65,47 @@ def create_celery_service( else None, ) + class ErrorHandlingTask(Task): + """Sets a 'custom' field with error in the Task response for FAILURE. + + Used when services are ran via the Celery backend. + """ + + def after_return( + self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo + ) -> None: + """Handler that ensures custom error response information is available whenever a Task fails. + + Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value + :param:`retval` is an `Exception`, this handler extracts information from the `Exception` + and constructs a custom error response JSON value (see :func:`error_response` for details). + + This handler then re-propagates the Celery-required exception information (`"exc_type"` and + `"exc_message"`) while adding this new error response information under the `"custom"` key. + """ + if status == states.FAILURE and isinstance(retval, Exception): + logger.warning(f"Setting custom error response for failed task {task_id}") + + info: dict = raw_celery_response(self.backend, task_id) + result: dict = info["result"] + err: Exception = retval + + error_payload = error_response("Internal failure", err) + + # Inspired by pattern from: + # https://www.distributedpython.com/2018/09/28/celery-task-states/ + self.update_state( + state=states.FAILURE, + meta={ + "exc_type": result["exc_type"], + "exc_message": result["exc_message"], + "custom": json.dumps(error_payload, indent=False), + }, + ) + request_params = args[0] + request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) + forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore + # See documentation for options: # https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options @app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True) diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 099fe7d4..4bbe885d 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -9,7 +9,6 @@ import sseclient import yaml from fastapi.responses import JSONResponse -from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import get_endpoint_config from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( @@ -126,7 +125,6 @@ class Forwarder(ModelEngineSerializationMixin): forward_http_status: bool def __call__(self, json_payload: Any) -> Any: - request_obj = EndpointPredictV1Request.parse_obj(json_payload) json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload @@ -163,8 +161,6 @@ def __call__(self, json_payload: Any) -> Any: if self.wrap_response: response = self.get_response_payload(using_serialize_results_as_string, response) - # TODO: we actually want to do this after we've returned the response. - self.post_inference_hooks_handler.handle(request_obj, response) if self.forward_http_status: return JSONResponse(content=response, status_code=response_raw.status_code) else: diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index f121bec2..1fdb030b 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -4,7 +4,7 @@ import subprocess from functools import lru_cache -from fastapi import Depends, FastAPI +from fastapi import BackgroundTasks, Depends, FastAPI from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger @@ -70,11 +70,20 @@ def load_streaming_forwarder(): @app.post("/predict") def predict( request: EndpointPredictV1Request, + background_tasks: BackgroundTasks, forwarder=Depends(load_forwarder), limiter=Depends(get_concurrency_limiter), ): with limiter: - return forwarder(request.dict()) + try: + response = forwarder(request.dict()) + background_tasks.add_task( + forwarder.post_inference_hooks_handler.handle, request, response + ) + return response + except Exception: + logger.error(f"Failed to decode payload from: {request}") + raise @app.post("/stream") diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 05dba306..142e998e 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -1,7 +1,9 @@ +import json from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import requests +from fastapi.responses import JSONResponse from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger @@ -108,13 +110,17 @@ def __init__( def handle( self, request_payload: EndpointPredictV1Request, - response: Dict[str, Any], + response: Union[Dict[str, Any], JSONResponse], task_id: Optional[str] = None, ): + if isinstance(response, JSONResponse): + loaded_response = json.loads(response.body) + else: + loaded_response = response for hook_name, hook in self._hooks.items(): self._monitoring_metrics_gateway.emit_attempted_post_inference_hook(hook_name) try: - hook.handle(request_payload, response, task_id) + hook.handle(request_payload, loaded_response, task_id) # pragma: no cover self._monitoring_metrics_gateway.emit_successful_post_inference_hook(hook_name) except Exception: logger.exception(f"Hook {hook_name} failed.") diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index 43fbdfbd..1ded1624 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -1,12 +1,23 @@ import threading -import time +from dataclasses import dataclass +from typing import Mapping +from unittest import mock import pytest +from fastapi import BackgroundTasks +from fastapi.responses import JSONResponse from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.inference.forwarding.forwarding import Forwarder from model_engine_server.inference.forwarding.http_forwarder import ( MultiprocessingConcurrencyLimiter, predict, ) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler + +PAYLOAD: Mapping[str, str] = {"hello": "world"} class ExceptionCapturedThread(threading.Thread): @@ -26,21 +37,90 @@ def join(self): raise self.ex -def mock_forwarder(dict): - time.sleep(1) - return dict +def mocked_get(*args, **kwargs): # noqa + @dataclass + class mocked_static_status_code: + status_code: int = 200 + + return mocked_static_status_code() + + +def mocked_post(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 200 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +@pytest.fixture +def post_inference_hooks_handler(): + handler = PostInferenceHooksHandler( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + post_inference_hooks=[], + user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], + default_callback_url=None, + default_callback_auth=None, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + ) + return handler + +@pytest.fixture +def mock_request(): + return EndpointPredictV1Request( + url="test_url", + return_pickled=False, + args={"x": 1}, + ) -def test_http_service_429(): + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_http_service_429(mock_request, post_inference_hooks_handler): + mock_forwarder = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) limiter = MultiprocessingConcurrencyLimiter(1, True) t1 = ExceptionCapturedThread( - target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter) + target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter) ) t2 = ExceptionCapturedThread( - target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter) + target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter) ) t1.start() t2.start() t1.join() with pytest.raises(Exception): # 429 thrown t2.join() + + +def test_handler_response(post_inference_hooks_handler): + try: + post_inference_hooks_handler.handle( + request_payload=mock_request, response=PAYLOAD, task_id="test_task_id" + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_handler_json_response(post_inference_hooks_handler): + try: + post_inference_hooks_handler.handle( + request_payload=mock_request, + response=JSONResponse(content=PAYLOAD), + task_id="test_task_id", + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") From d1306608ff592ee65537fccb7910b80ce7faa686 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:15:49 -0800 Subject: [PATCH 222/425] Batch inference client / doc (#424) * batch inference client / doc * fix * fixes --- clients/python/llmengine/__init__.py | 8 +++ clients/python/llmengine/completion.py | 98 ++++++++++++++++++++++++++ clients/python/llmengine/data_types.py | 96 +++++++++++++++++++++++++ docs/api/data_types.md | 38 ++++++++++ docs/api/python_client.md | 1 + docs/contributing.md | 2 +- 6 files changed, 242 insertions(+), 1 deletion(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index cc8f28be..ac05d1bb 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -25,6 +25,10 @@ CompletionStreamOutput, CompletionStreamResponse, CompletionSyncResponse, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsRequest, + CreateBatchCompletionsRequestContent, + CreateBatchCompletionsResponse, CreateFineTuneRequest, CreateFineTuneResponse, DeleteFileResponse, @@ -51,6 +55,10 @@ "CompletionStreamOutput", "CompletionStreamResponse", "CompletionSyncResponse", + "CreateBatchCompletionsModelConfig", + "CreateBatchCompletionsRequest", + "CreateBatchCompletionsRequestContent", + "CreateBatchCompletionsResponse", "CreateFineTuneRequest", "CreateFineTuneResponse", "DeleteFileResponse", diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 507754d8..3a02f04e 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -6,9 +6,14 @@ CompletionStreamV1Request, CompletionSyncResponse, CompletionSyncV1Request, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsRequest, + CreateBatchCompletionsRequestContent, + CreateBatchCompletionsResponse, ) COMPLETION_TIMEOUT = 300 +HTTP_TIMEOUT = 60 class Completion(APIEngine): @@ -397,3 +402,96 @@ def _create_stream(**kwargs): timeout=timeout, ) return CompletionSyncResponse.parse_obj(response) + + @classmethod + def batch_create( + cls, + output_data_path: str, + model_config: CreateBatchCompletionsModelConfig, + content: Optional[CreateBatchCompletionsRequestContent] = None, + input_data_path: Optional[str] = None, + data_parallelism: int = 1, + max_runtime_sec: int = 24 * 3600, + ) -> CreateBatchCompletionsResponse: + """ + Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. + + Prompts can be passed in from an input file, or as a part of the request. + + Args: + output_data_path (str): + The path to the output file. The output file will be a JSON file containing the completions. + + model_config (CreateBatchCompletionsModelConfig): + The model configuration to use for the batch completion. + + content (Optional[CreateBatchCompletionsRequestContent]): + The content to use for the batch completion. Either one of `content` or `input_data_path` must be provided. + + input_data_path (Optional[str]): + The path to the input file. The input file should be a JSON file with data of type `BatchCompletionsRequestContent`. Either one of `content` or `input_data_path` must be provided. + + data_parallelism (int): + The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1. + + max_runtime_sec (int): + The maximum runtime of the batch completion in seconds. Defaults to 24 hours. + + Returns: + response (CreateBatchCompletionsResponse): The response containing the job id. + + === "Batch completions with prompts in the request" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + + response = Completion.batch_create( + output_data_path="s3://my-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + content=CreateBatchCompletionsRequestContent( + prompts=["What is deep learning", "What is a neural network"], + max_new_tokens=10, + temperature=0.0 + ) + ) + print(response.json()) + ``` + + === "Batch completions with prompts in a file and with 2 parallel jobs" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + + # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + + response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2 + ) + print(response.json()) + ``` + """ + data = CreateBatchCompletionsRequest( + model_config=model_config, + content=content, + input_data_path=input_data_path, + output_data_path=output_data_path, + data_parallelism=data_parallelism, + max_runtime_sec=max_runtime_sec, + ).dict() + response = cls.post_sync( + resource_name="v1/llm/batch-completions", + data=data, + timeout=HTTP_TIMEOUT, + ) + return CreateBatchCompletionsResponse.parse_obj(response) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 2a37e912..64ec45a0 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -591,3 +591,99 @@ class GetFileContentResponse(BaseModel): content: str = Field(..., description="File content.") """File content.""" + + +class CreateBatchCompletionsRequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + + +class CreateBatchCompletionsModelConfig(BaseModel): + model: str + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + labels: Dict[str, str] + """ + Labels to attach to the batch inference job. + """ + num_shards: Optional[int] = 1 + """ + Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. + System may decide to use a different number than the given value. + """ + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + seed: Optional[int] = None + """ + Random seed for the model. + """ + + +class CreateBatchCompletionsRequest(BaseModel): + """ + Request object for batch completions. + """ + + input_data_path: Optional[str] + output_data_path: str + """ + Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. + """ + content: Optional[CreateBatchCompletionsRequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + model_config: CreateBatchCompletionsModelConfig + """ + Model configuration for the batch inference. Hardware configurations are inferred. + """ + data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) + """ + Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. + """ + max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600) + """ + Maximum runtime of the batch inference in seconds. Default to one day. + """ + + +class CreateBatchCompletionsResponse(BaseModel): + job_id: str + """ + The ID of the batch completions job. + """ diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 44dd3d8f..206c93e6 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -110,3 +110,41 @@ options: members: - deleted + +::: llmengine.CreateBatchCompletionsRequestContent + options: + members: + - prompts + - max_new_tokens + - temperature + - stop_sequences + - return_token_log_probs + - presence_penalty + - frequency_penalty + - top_k + - top_p + +::: llmengine.CreateBatchCompletionsModelConfig + options: + members: + - model + - checkpoint_path + - labels + - num_shards + - quantize + - seed + +::: llmengine.CreateBatchCompletionsRequest + options: + members: + - input_data_path + - output_data_path + - content + - model_config + - data_parallelism + - max_runtime_sec + +::: llmengine.CreateBatchCompletionsResponse + options: + members: + - job_id diff --git a/docs/api/python_client.md b/docs/api/python_client.md index d77d28bc..c9e22723 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -5,6 +5,7 @@ members: - create - acreate + - batch_create ::: llmengine.FineTune options: diff --git a/docs/contributing.md b/docs/contributing.md index 37a6793a..8423c202 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -21,7 +21,7 @@ pip install -r requirements-docs.txt Our Python client API reference is autogenerated from our client. You can install the client in editable mode with ``` -pip install -r clients/python +pip install -e clients/python ``` ### Step 4: Run Locally From a9843a1897f8bdf2f7a0e7185ae8c8777bd4166f Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:53:03 -0800 Subject: [PATCH 223/425] Minor fixes for batch inference (#426) * Fix file not found * progress fix * add tests * bump * typing --- clients/python/llmengine/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- .../inference/batch_inference/vllm_batch.py | 5 ++-- .../tests/unit/inference/test_vllm_batch.py | 27 ++++++++++++++++++- 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index ac05d1bb..9a52ed74 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b21" +__version__ = "0.0.0b22" import os from typing import Sequence diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index f9809cd2..4645f34f 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta21" +version = "0.0.0.beta22" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 907a44b7..0e06917f 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta21", + version="0.0.0.beta22", packages=find_packages(), ) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 20e0459d..1a1b3156 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -48,7 +48,8 @@ def file_exists(path): try: with smart_open.open(path, "r"): return True - except FileNotFoundError: + except Exception as exc: + print(f"Error checking if file exists: {exc}") return False @@ -124,7 +125,7 @@ async def batch_inference(): results_generators = await generate_with_vllm(request, content, model, job_index) - bar = tqdm(total=len(content.prompts), desc="Processed prompts") + bar = tqdm(total=len(results_generators), desc="Processed prompts") outputs = [] for generator in results_generators: diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py index dd717c04..ac586d4a 100644 --- a/model-engine/tests/unit/inference/test_vllm_batch.py +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, call, mock_open, patch import pytest -from model_engine_server.inference.batch_inference.vllm_batch import batch_inference +from model_engine_server.inference.batch_inference.vllm_batch import batch_inference, file_exists @pytest.mark.asyncio @@ -272,3 +272,28 @@ def side_effect(key, default): mock_s3_client.delete_object.assert_has_calls( [call(Bucket="bucket", Key="key.0"), call(Bucket="bucket", Key="key.1")] ) + + +def test_file_exists(): + mock_open_func = mock_open() + path = "test_path" + + with patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", mock_open_func + ): + result = file_exists(path) + + mock_open_func.assert_called_once_with(path, "r") + assert result is True + + +def test_file_exists_no_such_key(): + path = "test_path" + + with patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + side_effect=IOError("No such key"), + ): + result = file_exists(path) + + assert result is False From 1213b4cad8e4824b568c9e1d25cfdf329e8d55d6 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 31 Jan 2024 14:52:53 -0800 Subject: [PATCH 224/425] LLM benchmark script improvements (#427) * step concurrency * add completion time percentiles * fix up percentiles to be total request time * rename some things * wip percentiles for inter token latency * actually record the numbers * oops * add percentiles for time to first token --- scripts/throughput_benchmarks.py | 46 +++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py index a60267e5..c689d8cc 100644 --- a/scripts/throughput_benchmarks.py +++ b/scripts/throughput_benchmarks.py @@ -65,10 +65,17 @@ def send_request(url, request, user=None): stream=True, ) first_line = True + inter_token_latencies = [] + last_token_time = None for byte_payload in response.iter_lines(): + token_time = time.time() if first_line: - time_to_first_token = time.time() - start + time_to_first_token = token_time - start + last_token_time = token_time first_line = False + else: + inter_token_latencies.append(token_time - last_token_time) + last_token_time = token_time # Skip line if byte_payload == b"\n": @@ -85,6 +92,7 @@ def send_request(url, request, user=None): "payload": payload_json, "time_to_first_token": time_to_first_token, "total_time": time.time() - start, + "inter_token_latencies": inter_token_latencies, } @@ -255,7 +263,9 @@ def run_benchmark( time_to_process_prompt = [] time_per_completion = [] time_to_first_token = [] - inter_token_latency = [] + inter_token_latency = [] # one value per request, average inter-token latency in the request + total_request_time = [] + all_inter_token_latencies = [] # one value per token (except the first generated token) for result in results: avg_time_per_token = (result["total_time"] - result["time_to_first_token"]) / ( result["num_completion_tokens"] - 1 @@ -264,28 +274,57 @@ def run_benchmark( time_to_process_prompt.append(result["time_to_first_token"] - avg_time_per_token) time_per_completion.append(result["total_time"] - time_to_process_prompt[-1]) inter_token_latency.append(avg_time_per_token) + total_request_time.append(result["total_time"]) + all_inter_token_latencies.extend(result["inter_token_latencies"]) total_num_tokens = num_sampled_tokens + num_prompt_tokens avg_prefill_time = sum(time_to_process_prompt) / n avg_completion_time = sum(time_per_completion) / n + p50_request_time = np.percentile(total_request_time, 50) + p90_request_time = np.percentile(total_request_time, 90) + p95_request_time = np.percentile(total_request_time, 95) + p99_request_time = np.percentile(total_request_time, 99) + p50_inter_token_latency = np.percentile(all_inter_token_latencies, 50) + p90_inter_token_latency = np.percentile(all_inter_token_latencies, 90) + p95_inter_token_latency = np.percentile(all_inter_token_latencies, 95) + p99_inter_token_latency = np.percentile(all_inter_token_latencies, 99) + p999_inter_token_latency = np.percentile(all_inter_token_latencies, 99.9) + p50_time_to_first_token = np.percentile(time_to_first_token, 50) + p90_time_to_first_token = np.percentile(time_to_first_token, 90) + p95_time_to_first_token = np.percentile(time_to_first_token, 95) + p99_time_to_first_token = np.percentile(time_to_first_token, 99) statistics = { "concurrency": concurrency, "avg_prompt_throughput": num_prompt_tokens / (elapsed * avg_prefill_time / (avg_prefill_time + avg_completion_time)), "avg_time_to_first_token": sum(time_to_first_token) / n, + "p50_time_to_first_token": p50_time_to_first_token, + "p90_time_to_first_token": p90_time_to_first_token, + "p95_time_to_first_token": p95_time_to_first_token, + "p99_time_to_first_token": p99_time_to_first_token, "avg_sampling_throughput": num_sampled_tokens / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)), "avg_total_throughput": total_num_tokens / elapsed, "avg_per_session_sampling_throughput": num_sampled_tokens / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)) / concurrency, + "avg_request_throughput": n / elapsed, "avg_inter_token_latency": sum(inter_token_latency) / n, + "p50_inter_token_latency": p50_inter_token_latency, + "p90_inter_token_latency": p90_inter_token_latency, + "p95_inter_token_latency": p95_inter_token_latency, + "p99_inter_token_latency": p99_inter_token_latency, + "p99.9_inter_token_latency": p999_inter_token_latency, "num_prompt_tokens": prompt_num_tokens, "avg_num_sampled_tokens": num_sampled_tokens / n, "elapsed_time": elapsed, "avg_prefill_time": avg_prefill_time, "avg_completion_time": avg_completion_time, + "p50_request_time": p50_request_time, + "p90_request_time": p90_request_time, + "p95_request_time": p95_request_time, + "p99_request_time": p99_request_time, "num_requests": num_trials, "num_successful_requests": n, "total_num_tokens": total_num_tokens, @@ -361,6 +400,7 @@ def run_benchmarks_concurrency_range( use_localhost: bool = False, concurrency_min: int = 1, concurrency_max: int = 1, + concurrency_step: int = 1, verbose: bool = False, hf_model: Optional[str] = None, local_port: int = 5005, @@ -369,7 +409,7 @@ def run_benchmarks_concurrency_range( # Create empty file with open(output_file, "w"): pass - for concurrency in range(concurrency_min, concurrency_max + 1): + for concurrency in range(concurrency_min, concurrency_max + 1, concurrency_step): run_benchmarks( model, framework, From 8d8774ca5acd12877e6fec2da08bd489d1ec989f Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Fri, 2 Feb 2024 17:45:50 -0800 Subject: [PATCH 225/425] Allow using pydantic v2 (#429) * Allow using pydantic v2 * bump version --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 7 ++++++- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 9a52ed74..2ea56f44 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b22" +__version__ = "0.0.0b23" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 64ec45a0..b32ff7a6 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -5,7 +5,12 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field, HttpUrl +import pydantic + +if int(pydantic.__version__.split(".")[0]) > 1: + from pydantic.v1 import BaseModel, Field, HttpUrl +else: + from pydantic import BaseModel, Field, HttpUrl CpuSpecificationType = Union[str, int, float] StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 4645f34f..2c3868fb 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta22" +version = "0.0.0.beta23" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 0e06917f..5a641e2d 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta22", + version="0.0.0.beta23", packages=find_packages(), ) From a2a6563cc3d610bb0173abcffe4f1e44924e7f76 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 6 Feb 2024 15:40:10 -0500 Subject: [PATCH 226/425] Fix helm chart nodeSelector for GPU endpoints (#430) --- .../model-engine/templates/service_template_config_map.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 199a5b1d..5b8cf5ae 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -92,6 +92,9 @@ data: {{- toYaml . | nindent 12 }} {{- end }} {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -634,6 +637,9 @@ data: {{- toYaml . | nindent 12 }} {{- end }} {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" From ea38f1ee11c5a081ee5417957265ff945e0ad205 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 6 Feb 2024 17:18:46 -0800 Subject: [PATCH 227/425] Allow pydantic 2 in python client requested requirements (#433) * allow pydantic 2 in requirements * bump version * ignore mypy for line * fix mypy * can't just cast to HttpUrl like that --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 12 ++++++------ clients/python/llmengine/model.py | 3 ++- clients/python/pyproject.toml | 4 ++-- clients/python/setup.py | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 2ea56f44..f579cf83 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b23" +__version__ = "0.0.0b24" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index b32ff7a6..209084aa 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -10,7 +10,7 @@ if int(pydantic.__version__.split(".")[0]) > 1: from pydantic.v1 import BaseModel, Field, HttpUrl else: - from pydantic import BaseModel, Field, HttpUrl + from pydantic import BaseModel, Field, HttpUrl # type: ignore CpuSpecificationType = Union[str, int, float] StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. @@ -163,17 +163,17 @@ class CreateLLMEndpointRequest(BaseModel): cpus: CpuSpecificationType gpus: int memory: StorageSpecificationType - gpu_type: GpuType + gpu_type: Optional[GpuType] storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + optimize_costs: Optional[bool] = None min_workers: int max_workers: int per_worker: int labels: Dict[str, str] - prewarm: Optional[bool] + prewarm: Optional[bool] = None high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] + default_callback_url: Optional[HttpUrl] = None + default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = True """ Whether the endpoint can be used for inference for all users. LLM endpoints are public by default. diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index fa84d1e3..35e26631 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -277,7 +277,8 @@ def create( per_worker=per_worker, high_priority=high_priority, post_inference_hooks=post_inference_hooks_strs, - default_callback_url=default_callback_url, + # Pydantic automatically validates the url + default_callback_url=default_callback_url, # type: ignore storage=storage, public_inference=public_inference, ) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2c3868fb..2f90d572 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta23" +version = "0.0.0.beta24" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] @@ -13,7 +13,7 @@ packages = [{include = "llmengine"}] [tool.poetry.dependencies] python = "^3.7" -pydantic = "^1.10" +pydantic = ">=1.10" aiohttp = "^3.8" requests = "^2.31.0" diff --git a/clients/python/setup.py b/clients/python/setup.py index 5a641e2d..a7d2b8d8 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta23", + version="0.0.0.beta24", packages=find_packages(), ) From 70285752e41dd5d97910139183f67f92a6a12d3a Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:04:33 -0800 Subject: [PATCH 228/425] Fix permissions (#431) * Fix s5cmd env vars * more fixes for s5cmd * dont error * add back aws_profile * flush * fix test --- .../inference/batch_inference/vllm_batch.py | 13 +++++++--- .../tests/unit/inference/test_vllm_batch.py | 25 ++++++++++++++++--- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 1a1b3156..f976b87b 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -27,12 +27,17 @@ def get_s3_client(): def download_model(checkpoint_path, final_weights_folder): - s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + s5cmd = f"./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + env = os.environ.copy() + env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + # Need to override these env vars so s5cmd uses AWS_PROFILE + env["AWS_ROLE_ARN"] = "" + env["AWS_WEB_IDENTITY_TOKEN_FILE"] = "" process = subprocess.Popen( - s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env ) for line in process.stdout: - print(line) + print(line, flush=True) process.wait() @@ -41,7 +46,7 @@ def download_model(checkpoint_path, final_weights_folder): for line in iter(process.stderr.readline, ""): stderr_lines.append(line.strip()) - raise IOError(f"Error downloading model weights: {stderr_lines}") + print(f"Error downloading model weights: {stderr_lines}", flush=True) def file_exists(path): diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py index ac586d4a..e9ab0937 100644 --- a/model-engine/tests/unit/inference/test_vllm_batch.py +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -74,7 +74,7 @@ async def test_batch_inference( new_callable=mock_open, read_data="Mocked content", ) -async def test_batch_inference_failed_to_download_model( +async def test_batch_inference_failed_to_download_model_but_proceed( mock_open_func, mock_popen, mock_get_s3_client, @@ -86,9 +86,10 @@ async def test_batch_inference_failed_to_download_model( create_vllm_request_outputs, mock_s3_client, mock_process, + mock_completion_output, ): # Mock the necessary objects and data - mock_process.returncode = 1 + mock_process.returncode = 1 # Failed to download model mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request @@ -96,9 +97,25 @@ async def test_batch_inference_failed_to_download_model( create_batch_completions_request_content ) + mock_results_generator = MagicMock() + mock_results_generator.__aiter__.return_value = create_vllm_request_outputs + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_results_generator] + # Call the function - with pytest.raises(IOError): - await batch_inference() + await batch_inference() + + # Assertions + mock_create_batch_completions_request.parse_file.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) @pytest.mark.asyncio From e07fc7a5b7ec903327edc331c2b39339ae9410a8 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:11:18 -0800 Subject: [PATCH 229/425] [Client] Add Auth headers to the python async routes (#434) * add Auth headers to the async routes * bump version --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/api_engine.py | 10 +++++++--- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index f579cf83..cc19aefd 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b24" +__version__ = "0.0.0b25" import os from typing import Sequence diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index 1431d6cb..3abf86d6 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import requests -from aiohttp import ClientSession, ClientTimeout +from aiohttp import BasicAuth, ClientSession, ClientTimeout from llmengine.errors import parse_error SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine/" @@ -163,7 +163,9 @@ async def apost_sync( ) -> Dict[str, Any]: api_key = get_api_key() async with ClientSession( - timeout=ClientTimeout(timeout), headers={"x-api-key": api_key} + timeout=ClientTimeout(timeout), + headers={"x-api-key": api_key}, + auth=BasicAuth(api_key, ""), ) as session: async with session.post( urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data @@ -179,7 +181,9 @@ async def apost_stream( ) -> AsyncIterable[Dict[str, Any]]: api_key = get_api_key() async with ClientSession( - timeout=ClientTimeout(timeout), headers={"x-api-key": api_key} + timeout=ClientTimeout(timeout), + headers={"x-api-key": api_key}, + auth=BasicAuth(api_key, ""), ) as session: async with session.post( urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2f90d572..a0afe290 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta24" +version = "0.0.0.beta25" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index a7d2b8d8..961459dc 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta24", + version="0.0.0.beta25", packages=find_packages(), ) From 847317e74405562b5641ca70722d611a076ff44e Mon Sep 17 00:00:00 2001 From: Edward Gan Date: Wed, 7 Feb 2024 16:18:26 -0800 Subject: [PATCH 230/425] pin boto3 and urllib3 version (#432) --- .../model_engine_server/inference/requirements_base.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt index 14f6577a..4561bd06 100644 --- a/model-engine/model_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -1,5 +1,6 @@ aioredis~=2.0 -boto3>=1.28.38 +urllib3~=1.26.13 +boto3~=1.34.33 celery[redis,sqs,tblib]==5.3.1 datadog-api-client==2.11.0 datadog~=0.47.0 From 5bff3451bbbcade72c04fb104ba9f8a84fca0a4d Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:15:37 -0800 Subject: [PATCH 231/425] include stop string in output (#435) --- .../model_engine_server/common/dtos/llms.py | 8 ++++++++ .../use_cases/llm_model_endpoint_use_cases.py | 13 +++++++++++++ 2 files changed, 21 insertions(+) diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 11c21e24..35e1c744 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -180,6 +180,10 @@ class CompletionSyncV1Request(BaseModel): """ Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ class TokenOutput(BaseModel): @@ -240,6 +244,10 @@ class CompletionStreamV1Request(BaseModel): """ Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ class CompletionStreamOutput(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 8a18273e..eb71bf82 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1349,6 +1349,15 @@ def validate_and_update_completion_params( "return_token_log_probs is only supported in deepspeed, text-generation-inference, vllm, lightllm." ) + # include_stop_str_in_output + if inference_framework == LLMInferenceFramework.VLLM: + pass + else: + if request.include_stop_str_in_output is not None: + raise ObjectHasInvalidValueException( + "include_stop_str_in_output is only supported in vllm." + ) + return request @@ -1634,6 +1643,8 @@ async def execute( vllm_args["top_p"] = request.top_p if request.return_token_log_probs: vllm_args["logprobs"] = 1 + if request.include_stop_str_in_output is not None: + vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output inference_request = SyncEndpointPredictV1Request( args=vllm_args, @@ -1888,6 +1899,8 @@ async def execute( args["top_p"] = request.top_p if request.return_token_log_probs: args["logprobs"] = 1 + if request.include_stop_str_in_output is not None: + args["include_stop_str_in_output"] = request.include_stop_str_in_output args["stream"] = True elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: args = { From c427d0b5288d3c6b60c7d6f0c572c806921591a3 Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Wed, 14 Feb 2024 22:21:18 -0800 Subject: [PATCH 232/425] Logging post inference hook implementation (#428) * add logging hook * plumb other columns through endpoint config * add args to unit tests * change to labels * add config and assume role * new storage gateway * add stream name config * fix test * undo conftest * handle error * fix test * fake streaming storage gateway * move client to fn * change to fake gateway * PR comments * catch err * update test * remove error in response * try small test * add more tests * fix test --------- Co-authored-by: Sai Atmakuri --- .../model_engine_server/api/dependencies.py | 10 ++ .../model_engine_server/common/constants.py | 1 + .../model_engine_server/core/config.py | 2 + .../domain/entities/model_endpoint_entity.py | 4 + .../model_engine_server/domain/exceptions.py | 6 ++ .../inference/async_inference/tasks.py | 8 ++ .../gateways/streaming_storage_gateway.py | 19 ++++ .../inference/forwarding/forwarding.py | 13 +++ .../firehose_streaming_storage_gateway.py | 63 +++++++++++++ .../inference/post_inference_hooks.py | 82 +++++++++++++++- .../sync_inference/fastapi_server.py | 8 ++ .../gateways/resources/k8s_resource_types.py | 4 + model-engine/tests/unit/conftest.py | 16 ++++ .../tests/unit/inference/test_forwarding.py | 6 ++ .../unit/inference/test_http_forwarder.py | 38 ++++++++ ...test_firehose_streaming_storage_gateway.py | 93 +++++++++++++++++++ 16 files changed, 372 insertions(+), 1 deletion(-) create mode 100644 model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py create mode 100644 model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py create mode 100644 model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index b68080b8..140c5b11 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -43,6 +43,12 @@ LLMModelEndpointService, ModelEndpointService, ) +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, FakeMonitoringMetricsGateway, @@ -137,6 +143,7 @@ class ExternalInterfaces: cron_job_gateway: CronJobGateway monitoring_metrics_gateway: MonitoringMetricsGateway tokenizer_repository: TokenizerRepository + streaming_storage_gateway: StreamingStorageGateway def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway: @@ -265,6 +272,8 @@ def _get_external_interfaces( tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) + streaming_storage_gateway = FirehoseStreamingStorageGateway() + external_interfaces = ExternalInterfaces( docker_repository=docker_repository, model_bundle_repository=model_bundle_repository, @@ -287,6 +296,7 @@ def _get_external_interfaces( cron_job_gateway=cron_job_gateway, monitoring_metrics_gateway=monitoring_metrics_gateway, tokenizer_repository=tokenizer_repository, + streaming_storage_gateway=streaming_storage_gateway, ) return external_interfaces diff --git a/model-engine/model_engine_server/common/constants.py b/model-engine/model_engine_server/common/constants.py index 567df502..53795c41 100644 --- a/model-engine/model_engine_server/common/constants.py +++ b/model-engine/model_engine_server/common/constants.py @@ -2,6 +2,7 @@ BILLING_POST_INFERENCE_HOOK: str = "billing" CALLBACK_POST_INFERENCE_HOOK: str = "callback" +LOGGING_POST_INFERENCE_HOOK: str = "logging" READYZ_FPATH: str = "/tmp/readyz" DEFAULT_CELERY_TASK_NAME: str = "hosted_model_inference.inference.async_inference.tasks.predict" LIRA_CELERY_TASK_NAME: str = "ml_serve.celery_service.exec_func" diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index ef08839c..6403d64f 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -42,6 +42,8 @@ class InfraConfig: profile_ml_worker: str = "default" profile_ml_inference_worker: str = "default" identity_service_url: Optional[str] = None + firehose_role_arn: Optional[str] = None + firehose_stream_name: Optional[str] = None @classmethod def from_yaml(cls, yaml_path) -> "InfraConfig": diff --git a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py index 809035ba..cb6277f6 100644 --- a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py @@ -88,6 +88,10 @@ class ModelEndpointConfig(BaseModel): billing_tags: Optional[Dict[str, Any]] = None default_callback_url: Optional[str] = None default_callback_auth: Optional[CallbackAuth] + endpoint_id: Optional[str] = None + endpoint_type: Optional[ModelEndpointType] + bundle_id: Optional[str] = None + labels: Optional[Dict[str, str]] = None def serialize(self) -> str: return python_json_to_b64(dict_not_none(**self.dict())) diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index b78bb281..7b5ff902 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -164,3 +164,9 @@ class TriggerNameAlreadyExistsException(DomainException): """ Thrown if the requested name already exists in the trigger repository """ + + +class StreamPutException(DomainException): + """ + Thrown if the streaming storage gateway fails to put a record. + """ diff --git a/model-engine/model_engine_server/inference/async_inference/tasks.py b/model-engine/model_engine_server/inference/async_inference/tasks.py index 6fce0588..62bde09c 100644 --- a/model-engine/model_engine_server/inference/async_inference/tasks.py +++ b/model-engine/model_engine_server/inference/async_inference/tasks.py @@ -18,6 +18,9 @@ from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler logger = make_logger(logger_name()) @@ -46,6 +49,11 @@ def init_worker_global(): default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), ) # k8s health check with open(READYZ_FPATH, "w") as f: diff --git a/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py new file mode 100644 index 00000000..ae4216dd --- /dev/null +++ b/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class StreamingStorageGateway(ABC): + """ + Base class for a gateway that stores data through a streaming mechanism. + """ + + @abstractmethod + def put_record(self, stream_name: str, record: Dict[str, Any]) -> None: + """ + Put a record into a streaming storage mechanism. + + Args: + stream_name: The name of the stream. + record: The record to put into the stream. + """ + pass diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 4bbe885d..38b4e8cc 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -14,6 +14,9 @@ from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler __all__: Sequence[str] = ( @@ -279,6 +282,11 @@ def endpoint(route: str) -> str: default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), ) return Forwarder( @@ -451,6 +459,11 @@ def endpoint(route: str) -> str: default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), ) return StreamingForwarder( diff --git a/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py new file mode 100644 index 00000000..ab718737 --- /dev/null +++ b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py @@ -0,0 +1,63 @@ +import json +from typing import Any, Dict + +import boto3 +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import StreamPutException +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) + +logger = make_logger(logger_name()) + + +class FirehoseStreamingStorageGateway(StreamingStorageGateway): + """ + A gateway that stores data through the AWS Kinesis Firehose streaming mechanism. + """ + + def __init__(self): + pass + + """ + Creates a new firehose client. + + Streams with Snowflake as a destination and the AWS profile live in different + accounts. Firehose doesn't support resource-based policies, so we need to assume + a new role to write to the stream. + """ + + def _get_firehose_client(self): + sts_client = boto3.client("sts", region_name=infra_config().default_region) + assumed_role_object = sts_client.assume_role( + RoleArn=infra_config().firehose_role_arn, + RoleSessionName="AssumeMlLoggingRoleSession", + ) + credentials = assumed_role_object["Credentials"] + session = boto3.Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + firehose_client = session.client("firehose", region_name=infra_config().default_region) + return firehose_client + + def put_record(self, stream_name: str, record: Dict[str, Any]) -> None: + """ + Put a record into a Firehose stream. + + Args: + stream_name: The name of the stream. + record: The record to put into the stream. + """ + firehose_response = self._get_firehose_client().put_record( + DeliveryStreamName=stream_name, Record={"Data": json.dumps(record).encode("utf-8")} + ) + if firehose_response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise StreamPutException( + f"Failed to put record into firehose stream {stream_name}. Record content: {record}" + ) + logger.info( + f"Logged to firehose stream {stream_name}. Record content: {record}, Record ID: {firehose_response['RecordId']}" + ) diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 142e998e..6f388acb 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -4,13 +4,22 @@ import requests from fastapi.responses import JSONResponse -from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK +from model_engine_server.common.constants import ( + CALLBACK_POST_INFERENCE_HOOK, + LOGGING_POST_INFERENCE_HOOK, +) from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth +from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpointType +from model_engine_server.domain.exceptions import StreamPutException from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( InferenceMonitoringMetricsGateway, ) +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) from tenacity import Retrying, stop_after_attempt, wait_exponential logger = make_logger(logger_name()) @@ -76,6 +85,61 @@ def handle( assert 200 <= res.status_code < 300 +class LoggingHook(PostInferenceHook): + def __init__( + self, + endpoint_name: str, + bundle_name: str, + user_id: str, + endpoint_id: Optional[str], + endpoint_type: Optional[ModelEndpointType], + bundle_id: Optional[str], + labels: Optional[Dict[str, str]], + streaming_storage_gateway: StreamingStorageGateway, + ): + super().__init__(endpoint_name, bundle_name, user_id) + self._endpoint_id = endpoint_id + self._endpoint_type = endpoint_type + self._bundle_id = bundle_id + self._labels = labels + self._streaming_storage_gateway = streaming_storage_gateway + + def handle( + self, + request_payload: EndpointPredictV1Request, + response: Dict[str, Any], + task_id: Optional[str], + ): + if ( + not self._endpoint_id + or not self._endpoint_type + or not self._bundle_id + or not self._labels + ): + logger.warning( + "No endpoint_id, endpoint_type, bundle_id, or labels specified for request." + ) + return + response["task_id"] = task_id + data_record = { + "REQUEST_BODY": request_payload.json(), + "RESPONSE_BODY": response, + "ENDPOINT_ID": self._endpoint_id, + "ENDPOINT_NAME": self._endpoint_name, + "ENDPOINT_TYPE": self._endpoint_type.value, + "BUNDLE_ID": self._bundle_id, + "LABELS": self._labels, + } + stream_name = infra_config().firehose_stream_name + if stream_name is None: + logger.warning("No firehose stream name specified. Logging hook will not be executed.") + return + try: + self._streaming_storage_gateway.put_record(stream_name=stream_name, record=data_record) + except StreamPutException as e: + logger.error(f"Error in logging hook {e}") + + class PostInferenceHooksHandler: def __init__( self, @@ -88,6 +152,11 @@ def __init__( default_callback_auth: Optional[CallbackAuth], post_inference_hooks: Optional[List[str]], monitoring_metrics_gateway: InferenceMonitoringMetricsGateway, + endpoint_id: Optional[str], + endpoint_type: Optional[ModelEndpointType], + bundle_id: Optional[str], + labels: Optional[Dict[str, str]], + streaming_storage_gateway: StreamingStorageGateway, ): self._monitoring_metrics_gateway = monitoring_metrics_gateway self._hooks: Dict[str, PostInferenceHook] = {} @@ -104,6 +173,17 @@ def __init__( default_callback_url, default_callback_auth, ) + elif hook_lower == LOGGING_POST_INFERENCE_HOOK: + self._hooks[hook_lower] = LoggingHook( + endpoint_name, + bundle_name, + user_id, + endpoint_id, + endpoint_type, + bundle_id, + labels, + streaming_storage_gateway, + ) else: raise ValueError(f"Hook {hook_lower} is currently not supported.") diff --git a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py index 02b68eca..3d30bf0c 100644 --- a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py @@ -13,6 +13,9 @@ from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler from model_engine_server.inference.sync_inference.constants import ( CONCURRENCY, @@ -52,6 +55,11 @@ def _inner_2(*args, **kwargs): default_callback_url=endpoint_config.default_callback_url, default_callback_auth=endpoint_config.default_callback_auth, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), ) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index e9d4a657..483c5e5b 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -1082,6 +1082,10 @@ def get_endpoint_resource_arguments_from_request( billing_tags=build_endpoint_request.billing_tags, default_callback_url=build_endpoint_request.default_callback_url, default_callback_auth=build_endpoint_request.default_callback_auth, + endpoint_id=model_endpoint_record.id, + endpoint_type=model_endpoint_record.endpoint_type, + bundle_id=model_bundle.id, + labels=build_endpoint_request.labels, ).serialize() return EndpointConfigArguments( # Base resource arguments diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index eeef573c..e7ad32cc 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -114,6 +114,9 @@ LLMModelEndpointService, ModelEndpointService, ) +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) from model_engine_server.infra.gateways import ( BatchJobOrchestrationGateway, LiveBatchJobProgressGateway, @@ -1555,6 +1558,11 @@ async def emit_prewarm_metric(self, endpoint_id: str): pass +class FakeStreamingStorageGateway(StreamingStorageGateway): + def put_record(self, stream_name: str, record: Dict[str, Any]): + pass + + class FakeModelEndpointService(ModelEndpointService): db: Dict[str, ModelEndpoint] @@ -2100,6 +2108,12 @@ def fake_tokenizer_repository() -> TokenizerRepository: return FakeTokenizerRepository() +@pytest.fixture +def fake_streaming_storage_gateway() -> StreamingStorageGateway: + gateway = FakeStreamingStorageGateway() + return gateway + + @pytest.fixture def get_repositories_generator_wrapper(): def get_repositories_generator( @@ -2188,6 +2202,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_llm_fine_tuning_events_repository = FakeLLMFineTuneEventsRepository() fake_file_storage_gateway = FakeFileStorageGateway(fake_file_storage_gateway_contents) fake_tokenizer_repository = FakeTokenizerRepository() + fake_streaming_storage_gateway = FakeStreamingStorageGateway() repositories = ExternalInterfaces( docker_repository=FakeDockerRepository( @@ -2213,6 +2228,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: llm_artifact_gateway=fake_llm_artifact_gateway, monitoring_metrics_gateway=fake_monitoring_metrics_gateway, tokenizer_repository=fake_tokenizer_repository, + streaming_storage_gateway=fake_streaming_storage_gateway, ) try: yield repositories diff --git a/model-engine/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py index 07117967..68c9ab32 100644 --- a/model-engine/tests/unit/inference/test_forwarding.py +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -19,6 +19,7 @@ DatadogInferenceMonitoringMetricsGateway, ) from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from tests.unit.conftest import FakeStreamingStorageGateway PAYLOAD: Mapping[str, str] = {"hello": "world"} @@ -86,6 +87,11 @@ def post_inference_hooks_handler(): default_callback_url=None, default_callback_auth=None, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id="test_endpoint_id", + endpoint_type="sync", + bundle_id="test_bundle_id", + labels={}, + streaming_storage_gateway=FakeStreamingStorageGateway(), ) return handler diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index 1ded1624..bad6e6b4 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -16,6 +16,7 @@ DatadogInferenceMonitoringMetricsGateway, ) from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from tests.unit.conftest import FakeStreamingStorageGateway PAYLOAD: Mapping[str, str] = {"hello": "world"} @@ -68,6 +69,32 @@ def post_inference_hooks_handler(): default_callback_url=None, default_callback_auth=None, monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id="test_endpoint_id", + endpoint_type="sync", + bundle_id="test_bundle_id", + labels={}, + streaming_storage_gateway=FakeStreamingStorageGateway(), + ) + return handler + + +@pytest.fixture +def post_inference_hooks_handler_with_logging(): + handler = PostInferenceHooksHandler( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + post_inference_hooks=["logging"], + user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], + default_callback_url=None, + default_callback_auth=None, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id="test_endpoint_id", + endpoint_type="sync", + bundle_id="test_bundle_id", + labels={}, + streaming_storage_gateway=FakeStreamingStorageGateway(), ) return handler @@ -124,3 +151,14 @@ def test_handler_json_response(post_inference_hooks_handler): ) except Exception as e: pytest.fail(f"Unexpected exception: {e}") + + +def test_handler_with_logging(post_inference_hooks_handler_with_logging): + try: + post_inference_hooks_handler_with_logging.handle( + request_payload=mock_request, + response=JSONResponse(content=PAYLOAD), + task_id="test_task_id", + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") diff --git a/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py new file mode 100644 index 00000000..1cedaef6 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py @@ -0,0 +1,93 @@ +from unittest import mock + +import pytest +from model_engine_server.domain.exceptions import StreamPutException +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) + +stream_name = "fake-stream" + + +@pytest.fixture +def streaming_storage_gateway(): + gateway = FirehoseStreamingStorageGateway() + return gateway + + +@pytest.fixture +def fake_record(): + return {"Data": "fake-data"} + + +def mock_sts_client(*args, **kwargs): + mock_client = mock.Mock() + mock_client.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "fake-access-key-id", + "SecretAccessKey": "fake-secret-access-key", + "SessionToken": "fake-session-token", + } + } + return mock_client + + +def mock_firehose_client(*args, **kwargs): + mock_client = mock.Mock() + mock_client.put_record.return_value = { + "RecordId": "fake-record-id", + "Encrypted": False, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + return mock_client + + +def mock_session(*args, **kwargs): + mock_session_obj = mock.Mock() + mock_firehose = mock_firehose_client() + mock_session_obj.client.return_value = mock_firehose + return mock_session_obj + + +def mock_firehose_client_with_exception(*args, **kwargs): + mock_client = mock.Mock() + mock_client.put_record.return_value = { + "RecordId": "fake-record-id", + "Encrypted": False, + "ResponseMetadata": {"HTTPStatusCode": 500}, + } + return mock_client + + +def mock_session_with_exception(*args, **kwargs): + mock_session_obj = mock.Mock() + mock_firehose = mock_firehose_client_with_exception() + + mock_session_obj.client.return_value = mock_firehose + + return mock_session_obj + + +def test_firehose_streaming_storage_gateway_put_record(streaming_storage_gateway, fake_record): + with mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.client", + mock_sts_client, + ), mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", + mock_session, + ): + assert streaming_storage_gateway.put_record(stream_name, fake_record) is None + + +def test_firehose_streaming_storage_gateway_put_record_with_exception( + streaming_storage_gateway, fake_record +): + with mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.client", + mock_sts_client, + ), mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", + mock_session_with_exception, + ): + with pytest.raises(StreamPutException): + streaming_storage_gateway.put_record(stream_name, fake_record) From 0541e4962bb866b7606715bfa9e4547451bd2529 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:04:32 -0800 Subject: [PATCH 233/425] add codellama-70b models (#436) --- .../use_cases/llm_model_endpoint_use_cases.py | 42 ++++++++++++------- .../repositories/live_tokenizer_repository.py | 2 + 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index eb71bf82..be56c7c5 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -174,6 +174,8 @@ "codellama-13b-instruct", "codellama-34b", "codellama-34b-instruct", + "codellama-70b", + "codellama-70b-instruct", "mistral-7b", "mistral-7b-instruct", "mixtral-8x7b", @@ -1579,9 +1581,11 @@ async def execute( else: raise UpstreamServiceError( status_code=500, - content=predict_result.traceback.encode("utf-8") - if predict_result.traceback is not None - else b"", + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), ) elif ( endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE @@ -1615,9 +1619,11 @@ async def execute( if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, - content=predict_result.traceback.encode("utf-8") - if predict_result.traceback is not None - else b"", + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), ) output = json.loads(predict_result.result["result"]) @@ -1658,9 +1664,11 @@ async def execute( if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, - content=predict_result.traceback.encode("utf-8") - if predict_result.traceback is not None - else b"", + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), ) output = json.loads(predict_result.result["result"]) @@ -1702,9 +1710,11 @@ async def execute( if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, - content=predict_result.traceback.encode("utf-8") - if predict_result.traceback is not None - else b"", + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), ) output = json.loads(predict_result.result["result"]) @@ -1740,9 +1750,11 @@ async def execute( if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, - content=predict_result.traceback.encode("utf-8") - if predict_result.traceback is not None - else b"", + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), ) output = json.loads(predict_result.result["result"]) diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index b586bc9c..1140686f 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -52,6 +52,8 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "codellama-13b-instruct": ModelInfo("codellama/CodeLlama-13b-Instruct-hf", None), "codellama-34b": ModelInfo("codellama/CodeLlama-34b-hf", None), "codellama-34b-instruct": ModelInfo("codellama/CodeLlama-34b-Instruct-hf", None), + "codellama-70b": ModelInfo("codellama/CodeLlama-70b-hf", None), + "codellama-70b-instruct": ModelInfo("codellama/CodeLlama-70b-Instruct-hf", None), "llm-jp-13b-instruct-full": ModelInfo("llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", None), "llm-jp-13b-instruct-full-dolly": ModelInfo( "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", None From da86a9deb69267ee8c6a564d360c6c225f4b9c7c Mon Sep 17 00:00:00 2001 From: tiffzhao5 <142925794+tiffzhao5@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:37:35 -0800 Subject: [PATCH 234/425] Add hook validation and support logging for python client (#437) * add hook validation and support logging * fix test --- .../model_engine_server/api/model_endpoints_v1.py | 3 +++ .../model_engine_server/common/constants.py | 5 +++++ .../model_engine_server/domain/exceptions.py | 6 ++++++ .../domain/use_cases/model_endpoint_use_cases.py | 15 ++++++--------- .../unit/domain/test_model_endpoint_use_cases.py | 3 ++- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index eece8be3..662e5ef8 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -3,6 +3,7 @@ List model endpoint history: GET model-endpoints//history Read model endpoint creation logs: GET model-endpoints//creation-logs """ + from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query @@ -34,6 +35,7 @@ ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, + PostInferenceHooksException, ) from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CreateModelEndpointV1UseCase, @@ -150,6 +152,7 @@ async def update_model_endpoint( EndpointLabelsException, ObjectHasInvalidValueException, EndpointResourceInvalidRequestException, + PostInferenceHooksException, ) as exc: raise HTTPException( status_code=400, diff --git a/model-engine/model_engine_server/common/constants.py b/model-engine/model_engine_server/common/constants.py index 53795c41..00d00d6c 100644 --- a/model-engine/model_engine_server/common/constants.py +++ b/model-engine/model_engine_server/common/constants.py @@ -3,6 +3,11 @@ BILLING_POST_INFERENCE_HOOK: str = "billing" CALLBACK_POST_INFERENCE_HOOK: str = "callback" LOGGING_POST_INFERENCE_HOOK: str = "logging" +SUPPORTED_POST_INFERENCE_HOOKS: list = [ + BILLING_POST_INFERENCE_HOOK, + CALLBACK_POST_INFERENCE_HOOK, + LOGGING_POST_INFERENCE_HOOK, +] READYZ_FPATH: str = "/tmp/readyz" DEFAULT_CELERY_TASK_NAME: str = "hosted_model_inference.inference.async_inference.tasks.predict" LIRA_CELERY_TASK_NAME: str = "ml_serve.celery_service.exec_func" diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 7b5ff902..32a16bd8 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -170,3 +170,9 @@ class StreamPutException(DomainException): """ Thrown if the streaming storage gateway fails to put a record. """ + + +class PostInferenceHooksException(DomainException): + """ + Thrown if the post inference hooks are invalid. + """ diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 1f128e65..bfd51b17 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -7,10 +7,7 @@ import re from typing import Any, Dict, List, Optional -from model_engine_server.common.constants import ( - BILLING_POST_INFERENCE_HOOK, - CALLBACK_POST_INFERENCE_HOOK, -) +from model_engine_server.common.constants import SUPPORTED_POST_INFERENCE_HOOKS from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, @@ -41,6 +38,7 @@ ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, + PostInferenceHooksException, ) from model_engine_server.domain.repositories import ModelBundleRepository from model_engine_server.domain.services import ModelEndpointService @@ -184,11 +182,10 @@ def validate_post_inference_hooks(user: User, post_inference_hooks: Optional[Lis return for hook in post_inference_hooks: - if hook not in [ - BILLING_POST_INFERENCE_HOOK, - CALLBACK_POST_INFERENCE_HOOK, - ]: - raise ValueError(f"Unsupported post-inference hook {hook}") + if hook not in SUPPORTED_POST_INFERENCE_HOOKS: + raise PostInferenceHooksException( + f"Unsupported post-inference hook {hook}. The supported hooks are: {SUPPORTED_POST_INFERENCE_HOOKS}" + ) class CreateModelEndpointV1UseCase: diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index d0b27514..e9958b11 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -25,6 +25,7 @@ ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, + PostInferenceHooksException, ) from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CONVERTED_FROM_ARTIFACT_LIKE_KEY, @@ -463,7 +464,7 @@ async def test_create_model_endpoint_use_case_validates_post_inference_hooks( request = create_model_endpoint_request_async.copy() request.post_inference_hooks = ["invalid_hook"] - with pytest.raises(ValueError): + with pytest.raises(PostInferenceHooksException): await use_case.execute(user=user, request=request) From 4d0cd26087c3436b6bde2c722018d5b486e78fa9 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 20 Feb 2024 09:27:33 -0800 Subject: [PATCH 235/425] Azure refactor for async endpoints (#425) --- charts/model-engine/templates/_helpers.tpl | 44 +++++++- .../templates/cacher_deployment.yaml | 2 + .../celery_autoscaler_stateful_set.yaml | 23 +++- .../endpoint_builder_deployment.yaml | 2 + .../templates/gateway_deployment.yaml | 2 + .../service_template_config_map.yaml | 34 +++++- charts/model-engine/values_circleci.yaml | 6 +- charts/model-engine/values_sample.yaml | 17 ++- docs/guides/self_hosting.md | 3 +- .../model_engine_server/api/dependencies.py | 100 +++++++++++++----- .../model_engine_server/common/config.py | 19 +++- .../common/dtos/model_endpoints.py | 2 + model-engine/model_engine_server/common/io.py | 19 +++- .../model_engine_server/core/celery/abs.py | 23 ++++ .../model_engine_server/core/celery/app.py | 18 +++- .../core/celery/celery_autoscaler.py | 67 +++++++++++- .../model_engine_server/core/config.py | 1 + .../core/configs/default.yaml | 1 + .../model_engine_server/core/utils/url.py | 21 ++++ model-engine/model_engine_server/db/base.py | 43 ++++++-- .../entrypoints/k8s_cache.py | 40 ++++--- .../start_batch_job_orchestration.py | 44 +++++--- .../inference/forwarding/celery_forwarder.py | 21 +++- .../infra/gateways/__init__.py | 6 ++ .../gateways/abs_file_storage_gateway.py | 34 ++++++ .../infra/gateways/abs_filesystem_gateway.py | 48 +++++++++ .../gateways/abs_llm_artifact_gateway.py | 75 +++++++++++++ .../gateways/celery_task_queue_gateway.py | 27 ++++- .../asb_queue_endpoint_resource_delegate.py | 68 ++++++++++++ .../resources/endpoint_resource_gateway.py | 8 +- ... fake_queue_endpoint_resource_delegate.py} | 19 ++-- .../gateways/resources/k8s_resource_types.py | 13 ++- .../live_endpoint_resource_gateway.py | 49 ++++----- .../queue_endpoint_resource_delegate.py | 46 ++++++++ .../sqs_endpoint_resource_delegate.py | 48 --------- ...> sqs_queue_endpoint_resource_delegate.py} | 25 +++-- .../infra/repositories/__init__.py | 6 ++ ...bs_file_llm_fine_tune_events_repository.py | 19 ++++ .../abs_file_llm_fine_tune_repository.py | 19 ++++ .../repositories/acr_docker_repository.py | 42 ++++++++ .../service_builder/celery.py | 5 + .../service_builder/tasks_v1.py | 42 +++++--- model-engine/requirements.in | 9 +- model-engine/requirements.txt | 96 +++++++++++++---- .../service_config_circleci.yaml | 2 +- model-engine/setup.cfg | 14 +++ model-engine/tests/unit/conftest.py | 2 +- ...t_sqs_queue_endpoint_resource_delegate.py} | 26 ++--- 48 files changed, 1053 insertions(+), 247 deletions(-) create mode 100644 model-engine/model_engine_server/core/celery/abs.py create mode 100644 model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py create mode 100644 model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py create mode 100644 model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py create mode 100644 model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py rename model-engine/model_engine_server/infra/gateways/resources/{fake_sqs_endpoint_resource_delegate.py => fake_queue_endpoint_resource_delegate.py} (58%) create mode 100644 model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py delete mode 100644 model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py rename model-engine/model_engine_server/infra/gateways/resources/{live_sqs_endpoint_resource_delegate.py => sqs_queue_endpoint_resource_delegate.py} (86%) create mode 100644 model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py create mode 100644 model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py create mode 100644 model-engine/model_engine_server/infra/repositories/acr_docker_repository.py rename model-engine/tests/unit/infra/gateways/resources/{test_live_sqs_endpoint_resource_delegate.py => test_sqs_queue_endpoint_resource_delegate.py} (95%) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index b4737392..3df2ea81 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -167,8 +167,10 @@ env: value: "${LOAD_PREDICT_FN_MODULE_PATH}" - name: LOAD_MODEL_FN_MODULE_PATH value: "${LOAD_MODEL_FN_MODULE_PATH}" + {{- if .Values.aws }} - name: AWS_PROFILE value: "${AWS_ROLE}" + {{- end }} - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: CHILD_FN_INFO @@ -219,10 +221,12 @@ env: valueFrom: fieldRef: fieldPath: status.hostIP + {{- if .Values.aws }} - name: AWS_PROFILE value: "${AWS_ROLE}" - name: AWS_CONFIG_FILE value: /opt/.aws/config + {{- end }} - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -233,6 +237,16 @@ env: {{- else }} value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" {{- end }} + {{- if .Values.azure}} + - name: AZURE_IDENTITY_NAME + value: {{ .Values.azure.identity_name }} + - name: AZURE_CLIENT_ID + value: {{ .Values.azure.client_id }} + - name: AZURE_OBJECT_ID + value: {{ .Values.azure.object_id }} + - name: ABS_ACCOUNT_NAME + value: {{ .Values.azure.abs_account_name }} + {{- end }} {{- end }} {{- define "modelEngine.syncForwarderTemplateEnv" -}} @@ -251,6 +265,14 @@ env: value: "VISIBILITY_24H" - name: S3_BUCKET value: "${CELERY_S3_BUCKET}" + {{- if .Values.azure}} + - name: ABS_ACCOUNT_NAME + value: {{ .Values.azure.abs_account_name }} + - name: SERVICEBUS_NAMESPACE + value: {{ .Values.azure.servicebus_namespace }} + - name: SERVICEBUS_SAS_KEY + value: {{ .Values.azure.servicebus_sas_key }} + {{- end }} {{- end }} {{- define "modelEngine.serviceEnvBase" }} @@ -290,9 +312,9 @@ env: secretKeyRef: name: {{ .kubernetesDatabaseSecretName }} key: database_url - {{- else if .awsDatabaseSecretName }} + {{- else if .cloudDatabaseSecretName }} - name: DB_SECRET_NAME - value: {{ .awsDatabaseSecretName }} + value: {{ .cloudDatabaseSecretName }} {{- end }} {{- end }} {{- if .Values.config.file }} @@ -314,6 +336,22 @@ env: - name: REDIS_AUTH_TOKEN value: {{ .Values.redis.auth }} {{- end }} + {{- if .Values.azure}} + - name: AZURE_IDENTITY_NAME + value: {{ .Values.azure.identity_name }} + - name: AZURE_CLIENT_ID + value: {{ .Values.azure.client_id }} + - name: AZURE_OBJECT_ID + value: {{ .Values.azure.object_id }} + - name: AZURE_KEYVAULT_IDENTITY_CLIENT_ID + value: {{ .Values.azure.keyvault_identity_client_id }} + - name: KEYVAULT_NAME + value: {{ .Values.azure.keyvault_name }} + - name: ABS_ACCOUNT_NAME + value: {{ .Values.azure.abs_account_name }} + - name: SERVICEBUS_NAMESPACE + value: {{ .Values.azure.servicebus_namespace }} + {{- end }} {{- if eq .Values.context "circleci" }} - name: CIRCLECI value: "true" @@ -405,9 +443,11 @@ volumeMounts: {{- define "modelEngine.forwarderVolumeMounts" }} volumeMounts: + {{- if .Values.aws }} - name: config-volume mountPath: /opt/.aws/config subPath: config + {{- end }} - name: user-config mountPath: /workspace/user_config subPath: raw_data diff --git a/charts/model-engine/templates/cacher_deployment.yaml b/charts/model-engine/templates/cacher_deployment.yaml index 4cb2a9c2..09297aba 100644 --- a/charts/model-engine/templates/cacher_deployment.yaml +++ b/charts/model-engine/templates/cacher_deployment.yaml @@ -45,7 +45,9 @@ spec: command: - dumb-init - -- + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} args: - python - -m diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml index 768fdb8b..810e7e1f 100644 --- a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -5,7 +5,12 @@ {{- $tag := .Values.tag }} {{- $message_broker := .Values.celeryBrokerType }} {{- $num_shards := .Values.celery_autoscaler.num_shards }} -{{- $broker_name := ternary "redis-elasticache-message-broker-master" "sqs-message-broker-master" (eq $message_broker "elasticache") }} +{{- $broker_name := "redis-elasticache-message-broker-master" }} +{{- if eq $message_broker "sqs" }} +{{ $broker_name = "sqs-message-broker-master" }} +{{- else if eq $message_broker "servicebus" }} +{{ $broker_name = "servicebus-message-broker-master" }} +{{- end }} apiVersion: apps/v1 kind: StatefulSet metadata: @@ -30,15 +35,19 @@ spec: spec: containers: - args: + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} - python - -m - model_engine_server.core.celery.celery_autoscaler env: + {{- if .Values.aws }} - name: AWS_PROFILE value: {{ .Values.aws.profileName }} - name: AWS_CONFIG_FILE value: /opt/.aws/config + {{- end }} - name: DD_TRACE_ENABLED value: 'false' - name: DD_SERVICE @@ -63,16 +72,26 @@ spec: fieldPath: metadata.name - name: NUM_SHARDS value: '{{ $num_shards }}' + {{- if .Values.azure }} + - name: AZURE_CLIENT_ID + value: {{ .Values.azure.client_id }} + - name: AZURE_OBJECT_ID + value: {{ .Values.azure.object_id }} + - name: SERVICEBUS_NAMESPACE + value: {{ .Values.azure.servicebus_namespace }} + {{- end }} image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}" imagePullPolicy: Always name: main resources: requests: cpu: 1000m + {{- if .Values.aws }} volumeMounts: - mountPath: /opt/.aws/config name: config-volume subPath: config + {{- end }} {{ with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} @@ -83,9 +102,11 @@ spec: value: 'true' effect: NoSchedule serviceAccountName: {{ include "modelEngine.fullname" $ }} + {{- if .Values.aws }} volumes: - configMap: name: {{ .Values.aws.configMap.name }} name: config-volume + {{- end}} {{- end }} {{- end }} \ No newline at end of file diff --git a/charts/model-engine/templates/endpoint_builder_deployment.yaml b/charts/model-engine/templates/endpoint_builder_deployment.yaml index 2f62a11a..273543f5 100644 --- a/charts/model-engine/templates/endpoint_builder_deployment.yaml +++ b/charts/model-engine/templates/endpoint_builder_deployment.yaml @@ -46,7 +46,9 @@ spec: command: - dumb-init - -- + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} args: - celery - --app=model_engine_server.service_builder diff --git a/charts/model-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml index e5283319..ed1d6cae 100644 --- a/charts/model-engine/templates/gateway_deployment.yaml +++ b/charts/model-engine/templates/gateway_deployment.yaml @@ -52,7 +52,9 @@ spec: command: - dumb-init - -- + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} args: - python - -m diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 5b8cf5ae..9300637a 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -14,13 +14,15 @@ {{- $forwarder_volume_mounts := include "modelEngine.forwarderVolumeMounts" . }} {{- $gateway_repository := .Values.image.gatewayRepository -}} {{- $tag := .Values.tag -}} -{{- $aws_config_map_name := .Values.aws.configMap.name }} +{{- $aws_config_map_name := (.Values.aws).configMap.name }} {{- $security_context := .Values.serviceTemplate.securityContext }} {{- $mount_infra_config := .Values.serviceTemplate.mountInfraConfig }} {{- $service_template_service_account_name := .Values.serviceTemplate.serviceAccountName }} {{- $service_template_aws_config_map_name := .Values.serviceTemplate.awsConfigMapName }} {{- $celery_broker_type := .Values.celeryBrokerType }} {{- $node_selector := .Values.nodeSelector }} +{{- $require_aws_config := not (empty .Values.aws) }} +{{- $enable_datadog := .Values.datadog.enabled }} {{- if .Values.message }} {{- .Values.message }} @@ -111,7 +113,9 @@ data: command: - /usr/bin/dumb-init - -- + {{- if $enable_datadog }} - ddtrace-run + {{- end }} - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -155,7 +159,9 @@ data: command: - /usr/bin/dumb-init - -- + {{- if $enable_datadog }} - ddtrace-run + {{- end }} - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -201,7 +207,9 @@ data: command: - /usr/bin/dumb-init - -- + {{- if $enable_datadog }} - ddtrace-run + {{- end }} - python - -m - model_engine_server.inference.forwarding.celery_forwarder @@ -221,6 +229,12 @@ data: {{- end }} - --num-workers - "${PER_WORKER}" + - --broker-type + - {{ $celery_broker_type }} + {{- if eq $celery_broker_type "servicebus" }} + - --backend-protocol + - abs + {{- end }} {{- $async_forwarder_template_env | nindent 14 }} resources: requests: @@ -275,9 +289,11 @@ data: ${TRITON_MEMORY_DICT} ${TRITON_STORAGE_DICT} volumeMounts: + {{- if $require_aws_config }} - name: config-volume mountPath: /opt/.aws/config subPath: config + {{- end }} - mountPath: /dev/shm name: dshm {{- end }} @@ -313,9 +329,11 @@ data: memory: ${MEMORY} ${STORAGE_DICT} volumeMounts: + {{- if $require_aws_config }} - name: config-volume mountPath: /opt/.aws/config subPath: config + {{- end }} - mountPath: /dev/shm name: dshm {{- if $mount_infra_config }} @@ -336,6 +354,7 @@ data: securityContext: fsGroup: 65534 volumes: + {{- if $require_aws_config }} - name: config-volume configMap: {{- if $service_template_aws_config_map_name }} @@ -343,6 +362,7 @@ data: {{- else }} name: {{ $aws_config_map_name }} {{- end }} + {{- end }} - name: user-config configMap: name: ${RESOURCE_NAME} @@ -556,10 +576,12 @@ data: {{- toYaml . | nindent 12 }} {{- end }} serviceAccountName: {{ $launch_name }} + {{- if $require_aws_config }} volumes: - name: config-volume configMap: name: {{ $aws_config_map_name }} + {{- end }} containers: - name: main image: {{ $gateway_repository }}:${GIT_TAG} @@ -579,7 +601,9 @@ data: command: - dumb-init - -- + {{- if $enable_datadog }} - ddtrace-run + {{- end }} args: - python - -m @@ -602,10 +626,12 @@ data: limits: cpu: 4 memory: 32Gi + {{- if $require_aws_config }} volumeMounts: - name: config-volume mountPath: /opt/.aws/config subPath: config + {{- end }} {{- range $device := tuple "cpu" "gpu" }} docker-image-batch-job-{{- $device }}.yaml: |- apiVersion: batch/v1 @@ -652,9 +678,11 @@ data: serviceAccountName: {{ $launch_name }} {{- end }} volumes: + {{- if $require_aws_config }} - name: config-volume configMap: name: {{ $aws_config_map_name }} + {{- end }} - name: workdir emptyDir: {} - name: dshm @@ -694,9 +722,11 @@ data: memory: ${MEMORY} ${STORAGE_DICT} volumeMounts: + {{- if $require_aws_config }} - name: config-volume mountPath: /opt/.aws/config subPath: config + {{- end }} - name: workdir mountPath: ${MOUNT_PATH} - mountPath: /dev/shm @@ -726,9 +756,11 @@ data: cpu: 1 memory: 1Gi volumeMounts: + {{- if $require_aws_config }} - name: config-volume mountPath: /opt/.aws/config subPath: config + {{- end }} - name: workdir mountPath: ${MOUNT_PATH} {{- end }} diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 8d841c86..0f9d9337 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -85,6 +85,7 @@ affinity: { } config: values: infra: + cloud_provider: aws k8s_cluster_name: minikube dns_host_domain: localhost default_region: us-west-2 @@ -142,7 +143,7 @@ config: } billing_queue_arn: none - cache_redis_url: redis://redis-message-broker-master.default/15 + cache_redis_aws_url: redis://redis-message-broker-master.default/15 s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository" dd_trace_enabled: false istio_enabled: true @@ -221,3 +222,6 @@ imageCache: effect: "NoSchedule" celeryBrokerType: redis + +datadog: + enabled: false diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 7ed16da6..2d002c00 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -17,10 +17,13 @@ image: pullPolicy: Always secrets: - # kubernetesDatabaseSecretName or awsDatabaseSecretName [required] + # kubernetesDatabaseSecretName or cloudDatabaseSecretName [required] # is the name of the secret that contains the database credentials kubernetesDatabaseSecretName: llm-engine-postgres-credentials +# Azure Key Vault name to pull secrets from +keyvaultName: llm-engine-keyvault + db: runDbInitScript: false @@ -139,6 +142,8 @@ serviceTemplate: config: values: infra: + # cloud_provider [required]; either "aws" or "azure" + cloud_provider: aws # k8s_cluster_name [required] is the name of the k8s cluster k8s_cluster_name: main_cluster # dns_host_domain [required] is the domain name of the k8s cluster @@ -156,8 +161,11 @@ config: launch: # endpoint_namespace [required] is K8s namespace the endpoints will be created in endpoint_namespace: llm-engine - # cache_redis_url [required] is the full url for the redis cluster you wish to connect - cache_redis_url: redis://llm-engine-prod-cache.use1.cache.amazonaws.com:6379/15 + # cache_redis_aws_url is the full url for the redis cluster you wish to connect, + # cache_redis_azure_host is the redis cluster host when using cloud_provider azure + # one of cache_redis_aws_url and cache_redis_azure_host must be provided + cache_redis_aws_url: redis://llm-engine-prod-cache.use1.cache.amazonaws.com:6379/15 + cache_redis_azure_host: llm-engine-cache.redis.cache.windows.net:6380 # s3_file_llm_fine_tuning_job_repository [required] is the S3 URI for the S3 bucket/key that you wish to save fine-tuned assests s3_file_llm_fine_tuning_job_repository: "s3://llm-engine/llm-ft-job-repository" # dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster @@ -249,3 +257,6 @@ imageCache: # celeryBrokerType specifies the celery broker type for async endpoints, either "sqs" or "elasticache" celeryBrokerType: sqs + +datadog: + enabled: false diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 84aaa376..348e94be 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -112,7 +112,8 @@ Below are the configurations to specify in the `values_sample.yaml` file. | config.values.infra.redis_host | The hostname of the redis cluster you wish to connect | Yes | | config.values.infra.s3_bucket | The S3 bucket you wish to connect | Yes | | config.values.llm_engine.endpoint_namespace | K8s namespace the endpoints will be created in | Yes | -| config.values.llm_engine.cache_redis_url | The full url for the redis cluster you wish to connect | Yes | +| config.values.llm_engine.cache_redis_aws_url | The full url for the redis cluster you wish to connect | No | +| config.values.llm_engine.cache_redis_azure_host | The redis cluster host when using cloud_provider azure | No | | config.values.llm_engine.s3_file_llm_fine_tuning_job_repository | The S3 URI for the S3 bucket/key that you wish to save fine-tuned assets | Yes | | config.values.dd_trace_enabled | Whether to enable datadog tracing, datadog must be installed in the cluster | No | diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 140c5b11..b65f1189 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -13,6 +13,7 @@ from model_engine_server.core.auth.fake_authentication_repository import ( FakeAuthenticationRepository, ) +from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, @@ -50,6 +51,9 @@ FirehoseStreamingStorageGateway, ) from model_engine_server.infra.gateways import ( + ABSFileStorageGateway, + ABSFilesystemGateway, + ABSLLMArtifactGateway, CeleryTaskQueueGateway, FakeMonitoringMetricsGateway, LiveAsyncModelEndpointInferenceGateway, @@ -70,23 +74,29 @@ FakeModelPrimitiveGateway, ) from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, ) from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, ) -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) from model_engine_server.infra.gateways.s3_file_storage_gateway import S3FileStorageGateway from model_engine_server.infra.repositories import ( + ABSFileLLMFineTuneEventsRepository, + ABSFileLLMFineTuneRepository, + ACRDockerRepository, DbBatchJobRecordRepository, DbDockerImageBatchJobBundleRepository, DbModelBundleRepository, @@ -95,6 +105,7 @@ ECRDockerRepository, FakeDockerRepository, LiveTokenizerRepository, + LLMFineTuneRepository, RedisModelEndpointCacheRepository, S3FileLLMFineTuneEventsRepository, S3FileLLMFineTuneRepository, @@ -174,6 +185,7 @@ def _get_external_interfaces( redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) + servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS) monitoring_metrics_gateway = get_monitoring_metrics_gateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, @@ -181,25 +193,35 @@ def _get_external_interfaces( read_only=read_only, ) - sqs_delegate: SQSEndpointResourceDelegate + queue_delegate: QueueEndpointResourceDelegate if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() else: - sqs_delegate = LiveSQSEndpointResourceDelegate( + queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - inference_task_queue_gateway = ( - sqs_task_queue_gateway if not CIRCLECI else redis_24h_task_queue_gateway - ) - resource_gateway = LiveEndpointResourceGateway(sqs_delegate=sqs_delegate) + inference_task_queue_gateway: TaskQueueGateway + infra_task_queue_gateway: TaskQueueGateway + if CIRCLECI: + inference_task_queue_gateway = redis_24h_task_queue_gateway + infra_task_queue_gateway = redis_task_queue_gateway + elif infra_config().cloud_provider == "azure": + inference_task_queue_gateway = servicebus_task_queue_gateway + infra_task_queue_gateway = servicebus_task_queue_gateway + else: + inference_task_queue_gateway = sqs_task_queue_gateway + infra_task_queue_gateway = sqs_task_queue_gateway + resource_gateway = LiveEndpointResourceGateway(queue_delegate=queue_delegate) redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool()) model_endpoint_cache_repo = RedisModelEndpointCacheRepository( redis_client=redis_client, ) model_endpoint_infra_gateway = LiveModelEndpointInfraGateway( resource_gateway=resource_gateway, - task_queue_gateway=redis_task_queue_gateway, + task_queue_gateway=infra_task_queue_gateway, ) async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway( task_queue_gateway=inference_task_queue_gateway @@ -211,8 +233,16 @@ def _get_external_interfaces( streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( use_asyncio=(not CIRCLECI), ) - filesystem_gateway = S3FilesystemGateway() - llm_artifact_gateway = S3LLMArtifactGateway() + filesystem_gateway = ( + ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway() + ) + llm_artifact_gateway = ( + ABSLLMArtifactGateway() + if infra_config().cloud_provider == "azure" + else S3LLMArtifactGateway() + ) model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) @@ -253,22 +283,40 @@ def _get_external_interfaces( docker_image_batch_job_gateway = LiveDockerImageBatchJobGateway() cron_job_gateway = LiveCronJobGateway() - llm_fine_tune_repository = S3FileLLMFineTuneRepository( - file_path=os.getenv( - "S3_FILE_LLM_FINE_TUNE_REPOSITORY", - hmi_config.s3_file_llm_fine_tune_repository, - ), + llm_fine_tune_repository: LLMFineTuneRepository + if infra_config().cloud_provider == "azure": + llm_fine_tune_repository = ABSFileLLMFineTuneRepository("not supported yet") + else: + llm_fine_tune_repository = S3FileLLMFineTuneRepository( + file_path=os.getenv( + "S3_FILE_LLM_FINE_TUNE_REPOSITORY", + hmi_config.s3_file_llm_fine_tune_repository, + ), + ) + llm_fine_tune_events_repository = ( + ABSFileLLMFineTuneEventsRepository() + if infra_config().cloud_provider == "azure" + else S3FileLLMFineTuneEventsRepository() ) - llm_fine_tune_events_repository = S3FileLLMFineTuneEventsRepository() llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService( docker_image_batch_job_gateway=docker_image_batch_job_gateway, docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository, llm_fine_tune_repository=llm_fine_tune_repository, ) - file_storage_gateway = S3FileStorageGateway() + file_storage_gateway = ( + ABSFileStorageGateway() + if infra_config().cloud_provider == "azure" + else S3FileStorageGateway() + ) - docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() + docker_repository: DockerRepository + if CIRCLECI: + docker_repository = FakeDockerRepository() + elif infra_config().cloud_provider == "azure": + docker_repository = ACRDockerRepository() + else: + docker_repository = ECRDockerRepository() tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) @@ -281,8 +329,8 @@ def _get_external_interfaces( llm_model_endpoint_service=llm_model_endpoint_service, batch_job_service=batch_job_service, resource_gateway=resource_gateway, - endpoint_creation_task_queue_gateway=redis_task_queue_gateway, - inference_task_queue_gateway=sqs_task_queue_gateway, + endpoint_creation_task_queue_gateway=infra_task_queue_gateway, + inference_task_queue_gateway=inference_task_queue_gateway, model_endpoint_infra_gateway=model_endpoint_infra_gateway, model_primitive_gateway=model_primitive_gateway, docker_image_batch_job_bundle_repository=docker_image_batch_job_bundle_repository, diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 86625450..6c7088fc 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -4,9 +4,11 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import Sequence +from typing import Optional, Sequence import yaml +from azure.identity import DefaultAzureCredential +from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) @@ -45,7 +47,6 @@ def get_model_cache_directory_name(model_name: str): class HostedModelInferenceServiceConfig: endpoint_namespace: str billing_queue_arn: str - cache_redis_url: str # also using this to store sync autoscaling metrics sqs_profile: str sqs_queue_policy_template: str sqs_queue_tag_template: str @@ -64,6 +65,8 @@ class HostedModelInferenceServiceConfig: user_inference_tensorflow_repository: str docker_image_layer_cache_repository: str sensitive_log_mode: bool + cache_redis_aws_url: Optional[str] = None # also using this to store sync autoscaling metrics + cache_redis_azure_host: Optional[str] = None @classmethod def from_yaml(cls, yaml_path): @@ -71,10 +74,22 @@ def from_yaml(cls, yaml_path): raw_data = yaml.safe_load(f) return HostedModelInferenceServiceConfig(**raw_data) + @property + def cache_redis_url(self) -> str: + if self.cache_redis_aws_url: + return self.cache_redis_aws_url + + assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure" + username = os.getenv("AZURE_OBJECT_ID") + password = DefaultAzureCredential().get_token("https://redis.azure.com/.default").token + return f"rediss://{username}:{password}@{self.cache_redis_azure_host}" + @property def cache_redis_host_port(self) -> str: # redis://redis.url:6379/ # -> redis.url:6379 + if "rediss://" in self.cache_redis_url: + return self.cache_redis_url.split("rediss://")[1].split("/")[0] return self.cache_redis_url.split("redis://")[1].split("/")[0] @property diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index 301a2d45..06073ada 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -32,6 +32,7 @@ class BrokerType(str, Enum): REDIS = "redis" REDIS_24H = "redis_24h" SQS = "sqs" + SERVICEBUS = "servicebus" class BrokerName(str, Enum): @@ -42,6 +43,7 @@ class BrokerName(str, Enum): REDIS = "redis-message-broker-master" SQS = "sqs-message-broker-master" + SERVICEBUS = "servicebus-message-broker-master" class CreateModelEndpointV1Request(BaseModel): diff --git a/model-engine/model_engine_server/common/io.py b/model-engine/model_engine_server/common/io.py index 2247b6f4..93e2328a 100644 --- a/model-engine/model_engine_server/common/io.py +++ b/model-engine/model_engine_server/common/io.py @@ -1,14 +1,27 @@ """Launch Input/Output utils.""" import os +from typing import Any import boto3 import smart_open +from model_engine_server.core.config import infra_config def open_wrapper(uri: str, mode: str = "rt", **kwargs): + client: Any # This follows the 5.1.0 smart_open API - profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) - session = boto3.Session(profile_name=profile_name) - client = session.client("s3") + if infra_config().cloud_provider == "azure": + from azure.identity import DefaultAzureCredential + from azure.storage.blob import BlobServiceClient + + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + else: + profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) + session = boto3.Session(profile_name=profile_name) + client = session.client("s3") + transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) diff --git a/model-engine/model_engine_server/core/celery/abs.py b/model-engine/model_engine_server/core/celery/abs.py new file mode 100644 index 00000000..ea303947 --- /dev/null +++ b/model-engine/model_engine_server/core/celery/abs.py @@ -0,0 +1,23 @@ +from azure.core.exceptions import ResourceExistsError +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from celery.backends.azureblockblob import AzureBlockBlobBackend as DefaultAzureBlockBlobBackend +from kombu.utils import cached_property + + +class AzureBlockBlobBackend(DefaultAzureBlockBlobBackend): + @cached_property + def _blob_service_client(self): + client = BlobServiceClient( + f"https://{self._connection_string}.blob.core.windows.net", + credential=DefaultAzureCredential(), + connection_timeout=self._connection_timeout, + read_timeout=self._read_timeout, + ) + + try: + client.create_container(name=self._container_name) + except ResourceExistsError: + pass + + return client diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 7e87d2f0..a045d0aa 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -26,6 +26,9 @@ # override the backend with a class instead of a URL, despite the fact # that the `backend` constructor arg type is a Union[str, Type[celery.backends.base.Backend]] backends.BACKEND_ALIASES["s3"] = "model_engine_server.core.celery.s3:S3Backend" +backends.BACKEND_ALIASES[ + "azureblockblob" +] = "model_engine_server.core.celery.abs:AzureBlockBlobBackend" @unique @@ -347,14 +350,14 @@ def celery_app( :param s3_base_path: [optional] Base path for task results when using S3 as backend. The results uri will be "s3:////...". - :param backend_protocol: [optional] Backend protocol to use, currently supports "s3" and "redis". + :param backend_protocol: [optional] Backend protocol to use, currently supports "s3", "redis", and "abs". Defaults to "s3". Redis might be faster than S3 but is not persistent, so using "redis" is discouraged. If you do end up using this, make sure you set up `result_expires` (https://docs.celeryproject.org/en/stable/userguide/configuration.html#result-expires) to something reasonable (1 day by default) and run `celery beat` periodically to clear expired results from Redis. Visit https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html to learn more about celery beat - :param broker_type: [defaults to "redis"] The broker type. We currently support "redis" and "sqs". + :param broker_type: [defaults to "redis"] The broker type. We currently support "redis", "sqs", and "servicebus". :param aws_role: [optional] AWS role to use. @@ -481,9 +484,14 @@ def _get_broker_endpoint_and_transport_options( # Plain "sqs://" signifies to use instance metadata. return "sqs://", out_broker_transport_options + if broker_type == "servicebus": + return ( + f"azureservicebus://DefaultAzureCredential@{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + out_broker_transport_options, + ) raise ValueError( - f"Only 'redis' and 'sqs' are supported values for broker_type, got value {broker_type}" + f"Only 'redis', 'sqs', and 'servicebus' are supported values for broker_type, got value {broker_type}" ) @@ -514,9 +522,11 @@ def _get_backend_url_and_conf( "s3_base_path": s3_base_path, } ) + elif backend_protocol == "abs": + backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}" else: raise ValueError( - f'Unknown backend protocol "{backend_protocol}". Should be one of ["s3", "redis"].' + f'Unknown backend protocol "{backend_protocol}". Should be one of ["s3", "redis", "abs].' ) return backend_url, out_conf_changes diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py index b5b44a78..d8782e35 100644 --- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -13,6 +13,9 @@ import aioredis import stringcase +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.servicebus.management import ServiceBusAdministrationClient from celery.app.control import Inspect from datadog import statsd from kubernetes_asyncio import client @@ -41,6 +44,7 @@ def excluded_namespaces(): ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master" SQS_BROKER = "sqs-message-broker-master" +SERVICEBUS_BROKER = "servicebus-message-broker-master" UPDATE_DEPLOYMENT_MAX_RETRIES = 10 @@ -466,6 +470,54 @@ async def get_broker_metrics( ) # connection_count and max_connections are redis-specific metrics +class ASBBroker(AutoscalerBroker): + @staticmethod + def _get_asb_queue_size(queue_name: str): + with ServiceBusAdministrationClient( + f"{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) as client: + try: + queue_attributes = client.get_queue_runtime_properties(queue_name=queue_name) + active_queue_size = queue_attributes.active_message_count + + logger.info(f"ASB {queue_name} total: active queue size {active_queue_size}") + except ResourceNotFoundError as e: + logger.info(f"Queue does not exist {queue_name}: {e}") + active_queue_size = 0 + except Exception as e: + logger.error(f"Failed to get queue attributes {queue_name}: {e}") + active_queue_size = 0 + + return active_queue_size + + def _get_queue_sizes( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ): + queue_names = [queue_name for queue_name, _ in queues] + with ThreadPoolExecutor() as executor: + results = executor.map(ASBBroker._get_asb_queue_size, queue_names) + + for q, active_queue_size in zip(queues, results): + queue_sizes[q].enqueued += active_queue_size + queue_sizes[q].total += active_queue_size + return queue_sizes + + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + queue_sizes = self._get_queue_sizes(queues, queue_sizes) + return BrokerMetrics( + queue_sizes=queue_sizes, + connection_count=None, + max_connections=None, + ) # connection_count and max_connections are redis-specific metrics + + def get_worker_metrics( inspect: Dict[int, Inspect], queues: Set[Tuple[str, int]], @@ -533,10 +585,17 @@ async def main(): BROKER_NAME_TO_CLASS = { ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True), SQS_BROKER: SQSBroker(), + SERVICEBUS_BROKER: ASBBroker(), } broker = BROKER_NAME_TO_CLASS[autoscaler_broker] - broker_type = "redis" if isinstance(broker, RedisBroker) else "sqs" + broker_type = ( + "redis" + if isinstance(broker, RedisBroker) + else "sqs" + if isinstance(broker, SQSBroker) + else "servicebus" + ) if broker_type == "redis": inspect = { @@ -551,8 +610,12 @@ async def main(): # for sqs we will get active/reserved counts directly from sqs as opposed to using # an inspect object inspect = {} + elif broker_type == "servicebus": + inspect = { + 0: inspect_app(app=celery_app(None, broker_type=broker_type, backend_protocol="abs")) + } else: - raise ValueError("broker_type not redis or sqs, how did we get here?") + raise ValueError("broker_type not redis, sqs, or servicebus; how did we get here?") env = os.getenv("DD_ENV") instance_count = int(os.getenv("POD_NAME", "pod-0").split("-")[-1]) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 6403d64f..5bbe58bb 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -31,6 +31,7 @@ @dataclass class InfraConfig: + cloud_provider: str env: str k8s_cluster_name: str dns_host_domain: str diff --git a/model-engine/model_engine_server/core/configs/default.yaml b/model-engine/model_engine_server/core/configs/default.yaml index 745d3fc2..3529c814 100644 --- a/model-engine/model_engine_server/core/configs/default.yaml +++ b/model-engine/model_engine_server/core/configs/default.yaml @@ -1,3 +1,4 @@ +cloud_provider: "aws" env: "circleci" k8s_cluster_name: "minikube" dns_host_domain: "localhost" diff --git a/model-engine/model_engine_server/core/utils/url.py b/model-engine/model_engine_server/core/utils/url.py index 358f316c..81a48ffc 100644 --- a/model-engine/model_engine_server/core/utils/url.py +++ b/model-engine/model_engine_server/core/utils/url.py @@ -8,6 +8,7 @@ class ParsedURL(NamedTuple): bucket: str key: str region: Optional[str] + account: Optional[str] = None def canonical_url(self) -> str: """Packs the parsed URL information into a standard form of @@ -23,6 +24,10 @@ def s3(bucket: str, key: str, region: Optional[str] = None) -> "ParsedURL": def gs(bucket: str, key: str, region: Optional[str] = None) -> "ParsedURL": return ParsedURL(protocol="gs", bucket=bucket, key=key, region=region) + @staticmethod + def azure(bucket: str, key: str, account: Optional[str] = None) -> "ParsedURL": + return ParsedURL(protocol="azure", bucket=bucket, key=key, account=account) + @staticmethod def cds(bucket: str, key: str, region: Optional[str] = None) -> "ParsedURL": return ParsedURL(protocol="scale-cds", bucket=bucket, key=key, region=region) @@ -42,6 +47,7 @@ def parse_attachment_url(url: str, clean_key: bool = True) -> ParsedURL: bucket = None region = None key = None + account = None # s3://bucket/key1/key2 match = re.search("^s3://([^/]+)/(.*?)$", url) @@ -54,6 +60,13 @@ def parse_attachment_url(url: str, clean_key: bool = True) -> ParsedURL: protocol = "gs" bucket, key = match.group(1), match.group(2) + # azure://bucket/key1/key2 + # for Azure Blob Storage, bucket refers to an ABS container + match = re.search("^azure://([^/]+)/(.*?)$", url) + if match: + protocol = "azure" + bucket, key = match.group(1), match.group(2) + # http://bucket.s3.amazonaws.com/key1/key2 match = re.search("^https?://(.+).s3.amazonaws.com(.*?)$", url) if match: @@ -85,6 +98,13 @@ def parse_attachment_url(url: str, clean_key: bool = True) -> ParsedURL: if match: bucket, key = match.group(1), match.group(2) + # https://account.blob.core.windows.net/bucket/key1/key2 + # for Azure Blob Storage, bucket refers to an ABS container + match = re.search("^https?://([^/]+).blob.core.windows.net/([^/]+)(.*?)$", url) + if match: + protocol = "azure" + account, bucket, key = match.group(1), match.group(2), match.group(3) + match = re.search("scale-cds://(\\w+)/([\\-\\w\\/]+)", url) if match: bucket, key = match.group(1), match.group(2) @@ -103,4 +123,5 @@ def clean(v): bucket=clean(bucket), region=clean(region), key=clean(key) if clean_key else key, + account=clean(account), ) diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index 4469b30b..b6949617 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -4,6 +4,8 @@ from typing import Iterator, Optional import sqlalchemy +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from azure.keyvault.secrets import SecretClient from model_engine_server.core.aws.secrets import get_key_file from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger @@ -17,6 +19,8 @@ def get_key_file_name(environment: str) -> str: + if infra_config().cloud_provider == "azure": + return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "") return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "") @@ -42,17 +46,36 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool if key_file is None: key_file = get_key_file_name(env) # type: ignore logger.info(f"Using key file {key_file}") - db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE") - creds = get_key_file(key_file, db_secret_aws_profile) - user = creds.get("username") - password = creds.get("password") - host = creds.get("clusterHostRo") if read_only else creds.get("clusterHost") - port = str(creds.get("port")) - dbname = creds.get("dbname") - logger.info(f"Connecting to db {host}:{port}, name {dbname}") - - engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}" + if infra_config().cloud_provider == "azure": + client = SecretClient( + vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net", + credential=ManagedIdentityCredential( + client_id=os.getenv("AZURE_KEYVAULT_IDENTITY_CLIENT_ID") + ), # uses a different managed identity than the default + ) + db = client.get_secret(key_file).value + user = os.environ.get("AZURE_IDENTITY_NAME") + password = ( + DefaultAzureCredential() + .get_token("https://ossrdbms-aad.database.windows.net") + .token + ) + logger.info(f"Connecting to db {db} as user {user}") + + engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require" + else: + db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE") + creds = get_key_file(key_file, db_secret_aws_profile) + + user = creds.get("username") + password = creds.get("password") + host = creds.get("clusterHostRo") if read_only else creds.get("clusterHost") + port = str(creds.get("port")) + dbname = creds.get("dbname") + logger.info(f"Connecting to db {host}:{port}, name {dbname}") + + engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}" assert engine_url diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 12a6e82f..df1e9df2 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -14,26 +14,34 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.constants import READYZ_FPATH from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.base import SessionAsyncNullPool from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, ) from model_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.repositories import ( + ACRDockerRepository, + ECRDockerRepository, + FakeDockerRepository, ) -from model_engine_server.infra.repositories import ECRDockerRepository, FakeDockerRepository from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, ) @@ -98,19 +106,27 @@ async def main(args: Any): read_only=True, ) - sqs_delegate: SQSEndpointResourceDelegate + queue_delegate: QueueEndpointResourceDelegate if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() else: - sqs_delegate = LiveSQSEndpointResourceDelegate( + queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) k8s_resource_manager = LiveEndpointResourceGateway( - sqs_delegate=sqs_delegate, + queue_delegate=queue_delegate, ) image_cache_gateway = ImageCacheGateway() - docker_repo = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() + docker_repo: DockerRepository + if CIRCLECI: + docker_repo = FakeDockerRepository() + elif infra_config().cloud_provider == "azure": + docker_repo = ACRDockerRepository() + else: + docker_repo = ECRDockerRepository() while True: loop_start = time.time() await loop_iteration( diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index c52442ab..6cd8f5af 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -8,9 +8,11 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.config import infra_config from model_engine_server.db.base import SessionAsyncNullPool from model_engine_server.domain.entities import BatchJobSerializationFormat from model_engine_server.infra.gateways import ( + ABSFilesystemGateway, CeleryTaskQueueGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobProgressGateway, @@ -21,17 +23,20 @@ RedisInferenceAutoscalingMetricsGateway, S3FilesystemGateway, ) -from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, ) from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, ) -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) from model_engine_server.infra.repositories import ( DbBatchJobRecordRepository, @@ -54,32 +59,39 @@ async def run_batch_job( session = SessionAsyncNullPool pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) redis = aioredis.Redis(connection_pool=pool) - redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) + servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS) monitoring_metrics_gateway = get_monitoring_metrics_gateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ) - sqs_delegate: SQSEndpointResourceDelegate + queue_delegate: QueueEndpointResourceDelegate if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() else: - sqs_delegate = LiveSQSEndpointResourceDelegate( + queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - resource_gateway = LiveEndpointResourceGateway(sqs_delegate=sqs_delegate) + resource_gateway = LiveEndpointResourceGateway(queue_delegate=queue_delegate) model_endpoint_cache_repo = RedisModelEndpointCacheRepository( redis_client=redis, ) + inference_task_queue_gateway = ( + servicebus_task_queue_gateway + if infra_config().cloud_provider == "azure" + else sqs_task_queue_gateway + ) model_endpoint_infra_gateway = LiveModelEndpointInfraGateway( resource_gateway=resource_gateway, - task_queue_gateway=redis_task_queue_gateway, + task_queue_gateway=inference_task_queue_gateway, ) async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway( - task_queue_gateway=sqs_task_queue_gateway + task_queue_gateway=inference_task_queue_gateway ) streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( use_asyncio=(not CIRCLECI), @@ -87,7 +99,11 @@ async def run_batch_job( sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway( use_asyncio=(not CIRCLECI), ) - filesystem_gateway = S3FilesystemGateway() + filesystem_gateway = ( + ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway() + ) model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 9ed5e4dd..264f6af5 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -44,6 +44,8 @@ def error_response(msg: str, e_unhandled: Exception) -> ErrorResponse: def create_celery_service( forwarder: Forwarder, task_visibility: TaskVisibility, + broker_type: str, + backend_protocol: str, queue_name: Optional[str] = None, sqs_url: Optional[str] = None, ) -> Celery: @@ -59,10 +61,11 @@ def create_celery_service( s3_bucket=infra_config().s3_bucket, aws_role=infra_config().profile_ml_inference_worker, task_visibility=task_visibility, - broker_type=str(BrokerType.SQS.value if sqs_url else BrokerType.REDIS.value), + broker_type=broker_type, broker_transport_options={"predefined_queues": {queue_name: {"url": sqs_url}}} - if sqs_url + if broker_type == str(BrokerType.SQS.value) else None, + backend_protocol=backend_protocol, ) class ErrorHandlingTask(Task): @@ -157,16 +160,28 @@ def entrypoint(): parser.add_argument("--set", type=str, action="append") parser.add_argument("--task-visibility", type=str, required=True) parser.add_argument("--num-workers", type=int, required=True) + parser.add_argument("--broker-type", type=str, default=None) + parser.add_argument("--backend-protocol", type=str, default="s3") parser.add_argument("--queue", type=str, required=True) parser.add_argument("--sqs-url", type=str, default=None) args = parser.parse_args() + if args.broker_type is None: + args.broker_type = str(BrokerType.SQS.value if args.sqs_url else BrokerType.REDIS.value) + forwarder_config = load_named_config(args.config, args.set) forwarder_loader = LoadForwarder(**forwarder_config["async"]) forwader = forwarder_loader.load(None, None) - app = create_celery_service(forwader, TaskVisibility.VISIBILITY_24H, args.queue, args.sqs_url) + app = create_celery_service( + forwader, + TaskVisibility.VISIBILITY_24H, + args.broker_type, + args.backend_protocol, + args.queue, + args.sqs_url, + ) start_celery_service(app, args.queue, args.num_workers) diff --git a/model-engine/model_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py index de4eb6b7..b36fb641 100644 --- a/model-engine/model_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -1,5 +1,8 @@ from typing import Sequence +from .abs_file_storage_gateway import ABSFileStorageGateway +from .abs_filesystem_gateway import ABSFilesystemGateway +from .abs_llm_artifact_gateway import ABSLLMArtifactGateway from .batch_job_orchestration_gateway import BatchJobOrchestrationGateway from .batch_job_progress_gateway import BatchJobProgressGateway from .celery_task_queue_gateway import CeleryTaskQueueGateway @@ -22,6 +25,9 @@ from .s3_llm_artifact_gateway import S3LLMArtifactGateway __all__: Sequence[str] = [ + "ABSFileStorageGateway", + "ABSFilesystemGateway", + "ABSLLMArtifactGateway", "BatchJobOrchestrationGateway", "BatchJobProgressGateway", "CeleryTaskQueueGateway", diff --git a/model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py new file mode 100644 index 00000000..a12a0cb7 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py @@ -0,0 +1,34 @@ +from typing import List, Optional + +from model_engine_server.domain.gateways.file_storage_gateway import ( + FileMetadata, + FileStorageGateway, +) +from model_engine_server.infra.gateways.abs_filesystem_gateway import ABSFilesystemGateway + + +class ABSFileStorageGateway(FileStorageGateway): + """ + Concrete implementation of a file storage gateway backed by ABS. + """ + + def __init__(self): + self.filesystem_gateway = ABSFilesystemGateway() + + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + raise NotImplementedError("ABS not supported yet") + + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + raise NotImplementedError("ABS not supported yet") + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + raise NotImplementedError("ABS not supported yet") + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + raise NotImplementedError("ABS not supported yet") + + async def delete_file(self, owner: str, file_id: str) -> bool: + raise NotImplementedError("ABS not supported yet") + + async def list_files(self, owner: str) -> List[FileMetadata]: + raise NotImplementedError("ABS not supported yet") diff --git a/model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py new file mode 100644 index 00000000..abf6f99e --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py @@ -0,0 +1,48 @@ +import os +import re +from datetime import datetime, timedelta +from typing import IO + +import smart_open +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway + + +class ABSFilesystemGateway(FilesystemGateway): + """ + Concrete implementation for interacting with a filesystem backed by Azure Blob Storage. + """ + + # uri should start with azure:// (as opposed to https://) unless the container is publicly accessible + def open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str: + match = re.search("^https://([^/]+)\.blob\.core\.windows\.net/([^/]+)/(.*?)$", uri) + assert match + + account_name, container_name, blob_name = match.group(1), match.group(2), match.group(3) + + blob_service_client = BlobServiceClient( + f"https://{account_name}.blob.core.windows.net", DefaultAzureCredential() + ) + user_delegation_key = blob_service_client.get_user_delegation_key( + datetime.utcnow(), datetime.utcnow() + timedelta(seconds=expiration) + ) + + sas_blob = generate_blob_sas( + account_name=account_name, + container_name=container_name, + blob_name=blob_name, + user_delegation_key=user_delegation_key, + permission=BlobSasPermissions(read=True, write=False, create=False), + expiry=datetime.utcnow() + timedelta(seconds=expiration), + **kwargs, + ) + return uri + "?" + sas_blob diff --git a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py new file mode 100644 index 00000000..8ebbeda3 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py @@ -0,0 +1,75 @@ +import os +from typing import List + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient, ContainerClient +from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.url import parse_attachment_url +from model_engine_server.domain.gateways import LLMArtifactGateway + +logger = make_logger(logger_name()) + + +def _get_abs_container_client(bucket: str) -> ContainerClient: + blob_service_client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", DefaultAzureCredential() + ) + return blob_service_client.get_container_client(container=bucket) + + +class ABSLLMArtifactGateway(LLMArtifactGateway): + """ + Concrete implemention using Azure Blob Storage. + """ + + def list_files(self, path: str, **kwargs) -> List[str]: + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = parsed_remote.key + + container_client = _get_abs_container_client(bucket) + return list(container_client.list_blob_names(name_starts_with=key)) + + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = parsed_remote.key + + container_client = _get_abs_container_client(bucket) + + downloaded_files: List[str] = [] + for blob in container_client.list_blobs(name_starts_with=key): + file_path_suffix = blob.name.replace(key, "").lstrip("/") + local_path = os.path.join(target_path, file_path_suffix).rstrip("/") + + if not overwrite and os.path.exists(local_path): + downloaded_files.append(local_path) + continue + + local_dir = "/".join(local_path.split("/")[:-1]) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + logger.info(f"Downloading {blob.name} to {local_path}") + with open(file=local_path, mode="wb") as f: + f.write(container_client.download_blob(blob.name).readall()) + downloaded_files.append(local_path) + return downloaded_files + + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + parsed_remote = parse_attachment_url( + hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False + ) + account = parsed_remote.account + bucket = parsed_remote.bucket + fine_tuned_weights_prefix = parsed_remote.key + + container_client = _get_abs_container_client(bucket) + + model_files: List[str] = [] + model_cache_name = get_model_cache_directory_name(model_name) + prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" + for blob_name in container_client.list_blob_names(name_starts_with=prefix): + model_files.append(f"https://{account}.blob.core.windows.net/{bucket}/{blob_name}") + return model_files diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 8d487029..e1f2f11c 100644 --- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -10,32 +10,51 @@ from model_engine_server.core.config import infra_config from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway +backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3" + celery_redis = celery_app( - None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value) + None, + s3_bucket=infra_config().s3_bucket, + broker_type=str(BrokerType.REDIS.value), + backend_protocol=backend_protocol, ) celery_redis_24h = celery_app( None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value), task_visibility=TaskVisibility.VISIBILITY_24H, + backend_protocol=backend_protocol, ) celery_sqs = celery_app( - None, s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.SQS.value) + None, + s3_bucket=infra_config().s3_bucket, + broker_type=str(BrokerType.SQS.value), + backend_protocol=backend_protocol, +) +celery_servicebus = celery_app( + None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol ) class CeleryTaskQueueGateway(TaskQueueGateway): def __init__(self, broker_type: BrokerType): self.broker_type = broker_type - assert self.broker_type in [BrokerType.SQS, BrokerType.REDIS, BrokerType.REDIS_24H] + assert self.broker_type in [ + BrokerType.SQS, + BrokerType.REDIS, + BrokerType.REDIS_24H, + BrokerType.SERVICEBUS, + ] def _get_celery_dest(self): if self.broker_type == BrokerType.SQS: return celery_sqs elif self.broker_type == BrokerType.REDIS_24H: return celery_redis_24h - else: # self.broker_type == BrokerType.REDIS + elif self.broker_type == BrokerType.REDIS: return celery_redis + else: + return celery_servicebus def send_task( self, diff --git a/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py new file mode 100644 index 00000000..3799ed65 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py @@ -0,0 +1,68 @@ +import os +from typing import Any, Dict + +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.servicebus.management import ServiceBusAdministrationClient +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, +) + +logger = make_logger(logger_name()) + + +def _get_servicebus_administration_client() -> ServiceBusAdministrationClient: + return ServiceBusAdministrationClient( + f"{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) + + +class ASBQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): + """ + Using Azure Service Bus. + """ + + async def create_queue_if_not_exists( + self, + endpoint_id: str, + endpoint_name: str, + endpoint_created_by: str, + endpoint_labels: Dict[str, Any], + ) -> QueueInfo: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.create_queue(queue_name=queue_name) + except ResourceExistsError: + pass + + return QueueInfo(queue_name, None) + + async def delete_queue(self, endpoint_id: str) -> None: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.delete_queue(queue_name=queue_name) + except ResourceNotFoundError: + logger.info(f"Could not find ASB queue {queue_name} for endpoint {endpoint_id}") + + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + queue_attributes = client.get_queue_runtime_properties(queue_name=queue_name) + except ResourceNotFoundError as e: + raise EndpointResourceInfraException( + f"Could not find ASB queue {queue_name} for endpoint {endpoint_id}" + ) from e + + # queue_attributes does have other fields, but we don't need them right now + return { + "name": queue_attributes.name, + "total_message_count": queue_attributes.total_message_count, + "active_message_count": queue_attributes.active_message_count, + } diff --git a/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py index 145f675b..1e6a3f6d 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py @@ -1,18 +1,17 @@ from abc import ABC, abstractmethod from typing import Dict, Generic, Sequence, Tuple, TypeVar -from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, ModelEndpointType, ) +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import QueueInfo from pydantic import BaseModel __all__: Sequence[str] = ( "EndpointResourceGateway", - "QueueInfo", "EndpointResourceGatewayCreateOrUpdateResourcesResponse", ) @@ -21,11 +20,6 @@ class EndpointResourceGatewayCreateOrUpdateResourcesResponse(BaseModel): destination: str -class QueueInfo(BaseModel): - queue_name: str - broker: BrokerType - - Q = TypeVar("Q", bound=QueueInfo) """Either a QueueInfo or some specialization of it. """ diff --git a/model-engine/model_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py similarity index 58% rename from model-engine/model_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py index e8cfa497..9ded2d6e 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py @@ -1,32 +1,31 @@ from typing import Any, Dict, Sequence -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, - SQSQueueInfo, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, ) -from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef -__all__: Sequence[str] = ("FakeSQSEndpointResourceDelegate",) +__all__: Sequence[str] = ("FakeQueueEndpointResourceDelegate",) -class FakeSQSEndpointResourceDelegate(SQSEndpointResourceDelegate): +class FakeQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): async def create_queue_if_not_exists( self, endpoint_id: str, endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], - ) -> SQSQueueInfo: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + ) -> QueueInfo: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) queue_url = f"http://foobar.com/{queue_name}" - return SQSQueueInfo(queue_name, queue_url) + return QueueInfo(queue_name, queue_url) async def delete_queue(self, endpoint_id: str) -> None: # Don't need to do anything, since the contract says that deleting is a no-op, # and we don't need to simulate real exceptions. pass - async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResultTypeDef: + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: return { "Attributes": { "ApproximateNumberOfMessages": "100", diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 483c5e5b..6b958920 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -509,10 +509,17 @@ def get_endpoint_resource_arguments_from_request( image_hash = compute_image_hash(request.image) # In Circle CI, we use Redis on localhost instead of SQS - broker_name = BrokerName.SQS.value if not CIRCLECI else BrokerName.REDIS.value - broker_type = BrokerType.SQS.value if not CIRCLECI else BrokerType.REDIS.value + if CIRCLECI: + broker_name = BrokerName.REDIS.value + broker_type = BrokerType.REDIS.value + elif infra_config().cloud_provider == "azure": + broker_name = BrokerName.SERVICEBUS.value + broker_type = BrokerType.SERVICEBUS.value + else: + broker_name = BrokerName.SQS.value + broker_type = BrokerType.SQS.value dd_trace_enabled = hmi_config.dd_trace_enabled - if broker_type == BrokerType.REDIS.value: + if broker_type != BrokerType.SQS.value: sqs_queue_url = "" main_env = [] diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 2d6b7410..516470ba 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -1,6 +1,5 @@ from typing import Dict, Optional, Tuple -from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ( @@ -8,53 +7,40 @@ ModelEndpointRecord, ModelEndpointType, ) -from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, EndpointResourceGatewayCreateOrUpdateResourcesResponse, - QueueInfo, ) from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( K8SEndpointResourceDelegate, ) -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, ) logger = make_logger(logger_name()) -class SqsQueueInfo(QueueInfo): - """Live endpoints create and use SQS queues. These come with an additional per-queue URL. - - NOTE: broker for this class **MUST** always be SQS. - """ - - queue_url: str - - @staticmethod - def new(queue_name: str, queue_url: str) -> "SqsQueueInfo": - return SqsQueueInfo(queue_name=queue_name, broker=BrokerType.SQS, queue_url=queue_url) - - -class LiveEndpointResourceGateway(EndpointResourceGateway[SqsQueueInfo]): - def __init__(self, sqs_delegate: SQSEndpointResourceDelegate): +class LiveEndpointResourceGateway(EndpointResourceGateway[QueueInfo]): + def __init__(self, queue_delegate: QueueEndpointResourceDelegate): self.k8s_delegate = K8SEndpointResourceDelegate() - self.sqs_delegate = sqs_delegate + self.queue_delegate = queue_delegate async def create_queue( self, endpoint_record: ModelEndpointRecord, labels: Dict[str, str], - ) -> SqsQueueInfo: - """Creates a new SQS queue, returning its unique name and queue URL.""" - queue_name, queue_url = await self.sqs_delegate.create_queue_if_not_exists( + ) -> QueueInfo: + """Creates a new queue, returning its unique name and queue URL.""" + queue_name, queue_url = await self.queue_delegate.create_queue_if_not_exists( endpoint_id=endpoint_record.id, endpoint_name=endpoint_record.name, endpoint_created_by=endpoint_record.created_by, endpoint_labels=labels, ) - return SqsQueueInfo.new(queue_name, queue_url) + return QueueInfo(queue_name, queue_url) async def create_or_update_resources( self, request: CreateOrUpdateResourcesRequest @@ -90,11 +76,16 @@ async def get_resources( ) if endpoint_type == ModelEndpointType.ASYNC: - sqs_attributes = await self.sqs_delegate.get_queue_attributes(endpoint_id=endpoint_id) - if "ApproximateNumberOfMessages" in sqs_attributes["Attributes"]: + sqs_attributes = await self.queue_delegate.get_queue_attributes(endpoint_id=endpoint_id) + if ( + "Attributes" in sqs_attributes + and "ApproximateNumberOfMessages" in sqs_attributes["Attributes"] + ): resources.num_queued_items = int( sqs_attributes["Attributes"]["ApproximateNumberOfMessages"] ) + elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate + resources.num_queued_items = int(sqs_attributes["active_message_count"]) return resources @@ -113,8 +104,8 @@ async def delete_resources( ) sqs_result = True try: - await self.sqs_delegate.delete_queue(endpoint_id=endpoint_id) - except EndpointResourceInvalidRequestException as e: + await self.queue_delegate.delete_queue(endpoint_id=endpoint_id) + except EndpointResourceInfraException as e: logger.warning("Could not delete SQS resources", exc_info=e) sqs_result = False diff --git a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py new file mode 100644 index 00000000..76c77e64 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, NamedTuple, Optional, Sequence + +__all__: Sequence[str] = ( + "QueueInfo", + "QueueEndpointResourceDelegate", +) + + +class QueueInfo(NamedTuple): + queue_name: str + queue_url: Optional[str] + + +class QueueEndpointResourceDelegate(ABC): + """ + Base class for an interactor with SQS or ASB. This is used by the LiveEndpointResourceGateway. + """ + + @abstractmethod + async def create_queue_if_not_exists( + self, + endpoint_id: str, + endpoint_name: str, + endpoint_created_by: str, + endpoint_labels: Dict[str, Any], + ) -> QueueInfo: + """ + Creates a queue associated with the given endpoint_id. Other fields are set as tags on the queue. + """ + + @abstractmethod + async def delete_queue(self, endpoint_id: str) -> None: + """ + Deletes a queue associated with the given endpoint_id. This is a no-op if the queue does not exist. + """ + + @abstractmethod + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + """ + Get attributes of a queue. + """ + + @staticmethod + def endpoint_id_to_queue_name(endpoint_id: str) -> str: + return f"launch-endpoint-id-{endpoint_id}" diff --git a/model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py deleted file mode 100644 index de3a59e3..00000000 --- a/model-engine/model_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py +++ /dev/null @@ -1,48 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, NamedTuple, Sequence - -from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef - -__all__: Sequence[str] = ( - "SQSQueueInfo", - "SQSEndpointResourceDelegate", -) - - -class SQSQueueInfo(NamedTuple): - queue_name: str - queue_url: str - - -class SQSEndpointResourceDelegate(ABC): - """ - Base class for an interactor with SQS. This is used by the LiveEndpointResourceGateway. - """ - - @abstractmethod - async def create_queue_if_not_exists( - self, - endpoint_id: str, - endpoint_name: str, - endpoint_created_by: str, - endpoint_labels: Dict[str, Any], - ) -> SQSQueueInfo: - """ - Creates an SQS queue associated with the given endpoint_id. Other fields are set as tags on the queue. - """ - - @abstractmethod - async def delete_queue(self, endpoint_id: str) -> None: - """ - Deletes an SQS queue associated with the given endpoint_id. This is a no-op if the queue does not exist. - """ - - @abstractmethod - async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResultTypeDef: - """ - Get all attributes of an SQS queue. - """ - - @staticmethod - def endpoint_id_to_queue_name(endpoint_id: str) -> str: - return f"launch-endpoint-id-{endpoint_id}" diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py similarity index 86% rename from model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py index f04d6b65..748c3f69 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py @@ -10,15 +10,14 @@ from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import EndpointResourceInfraException -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, - SQSQueueInfo, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, ) -from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef logger = make_logger(logger_name()) -__all__: Sequence[str] = ("LiveSQSEndpointResourceDelegate",) +__all__: Sequence[str] = ("SQSQueueEndpointResourceDelegate",) def _create_async_sqs_client(sqs_profile: Optional[str]) -> AioBaseClient: @@ -46,7 +45,7 @@ def _get_queue_tags( ) -class LiveSQSEndpointResourceDelegate(SQSEndpointResourceDelegate): +class SQSQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): def __init__(self, sqs_profile: Optional[str]): self.sqs_profile = sqs_profile @@ -56,13 +55,13 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], - ) -> SQSQueueInfo: + ) -> QueueInfo: async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) try: get_queue_url_response = await sqs_client.get_queue_url(QueueName=queue_name) - return SQSQueueInfo( + return QueueInfo( queue_name=queue_name, queue_url=get_queue_url_response["QueueUrl"], ) @@ -94,10 +93,10 @@ async def create_queue_if_not_exists( f"Creating SQS queue got non-200 response: {create_response}" ) - return SQSQueueInfo(queue_name, create_response["QueueUrl"]) + return QueueInfo(queue_name, create_response["QueueUrl"]) async def delete_queue(self, endpoint_id: str) -> None: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) async with _create_async_sqs_client(self.sqs_profile) as sqs_client: try: queue_url = (await sqs_client.get_queue_url(QueueName=queue_name))["QueueUrl"] @@ -122,8 +121,8 @@ async def delete_queue(self, endpoint_id: str) -> None: f"Deleting SQS queue got non-200 response: {delete_response}" ) - async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResultTypeDef: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) async with _create_async_sqs_client(self.sqs_profile) as sqs_client: try: queue_url = (await sqs_client.get_queue_url(QueueName=queue_name))["QueueUrl"] diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index 42e9988c..f14cf69f 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -1,5 +1,8 @@ from typing import Sequence +from .abs_file_llm_fine_tune_events_repository import ABSFileLLMFineTuneEventsRepository +from .abs_file_llm_fine_tune_repository import ABSFileLLMFineTuneRepository +from .acr_docker_repository import ACRDockerRepository from .batch_job_record_repository import BatchJobRecordRepository from .db_batch_job_record_repository import DbBatchJobRecordRepository from .db_docker_image_batch_job_bundle_repository import DbDockerImageBatchJobBundleRepository @@ -19,6 +22,9 @@ from .s3_file_llm_fine_tune_repository import S3FileLLMFineTuneRepository __all__: Sequence[str] = [ + "ABSFileLLMFineTuneEventsRepository", + "ABSFileLLMFineTuneRepository", + "ACRDockerRepository", "BatchJobRecordRepository", "DbBatchJobRecordRepository", "DbDockerImageBatchJobBundleRepository", diff --git a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py new file mode 100644 index 00000000..9d33585b --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py @@ -0,0 +1,19 @@ +from typing import List + +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent +from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( + LLMFineTuneEventsRepository, +) + + +class ABSFileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): + def __init__(self): + pass + + async def get_fine_tune_events( + self, user_id: str, model_endpoint_name: str + ) -> List[LLMFineTuneEvent]: + raise NotImplementedError("ABS not supported yet") + + async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: + raise NotImplementedError("ABS not supported yet") diff --git a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py new file mode 100644 index 00000000..a205fd83 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py @@ -0,0 +1,19 @@ +from typing import Optional + +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository + + +class ABSFileLLMFineTuneRepository(LLMFineTuneRepository): + def __init__(self, file_path: str): + self.file_path = file_path + + async def get_job_template_for_model( + self, model_name: str, fine_tuning_method: str + ) -> Optional[LLMFineTuneTemplate]: + raise NotImplementedError("ABS not supported yet") + + async def write_job_template_for_model( + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate + ): + raise NotImplementedError("ABS not supported yet") diff --git a/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py new file mode 100644 index 00000000..2d6e1cc3 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py @@ -0,0 +1,42 @@ +from typing import Optional + +from azure.containerregistry import ContainerRegistryClient +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + + +class ACRDockerRepository(DockerRepository): + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + endpoint = f"https://{infra_config().docker_repo_prefix}" + credential = DefaultAzureCredential() + client = ContainerRegistryClient(endpoint, credential) + + try: + client.get_manifest_properties(repository_name, image_tag) + except ResourceNotFoundError: + return False + return True + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError("ACR image build not supported yet") + + def get_latest_image_tag(self, repository_name: str) -> str: + endpoint = f"https://{infra_config().docker_repo_prefix}" + credential = DefaultAzureCredential() + client = ContainerRegistryClient(endpoint, credential) + + image = client.list_manifest_properties( + repository_name, order_by="time_desc", results_per_page=1 + ).next() + return image.tags[0] diff --git a/model-engine/model_engine_server/service_builder/celery.py b/model-engine/model_engine_server/service_builder/celery.py index a5ac93e6..67cb94b0 100644 --- a/model-engine/model_engine_server/service_builder/celery.py +++ b/model-engine/model_engine_server/service_builder/celery.py @@ -1,3 +1,4 @@ +from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.core.celery import celery_app from model_engine_server.core.config import infra_config @@ -7,6 +8,10 @@ "model_engine_server.service_builder.tasks_v1", ], s3_bucket=infra_config().s3_bucket, + broker_type=str(BrokerType.SERVICEBUS.value) + if infra_config().cloud_provider == "azure" + else str(BrokerType.REDIS.value), + backend_protocol="abs" if infra_config().cloud_provider == "azure" else "s3", ) if __name__ == "__main__": diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index b8b38a28..7615e123 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -12,11 +12,16 @@ BuildEndpointResponse, ) from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.config import infra_config from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway from model_engine_server.db.base import SessionAsyncNullPool -from model_engine_server.infra.gateways import S3FilesystemGateway -from model_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways import ABSFilesystemGateway, S3FilesystemGateway +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, ) from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( set_lazy_load_kubernetes_clients, @@ -24,13 +29,14 @@ from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, ) -from model_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) from model_engine_server.infra.repositories import ( + ACRDockerRepository, DbModelEndpointRecordRepository, ECRDockerRepository, FakeDockerRepository, @@ -49,27 +55,37 @@ def get_live_endpoint_builder_service( session: Any, redis: aioredis.Redis, ): - sqs_delegate: SQSEndpointResourceDelegate + queue_delegate: QueueEndpointResourceDelegate if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() else: - sqs_delegate = LiveSQSEndpointResourceDelegate( + queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) notification_gateway = FakeNotificationGateway() monitoring_metrics_gateway = get_monitoring_metrics_gateway() - docker_repository = ECRDockerRepository() if not CIRCLECI else FakeDockerRepository() + docker_repository: DockerRepository + if CIRCLECI: + docker_repository = FakeDockerRepository() + elif infra_config().cloud_provider == "azure": + docker_repository = ACRDockerRepository() + else: + docker_repository = ECRDockerRepository() service = LiveEndpointBuilderService( docker_repository=docker_repository, resource_gateway=LiveEndpointResourceGateway( - sqs_delegate=sqs_delegate, + queue_delegate=queue_delegate, ), monitoring_metrics_gateway=monitoring_metrics_gateway, model_endpoint_record_repository=DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ), model_endpoint_cache_repository=RedisModelEndpointCacheRepository(redis_client=redis), - filesystem_gateway=S3FilesystemGateway(), + filesystem_gateway=ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway(), notification_gateway=notification_gateway, feature_flag_repo=RedisFeatureFlagRepository(redis_client=redis), ) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index eb2d393e..eaa46f55 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -4,11 +4,16 @@ aiohttp~=3.8 aioredis~=2.0 alembic==1.8.1 asyncpg==0.27.0 +azure-containerregistry~=1.2.0 +azure-identity~=1.15.0 +azure-keyvault-secrets~=4.7.0 +azure-servicebus~=7.11.4 +azure-storage-blob~=12.19.0 boto3-stubs[essential]==1.26.67 boto3~=1.21 botocore~=1.24 build==0.8.0 -celery[redis,sqs,tblib]~=5.2 +celery[redis,sqs,tblib]~=5.3.6 click~=8.1 cloudpickle==2.1.0 croniter==1.4.1 @@ -24,7 +29,7 @@ gunicorn~=20.0 httptools==0.5.0 json-log-formatter~=0.3 kubeconfig~=1.1 -kubernetes-asyncio==24.2.2 +kubernetes-asyncio==25.11.0 kubernetes~=25.3.0 orjson==3.8.6 protobuf~=3.20 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 2a7390f8..56acbaf9 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -6,7 +6,7 @@ # aiofiles==23.1.0 # via quart -aiohttp==3.8.5 +aiohttp==3.9.1 # via # -r model-engine/requirements.in # kubernetes-asyncio @@ -19,7 +19,9 @@ alembic==1.8.1 amqp==5.1.1 # via kombu anyio==3.7.1 - # via starlette + # via + # azure-core + # starlette asgiref==3.7.2 # via uvicorn asn1crypto==1.5.1 @@ -38,11 +40,30 @@ attrs==23.1.0 # ddtrace # jsonschema # referencing +azure-common==1.1.28 + # via azure-keyvault-secrets +azure-containerregistry==1.2.0 + # via -r model-engine/requirements.in +azure-core==1.29.6 + # via + # azure-containerregistry + # azure-identity + # azure-keyvault-secrets + # azure-servicebus + # azure-storage-blob +azure-identity==1.15.0 + # via -r model-engine/requirements.in +azure-keyvault-secrets==4.7.0 + # via -r model-engine/requirements.in +azure-servicebus==7.11.4 + # via -r model-engine/requirements.in +azure-storage-blob==12.19.0 + # via -r model-engine/requirements.in backports-zoneinfo[tzdata]==0.2.1 # via # celery # kombu -billiard==4.1.0 +billiard==4.2.0 # via celery bleach==6.0.0 # via readme-renderer @@ -54,9 +75,7 @@ boto3==1.28.1 # celery # kombu boto3-stubs[essential]==1.26.67 - # via - # -r model-engine/requirements.in - # boto3-stubs + # via -r model-engine/requirements.in botocore==1.31.1 # via # -r model-engine/requirements.in @@ -72,20 +91,18 @@ cachetools==5.3.1 # via google-auth cattrs==23.1.2 # via ddtrace -celery[redis,sqs,tblib]==5.3.1 - # via - # -r model-engine/requirements.in - # celery +celery[redis,sqs,tblib]==5.3.6 + # via -r model-engine/requirements.in certifi==2023.7.22 # via # datadog-api-client # kubernetes # kubernetes-asyncio # requests +cffi==1.16.0 + # via cryptography charset-normalizer==3.2.0 - # via - # aiohttp - # requests + # via requests click==8.1.4 # via # -r model-engine/requirements.in @@ -109,6 +126,13 @@ commonmark==0.9.1 # via rich croniter==1.4.1 # via -r model-engine/requirements.in +cryptography==41.0.7 + # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # secretstorage dataclasses-json==0.5.9 # via -r model-engine/requirements.in datadog==0.47.0 @@ -191,10 +215,20 @@ importlib-resources==6.1.1 # jsonschema # jsonschema-specifications # keyring +isodate==0.6.1 + # via + # azure-containerregistry + # azure-keyvault-secrets + # azure-servicebus + # azure-storage-blob itsdangerous==2.1.2 # via quart jaraco-classes==3.3.0 # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.0.3 # via # -r model-engine/requirements.in @@ -211,13 +245,13 @@ jsonschema-specifications==2023.7.1 # via jsonschema keyring==24.2.0 # via twine -kombu[sqs]==5.3.1 +kombu[sqs]==5.3.5 # via celery kubeconfig==1.1.1 # via -r model-engine/requirements.in kubernetes==25.3.0 # via -r model-engine/requirements.in -kubernetes-asyncio==24.2.2 +kubernetes-asyncio==25.11.0 # via -r model-engine/requirements.in mako==1.2.4 # via alembic @@ -235,6 +269,12 @@ marshmallow-enum==1.5.1 # via dataclasses-json more-itertools==9.1.0 # via jaraco-classes +msal==1.26.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.1.0 + # via azure-identity multidict==6.0.4 # via # aiohttp @@ -268,6 +308,7 @@ packaging==23.1 # deprecation # huggingface-hub # marshmallow + # msal-extensions # transformers pep517==0.13.0 # via build @@ -277,6 +318,8 @@ pkginfo==1.9.6 # via twine pkgutil-resolve-name==1.3.10 # via jsonschema +portalocker==2.8.2 + # via msal-extensions priority==2.0.0 # via hypercorn prompt-toolkit==3.0.39 @@ -296,6 +339,8 @@ pyasn1==0.5.0 # rsa pyasn1-modules==0.3.0 # via google-auth +pycparser==2.21 + # via cffi pycurl==7.45.2 # via # -r model-engine/requirements.in @@ -309,6 +354,8 @@ pygments==2.15.1 # via # readme-renderer # rich +pyjwt[crypto]==2.8.0 + # via msal python-dateutil==2.8.2 # via # botocore @@ -342,10 +389,12 @@ regex==2023.10.3 requests==2.31.0 # via # -r model-engine/requirements.in + # azure-core # datadog # docker # huggingface-hub # kubernetes + # msal # requests-auth-aws-sigv4 # requests-oauthlib # requests-toolbelt @@ -373,16 +422,20 @@ safetensors==0.4.0 # via transformers scramp==1.4.4 # via pg8000 +secretstorage==3.3.3 + # via keyring sentencepiece==0.1.99 # via -r model-engine/requirements.in sh==1.14.3 # via -r model-engine/requirements.in six==1.16.0 # via + # azure-core # bleach # ddsketch # ddtrace # google-auth + # isodate # kubernetes # kubernetes-asyncio # python-dateutil @@ -401,7 +454,6 @@ sqlalchemy[asyncio]==2.0.4 # via # -r model-engine/requirements.in # alembic - # sqlalchemy sse-starlette==1.6.1 # via -r model-engine/requirements.in sseclient-py==1.7.2 @@ -449,6 +501,10 @@ typing-extensions==4.7.1 # via # aioredis # asgiref + # azure-core + # azure-keyvault-secrets + # azure-servicebus + # azure-storage-blob # boto3-stubs # botocore-stubs # bytecode @@ -489,7 +545,7 @@ uvicorn==0.17.6 # via -r model-engine/requirements.in uvloop==0.17.0 # via -r model-engine/requirements.in -vine==5.0.0 +vine==5.1.0 # via # amqp # celery @@ -518,4 +574,8 @@ zipp==3.16.0 # importlib-resources # The following packages are considered to be unsafe in a requirements file: -# setuptools +setuptools==69.0.3 + # via + # gunicorn + # kubernetes + # kubernetes-asyncio diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index a42fdc1b..001a54b7 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -48,7 +48,7 @@ sqs_queue_tag_template: > # Billing billing_queue_arn: none # There's a separate piece of infra that caches k8s state onto redis, so we need a url to it -cache_redis_url: redis://127.0.0.1:6379/15 +cache_redis_aws_url: redis://127.0.0.1:6379/15 s3_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune_repository/circleci" diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index 053cae1e..76fa54d1 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -5,7 +5,21 @@ test=pytest omit = model_engine_server/entrypoints/* model_engine_server/api/app.py + model_engine_server/api/dependencies.py + model_engine_server/common/config.py + model_engine_server/common/io.py + model_engine_server/core/celery/app.py model_engine_server/core/docker/ecr.py + model_engine_server/db/base.py + model_engine_server/infra/gateways/abs_file_storage_gateway.py + model_engine_server/infra/gateways/abs_filesystem_gateway.py + model_engine_server/infra/gateways/abs_llm_artifact_gateway.py + model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py + model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py + model_engine_server/infra/gateways/resources/k8s_resource_types.py + model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py + model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py + model_engine_server/infra/repositories/acr_docker_repository.py # TODO: Fix pylint errors # [pylint] diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index e7ad32cc..61473b37 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1163,7 +1163,7 @@ async def create_queue( """Creates a new, unique queue name. Used by this endpoint resource gateway to create new resources. """ - return QueueInfo(queue_name="foobar", broker=BrokerType.REDIS) + return QueueInfo(queue_name="foobar", queue_url=None) async def create_or_update_resources( self, request: CreateOrUpdateResourcesRequest diff --git a/model-engine/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py similarity index 95% rename from model-engine/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py rename to model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py index 1ab2143a..ae00ac43 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py @@ -7,11 +7,11 @@ from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest from model_engine_server.domain.entities import ModelEndpointRecord from model_engine_server.domain.exceptions import EndpointResourceInfraException -from model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) -MODULE_PATH = "model_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate" +MODULE_PATH = "model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate" EXPECTED_QUEUE_POLICY = """ { @@ -340,7 +340,7 @@ async def test_sqs_create_or_update_resources_endpoint_exists( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_get_queue_url, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record queue_name, queue_url = await delegate.create_queue_if_not_exists( endpoint_id=endpoint_record.id, @@ -368,7 +368,7 @@ async def test_sqs_create_or_update_resources( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_create_queue, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record queue_name, queue_url = await delegate.create_queue_if_not_exists( endpoint_id=endpoint_record.id, @@ -408,7 +408,7 @@ async def test_sqs_create_or_update_resources_throws_exception( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_create_queue_throws_exception, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record with pytest.raises(EndpointResourceInfraException): await delegate.create_queue_if_not_exists( @@ -424,7 +424,7 @@ async def test_sqs_create_or_update_resources_non_200( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_create_queue_returns_non_200, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record with pytest.raises(EndpointResourceInfraException): await delegate.create_queue_if_not_exists( @@ -437,7 +437,7 @@ async def test_sqs_create_or_update_resources_non_200( @pytest.mark.asyncio async def test_sqs_delete_resources(mock_create_async_sqs_client_delete_queue): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.delete_queue(endpoint_id="model_endpoint_id_1") mock_create_async_sqs_client_delete_queue.__aenter__.assert_called_once() @@ -456,7 +456,7 @@ async def test_sqs_delete_resources_throws_exception( mock_create_async_sqs_client_delete_queue_throws_exception, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.delete_queue(endpoint_id="model_endpoint_id_1") @@ -465,13 +465,13 @@ async def test_sqs_delete_resources_non_200( mock_create_async_sqs_client_delete_queue_returns_non_200, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.delete_queue(endpoint_id="model_endpoint_id_1") @pytest.mark.asyncio async def test_sqs_get_queue_attributes(mock_create_async_sqs_client_get_queue_attributes): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") response = await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") mock_create_async_sqs_client_get_queue_attributes.__aenter__.assert_called_once() @@ -494,7 +494,7 @@ async def test_sqs_get_queue_attributes_queue_not_found( mock_create_async_sqs_client_get_queue_attributes_queue_not_found, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") @@ -503,5 +503,5 @@ async def test_sqs_get_queue_attributes_queue_throws_exception( mock_create_async_sqs_client_get_queue_attributes_queue_throws_exception, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") From d88511b5ae60bca02a2722fcab51a78e46f47203 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Tue, 20 Feb 2024 13:24:03 -0800 Subject: [PATCH 236/425] remove handling (#438) --- .../inference/async_inference/tasks.py | 34 +----------------- .../sync_inference/fastapi_server.py | 35 ++----------------- 2 files changed, 4 insertions(+), 65 deletions(-) diff --git a/model-engine/model_engine_server/inference/async_inference/tasks.py b/model-engine/model_engine_server/inference/async_inference/tasks.py index 62bde09c..69f9c9d0 100644 --- a/model-engine/model_engine_server/inference/async_inference/tasks.py +++ b/model-engine/model_engine_server/inference/async_inference/tasks.py @@ -10,17 +10,7 @@ from model_engine_server.core.utils.timer import timer from model_engine_server.domain.entities import ModelEndpointConfig from model_engine_server.inference.async_inference.celery import async_inference_service -from model_engine_server.inference.common import ( - get_endpoint_config, - load_predict_fn_or_cls, - run_predict, -) -from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( - DatadogInferenceMonitoringMetricsGateway, -) -from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( - FirehoseStreamingStorageGateway, -) +from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler logger = make_logger(logger_name()) @@ -38,23 +28,6 @@ def init_worker_global(): with timer(logger=logger, name="load_predict_fn_or_cls"): predict_fn_or_cls = load_predict_fn_or_cls() - endpoint_config = get_endpoint_config() - hooks = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - billing_queue=endpoint_config.billing_queue, - billing_tags=endpoint_config.billing_tags, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - endpoint_id=endpoint_config.endpoint_id, - endpoint_type=endpoint_config.endpoint_type, - bundle_id=endpoint_config.bundle_id, - labels=endpoint_config.labels, - streaming_storage_gateway=FirehoseStreamingStorageGateway(), - ) # k8s health check with open(READYZ_FPATH, "w") as f: f.write("READY") @@ -96,11 +69,6 @@ def predict(self, request_params, return_pickled): request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) return run_predict(predict_fn_or_cls, request_params_pydantic) # type: ignore - def on_success(self, retval, task_id, args, kwargs): - request_params = args[0] - request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) - hooks.handle(request_params_pydantic, retval, task_id) # type: ignore - @async_inference_service.task( base=InferenceTask, diff --git a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py index 3d30bf0c..aba74bbe 100644 --- a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py @@ -1,22 +1,11 @@ import traceback from functools import wraps -from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status +from fastapi import FastAPI, HTTPException, Response, status from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger -from model_engine_server.inference.common import ( - get_endpoint_config, - load_predict_fn_or_cls, - run_predict, -) -from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( - DatadogInferenceMonitoringMetricsGateway, -) -from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( - FirehoseStreamingStorageGateway, -) -from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict from model_engine_server.inference.sync_inference.constants import ( CONCURRENCY, FAIL_ON_CONCURRENCY_LIMIT, @@ -44,23 +33,6 @@ def _inner_2(*args, **kwargs): # How does this interact with threads? # Analogous to init_worker() inside async_inference predict_fn = load_predict_fn_or_cls() -endpoint_config = get_endpoint_config() -hooks = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - billing_queue=endpoint_config.billing_queue, - billing_tags=endpoint_config.billing_tags, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - endpoint_id=endpoint_config.endpoint_id, - endpoint_type=endpoint_config.endpoint_type, - bundle_id=endpoint_config.bundle_id, - labels=endpoint_config.labels, - streaming_storage_gateway=FirehoseStreamingStorageGateway(), -) @app.get("/healthcheck") @@ -72,14 +44,13 @@ def healthcheck(): @app.post("/predict") @with_concurrency_limit(concurrency_limiter) -def predict(payload: EndpointPredictV1Request, background_tasks: BackgroundTasks): +def predict(payload: EndpointPredictV1Request): """ Assumption: payload is a JSON with format {"url": , "args": , "returned_pickled": boolean} Returns: Results of running the predict function on the request url. See `run_predict`. """ try: result = run_predict(predict_fn, payload) - background_tasks.add_task(hooks.handle, payload, result) return result except Exception: raise HTTPException(status_code=500, detail=dict(traceback=str(traceback.format_exc()))) From b4e7a5c1da642322b47733c602a64038caf09a54 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:23:12 -0800 Subject: [PATCH 237/425] Clean up logs for logging hook (#439) * clean up logs * fix test * fix typing * fix test * fixes * pragma no cover * pragma no cover * pragma no cover --- .../domain/gateways/streaming_storage_gateway.py | 2 +- .../firehose_streaming_storage_gateway.py | 7 ++++--- .../inference/post_inference_hooks.py | 13 ++++++++++--- .../test_firehose_streaming_storage_gateway.py | 16 +++++++++------- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py index ae4216dd..add325bc 100644 --- a/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py +++ b/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py @@ -8,7 +8,7 @@ class StreamingStorageGateway(ABC): """ @abstractmethod - def put_record(self, stream_name: str, record: Dict[str, Any]) -> None: + def put_record(self, stream_name: str, record: Dict[str, Any]) -> Dict[str, Any]: """ Put a record into a streaming storage mechanism. diff --git a/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py index ab718737..801178af 100644 --- a/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py +++ b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py @@ -43,7 +43,7 @@ def _get_firehose_client(self): firehose_client = session.client("firehose", region_name=infra_config().default_region) return firehose_client - def put_record(self, stream_name: str, record: Dict[str, Any]) -> None: + def put_record(self, stream_name: str, record: Dict[str, Any]) -> Dict[str, Any]: """ Put a record into a Firehose stream. @@ -56,8 +56,9 @@ def put_record(self, stream_name: str, record: Dict[str, Any]) -> None: ) if firehose_response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise StreamPutException( - f"Failed to put record into firehose stream {stream_name}. Record content: {record}" + f"Failed to put record into firehose stream {stream_name}. Response metadata {firehose_response['ResponseMetadata']}." ) logger.info( - f"Logged to firehose stream {stream_name}. Record content: {record}, Record ID: {firehose_response['RecordId']}" + f"Logged to firehose stream {stream_name}. Record ID: {firehose_response['RecordId']}. Task ID: {record['RESPONSE_BODY']['task_id']}" ) + return firehose_response diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 6f388acb..39f5fcd7 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -134,10 +134,17 @@ def handle( if stream_name is None: logger.warning("No firehose stream name specified. Logging hook will not be executed.") return + streaming_storage_response = {} # pragma: no cover try: - self._streaming_storage_gateway.put_record(stream_name=stream_name, record=data_record) - except StreamPutException as e: - logger.error(f"Error in logging hook {e}") + streaming_storage_response = ( + self._streaming_storage_gateway.put_record( # pragma: no cover + stream_name=stream_name, record=data_record + ) + ) + except StreamPutException: # pragma: no cover + logger.error( # pragma: no cover + f"Failed to put record into firehose stream {stream_name}. Response metadata {streaming_storage_response.get('ResponseMetadata')}." + ) class PostInferenceHooksHandler: diff --git a/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py index 1cedaef6..3ae72a6e 100644 --- a/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py @@ -8,6 +8,12 @@ stream_name = "fake-stream" +return_value = { + "RecordId": "fake-record-id", + "Encrypted": False, + "ResponseMetadata": {"HTTPStatusCode": 200}, +} + @pytest.fixture def streaming_storage_gateway(): @@ -17,7 +23,7 @@ def streaming_storage_gateway(): @pytest.fixture def fake_record(): - return {"Data": "fake-data"} + return {"RESPONSE_BODY": {"task_id": "fake-task-id"}} def mock_sts_client(*args, **kwargs): @@ -34,11 +40,7 @@ def mock_sts_client(*args, **kwargs): def mock_firehose_client(*args, **kwargs): mock_client = mock.Mock() - mock_client.put_record.return_value = { - "RecordId": "fake-record-id", - "Encrypted": False, - "ResponseMetadata": {"HTTPStatusCode": 200}, - } + mock_client.put_record.return_value = return_value return mock_client @@ -76,7 +78,7 @@ def test_firehose_streaming_storage_gateway_put_record(streaming_storage_gateway "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", mock_session, ): - assert streaming_storage_gateway.put_record(stream_name, fake_record) is None + assert streaming_storage_gateway.put_record(stream_name, fake_record) is return_value def test_firehose_streaming_storage_gateway_put_record_with_exception( From 9a892cfa369c54a8a5ef38f974e0ef74a344f057 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Wed, 21 Feb 2024 22:25:15 -0800 Subject: [PATCH 238/425] Fix Infra Task Gateway (#443) * revert use to redis * move service-builder to sqs * service builder uses redis for circleci --- .../endpoint_builder_deployment.yaml | 4 ++-- .../model_engine_server/common/settings.py | 4 ++-- .../start_batch_job_orchestration.py | 24 ++++++++++++------- .../service_builder/celery.py | 13 +++++++--- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/charts/model-engine/templates/endpoint_builder_deployment.yaml b/charts/model-engine/templates/endpoint_builder_deployment.yaml index 273543f5..2868e87b 100644 --- a/charts/model-engine/templates/endpoint_builder_deployment.yaml +++ b/charts/model-engine/templates/endpoint_builder_deployment.yaml @@ -56,9 +56,9 @@ spec: - --loglevel=INFO - --concurrency=2 {{- if .Values.serviceIdentifier }} - - --queues=model-engine-{{ .Values.serviceIdentifier }}.service-builder + - --queues=model-engine-{{ .Values.serviceIdentifier }}-service-builder {{- else }} - - --queues=model-engine.service-builder + - --queues=model-engine-service-builder {{- end }} resources: {{- toYaml .Values.resources | nindent 12 }} diff --git a/model-engine/model_engine_server/common/settings.py b/model-engine/model_engine_server/common/settings.py index 7dc6c6bb..7438844a 100644 --- a/model-engine/model_engine_server/common/settings.py +++ b/model-engine/model_engine_server/common/settings.py @@ -86,9 +86,9 @@ def get_sync_endpoint_elb_url(deployment_name: str) -> str: def get_service_builder_queue(service_identifier=None): return ( - f"{SERVICE_BUILDER_QUEUE_PREFIX}-{service_identifier}.{SERVICE_BUILDER_QUEUE_SUFFIX}" + f"{SERVICE_BUILDER_QUEUE_PREFIX}-{service_identifier}-{SERVICE_BUILDER_QUEUE_SUFFIX}" if service_identifier - else f"{SERVICE_BUILDER_QUEUE_PREFIX}.{SERVICE_BUILDER_QUEUE_SUFFIX}" + else f"{SERVICE_BUILDER_QUEUE_PREFIX}-{SERVICE_BUILDER_QUEUE_SUFFIX}" ) diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 6cd8f5af..c9abea51 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -11,6 +11,7 @@ from model_engine_server.core.config import infra_config from model_engine_server.db.base import SessionAsyncNullPool from model_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.gateways import TaskQueueGateway from model_engine_server.infra.gateways import ( ABSFilesystemGateway, CeleryTaskQueueGateway, @@ -78,17 +79,22 @@ async def run_batch_job( ) resource_gateway = LiveEndpointResourceGateway(queue_delegate=queue_delegate) - model_endpoint_cache_repo = RedisModelEndpointCacheRepository( - redis_client=redis, - ) - inference_task_queue_gateway = ( - servicebus_task_queue_gateway - if infra_config().cloud_provider == "azure" - else sqs_task_queue_gateway - ) + + inference_task_queue_gateway: TaskQueueGateway + infra_task_queue_gateway: TaskQueueGateway + if infra_config().cloud_provider == "azure": + inference_task_queue_gateway = servicebus_task_queue_gateway + infra_task_queue_gateway = servicebus_task_queue_gateway + else: + inference_task_queue_gateway = sqs_task_queue_gateway + infra_task_queue_gateway = sqs_task_queue_gateway + model_endpoint_infra_gateway = LiveModelEndpointInfraGateway( resource_gateway=resource_gateway, - task_queue_gateway=inference_task_queue_gateway, + task_queue_gateway=infra_task_queue_gateway, + ) + model_endpoint_cache_repo = RedisModelEndpointCacheRepository( + redis_client=redis, ) async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway( task_queue_gateway=inference_task_queue_gateway diff --git a/model-engine/model_engine_server/service_builder/celery.py b/model-engine/model_engine_server/service_builder/celery.py index 67cb94b0..06384c9e 100644 --- a/model-engine/model_engine_server/service_builder/celery.py +++ b/model-engine/model_engine_server/service_builder/celery.py @@ -1,16 +1,23 @@ from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.core.celery import celery_app from model_engine_server.core.config import infra_config +service_builder_broker_type: str +if CIRCLECI: + service_builder_broker_type = str(BrokerType.REDIS.value) +elif infra_config().cloud_provider == "azure": + service_builder_broker_type = str(BrokerType.SERVICEBUS.value) +else: + service_builder_broker_type = str(BrokerType.SQS.value) + service_builder_service = celery_app( name="model_engine_server.service_builder", modules=[ "model_engine_server.service_builder.tasks_v1", ], s3_bucket=infra_config().s3_bucket, - broker_type=str(BrokerType.SERVICEBUS.value) - if infra_config().cloud_provider == "azure" - else str(BrokerType.REDIS.value), + broker_type=service_builder_broker_type, backend_protocol="abs" if infra_config().cloud_provider == "azure" else "s3", ) From a63642111b2a930e6870fbf741479338fcc0f814 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Thu, 22 Feb 2024 12:32:01 -0800 Subject: [PATCH 239/425] support gemma models (#444) * upgrade vllm to 0.3.2 * bump transformers * tokenizer changes * rename -it to -instruct --- docs/model_zoo.md | 6 ++++++ .../domain/use_cases/llm_model_endpoint_use_cases.py | 5 +++++ .../inference/vllm/requirements.txt | 6 +++--- .../infra/repositories/live_tokenizer_repository.py | 4 ++++ model-engine/requirements.in | 4 ++-- model-engine/requirements.txt | 12 +++++++----- 6 files changed, 27 insertions(+), 10 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 18610abb..8805418c 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -28,8 +28,14 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | | `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | | `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-70b` | ✅ | | vllm | 16384 | +| `codellama-70b-instruct` | ✅ | | vllm | 4096 | | `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | | `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | +| `gemma-2b` | ✅ | | vllm | 8192 | +| `gemma-2b-instruct` | ✅ | | vllm | 8192 | +| `gemma-7b` | ✅ | | vllm | 8192 | +| `gemma-7b-instruct` | ✅ | | vllm | 8192 | ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index be56c7c5..ce2bf265 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -185,6 +185,10 @@ "mammoth-coder-llama-2-34b", "zephyr-7b-alpha", "zephyr-7b-beta", + "gemma-2b", + "gemma-2b-instruct", + "gemma-7b", + "gemma-7b-instruct", ] ), LLMInferenceFramework.LIGHTLLM: set( @@ -223,6 +227,7 @@ }, # setting both for backwards compatibility, will phase code-llama out in a future pr # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 # Can also see 13B, 34B there too + "gemma": {"max_model_len": 8192, "max_num_batched_tokens": 8192}, "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, "mixtral": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 4cc6239a..78e033bb 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,3 @@ -ray==2.6.3 -git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm -pydantic==1.10.13 +ray>=2.9 +vllm==0.3.2 +pydantic>=2.0 diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 1140686f..41356aef 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -73,6 +73,10 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "vicuna-13b": ModelInfo("eachadea/vicuna-13b-1.1", None), "zephyr-7b-alpha": ModelInfo("HuggingFaceH4/zephyr-7b-alpha", None), "zephyr-7b-beta": ModelInfo("HuggingFaceH4/zephyr-7b-beta", None), + "gemma-2b": ModelInfo("google/gemma-2b", None), + "gemma-2b-instruct": ModelInfo("google/gemma-2b-it", None), + "gemma-7b": ModelInfo("google/gemma-7b", None), + "gemma-7b-instruct": ModelInfo("google/gemma-7b-it", None), } diff --git a/model-engine/requirements.in b/model-engine/requirements.in index eaa46f55..380f7ec9 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -51,9 +51,9 @@ sseclient-py==1.7.2 stringcase==1.2.0 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 -transformers==4.34.1 +tokenizers~=0.15.2 tqdm~=4.64 -transformers==4.34.1 +transformers==4.38.0 twine==3.7.1 uvicorn==0.17.6 uvloop==0.17.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 56acbaf9..d4e6cd11 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -190,7 +190,7 @@ hpack==4.0.0 # via h2 httptools==0.5.0 # via -r model-engine/requirements.in -huggingface-hub==0.17.3 +huggingface-hub==0.20.3 # via # tokenizers # transformers @@ -418,7 +418,7 @@ rsa==4.9 # via google-auth s3transfer==0.6.1 # via boto3 -safetensors==0.4.0 +safetensors==0.4.2 # via transformers scramp==1.4.4 # via pg8000 @@ -474,8 +474,10 @@ testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 # via -r model-engine/requirements.in -tokenizers==0.14.1 - # via transformers +tokenizers==0.15.2 + # via + # -r model-engine/requirements.in + # transformers tomli==2.0.1 # via # build @@ -487,7 +489,7 @@ tqdm==4.65.0 # huggingface-hub # transformers # twine -transformers==4.34.1 +transformers==4.38.0 # via -r model-engine/requirements.in twine==3.7.1 # via -r model-engine/requirements.in From 31c7c5a392a1265c52b7f70e8e08d20ca293f697 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Thu, 22 Feb 2024 15:26:06 -0800 Subject: [PATCH 240/425] Fix infra config dependency (#449) --- model-engine/model_engine_server/common/io.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/common/io.py b/model-engine/model_engine_server/common/io.py index 93e2328a..ae53e7b9 100644 --- a/model-engine/model_engine_server/common/io.py +++ b/model-engine/model_engine_server/common/io.py @@ -9,8 +9,13 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs): client: Any + cloud_provider: str # This follows the 5.1.0 smart_open API - if infra_config().cloud_provider == "azure": + try: + cloud_provider = infra_config().cloud_provider + except Exception: + cloud_provider = "aws" + if cloud_provider == "azure": from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient From b3a0036228131422b1142f55fef5bd48a69f7605 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:43:54 -0800 Subject: [PATCH 241/425] Add emitted timestamp for logging (#450) * add timestamp * change to utc --- .../model_engine_server/inference/post_inference_hooks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 39f5fcd7..3295c3b4 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -1,7 +1,9 @@ import json from abc import ABC, abstractmethod +from datetime import datetime from typing import Any, Dict, List, Optional, Union +import pytz import requests from fastapi.responses import JSONResponse from model_engine_server.common.constants import ( @@ -122,6 +124,7 @@ def handle( return response["task_id"] = task_id data_record = { + "EMITTED_AT": datetime.now(pytz.timezone("UTC")).strftime("%Y-%m-%dT%H:%M:%S"), "REQUEST_BODY": request_payload.json(), "RESPONSE_BODY": response, "ENDPOINT_ID": self._endpoint_id, From c4db5e4cfafb104d1cbff9c4428109d1f519f130 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:54:55 -0800 Subject: [PATCH 242/425] change cache update time (#451) --- integration_tests/test_endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index 880acd7c..26f2dbe2 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -94,7 +94,7 @@ def test_async_model_endpoint( user, ) # Let the cache update - time.sleep(30) + time.sleep(60) # Endpoint builds should be cached now. ensure_n_ready_endpoints_short(1, user) From b4aef83ebf5c1fca42a893995d30f1bb91243558 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 23 Feb 2024 13:46:23 -0800 Subject: [PATCH 243/425] Bump aiohttp from 3.9.1 to 3.9.2 in /model-engine (#446) Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.9.1 to 3.9.2. - [Release notes](https://github.com/aio-libs/aiohttp/releases) - [Changelog](https://github.com/aio-libs/aiohttp/blob/master/CHANGES.rst) - [Commits](https://github.com/aio-libs/aiohttp/compare/v3.9.1...v3.9.2) --- updated-dependencies: - dependency-name: aiohttp dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> --- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 131 ++++++++++++++++++---------------- 2 files changed, 70 insertions(+), 63 deletions(-) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 380f7ec9..15e32deb 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -1,6 +1,6 @@ GitPython~=3.0 Jinja2==3.0.3 # version 3.1.0 had a bug -aiohttp~=3.8 +aiohttp~=3.9 aioredis~=2.0 alembic==1.8.1 asyncpg==0.27.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index d4e6cd11..21c0c918 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -6,16 +6,16 @@ # aiofiles==23.1.0 # via quart -aiohttp==3.9.1 +aiohttp==3.9.2 # via - # -r model-engine/requirements.in + # -r requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r model-engine/requirements.in + # via -r requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r model-engine/requirements.in + # via -r requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 @@ -32,7 +32,7 @@ async-timeout==4.0.2 # aioredis # redis asyncpg==0.27.0 - # via -r model-engine/requirements.in + # via -r requirements.in attrs==23.1.0 # via # aiohttp @@ -43,7 +43,7 @@ attrs==23.1.0 azure-common==1.1.28 # via azure-keyvault-secrets azure-containerregistry==1.2.0 - # via -r model-engine/requirements.in + # via -r requirements.in azure-core==1.29.6 # via # azure-containerregistry @@ -52,13 +52,13 @@ azure-core==1.29.6 # azure-servicebus # azure-storage-blob azure-identity==1.15.0 - # via -r model-engine/requirements.in + # via -r requirements.in azure-keyvault-secrets==4.7.0 - # via -r model-engine/requirements.in + # via -r requirements.in azure-servicebus==7.11.4 - # via -r model-engine/requirements.in + # via -r requirements.in azure-storage-blob==12.19.0 - # via -r model-engine/requirements.in + # via -r requirements.in backports-zoneinfo[tzdata]==0.2.1 # via # celery @@ -71,20 +71,22 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via -r model-engine/requirements.in + # via + # -r requirements.in + # boto3-stubs botocore==1.31.1 # via - # -r model-engine/requirements.in + # -r requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r model-engine/requirements.in + # via -r requirements.in bytecode==0.14.2 # via ddtrace cachetools==5.3.1 @@ -92,7 +94,9 @@ cachetools==5.3.1 cattrs==23.1.2 # via ddtrace celery[redis,sqs,tblib]==5.3.6 - # via -r model-engine/requirements.in + # via + # -r requirements.in + # celery certifi==2023.7.22 # via # datadog-api-client @@ -105,7 +109,7 @@ charset-normalizer==3.2.0 # via requests click==8.1.4 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # click-didyoumean # click-plugins @@ -119,13 +123,13 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r model-engine/requirements.in + # via -r requirements.in cryptography==41.0.7 # via # azure-identity @@ -134,19 +138,19 @@ cryptography==41.0.7 # pyjwt # secretstorage dataclasses-json==0.5.9 - # via -r model-engine/requirements.in + # via -r requirements.in datadog==0.47.0 - # via -r model-engine/requirements.in + # via -r requirements.in datadog-api-client==2.11.0 - # via -r model-engine/requirements.in + # via -r requirements.in ddsketch==2.0.4 # via ddtrace ddtrace==1.8.3 - # via -r model-engine/requirements.in + # via -r requirements.in deprecation==2.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in docker==5.0.3 - # via -r model-engine/requirements.in + # via -r requirements.in docutils==0.20.1 # via readme-renderer envier==0.4.0 @@ -156,7 +160,7 @@ exceptiongroup==1.2.0 # anyio # cattrs fastapi==0.78.0 - # via -r model-engine/requirements.in + # via -r requirements.in filelock==3.13.1 # via # huggingface-hub @@ -170,15 +174,15 @@ fsspec==2023.10.0 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r model-engine/requirements.in + # via -r requirements.in gitpython==3.1.32 - # via -r model-engine/requirements.in + # via -r requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in h11==0.14.0 # via # hypercorn @@ -189,7 +193,7 @@ h2==4.1.0 hpack==4.0.0 # via h2 httptools==0.5.0 - # via -r model-engine/requirements.in + # via -r requirements.in huggingface-hub==0.20.3 # via # tokenizers @@ -231,14 +235,14 @@ jeepney==0.8.0 # secretstorage jinja2==3.0.3 # via - # -r model-engine/requirements.in + # -r requirements.in # quart jmespath==1.0.1 # via # boto3 # botocore json-log-formatter==0.5.2 - # via -r model-engine/requirements.in + # via -r requirements.in jsonschema==4.19.0 # via ddtrace jsonschema-specifications==2023.7.1 @@ -248,11 +252,11 @@ keyring==24.2.0 kombu[sqs]==5.3.5 # via celery kubeconfig==1.1.1 - # via -r model-engine/requirements.in + # via -r requirements.in kubernetes==25.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in kubernetes-asyncio==25.11.0 - # via -r model-engine/requirements.in + # via -r requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -300,7 +304,7 @@ numpy==1.24.4 oauthlib==3.2.2 # via requests-oauthlib orjson==3.8.6 - # via -r model-engine/requirements.in + # via -r requirements.in packaging==23.1 # via # build @@ -326,13 +330,13 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r model-engine/requirements.in + # -r requirements.in # ddsketch # ddtrace psycopg2-binary==2.9.3 - # via -r model-engine/requirements.in + # via -r requirements.in py-xid==0.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -343,19 +347,21 @@ pycparser==2.21 # via cffi pycurl==7.45.2 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # kombu pydantic==1.10.11 # via - # -r model-engine/requirements.in + # -r requirements.in # fastapi pygments==2.15.1 # via # readme-renderer # rich pyjwt[crypto]==2.8.0 - # via msal + # via + # msal + # pyjwt python-dateutil==2.8.2 # via # botocore @@ -366,7 +372,7 @@ python-dateutil==2.8.2 # kubernetes-asyncio # pg8000 python-multipart==0.0.6 - # via -r model-engine/requirements.in + # via -r requirements.in pyyaml==6.0.1 # via # huggingface-hub @@ -375,7 +381,7 @@ pyyaml==6.0.1 # kubernetes-asyncio # transformers quart==0.18.3 - # via -r model-engine/requirements.in + # via -r requirements.in readme-renderer==40.0 # via twine redis==4.6.0 @@ -388,7 +394,7 @@ regex==2023.10.3 # via transformers requests==2.31.0 # via - # -r model-engine/requirements.in + # -r requirements.in # azure-core # datadog # docker @@ -401,7 +407,7 @@ requests==2.31.0 # transformers # twine requests-auth-aws-sigv4==0.7 - # via -r model-engine/requirements.in + # via -r requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -409,7 +415,7 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r model-engine/requirements.in + # via -r requirements.in rpds-py==0.10.0 # via # jsonschema @@ -425,9 +431,9 @@ scramp==1.4.4 secretstorage==3.3.3 # via keyring sentencepiece==0.1.99 - # via -r model-engine/requirements.in + # via -r requirements.in sh==1.14.3 - # via -r model-engine/requirements.in + # via -r requirements.in six==1.16.0 # via # azure-core @@ -441,7 +447,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r model-engine/requirements.in + # via -r requirements.in smmap==5.0.0 # via # gitdb @@ -452,31 +458,32 @@ sniffio==1.3.0 # via anyio sqlalchemy[asyncio]==2.0.4 # via - # -r model-engine/requirements.in + # -r requirements.in # alembic + # sqlalchemy sse-starlette==1.6.1 - # via -r model-engine/requirements.in + # via -r requirements.in sseclient-py==1.7.2 - # via -r model-engine/requirements.in + # via -r requirements.in starlette==0.19.1 # via # fastapi # sse-starlette stringcase==1.2.0 - # via -r model-engine/requirements.in + # via -r requirements.in tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r model-engine/requirements.in + # -r requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in tokenizers==0.15.2 # via - # -r model-engine/requirements.in + # -r requirements.in # transformers tomli==2.0.1 # via @@ -485,14 +492,14 @@ tomli==2.0.1 # pep517 tqdm==4.65.0 # via - # -r model-engine/requirements.in + # -r requirements.in # huggingface-hub # transformers # twine transformers==4.38.0 - # via -r model-engine/requirements.in + # via -r requirements.in twine==3.7.1 - # via -r model-engine/requirements.in + # via -r requirements.in types-awscrt==0.16.23 # via # botocore-stubs @@ -544,9 +551,9 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r model-engine/requirements.in + # via -r requirements.in uvloop==0.17.0 - # via -r model-engine/requirements.in + # via -r requirements.in vine==5.1.0 # via # amqp @@ -568,7 +575,7 @@ xmltodict==0.13.0 # via ddtrace yarl==1.9.2 # via - # -r model-engine/requirements.in + # -r requirements.in # aiohttp zipp==3.16.0 # via From dc03fd4efa90663e6d6fde9ccf1c13c57ed008b9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:16:04 -0800 Subject: [PATCH 244/425] Bump python-multipart from 0.0.6 to 0.0.7 in /model-engine (#447) Bumps [python-multipart](https://github.com/andrew-d/python-multipart) from 0.0.6 to 0.0.7. - [Release notes](https://github.com/andrew-d/python-multipart/releases) - [Changelog](https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md) - [Commits](https://github.com/andrew-d/python-multipart/compare/0.0.6...0.0.7) --- updated-dependencies: - dependency-name: python-multipart dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 15e32deb..e03739c3 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -37,7 +37,7 @@ psycopg2-binary==2.9.3 py-xid==0.3.0 pycurl~=7.44 # For celery[sqs] pydantic~=1.10.11 -python-multipart~=0.0.6 +python-multipart~=0.0.7 quart==0.18.3 requests-auth-aws-sigv4~=0.7 requests~=2.25 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 21c0c918..14230e8f 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -371,7 +371,7 @@ python-dateutil==2.8.2 # kubernetes # kubernetes-asyncio # pg8000 -python-multipart==0.0.6 +python-multipart==0.0.7 # via -r requirements.in pyyaml==6.0.1 # via From be330c2397a491036d5427de78e0ae60a2395f49 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:50:05 -0800 Subject: [PATCH 245/425] Bump gitpython from 3.1.32 to 3.1.41 in /model-engine (#453) Bumps [gitpython](https://github.com/gitpython-developers/GitPython) from 3.1.32 to 3.1.41. - [Release notes](https://github.com/gitpython-developers/GitPython/releases) - [Changelog](https://github.com/gitpython-developers/GitPython/blob/main/CHANGES) - [Commits](https://github.com/gitpython-developers/GitPython/compare/3.1.32...3.1.41) --- updated-dependencies: - dependency-name: gitpython dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index e03739c3..e93e9e5b 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -1,4 +1,4 @@ -GitPython~=3.0 +GitPython~=3.1 Jinja2==3.0.3 # version 3.1.0 had a bug aiohttp~=3.9 aioredis~=2.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 14230e8f..9a4fe51e 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -175,7 +175,7 @@ gitdb==4.0.10 # via gitpython gitdb2==2.0.6 # via -r requirements.in -gitpython==3.1.32 +gitpython==3.1.41 # via -r requirements.in google-auth==2.21.0 # via kubernetes From 37d38d4c4a4da6b653f4ebdacb1717f07004f56f Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 26 Feb 2024 12:37:10 -0800 Subject: [PATCH 246/425] Log endpoint in sensitive_log_mode (#455) --- model-engine/model_engine_server/api/llms_v1.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index c076e2b1..9660f0d0 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -311,7 +311,9 @@ async def create_completion_sync_task( """ Runs a sync prompt completion on an LLM. """ - if not hmi_config.sensitive_log_mode: + if hmi_config.sensitive_log_mode: # pragma: no cover + logger.info(f"POST /completion_sync to endpoint {model_endpoint_name} for {auth}") + else: logger.info( f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}" ) @@ -374,7 +376,9 @@ async def create_completion_stream_task( """ Runs a stream prompt completion on an LLM. """ - if not hmi_config.sensitive_log_mode: # pragma: no cover + if hmi_config.sensitive_log_mode: # pragma: no cover + logger.info(f"POST /completion_stream to endpoint {model_endpoint_name} for {auth}") + else: logger.info( f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" ) From 06bc25e9a7c104f2926bc1ffb4759668e90aecd6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:39:11 -0800 Subject: [PATCH 247/425] Bump orjson from 3.8.6 to 3.9.15 in /model-engine (#456) --- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index e93e9e5b..49984a54 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -31,7 +31,7 @@ json-log-formatter~=0.3 kubeconfig~=1.1 kubernetes-asyncio==25.11.0 kubernetes~=25.3.0 -orjson==3.8.6 +orjson==3.9.15 protobuf~=3.20 psycopg2-binary==2.9.3 py-xid==0.3.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 9a4fe51e..47d2fcef 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -303,7 +303,7 @@ numpy==1.24.4 # via transformers oauthlib==3.2.2 # via requests-oauthlib -orjson==3.8.6 +orjson==3.9.15 # via -r requirements.in packaging==23.1 # via From 9a4e2e53c464cc271f1ce7c9e8abfa6581b24a51 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:16:31 -0800 Subject: [PATCH 248/425] Allow the load test script to use a csv of inputs (#440) * prepare allowing a csv input * randomly select input * pass some args through * log output token count percentiles * debug + ignore first line in file * lazy try except lmao * oops * oops x2 * oops x3 * . * oops prompt sample is reused * revert the changes to main, I'm gonna just have it take in a distribution of output token counts * output token count distribution * renane var to be more clear --- scripts/throughput_benchmarks.py | 57 ++++++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py index c689d8cc..d67614a5 100644 --- a/scripts/throughput_benchmarks.py +++ b/scripts/throughput_benchmarks.py @@ -221,6 +221,34 @@ def generate_output_token_counts(mean, std, num, input_token_count): return output +def generate_output_token_counts_from_existing( + distribution: List[int], num: int, input_token_count: int +): + assert len(distribution) > 0, "Can't have a distribution with 0 tokens" + output = [] + # Sample without replacement so that we don't have as much variance + for _ in range(num // len(distribution)): + random.shuffle(distribution) + output.extend(distribution) + random.shuffle(distribution) + output.extend(distribution[: num % len(distribution)]) + assert len(output) == num + + for i in range(len(output)): + output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count) + return output + + +def read_distribution_from_file(fpath: str): + # Assumes the distribution is some json-formatted string that represents a list + try: + with open(fpath, "r") as fin: + return json.load(fin) + except FileNotFoundError: + print("File not found. Exiting.") + raise + + def run_benchmark( model: str, framework: InferenceFramework, @@ -231,17 +259,23 @@ def run_benchmark( concurrency: int, verbose: bool, local_port: int, + response_token_count_distribution: Optional[List] = None, ): prompt = generate_prompt(config.input_token_count, hf_model) prompt_num_tokens = config.input_token_count - output_token_counts = generate_output_token_counts( - config.output_token_count_mean, - config.output_token_count_std, - num_trials, - config.input_token_count, - ) + if response_token_count_distribution is not None: + output_token_counts = generate_output_token_counts_from_existing( + response_token_count_distribution, num_trials, config.input_token_count + ) + else: + output_token_counts = generate_output_token_counts( + config.output_token_count_mean, + config.output_token_count_std, + num_trials, + config.input_token_count, + ) start = time.time() results = send_requests( @@ -352,10 +386,18 @@ def run_benchmarks( verbose: bool = False, hf_model: Optional[str] = None, local_port: int = 5005, + response_token_count_distribution_file: Optional[str] = None, ): """Run benchmarks.""" all_statistics = [] config = BenchmarkConfig(input_token_count, output_token_count_mean) + + response_token_count_distribution = None + if response_token_count_distribution_file is not None: + response_token_count_distribution = read_distribution_from_file( + response_token_count_distribution_file + ) + try: if verbose: print(f"Running benchmark for config {config}") @@ -375,6 +417,7 @@ def run_benchmarks( concurrency, verbose, local_port, + response_token_count_distribution, ) all_statistics.append(statistics) except Exception: @@ -404,6 +447,7 @@ def run_benchmarks_concurrency_range( verbose: bool = False, hf_model: Optional[str] = None, local_port: int = 5005, + response_token_count_distribution_file: Optional[str] = None, ): if output_file is not None: # Create empty file @@ -422,6 +466,7 @@ def run_benchmarks_concurrency_range( verbose, hf_model, local_port, + response_token_count_distribution_file, ) From 38c59e22d1445694c610bb8ca161be7f36c49175 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:01:09 -0800 Subject: [PATCH 249/425] add some debugging to vllm docker (#454) * add some debugging to vllm docker * update * check processes using GPU * lint --- .../inference/vllm/Dockerfile | 7 +++ .../inference/vllm/vllm_server.py | 52 +++++++++++++++++-- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index 907e795d..75b9e1f5 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -1,5 +1,12 @@ FROM nvcr.io/nvidia/pytorch:23.09-py3 +RUN apt-get update \ + && apt-get install -y \ + gdb \ + psmisc \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* + RUN pip uninstall torch -y COPY requirements.txt /workspace/requirements.txt RUN pip install -r requirements.txt diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 9c66ae7a..e402db82 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -1,5 +1,9 @@ import argparse +import code import json +import signal +import subprocess +import traceback from typing import AsyncGenerator import uvicorn @@ -46,9 +50,9 @@ async def stream_results() -> AsyncGenerator[str, None]: "text": request_output.outputs[-1].text[len(last_output_text) :], "count_prompt_tokens": len(request_output.prompt_token_ids), "count_output_tokens": len(request_output.outputs[0].token_ids), - "log_probs": request_output.outputs[0].logprobs[-1] - if sampling_params.logprobs - else None, + "log_probs": ( + request_output.outputs[0].logprobs[-1] if sampling_params.logprobs else None + ), "finished": request_output.finished, } last_output_text = request_output.outputs[-1].text @@ -88,7 +92,47 @@ async def abort_request() -> None: return Response(content=json.dumps(ret)) +def get_gpu_free_memory(): + """Get GPU free memory using nvidia-smi.""" + try: + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"] + ).decode("utf-8") + gpu_memory = [int(x) for x in output.strip().split("\n")] + return gpu_memory + except subprocess.CalledProcessError: + return None + + +def check_unknown_startup_memory_usage(): + """Check for unknown memory usage at startup.""" + gpu_free_memory = get_gpu_free_memory() + if gpu_free_memory is not None: + min_mem = min(gpu_free_memory) + max_mem = max(gpu_free_memory) + if max_mem - min_mem > 10: + print( + f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." + ) + output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8") + print(f"Processes using GPU: {output}") + + +def debug(sig, frame): + """Interrupt running process, and provide a python prompt for + interactive debugging.""" + d = {"_frame": frame} # Allow access to frame object. + d.update(frame.f_globals) # Unless shadowed by global + d.update(frame.f_locals) + + i = code.InteractiveConsole(d) + message = "Signal received : entering python shell.\nTraceback:\n" + message += "".join(traceback.format_stack(frame)) + i.interact(message) + + if __name__ == "__main__": + check_unknown_startup_memory_usage() parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) # None == IPv4 / IPv6 dualstack parser.add_argument("--port", type=int, default=5005) @@ -98,6 +142,8 @@ async def abort_request() -> None: engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) + signal.signal(signal.SIGUSR1, debug) + uvicorn.run( app, host=args.host, From 468bcbe43bf10974b460c537bcee752fb8eb702d Mon Sep 17 00:00:00 2001 From: Edward Gan Date: Tue, 27 Feb 2024 14:31:33 -0500 Subject: [PATCH 250/425] Add product label validation (#442) Adds support for a new shared plugin that validates the product and team labels --- .../use_cases/model_endpoint_use_cases.py | 27 +++++++++++++++++++ .../tests/unit/api/test_model_endpoints.py | 11 ++++---- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index bfd51b17..9d355307 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -5,6 +5,7 @@ """ import re +from dataclasses import dataclass from typing import Any, Dict, List, Optional from model_engine_server.common.constants import SUPPORTED_POST_INFERENCE_HOOKS @@ -46,6 +47,8 @@ CONVERTED_FROM_ARTIFACT_LIKE_KEY = "_CONVERTED_FROM_ARTIFACT_LIKE" MODEL_BUNDLE_CHANGED_KEY = "_MODEL_BUNDLE_CHANGED" +DEFAULT_DISALLOWED_TEAMS = ["_INVALID_TEAM"] + logger = make_logger(logger_name()) @@ -118,6 +121,20 @@ def validate_deployment_resources( ) +@dataclass +class ValidationResult: + passed: bool + message: str + + +# Placeholder team and product label validator that only checks for a single invalid team +def simple_team_product_validator(team: str, product: str) -> ValidationResult: + if team in DEFAULT_DISALLOWED_TEAMS: + return ValidationResult(False, "Invalid team") + else: + return ValidationResult(True, "Valid team") + + def validate_labels(labels: Dict[str, str]) -> None: for required_label in REQUIRED_ENDPOINT_LABELS: if required_label not in labels: @@ -129,6 +146,7 @@ def validate_labels(labels: Dict[str, str]) -> None: if restricted_label in labels: raise EndpointLabelsException(f"Cannot specify '{restricted_label}' in labels") + # TODO: remove after we fully migrate to the new team + product validator try: from plugins.known_users import ALLOWED_TEAMS @@ -138,6 +156,15 @@ def validate_labels(labels: Dict[str, str]) -> None: except ModuleNotFoundError: pass + try: + from shared_plugins.team_product_label_validation import validate_team_product_label + except ModuleNotFoundError: + validate_team_product_label = simple_team_product_validator + + validation_result = validate_team_product_label(labels["team"], labels["product"]) + if not validation_result.passed: + raise EndpointLabelsException(validation_result.message) + # Check k8s will accept the label values regex_pattern = "(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?" # k8s label regex for label_value in labels.values(): diff --git a/model-engine/tests/unit/api/test_model_endpoints.py b/model-engine/tests/unit/api/test_model_endpoints.py index 614e5907..d3d8b9a6 100644 --- a/model-engine/tests/unit/api/test_model_endpoints.py +++ b/model-engine/tests/unit/api/test_model_endpoints.py @@ -5,6 +5,7 @@ from fastapi.testclient import TestClient from model_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response from model_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus +from model_engine_server.domain.use_cases.model_endpoint_use_cases import DEFAULT_DISALLOWED_TEAMS def test_create_model_endpoint_success( @@ -40,7 +41,6 @@ def test_create_model_endpoint_success( assert response_2.status_code == 200 -@pytest.mark.skip(reason="TODO: team validation is currently disabled") def test_create_model_endpoint_invalid_team_returns_400( model_bundle_1_v1: Tuple[ModelBundle, Any], create_model_endpoint_request_sync: Dict[str, Any], @@ -59,7 +59,8 @@ def test_create_model_endpoint_invalid_team_returns_400( fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, ) - create_model_endpoint_request_sync["labels"]["team"] = "some_invalid_team" + invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0] + create_model_endpoint_request_sync["labels"]["team"] = invalid_team_name response_1 = client.post( "/v1/model-endpoints", auth=(test_api_key, ""), @@ -67,7 +68,7 @@ def test_create_model_endpoint_invalid_team_returns_400( ) assert response_1.status_code == 400 - create_model_endpoint_request_async["labels"]["team"] = "some_invalid_team" + create_model_endpoint_request_async["labels"]["team"] = invalid_team_name response_2 = client.post( "/v1/model-endpoints", auth=(test_api_key, ""), @@ -394,7 +395,6 @@ def test_update_model_endpoint_by_id_success( assert response.json()["endpoint_creation_task_id"] -@pytest.mark.skip(reason="TODO: team validation is currently disabled") def test_update_model_endpoint_by_id_invalid_team_returns_400( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], @@ -418,8 +418,9 @@ def test_update_model_endpoint_by_id_invalid_team_returns_400( fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, ) + invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0] update_model_endpoint_request["labels"] = { - "team": "some_invalid_team", + "team": invalid_team_name, "product": "my_product", } response = client.put( From f9a3ff5f2d96a3348eb5c4283b97e1ed9ac32ab4 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Wed, 28 Feb 2024 10:49:38 -0800 Subject: [PATCH 251/425] Add log statement for gateway sending async task (#459) * log * fix --- .../infra/gateways/celery_task_queue_gateway.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index e1f2f11c..7a8f6911 100644 --- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -8,8 +8,10 @@ ) from model_engine_server.core.celery import TaskVisibility, celery_app from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway +logger = make_logger(logger_name()) backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3" celery_redis = celery_app( @@ -72,6 +74,7 @@ def send_task( kwargs=kwargs, queue=queue_name, ) + logger.info(f"Task {res.id} sent to queue {queue_name} from gateway") # pragma: no cover return CreateAsyncTaskV1Response(task_id=res.id) def get_task(self, task_id: str) -> GetAsyncTaskV1Response: From 39ef7c4006f694ceea4685cb646d5a8f50a7af7f Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 1 Mar 2024 18:17:21 -0800 Subject: [PATCH 252/425] Some batch inference improvements (#460) * Some batch inference improvements * fix unit test * coverage * integration test * fix --- docs/guides/completions.md | 49 +++++++++++++++++++ .../inference/batch_inference/Dockerfile_vllm | 2 +- .../inference/batch_inference/vllm_batch.py | 30 ++++++++++++ .../inference/vllm/vllm_server.py | 1 + .../infra/services/image_cache_service.py | 33 ++++++++++--- .../services/test_image_cache_service.py | 25 ++++++++-- 6 files changed, 128 insertions(+), 12 deletions(-) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index dee51f61..f48f05c4 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -120,6 +120,55 @@ async def main(): asyncio.run(main()) ``` +## Batch completions + +The Python client also supports batch completions. Batch completions supports distributing data to multiple workers to accelerate inference. It also tries to maximize throughput so the completions should finish quite a bit faster than hitting models through HTTP. Use [Completion.batch_complete](../../api/python_client/#llmengine.completion.Completion.batch_complete) to utilize batch completions. + +Some examples of batch completions: + +=== "Batch completions with prompts in the request" +```python +from llmengine import Completion +from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + +content = CreateBatchCompletionsRequestContent( + prompts=["What is deep learning", "What is a neural network"], + max_new_tokens=10, + temperature=0.0 +) + +response = Completion.batch_create( + output_data_path="s3://my-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + content=content +) +print(response.job_id) +``` + +=== "Batch completions with prompts in a file and with 2 parallel jobs" +```python +from llmengine import Completion +from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + +# Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + +response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2 +) +print(response.job_id) +``` + ## Which model should I use? See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions. diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm index 92820714..c79c51a0 100644 --- a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm +++ b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm @@ -1,7 +1,7 @@ FROM nvcr.io/nvidia/pytorch:23.09-py3 RUN apt-get update && \ - apt-get install -y dumb-init && \ + apt-get install -y dumb-init psmisc && \ apt-get autoremove -y && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index f976b87b..f7bfab11 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -33,6 +33,7 @@ def download_model(checkpoint_path, final_weights_folder): # Need to override these env vars so s5cmd uses AWS_PROFILE env["AWS_ROLE_ARN"] = "" env["AWS_WEB_IDENTITY_TOKEN_FILE"] = "" + # nosemgrep process = subprocess.Popen( s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env ) @@ -193,6 +194,7 @@ async def generate_with_vllm(request, content, model, job_index): tensor_parallel_size=request.model_config.num_shards, seed=request.model_config.seed or 0, disable_log_requests=True, + gpu_memory_utilization=0.8, # To avoid OOM errors when there's host machine GPU usage ) llm = AsyncLLMEngine.from_engine_args(engine_args) @@ -220,5 +222,33 @@ async def generate_with_vllm(request, content, model, job_index): return results_generators +def get_gpu_free_memory(): # pragma: no cover + """Get GPU free memory using nvidia-smi.""" + try: + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"] + ).decode("utf-8") + gpu_memory = [int(x) for x in output.strip().split("\n")] + return gpu_memory + except subprocess.CalledProcessError: + return None + + +def check_unknown_startup_memory_usage(): # pragma: no cover + """Check for unknown memory usage at startup.""" + gpu_free_memory = get_gpu_free_memory() + if gpu_free_memory is not None: + min_mem = min(gpu_free_memory) + max_mem = max(gpu_free_memory) + if max_mem - min_mem > 10: + print( + f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." + ) + # nosemgrep + output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8") + print(f"Processes using GPU: {output}") + + if __name__ == "__main__": + check_unknown_startup_memory_usage() asyncio.run(batch_inference()) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index e402db82..5bd3f6e4 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -114,6 +114,7 @@ def check_unknown_startup_memory_usage(): print( f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) + # nosemgrep output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8") print(f"Processes using GPU: {output}") diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index d79f4c49..a14c2b45 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -3,7 +3,7 @@ import pytz from model_engine_server.common.config import hmi_config -from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState @@ -69,17 +69,38 @@ def _cache_finetune_llm_images( ) istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0") - tgi_image = DockerImage( - f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3" + tgi_image_110 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "1.1.0" ) - tgi_image_2 = DockerImage( - f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.4" + vllm_image_027 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.2.7" + ) + vllm_image_032 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" + ) + latest_tag = ( + self.docker_repository.get_latest_image_tag( + f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}" + ) + if not CIRCLECI + else "fake_docker_repository_latest_image_tag" + ) + vllm_batch_image_latest = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", + latest_tag, ) forwarder_image = DockerImage( f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG ) - for llm_image in [istio_image, tgi_image, tgi_image_2, forwarder_image]: + for llm_image in [ + istio_image, + tgi_image_110, + vllm_image_027, + vllm_image_032, + vllm_batch_image_latest, + forwarder_image, + ]: if self.docker_repository.is_repo_name( llm_image.repo ) and not self.docker_repository.image_exists(llm_image.tag, llm_image.repo): diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py index aa1821fa..bf578c6d 100644 --- a/model-engine/tests/unit/infra/services/test_image_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -52,14 +52,29 @@ async def test_caching_finetune_llm_images( gateway: Any = fake_image_cache_service.image_cache_gateway istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0") - tgi_image = DockerImage( - f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.3-launch_s3" + tgi_image_110 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "1.1.0" ) - tgi_image_2 = DockerImage( - f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "0.9.4" + vllm_image_027 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.2.7" + ) + vllm_image_032 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" + ) + latest_tag = "fake_docker_repository_latest_image_tag" + vllm_batch_image_latest = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", + latest_tag, ) forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG) for key in ["a10", "a100"]: - for llm_image in [istio_image, tgi_image, tgi_image_2, forwarder_image]: + for llm_image in [ + istio_image, + tgi_image_110, + vllm_image_027, + vllm_image_032, + vllm_batch_image_latest, + forwarder_image, + ]: assert f"{llm_image.repo}:{llm_image.tag}" in gateway.cached_images[key] From 036b1a9ccb8cffee8039f304ab06384f0904dbb1 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 5 Mar 2024 17:25:40 -0800 Subject: [PATCH 253/425] Fix cacher (#462) * Fix cacher * format --- .../model_engine_server/infra/services/image_cache_service.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index a14c2b45..5d5c9d13 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -79,9 +79,7 @@ def _cache_finetune_llm_images( f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" ) latest_tag = ( - self.docker_repository.get_latest_image_tag( - f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}" - ) + self.docker_repository.get_latest_image_tag(hmi_config.batch_inference_vllm_repository) if not CIRCLECI else "fake_docker_repository_latest_image_tag" ) From 575eaa621232a358c31f562ac1f7e7a3046baa6b Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 6 Mar 2024 20:06:02 -0800 Subject: [PATCH 254/425] Fix vllm batch docker image (#463) * Fix vllm batch docker image * try again with 0.2.5 --- .../inference/batch_inference/Dockerfile_vllm | 7 ++++++- .../inference/batch_inference/requirements.txt | 2 -- .../inference/batch_inference/sample_config.json | 2 +- .../inference/batch_inference/vllm_batch.py | 5 +++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm index c79c51a0..d0a3b36b 100644 --- a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm +++ b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm @@ -6,10 +6,15 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean -RUN pip uninstall torch -y COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt RUN pip install -r requirements.txt +RUN pip uninstall torch -y +RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121 + +RUN pip uninstall xformers -y +RUN pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121 + RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index 9e8d1188..ab055543 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,5 +1,3 @@ -ray==2.6.3 -#git+https://github.com/vllm-project/vllm.git@4b61c6b669e368c6850531815940d9a542b9f223#egg=vllm vllm==0.2.5 pydantic==1.10.13 boto3==1.34.15 diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_config.json b/model-engine/model_engine_server/inference/batch_inference/sample_config.json index 366d9785..d047d7f8 100644 --- a/model-engine/model_engine_server/inference/batch_inference/sample_config.json +++ b/model-engine/model_engine_server/inference/batch_inference/sample_config.json @@ -8,4 +8,4 @@ "labels": {"team": "my_team"} }, "data_parallelism":2 -} \ No newline at end of file +} diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index f7bfab11..ffbdac3a 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -17,6 +17,7 @@ CONFIG_FILE = os.getenv("CONFIG_FILE") AWS_REGION = os.getenv("AWS_REGION", "us-west-2") +MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights") os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") @@ -118,7 +119,7 @@ async def batch_inference(): request = CreateBatchCompletionsRequest.parse_file(CONFIG_FILE) if request.model_config.checkpoint_path is not None: - download_model(request.model_config.checkpoint_path, "./model_weights") + download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER) content = request.content if content is None: @@ -126,7 +127,7 @@ async def batch_inference(): content = CreateBatchCompletionsRequestContent.parse_raw(f.read()) model = ( - "./model_weights" if request.model_config.checkpoint_path else request.model_config.model + MODEL_WEIGHTS_FOLDER if request.model_config.checkpoint_path else request.model_config.model ) results_generators = await generate_with_vllm(request, content, model, job_index) From 0528b52776da51191b0b45a92b993a54f4c89baa Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:40:33 -0800 Subject: [PATCH 255/425] Add tool completion to batch inference (#461) * test * Impl * remove logging * lints * fix tests * fix * fix * try fix unit test * fix * no cover * stop sequences --- .../model_engine_server/common/dtos/llms.py | 29 ++ .../use_cases/llm_model_endpoint_use_cases.py | 5 + .../generate_tool_sample_data.py | 79 +++++ .../batch_inference/requirements.txt | 4 +- .../batch_inference/sample_config_tool.json | 14 + .../batch_inference/sample_data_tool.json | 15 + .../inference/batch_inference/vllm_batch.py | 317 ++++++++++++++---- .../inference/tool_completion/__init__.py | 0 .../inference/tool_completion/base.py | 17 + .../inference/tool_completion/tools.py | 249 ++++++++++++++ .../inference/tool_completion/utils.py | 107 ++++++ model-engine/requirements-test.txt | 1 + model-engine/tests/unit/inference/conftest.py | 88 +++++ .../tests/unit/inference/test_vllm_batch.py | 144 ++++++-- 14 files changed, 985 insertions(+), 84 deletions(-) create mode 100644 model-engine/model_engine_server/inference/batch_inference/generate_tool_sample_data.py create mode 100644 model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json create mode 100644 model-engine/model_engine_server/inference/batch_inference/sample_data_tool.json create mode 100644 model-engine/model_engine_server/inference/tool_completion/__init__.py create mode 100644 model-engine/model_engine_server/inference/tool_completion/base.py create mode 100644 model-engine/model_engine_server/inference/tool_completion/tools.py create mode 100644 model-engine/model_engine_server/inference/tool_completion/utils.py diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 35e1c744..8d335d8d 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -429,6 +429,30 @@ class CreateBatchCompletionsModelConfig(BaseModel): """ +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + class CreateBatchCompletionsRequest(BaseModel): """ Request object for batch completions. @@ -456,6 +480,11 @@ class CreateBatchCompletionsRequest(BaseModel): """ Maximum runtime of the batch inference in seconds. Default to one day. """ + tool_config: Optional[ToolConfig] = None + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ class CreateBatchCompletionsResponse(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index ce2bf265..b458343c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2275,6 +2275,11 @@ async def execute( hardware.gpus = max(hardware.gpus, request.model_config.num_shards) request.model_config.num_shards = hardware.gpus + if request.tool_config and request.tool_config.name != "code_evaluator": + raise ObjectHasInvalidValueException( + "Only code_evaluator tool is supported for batch completions." + ) + batch_bundle = await self.create_batch_job_bundle(user, request, hardware) validate_resource_requests( diff --git a/model-engine/model_engine_server/inference/batch_inference/generate_tool_sample_data.py b/model-engine/model_engine_server/inference/batch_inference/generate_tool_sample_data.py new file mode 100644 index 00000000..d60a76d4 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/generate_tool_sample_data.py @@ -0,0 +1,79 @@ +import json + +COMPLETION_PROMPT1 = """\ +FYI: you can write code like this: +```python +import math +print(math.sqrt(2)) +``` +1.41... +>>> + +For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer. + +### Problem: + +What is the 4th digit of pi? + +### Answer: +```python +import math +print(math.pi) +``` +3.141592653589793 +>>> + +Final Answer: 1 + +### Problem: + +What is the 4th digit of the square root of 2? + +### Answer: +""" + +COMPLETION_PROMPT2 = """\ +FYI: you can write code like this: +```python +import math +print(math.sqrt(2)) +``` +1.41... +>>> + +For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer. + +### Problem: + +What is the 4th digit of pi? + +### Answer: +```python +import math +print(math.pi) +``` +3.141592653589793 +>>> + +Final Answer: 1 + +### Problem: + +What is the 5th digit of the square root of 2? + +### Answer: +""" + +data = { + "prompts": [ + COMPLETION_PROMPT1, + COMPLETION_PROMPT2, + "what is deep learning", + ], + "max_new_tokens": 100, + "temperature": 0.0, + "return_token_log_probs": True, + "stop_sequences": ["", "\n### Problem:\n", ">>>\n"], +} + +json.dump(data, open("sample_data_tool.json", "w")) diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index ab055543..bbc99b04 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -2,4 +2,6 @@ vllm==0.2.5 pydantic==1.10.13 boto3==1.34.15 smart-open==6.4.0 -ddtrace==2.4.0 \ No newline at end of file +ddtrace==2.4.0 +docker==7.0.0 +func-timeout==4.3.5 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json b/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json new file mode 100644 index 00000000..d9a3af4a --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json @@ -0,0 +1,14 @@ +{ + "input_data_path":"./sample_data_tool.json", + "output_data_path":"./sample_output_tool.json", + "model_config":{ + "model":"mistral-7b", + "checkpoint_path":"s3://scale-ml/models/mistral-7b", + "num_shards": 1, + "labels": {"team": "my_team"} + }, + "data_parallelism":2, + "tool_config": { + "name": "code_evaluator" + } +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_data_tool.json b/model-engine/model_engine_server/inference/batch_inference/sample_data_tool.json new file mode 100644 index 00000000..f529eca4 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/sample_data_tool.json @@ -0,0 +1,15 @@ +{ + "prompts": [ + "FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 4th digit of the square root of 2?\n\n### Answer: \n", + "FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 5th digit of the square root of 2?\n\n### Answer: \n", + "what is deep learning" + ], + "max_new_tokens": 100, + "temperature": 0.0, + "return_token_log_probs": true, + "stop_sequences": [ + "", + "\n### Problem:\n", + ">>>\n" + ] +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index ffbdac3a..d783b3f4 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -2,17 +2,23 @@ import json import os import subprocess +import sys import time +import uuid +from typing import List, Optional, Type from urllib.parse import urlparse import boto3 import smart_open +from func_timeout import FunctionTimedOut, func_set_timeout from model_engine_server.common.dtos.llms import ( CompletionOutput, CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, TokenOutput, + ToolConfig, ) +from model_engine_server.inference.tool_completion.tools import TOOL_MAP, BaseTool, Tools, tokenizer from tqdm import tqdm CONFIG_FILE = os.getenv("CONFIG_FILE") @@ -28,7 +34,7 @@ def get_s3_client(): def download_model(checkpoint_path, final_weights_folder): - s5cmd = f"./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" env = os.environ.copy() env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") # Need to override these env vars so s5cmd uses AWS_PROFILE @@ -113,6 +119,164 @@ def delete_s3_chunks(request): print("Chunks deleted") +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def get_vllm_engine(model, request): + from vllm import AsyncEngineArgs, AsyncLLMEngine + + engine_args = AsyncEngineArgs( + model=model, + quantization=request.model_config.quantize, + tensor_parallel_size=request.model_config.num_shards, + seed=request.model_config.seed or 0, + disable_log_requests=True, + gpu_memory_utilization=0.8, # To avoid OOM errors when there's host machine GPU usage + ) + + llm = AsyncLLMEngine.from_engine_args(engine_args) + return llm + + +async def generate_with_tool( + llm, + tool_config: ToolConfig, + content: CreateBatchCompletionsRequestContent, + prompts, + tool: Type[BaseTool], +): + class IterativeGeneration: + def __init__(self, prompt, max_new_tokens): + self.generated_text = "" + self.num_prompt_tokens = 0 + self.remaining_tokens = max_new_tokens + self.token_logits = [] + self.tool_exception = None + self.prompt = prompt + self.completed = False + + def __repr__(self) -> str: + return f"generated_text: {self.generated_text}, num_prompt_tokens: {self.num_prompt_tokens}, remaining_tokens: {self.remaining_tokens}, tool_exception: {self.tool_exception}, prompt: {self.prompt}, completed: {self.completed}" + + num_iters = 0 + generations = [IterativeGeneration(prompt, content.max_new_tokens) for prompt in prompts] + max_iterations = tool_config.max_iterations or 10 + stop_sequences = content.stop_sequences or [] + stop_sequences.append(tool.tool_context_end) + + while num_iters < max_iterations: + num_iters += 1 + + iter_prompts = [ + (gen.prompt + gen.generated_text, idx) + for idx, gen in enumerate(generations) + if not gen.completed + ] + + if not iter_prompts: + break + + bar = tqdm( + total=len(iter_prompts), + desc=f"Generating outputs, iteration {num_iters}", + file=sys.stdout, + ) + + outputs = await generate_with_vllm( + llm, + content.max_new_tokens, + content.temperature, + content.stop_sequences, + content.return_token_log_probs, + content.presence_penalty, + content.frequency_penalty, + content.top_k, + content.top_p, + [iter[0] for iter in iter_prompts], + bar, + ) + + bar = tqdm( + total=len(iter_prompts), + desc=f"Running tools, iteration {num_iters}", + file=sys.stdout, + ) + for i in range(len(iter_prompts)): + bar.update(1) + response = outputs[i] + gen_item = generations[iter_prompts[i][1]] + new_text = response.text + + if content.return_token_log_probs: + gen_item.token_logits += response.tokens + + if not gen_item.num_prompt_tokens: + gen_item.num_prompt_tokens = response.num_prompt_tokens + + # break the loop if generation is complete even if remaining_tokens>0 + if len(new_text) == 0: + gen_item.completed = True + continue + + # To-do write tools to receive response object itself rather than the text + try: + # We need to pass the tool/text to a function that times out if the python code can't execute + @func_set_timeout(tool_config.execution_timeout_seconds) + def tool_func(text: str, past_context: Optional[str]): + return tool()(text, past_context) + + past_context = ( + gen_item.generated_text if tool_config.should_retry_on_error else None + ) + new_text, num_tool_output_tokens = tool_func(new_text, past_context) + + except (Exception, FunctionTimedOut) as e: + # If the tool failed, we should add the error message to the generated text and keep going. It should be added right after the + # tool call token and concluded with the tool_context_end_token. + new_text_split = new_text.rsplit(tool.tool_call_token, 1) + + # We can guarantee this because the tool is not called if it doesn't have the tool call token + # We still want to replace what the LLM thinks the output should be.. + added_text = str(e) + tool.tool_context_end + subtracted_text = new_text_split[1] + + new_text = f"{new_text_split[0]}{tool.tool_call_token}{e}{tool.tool_context_end}" + + # Now let's add the additional tokens + num_tool_output_tokens = min( + len(tokenizer(added_text).input_ids) + - len(tokenizer(subtracted_text).input_ids), + 0, + ) + + # Also, define the tool exception here so we can raise it later + gen_item.tool_exception = e + + num_completion_tokens = response.num_completion_tokens + + gen_item.remaining_tokens -= num_completion_tokens + gen_item.remaining_tokens -= num_tool_output_tokens + gen_item.generated_text += new_text + + # If we didn't just execute a tool, we're done + if not gen_item.generated_text.endswith(tool.tool_context_end): + gen_item.completed = True + continue + + results = [ + CompletionOutput( + text=gen_item.generated_text, + num_prompt_tokens=gen_item.num_prompt_tokens, + num_completion_tokens=content.max_new_tokens - gen_item.remaining_tokens, + tokens=gen_item.token_logits if content.return_token_log_probs else None, + ) + for gen_item in generations + ] + + return results + + async def batch_inference(): job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0)) @@ -130,9 +294,92 @@ async def batch_inference(): MODEL_WEIGHTS_FOLDER if request.model_config.checkpoint_path else request.model_config.model ) - results_generators = await generate_with_vllm(request, content, model, job_index) + llm = get_vllm_engine(model, request) + + prompts = [] + prompts_per_pod = len(content.prompts) // request.data_parallelism + if job_index == request.data_parallelism - 1: + for prompt in content.prompts[prompts_per_pod * job_index :]: + prompts.append(prompt) + else: + for prompt in content.prompts[ + prompts_per_pod * job_index : prompts_per_pod * (job_index + 1) + ]: + prompts.append(prompt) + + if request.tool_config is not None: + tool_enum = Tools(request.tool_config.name) + tool = TOOL_MAP[tool_enum] + outputs = await generate_with_tool(llm, request.tool_config, content, prompts, tool) + else: + bar = tqdm(total=len(prompts), desc="Processed prompts") + + outputs = await generate_with_vllm( + llm, + content.max_new_tokens, + content.temperature, + content.stop_sequences, + content.return_token_log_probs, + content.presence_penalty, + content.frequency_penalty, + content.top_k, + content.top_p, + prompts, + bar, + ) + + bar.close() + + output_dicts = [output.dict() for output in outputs] + + if request.data_parallelism == 1: + with smart_open.open(request.output_data_path, "w") as f: + f.write(json.dumps(output_dicts)) + else: + chunk_file = f"{request.output_data_path}.{job_index}" + with smart_open.open(chunk_file, "w") as f: + f.write(json.dumps(output_dicts)) + if job_index == 0: + wait_for_all_chunks(request) + combine_all_chunks(request) + if request.output_data_path.startswith("s3://"): + delete_s3_chunks(request) + + +async def generate_with_vllm( + engine, + max_new_tokens, + temperature, + stop_sequences, + return_token_log_probs, + presence_penalty, + frequency_penalty, + top_k, + top_p, + prompts, + bar, +) -> List[CompletionOutput]: + from vllm import SamplingParams - bar = tqdm(total=len(results_generators), desc="Processed prompts") + # Add the requests to the engine. + sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=temperature, + stop=stop_sequences, + logprobs=1 if return_token_log_probs else None, + presence_penalty=presence_penalty or 0.0, + frequency_penalty=frequency_penalty or 0.0, + top_k=top_k or -1, + top_p=top_p or 1.0, + ) + + results_generators = [] + for prompt in prompts: + request_id = random_uuid() + results_generator = await engine.add_request( + request_id, prompt, sampling_params, None, time.monotonic() + ) + results_generators.append(results_generator) outputs = [] for generator in results_generators: @@ -143,12 +390,10 @@ async def batch_inference(): bar.update(1) token_text = request_output.outputs[-1].text[len(last_output_text) :] - log_probs = ( - request_output.outputs[0].logprobs[-1] if content.return_token_log_probs else None - ) + log_probs = request_output.outputs[0].logprobs[-1] if return_token_log_probs else None last_output_text = request_output.outputs[-1].text - if content.return_token_log_probs: + if return_token_log_probs: tokens.append( TokenOutput( token=token_text, @@ -164,63 +409,11 @@ async def batch_inference(): num_prompt_tokens=num_prompt_tokens, num_completion_tokens=num_completion_tokens, ) - if content.return_token_log_probs: + if return_token_log_probs: output.tokens = tokens - outputs.append(output.dict()) - - bar.close() - - if request.data_parallelism == 1: - with smart_open.open(request.output_data_path, "w") as f: - f.write(json.dumps(outputs)) - else: - chunk_file = f"{request.output_data_path}.{job_index}" - with smart_open.open(chunk_file, "w") as f: - f.write(json.dumps(outputs)) - if job_index == 0: - wait_for_all_chunks(request) - combine_all_chunks(request) - if request.output_data_path.startswith("s3://"): - delete_s3_chunks(request) - - -async def generate_with_vllm(request, content, model, job_index): - from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams - from vllm.utils import random_uuid - - engine_args = AsyncEngineArgs( - model=model, - quantization=request.model_config.quantize, - tensor_parallel_size=request.model_config.num_shards, - seed=request.model_config.seed or 0, - disable_log_requests=True, - gpu_memory_utilization=0.8, # To avoid OOM errors when there's host machine GPU usage - ) - - llm = AsyncLLMEngine.from_engine_args(engine_args) - - # Add the requests to the engine. - sampling_params = SamplingParams( - max_tokens=content.max_new_tokens, - temperature=content.temperature, - stop=content.stop_sequences, - logprobs=1 if content.return_token_log_probs else None, - presence_penalty=content.presence_penalty or 0.0, - frequency_penalty=content.frequency_penalty or 0.0, - top_k=content.top_k or -1, - top_p=content.top_p or 1.0, - ) - - results_generators = [] - prompts_per_pod = len(content.prompts) // request.data_parallelism - for prompt in content.prompts[prompts_per_pod * job_index : prompts_per_pod * (job_index + 1)]: - request_id = random_uuid() - results_generator = await llm.add_request( - request_id, prompt, sampling_params, None, time.monotonic() - ) - results_generators.append(results_generator) - return results_generators + outputs.append(output) + return outputs def get_gpu_free_memory(): # pragma: no cover diff --git a/model-engine/model_engine_server/inference/tool_completion/__init__.py b/model-engine/model_engine_server/inference/tool_completion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/inference/tool_completion/base.py b/model-engine/model_engine_server/inference/tool_completion/base.py new file mode 100644 index 00000000..6166987a --- /dev/null +++ b/model-engine/model_engine_server/inference/tool_completion/base.py @@ -0,0 +1,17 @@ +from typing import Optional, Tuple + + +class BaseTool: + """ + Base class for third-party tools. + """ + + tool_context_start = "" + tool_call_token = "" + tool_context_end = "" + + def __call__(self, expression: str, past_context: Optional[str]) -> Tuple[str, int]: + """ + Call method to be overridden by child classes. + """ + raise NotImplementedError("The evaluate method must be implemented by child classes.") diff --git a/model-engine/model_engine_server/inference/tool_completion/tools.py b/model-engine/model_engine_server/inference/tool_completion/tools.py new file mode 100644 index 00000000..8e84bdff --- /dev/null +++ b/model-engine/model_engine_server/inference/tool_completion/tools.py @@ -0,0 +1,249 @@ +import re +import subprocess +from enum import Enum +from typing import Optional, Tuple + +import docker +from model_engine_server.inference.tool_completion.base import BaseTool +from model_engine_server.inference.tool_completion.utils import ( + FIX_ERRORS_MAPPING, + NAME_ERROR_PATTERN, + PRINT_PATTERN, +) +from transformers import LlamaTokenizer + +tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_7b", legacy=False) +MAX_CODEBLOCK_RETRIES = 3 + + +class CodeBlockEvaluator(BaseTool): # pragma: no cover + """ + A evaluator to "pseudo-safely" execute python code blocks. + Executes code from a model generated response using a safe python interpreter. + the code should have the following format: + + ```python + {code} + ``` + {output} + >>> + + The output will be replaced with the output from executing the code. + """ + + tool_context_start = "```python\n" + tool_call_token = "\n```\n" + tool_context_end = "\n>>>\n" + + @staticmethod + def _cleanup_code_error(error_code: str) -> str: + """This function will clean up an error code from code execution + + Args: + error_code (str): The full error code (e.g. like below): + + Command '['python', '-c', 'import math\nx = 2\nmath.sqrt(y)']' in image 'continuumio/anaconda3' + returned non-zero exit status 1: b'Traceback (most recent call last): + File "", line 3, in \nNameError: name \'y\' is not defined\n' + + Returns: + str: like the following: + + Traceback (most recent call last): File "", line 3, in + NameError: name \'y\' is not defined + + """ + if "Traceback" not in error_code: + return error_code + + # Let's find the byte string: (e.g. b') + stacktrace = error_code.split("b'")[-1] + + # Now read it as a bytestring + stacktrace = "\n" + stacktrace.encode("utf-8").decode("unicode_escape") + + return stacktrace.strip("'") + + def __init__(self): + # Condition to check if we can use docker + try: + self.client = docker.from_env() + self.evaluate = self.evaluate_code_in_docker + except docker.errors.DockerException: + # If docker is not available, use the python interpreter + self.evaluate = self.evaluate_code_using_exec + + def __call__( + self, + expression: str, + past_context: Optional[str] = None, + ) -> Tuple[str, int]: + """ + Given an expression, extract the code block and execute it using a safe python interpreter. Additionally, + approximate the number of tokens added to the expression from the tool output along with handling retries + due to simple tool errors (e.g. import errors, missing variables) + + Args: + expression (str): text with natural language and code blocks + past_context (Optional[str]): previously generated code blocks for retrying simple code errors + + Returns: + str: Formatted output from the code execution tool + int: Number of tokens added + + Raises: + RuntimeError: If any errors occur during the code execution or retries for simple code errors. + """ + tool_output = "" + expression_ = expression + num_tokens = 0 + if (CodeBlockEvaluator.tool_context_start in expression) and ( + CodeBlockEvaluator.tool_call_token in expression + ): + # Extract the expression between the start token and the special token for the tool to evaluate + code_expression = expression.split(CodeBlockEvaluator.tool_context_start)[-1].split( + CodeBlockEvaluator.tool_call_token + )[0] + + # Note: Can increase max retries if needed (e.g. > 1 import errors + variable not defined in code_expression) + for retry_count in range(MAX_CODEBLOCK_RETRIES): + try: + tool_output = self.evaluate(code_expression) + break + except Exception as e: + name_error = re.search(NAME_ERROR_PATTERN, str(e)) + if ( + past_context is None + or name_error is None + or retry_count == MAX_CODEBLOCK_RETRIES - 1 + ): + error_code = self._cleanup_code_error(str(e)) + raise RuntimeError(f"failed with error: {error_code}") + + if retry_count == 0 and past_context != "": + # Grab all the prior code blocks in "```python\n{code}\n```\n" format + code_expression = ( + self._extract_code_blocks(past_context) + "\n" + code_expression + ) + else: + current_error = name_error.group(1).replace("\\", "") + # Make sure error is one of the fixable/common import errors seen in the past + if current_error not in FIX_ERRORS_MAPPING.keys(): + error_code = self._cleanup_code_error(str(e)) + raise RuntimeError( + f"failed on retry: {retry_count}, NameError variable: {current_error}, and error: {error_code}" + ) + + code_expression = FIX_ERRORS_MAPPING[current_error] + "\n" + code_expression + + tool_output = ( + CodeBlockEvaluator.tool_call_token + + tool_output + + CodeBlockEvaluator.tool_context_end + ) + + expression_ = expression.split(CodeBlockEvaluator.tool_call_token)[0] + tool_output + num_tokens = max( + 0, len(tokenizer(expression_).input_ids) - len(tokenizer(expression).input_ids) + ) + return expression_, num_tokens + + def _extract_code_blocks(self, context: str): + """ + Given some text (e.g. previous completion), extract all the code blocks in the format + along with removing any old print statements. + + Args: + context (str): text with natural language and code blocks + + Returns: + str: Parsed code blocks with print statements removed + """ + code_block_pattern = re.compile( + rf"{CodeBlockEvaluator.tool_context_start}(.*?){CodeBlockEvaluator.tool_call_token}", + re.DOTALL, + ) + code_block_matches = code_block_pattern.findall(context) + # Remove lines with print statements bc already included in model response + cleaned_code_blocks = [] + for code_block in code_block_matches: + no_print_code_blocks = [] + for line in code_block.split("\n"): + # Ignore lines with print format + if re.search(PRINT_PATTERN, line) is None: + no_print_code_blocks.append(line) + cleaned_code_blocks.append("\n".join(no_print_code_blocks)) + return "\n".join(cleaned_code_blocks) + + def evaluate_code_in_docker(self, code: str) -> str: + """ + Executes a block of code using a safe python interpreter and returns the output as a string. + + This function uses a docker container to safely execute a given block of code. + The function returns the output of the last executed line, if any. + + Args: + code (str): A string containing the Python code to be executed. + + Returns: + str: The output of the executed code, converted to string. If there's no explicit output, + the function returns the result of the last line of code. + + Raises: + RuntimeError: If any errors occur during the code execution. + """ + + try: + output = self.client.containers.run( + "continuumio/anaconda3", command=["python", "-c", code] + ).decode() + output = output.strip() + except docker.errors.ContainerError as e: + raise RuntimeError(e) + + return output + + def evaluate_code_using_exec(self, code: str) -> str: + """ + Executes a block of code using the python "exec" function. Returns the output as a string. + Unfortunately it doesn't have the same safety guarantees as the docker version. + + However, it will only ever be enabled when we are in a scale environment as we check the llmengine + path. + + Args: + code (str): A string containing the Python code to be executed. + + Returns: + str: The output of the executed code, converted to string. If there's no explicit output, + the function returns the result of the last line of code. + """ + try: + p = subprocess.run(["python", "-c", code], capture_output=True, text=True) + p.check_returncode() # Raises CalledProcessError if the exit code is non-zero + output_str = p.stdout + + # If output is empty and the last line didn't have a print statement, edit the code to add one + if output_str == "" and "print" not in code.split("\n")[-1]: + new_code = "\n".join(code.split("\n")[:-1]) + last_line = code.split("\n")[-1] + new_code = new_code + f"\nprint({last_line})" + + # Re-run it + p = subprocess.run(["python", "-c", new_code], capture_output=True, text=True) + p.check_returncode() + output_str = p.stdout + + except subprocess.CalledProcessError as e: + raise RuntimeError(p.stderr) from e + + return output_str + + +class Tools(str, Enum): + CODE_EVALUATOR = "code_evaluator" + + +TOOL_MAP = { + Tools.CODE_EVALUATOR: CodeBlockEvaluator, +} diff --git a/model-engine/model_engine_server/inference/tool_completion/utils.py b/model-engine/model_engine_server/inference/tool_completion/utils.py new file mode 100644 index 00000000..bb30d116 --- /dev/null +++ b/model-engine/model_engine_server/inference/tool_completion/utils.py @@ -0,0 +1,107 @@ +from queue import Queue +from typing import Tuple + +from model_engine_server.inference.tool_completion.base import BaseTool + +NAME_ERROR_PATTERN = r"NameError: name \\?'([^']+)\\?' is not defined" + +PRINT_PATTERN = r"print\(.+?\)" + +# Most common imports used during code execution +FIX_ERRORS_MAPPING = { + "math": "import math", + "np": "import numpy as np", + "cmath": "import cmath", + "norm": "from scipy.stats import norm", + "plt": "import matplotlib.pyplot as plt", + "sp": "import sympy as sp", + "sympy": "import sympy", + "sqrt": "from cmath import sqrt", + "erfinv": "from scipy.special import erfinv", + "t": "from scipy.stats import t", + "comb": "from scipy.special import comb", + "Fraction": "from fractions import Fraction", + "st": "import steam_table as st", + "pd": "import pandas as pd", + "stats": "import scipy.stats as stats", + "opt": "import scipy.optimize as opt", + "Counter": "from collections import Counter", + "datetime": "import datetime", + "gcd": "from fractions import gcd", + "pi": "from math import pi", + "quad": "from scipy.integrate import quad", + "fsolve": "from scipy.optimize import fsolve", + "factorial": "from math import factorial", + "tan": "from math import tan", + "log": "from math import log", + "symbols": "from sympy import symbols, sin, cos", + "integrate": "from sympy import symbols, integrate", + "diff": "from sympy import symbols, sin, cos, diff", + "sin": "from sympy import symbols, sin, cos", + "cos": "from sympy import symbols, sin, cos", + "time": "import time", + "Symbol": "from sympy import Symbol", +} + + +# Check if a model response indicates it could be starting a tool +def check_streaming_tool_start(stream_queue: Queue, tool: BaseTool) -> bool: # pragma: no cover + # If the queue is empty, we can't start the tool + if stream_queue.qsize() == 0: + return False + + # Create the full string from the queue + queue_str = "" + for response in list(stream_queue.queue): + queue_str += response.output.text + + # Check if the start token is in the queue + if tool.tool_context_start in queue_str: + return True + + return False + + +def check_either_substr(str1: str, str2: str) -> bool: + return str1 in str2 or str2 in str1 + + +# Check if some responses from the queue should be returned +def get_responses_to_yield( + stream_queue: Queue, tool: BaseTool, tool_started: bool +) -> Tuple[Queue, Queue]: # pragma: no cover + """We return a tuple, (responses_to_yield, stream_queue) based on what should be returned""" + # If we've started the tool, we shouldn't yield anything + if tool_started: + return Queue(), stream_queue + + # Otherwise, we should yield everything in the queue that *can't* be part of the start of a tool + concatenated_queue_str = "" + responses_to_yield: Queue = Queue() # These are values we're sure we want to return right now + undecided_queue: Queue = ( + Queue() + ) # These are values that could be part of start token but we aren't sure yet + + # Iterate through the queue and add to the concatenated queue string + while stream_queue.qsize() > 0: + response = stream_queue.get() + + # First check if the adding the current response could be part of the start token + if check_either_substr( + concatenated_queue_str + response.output.text, tool.tool_context_start + ): + # If so, add it to the undecided queue + undecided_queue.put(response) + concatenated_queue_str += response.output.text + + # Otherwise, we are confident that everything in the undecided *can't* be part of the start token + # in addition to the concatenated queue string + else: + while not undecided_queue.empty(): + responses_to_yield.put(undecided_queue.get()) + + responses_to_yield.put(response) + concatenated_queue_str = "" + + # Finally, return the responses to yield and the new stream queue + return responses_to_yield, undecided_queue diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index 158e0743..bbe191e4 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -24,3 +24,4 @@ types-ujson==5.5.0 types-urllib3==1.26.14 types-waitress==2.1.4 frozendict==2.3.4 +func-timeout==4.3.5 diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py index 4d0ec72c..26a3a0a3 100644 --- a/model-engine/tests/unit/inference/conftest.py +++ b/model-engine/tests/unit/inference/conftest.py @@ -7,6 +7,7 @@ CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, TokenOutput, + ToolConfig, ) @@ -22,6 +23,29 @@ def create_batch_completions_request(): ) +@pytest.fixture +def create_batch_completions_tool_completion_request(): + return CreateBatchCompletionsRequest( + model_config=CreateBatchCompletionsModelConfig( + checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={} + ), + data_parallelism=1, + input_data_path="input_data_path", + output_data_path="output_data_path", + tool_config=ToolConfig(name="code_evaluator"), + ) + + +@pytest.fixture +def create_batch_completions_tool_completion_request_content(): + return CreateBatchCompletionsRequestContent( + prompts=["prompt1"], + max_new_tokens=100, + temperature=0.8, + return_token_log_probs=True, + ) + + @pytest.fixture def create_batch_completions_request_content(): return CreateBatchCompletionsRequestContent( @@ -91,3 +115,67 @@ def mock_completion_output(): TokenOutput(token=" text3", log_prob=0.3), ], ) + + +@pytest.fixture +def mock_tool_completion_output(): + return CompletionOutput( + text="```python\nimport math\nprint(math.sqrt(2))\n```\n1.414...\n", + num_prompt_tokens=10, + num_completion_tokens=28, + tokens=[ + TokenOutput(token="``", log_prob=-0.1980377733707428), + TokenOutput(token="`", log_prob=-0.0037908137310296297), + TokenOutput(token="python", log_prob=-0.015637163072824478), + TokenOutput(token="\n", log_prob=-0.0010788579238578677), + TokenOutput(token="import", log_prob=-0.04351021721959114), + TokenOutput(token=" math", log_prob=-0.0021214615553617477), + TokenOutput(token="\n", log_prob=-0.002169043058529496), + TokenOutput(token="print", log_prob=-0.06555093079805374), + TokenOutput(token="(", log_prob=-0.005272886715829372), + TokenOutput(token="math", log_prob=-0.009995171800255775), + TokenOutput(token=".", log_prob=-0.0002040654799202457), + TokenOutput(token="sqrt", log_prob=-0.00886327400803566), + TokenOutput(token="(", log_prob=-0.0015410225605592132), + TokenOutput(token="2", log_prob=-0.008573509752750397), + TokenOutput(token="))", log_prob=-0.010970987379550934), + TokenOutput(token="\n", log_prob=-0.002175347413867712), + TokenOutput(token="``", log_prob=-0.01911235973238945), + TokenOutput(token="`", log_prob=-0.0005327236140146852), + TokenOutput(token="\n", log_prob=-0.002304519060999155), + TokenOutput(token="1", log_prob=-0.10852570831775665), + TokenOutput(token=".", log_prob=-0.007146273739635944), + TokenOutput(token="4", log_prob=-0.003810290014371276), + TokenOutput(token="1", log_prob=-0.002774677239358425), + TokenOutput(token="4", log_prob=-0.16946221888065338), + TokenOutput(token=".", log_prob=-0.007678280584514141), + TokenOutput(token=".", log_prob=-0.021146666258573532), + TokenOutput(token=".", log_prob=-0.3870151937007904), + TokenOutput(token="\n", log_prob=-0.027081478387117386), + ], + ) + + +@pytest.fixture +def mock_tool_completion_output2(): + return CompletionOutput( + text="Final Answer: 4\n", + num_prompt_tokens=38, + num_completion_tokens=6, + tokens=[ + TokenOutput(token="Final", log_prob=-0.1980377733707428), + TokenOutput(token=" Answer", log_prob=-0.0037908137310296297), + TokenOutput(token=":", log_prob=-0.015637163072824478), + TokenOutput(token=" ", log_prob=-0.0010788579238578677), + TokenOutput(token="4", log_prob=-0.04351021721959114), + TokenOutput(token="\n", log_prob=-0.0021214615553617477), + ], + ) + + +@pytest.fixture +def mock_run_output(): + value = MagicMock() + value.stdout = "1.4142135623730951" + value.check_returncode = MagicMock() + return value diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py index e9ab0937..7dbaad42 100644 --- a/model-engine/tests/unit/inference/test_vllm_batch.py +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -1,11 +1,12 @@ import json -from unittest.mock import MagicMock, call, mock_open, patch +from unittest.mock import call, mock_open, patch import pytest from model_engine_server.inference.batch_inference.vllm_batch import batch_inference, file_exists @pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") @patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" @@ -25,9 +26,9 @@ async def test_batch_inference( mock_generate_with_vllm, mock_create_batch_completions_request_content, mock_create_batch_completions_request, + mock_vllm, create_batch_completions_request, create_batch_completions_request_content, - create_vllm_request_outputs, mock_s3_client, mock_process, mock_completion_output, @@ -40,11 +41,8 @@ async def test_batch_inference( create_batch_completions_request_content ) - mock_results_generator = MagicMock() - mock_results_generator.__aiter__.return_value = create_vllm_request_outputs - # Mock the generate_with_vllm function - mock_generate_with_vllm.return_value = [mock_results_generator] + mock_generate_with_vllm.return_value = [mock_completion_output] # Call the function await batch_inference() @@ -62,6 +60,7 @@ async def test_batch_inference( @pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") @patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" @@ -81,9 +80,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed( mock_generate_with_vllm, mock_create_batch_completions_request_content, mock_create_batch_completions_request, + mock_vllm, create_batch_completions_request, create_batch_completions_request_content, - create_vllm_request_outputs, mock_s3_client, mock_process, mock_completion_output, @@ -97,11 +96,8 @@ async def test_batch_inference_failed_to_download_model_but_proceed( create_batch_completions_request_content ) - mock_results_generator = MagicMock() - mock_results_generator.__aiter__.return_value = create_vllm_request_outputs - # Mock the generate_with_vllm function - mock_generate_with_vllm.return_value = [mock_results_generator] + mock_generate_with_vllm.return_value = [mock_completion_output] # Call the function await batch_inference() @@ -119,6 +115,7 @@ async def test_batch_inference_failed_to_download_model_but_proceed( @pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") @patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" @@ -140,9 +137,9 @@ async def test_batch_inference_two_workers( mock_generate_with_vllm, mock_create_batch_completions_request_content, mock_create_batch_completions_request, + mock_vllm, create_batch_completions_request, create_batch_completions_request_content, - create_vllm_request_outputs, mock_s3_client, mock_process, mock_completion_output, @@ -156,11 +153,8 @@ async def test_batch_inference_two_workers( create_batch_completions_request_content ) - mock_results_generator = MagicMock() - mock_results_generator.__aiter__.return_value = create_vllm_request_outputs - # Mock the generate_with_vllm function - mock_generate_with_vllm.return_value = [mock_results_generator] + mock_generate_with_vllm.return_value = [mock_completion_output] indexes = [1, 0] @@ -203,6 +197,7 @@ def side_effect(key, default): @pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") @patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" @@ -224,9 +219,9 @@ async def test_batch_inference_delete_chunks( mock_generate_with_vllm, mock_create_batch_completions_request_content, mock_create_batch_completions_request, + mock_vllm, create_batch_completions_request, create_batch_completions_request_content, - create_vllm_request_outputs, mock_s3_client, mock_process, mock_completion_output, @@ -241,11 +236,8 @@ async def test_batch_inference_delete_chunks( create_batch_completions_request_content ) - mock_results_generator = MagicMock() - mock_results_generator.__aiter__.return_value = create_vllm_request_outputs - # Mock the generate_with_vllm function - mock_generate_with_vllm.return_value = [mock_results_generator] + mock_generate_with_vllm.return_value = [mock_completion_output] indexes = [1, 0] @@ -314,3 +306,113 @@ def test_file_exists_no_such_key(): result = file_exists(path) assert result is False + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") +@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("model_engine_server.inference.batch_inference.vllm_batch.subprocess.Popen") +@patch("subprocess.run") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +async def test_batch_inference_tool_completion( + mock_open_func, + mock_run, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_request, + mock_vllm, + create_batch_completions_tool_completion_request, + create_batch_completions_tool_completion_request_content, + mock_s3_client, + mock_process, + mock_tool_completion_output, + mock_tool_completion_output2, + mock_run_output, +): + # Mock the necessary objects and data + mock_run.return_value = mock_run_output + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + mock_create_batch_completions_request.parse_file.return_value = ( + create_batch_completions_tool_completion_request + ) + mock_create_batch_completions_request_content.parse_raw.return_value = ( + create_batch_completions_tool_completion_request_content + ) + + # Mock the generate_with_vllm function + mock_generate_with_vllm.side_effect = [ + [mock_tool_completion_output], + [mock_tool_completion_output2], + ] + + # Call the function + await batch_inference() + + # Assertions + mock_create_batch_completions_request.parse_file.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path", "w"), + call().write( + json.dumps( + [ + { + "text": "```python\nimport math\nprint(math.sqrt(2))\n```\n1.4142135623730951\n>>>\nFinal Answer: 4\n", + "num_prompt_tokens": 10, + "num_completion_tokens": 49, + "tokens": [ + {"token": "``", "log_prob": -0.1980377733707428}, + {"token": "`", "log_prob": -0.0037908137310296297}, + {"token": "python", "log_prob": -0.015637163072824478}, + {"token": "\n", "log_prob": -0.0010788579238578677}, + {"token": "import", "log_prob": -0.04351021721959114}, + {"token": " math", "log_prob": -0.0021214615553617477}, + {"token": "\n", "log_prob": -0.002169043058529496}, + {"token": "print", "log_prob": -0.06555093079805374}, + {"token": "(", "log_prob": -0.005272886715829372}, + {"token": "math", "log_prob": -0.009995171800255775}, + {"token": ".", "log_prob": -0.0002040654799202457}, + {"token": "sqrt", "log_prob": -0.00886327400803566}, + {"token": "(", "log_prob": -0.0015410225605592132}, + {"token": "2", "log_prob": -0.008573509752750397}, + {"token": "))", "log_prob": -0.010970987379550934}, + {"token": "\n", "log_prob": -0.002175347413867712}, + {"token": "``", "log_prob": -0.01911235973238945}, + {"token": "`", "log_prob": -0.0005327236140146852}, + {"token": "\n", "log_prob": -0.002304519060999155}, + {"token": "1", "log_prob": -0.10852570831775665}, + {"token": ".", "log_prob": -0.007146273739635944}, + {"token": "4", "log_prob": -0.003810290014371276}, + {"token": "1", "log_prob": -0.002774677239358425}, + {"token": "4", "log_prob": -0.16946221888065338}, + {"token": ".", "log_prob": -0.007678280584514141}, + {"token": ".", "log_prob": -0.021146666258573532}, + {"token": ".", "log_prob": -0.3870151937007904}, + {"token": "\n", "log_prob": -0.027081478387117386}, + {"token": "Final", "log_prob": -0.1980377733707428}, + {"token": " Answer", "log_prob": -0.0037908137310296297}, + {"token": ":", "log_prob": -0.015637163072824478}, + {"token": " ", "log_prob": -0.0010788579238578677}, + {"token": "4", "log_prob": -0.04351021721959114}, + {"token": "\n", "log_prob": -0.0021214615553617477}, + ], + } + ] + ) + ), + ], + any_order=True, + ) From 659d08ddf850667d53eb9e35827070f434cecef1 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:02:54 -0800 Subject: [PATCH 256/425] fix llm-engine finetune.create failures (#464) --- .../services/docker_image_batch_job_llm_fine_tuning_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index f4622a16..fd60966f 100644 --- a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -77,7 +77,7 @@ async def create_fine_tune( raise LLMFineTuningMethodNotImplementedException("Fine-tuning method not accessible") # TODO: Pass user-defined labels - labels = dict(team="egp", product="llm-fine-tune") + labels = dict(team="egp", product="training.llm_engine_fine_tune") logger.info( f"Using bundle {di_batch_job_bundle.id} for fine-tune job: {di_batch_job_bundle.image_repository=}, {di_batch_job_bundle.image_tag=}" From bfcfbbab62b4d6795611eb8cdb9a1e662388403b Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 8 Mar 2024 13:05:50 -0800 Subject: [PATCH 257/425] Change back batch infer GPU util and add tool completion client changes (#465) * Change back batch infer gpu util * Add client changes * fixes * bump --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/completion.py | 33 +++++++++++++++++++ clients/python/llmengine/data_types.py | 30 +++++++++++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- docs/api/data_types.md | 1 + docs/guides/completions.md | 26 ++++++++++++++- .../inference/batch_inference/vllm_batch.py | 3 +- 8 files changed, 94 insertions(+), 5 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index cc19aefd..998388ac 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b25" +__version__ = "0.0.0b26" import os from typing import Sequence diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 3a02f04e..43d0813c 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -10,6 +10,7 @@ CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, CreateBatchCompletionsResponse, + ToolConfig, ) COMPLETION_TIMEOUT = 300 @@ -412,6 +413,7 @@ def batch_create( input_data_path: Optional[str] = None, data_parallelism: int = 1, max_runtime_sec: int = 24 * 3600, + tool_config: Optional[ToolConfig] = None, ) -> CreateBatchCompletionsResponse: """ Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. @@ -437,6 +439,13 @@ def batch_create( max_runtime_sec (int): The maximum runtime of the batch completion in seconds. Defaults to 24 hours. + tool_config (Optional[ToolConfig]): + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + Currently only Python code evaluator is supported. + Python code context starts with "\`\`\`python\\n" and ends with "\\n>>>\\n", data before "\\n\`\`\`\\n" and content end will be replaced by the Python execution results. + Please format prompts accordingly and provide examples so LLMs could properly generate Python code. + Returns: response (CreateBatchCompletionsResponse): The response containing the job id. @@ -480,6 +489,29 @@ def batch_create( ) print(response.json()) ``` + + === "Batch completions with prompts and use tool" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent, ToolConfig + + # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + + response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2, + tool_config=ToolConfig( + name="code_evaluator", + ) + ) + print(response.json()) + ``` """ data = CreateBatchCompletionsRequest( model_config=model_config, @@ -488,6 +520,7 @@ def batch_create( output_data_path=output_data_path, data_parallelism=data_parallelism, max_runtime_sec=max_runtime_sec, + tool_config=tool_config, ).dict() response = cls.post_sync( resource_name="v1/llm/batch-completions", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 209084aa..06c0b805 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -1,6 +1,7 @@ """ DTOs for LLM APIs. """ + import datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union @@ -658,6 +659,30 @@ class CreateBatchCompletionsModelConfig(BaseModel): """ +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + class CreateBatchCompletionsRequest(BaseModel): """ Request object for batch completions. @@ -685,6 +710,11 @@ class CreateBatchCompletionsRequest(BaseModel): """ Maximum runtime of the batch inference in seconds. Default to one day. """ + tool_config: Optional[ToolConfig] = None + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ class CreateBatchCompletionsResponse(BaseModel): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index a0afe290..a2fdc9ce 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta25" +version = "0.0.0.beta26" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 961459dc..9afe6136 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta25", + version="0.0.0.beta26", packages=find_packages(), ) diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 206c93e6..0576329c 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -143,6 +143,7 @@ - model_config - data_parallelism - max_runtime_sec + - tool_config ::: llmengine.CreateBatchCompletionsResponse options: diff --git a/docs/guides/completions.md b/docs/guides/completions.md index f48f05c4..69dfe1bd 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -122,7 +122,7 @@ asyncio.run(main()) ## Batch completions -The Python client also supports batch completions. Batch completions supports distributing data to multiple workers to accelerate inference. It also tries to maximize throughput so the completions should finish quite a bit faster than hitting models through HTTP. Use [Completion.batch_complete](../../api/python_client/#llmengine.completion.Completion.batch_complete) to utilize batch completions. +The Python client also supports batch completions. Batch completions supports distributing data to multiple workers to accelerate inference. It also tries to maximize throughput so the completions should finish quite a bit faster than hitting models through HTTP. Use [Completion.batch_create](../../api/python_client/#llmengine.Completion.batch_create) to utilize batch completions. Some examples of batch completions: @@ -169,6 +169,30 @@ response = Completion.batch_create( print(response.job_id) ``` +=== "Batch completions with prompts and use tool" +For how to properly use the tool please see [Completion.batch_create](../../api/python_client/#llmengine.Completion.batch_create) tool_config doc. +```python +from llmengine import Completion +from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent, ToolConfig + +# Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + +response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2, + tool_config=ToolConfig( + name="code_evaluator", + ) +) +print(response.json()) +``` + ## Which model should I use? See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions. diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index d783b3f4..dced4e84 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -132,7 +132,7 @@ def get_vllm_engine(model, request): tensor_parallel_size=request.model_config.num_shards, seed=request.model_config.seed or 0, disable_log_requests=True, - gpu_memory_utilization=0.8, # To avoid OOM errors when there's host machine GPU usage + gpu_memory_utilization=0.9, ) llm = AsyncLLMEngine.from_engine_args(engine_args) @@ -432,6 +432,7 @@ def check_unknown_startup_memory_usage(): # pragma: no cover """Check for unknown memory usage at startup.""" gpu_free_memory = get_gpu_free_memory() if gpu_free_memory is not None: + print(f"GPU free memory at startup in MB: {gpu_free_memory}") min_mem = min(gpu_free_memory) max_mem = max(gpu_free_memory) if max_mem - min_mem > 10: From 4b012f0155cc7850cf804b81a24b0a24427aaf0a Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 11 Mar 2024 14:21:56 -0700 Subject: [PATCH 258/425] Try to fix async requests getting stuck (#466) --- model-engine/model_engine_server/core/celery/app.py | 5 +++++ .../model_engine_server/inference/forwarding/echo_server.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index a045d0aa..0b37966f 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -476,6 +476,11 @@ def _get_broker_endpoint_and_transport_options( # Going to try this with defaults first. out_broker_transport_options["region"] = os.environ.get("AWS_REGION", "us-west-2") + # changing wait_time_seconds from the default of 10 based on https://github.com/celery/celery/discussions/7283 + # goal is to prevent async requests from being stuck in pending when workers die; the hypothesis is that this is caused by SQS long polling + out_broker_transport_options["wait_time_seconds"] = 0 + out_broker_transport_options["polling_interval"] = 5 + # NOTE: The endpoints should ideally use predefined queues. However, the sender probably needs the flexibility # of not requiring predefined queues. # assert ( diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py index db6c0b3c..3581f678 100644 --- a/model-engine/model_engine_server/inference/forwarding/echo_server.py +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -3,6 +3,7 @@ """ import argparse import subprocess +import time from fastapi import FastAPI, Request from fastapi.responses import JSONResponse @@ -19,7 +20,10 @@ def healthcheck(): @app.post("/predict") async def predict(request: Request): - return await request.json() + dictionary = await request.json() + if "delay" in dictionary: + time.sleep(dictionary["delay"]) + return dictionary @app.post("/predict500") From b09c106414fa5a08418b2172236e088caede5a15 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:12:38 -0700 Subject: [PATCH 259/425] [Client] Add num_prompt_tokens to the client's CompletionOutputs (#467) * add prompt token, untested * comment * remove stop_str stuff, it doesn't do anything with the public api, and it breaks on certain frameworks when hosted locally --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 9 +++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 998388ac..17dacfa9 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b26" +__version__ = "0.0.0b27" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 06c0b805..70abd6cb 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -305,6 +305,11 @@ class CompletionOutput(BaseModel): text: str """The text of the completion.""" + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + # If we send request to api.spellbook.scale.com, we don't get this back. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + num_completion_tokens: int """Number of tokens in the completion.""" @@ -353,6 +358,10 @@ class CompletionStreamOutput(BaseModel): finished: bool """Whether the completion is finished.""" + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + num_completion_tokens: Optional[int] = None """Number of tokens in the completion.""" diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index a2fdc9ce..2563b814 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta26" +version = "0.0.0.beta27" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 9afe6136..257516fc 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta26", + version="0.0.0.beta27", packages=find_packages(), ) From 80a2d3efe9a7611a2e3fe7a1002cc1c02ade7f26 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:20:48 -0700 Subject: [PATCH 260/425] Tool completion respect num new tokens (#469) * Tool completion respect num new tokens * more fix * remove unused import * format * empty * no cover --- .../inference/batch_inference/vllm_batch.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index dced4e84..718d5e24 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -185,7 +185,7 @@ def __repr__(self) -> str: outputs = await generate_with_vllm( llm, - content.max_new_tokens, + [generations[iter[1]].remaining_tokens for iter in iter_prompts], content.temperature, content.stop_sequences, content.return_token_log_probs, @@ -260,7 +260,10 @@ def tool_func(text: str, past_context: Optional[str]): gen_item.generated_text += new_text # If we didn't just execute a tool, we're done - if not gen_item.generated_text.endswith(tool.tool_context_end): + if ( + not gen_item.generated_text.endswith(tool.tool_context_end) + or gen_item.remaining_tokens <= 0 + ): gen_item.completed = True continue @@ -316,7 +319,7 @@ async def batch_inference(): outputs = await generate_with_vllm( llm, - content.max_new_tokens, + [content.max_new_tokens] * len(prompts), content.temperature, content.stop_sequences, content.return_token_log_probs, @@ -358,24 +361,23 @@ async def generate_with_vllm( top_p, prompts, bar, -) -> List[CompletionOutput]: +) -> List[CompletionOutput]: # pragma: no cover from vllm import SamplingParams # Add the requests to the engine. - sampling_params = SamplingParams( - max_tokens=max_new_tokens, - temperature=temperature, - stop=stop_sequences, - logprobs=1 if return_token_log_probs else None, - presence_penalty=presence_penalty or 0.0, - frequency_penalty=frequency_penalty or 0.0, - top_k=top_k or -1, - top_p=top_p or 1.0, - ) - results_generators = [] - for prompt in prompts: + for idx, prompt in enumerate(prompts): request_id = random_uuid() + sampling_params = SamplingParams( + max_tokens=max_new_tokens[idx], + temperature=temperature, + stop=stop_sequences, + logprobs=1 if return_token_log_probs else None, + presence_penalty=presence_penalty or 0.0, + frequency_penalty=frequency_penalty or 0.0, + top_k=top_k or -1, + top_p=top_p or 1.0, + ) results_generator = await engine.add_request( request_id, prompt, sampling_params, None, time.monotonic() ) From 24314f5618ee1fe298ac7dacda47581197ee3ef4 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 15 Mar 2024 11:32:38 -0700 Subject: [PATCH 261/425] Azure fixes + additional asks (#468) --- charts/model-engine/templates/_helpers.tpl | 2 - model-engine/Dockerfile | 2 +- .../model_engine_server/api/dependencies.py | 4 +- .../model_engine_server/api/files_v1.py | 2 +- .../model_engine_server/common/config.py | 12 +- model-engine/model_engine_server/db/base.py | 6 +- .../model_engine_server/domain/exceptions.py | 6 + .../live_model_endpoints_schema_gateway.py | 10 +- .../repositories/acr_docker_repository.py | 13 +- .../infra/services/image_cache_service.py | 14 +- model-engine/requirements.in | 10 +- model-engine/requirements.txt | 166 ++++++++++-------- model-engine/tests/unit/api/test_tasks.py | 19 +- 13 files changed, 158 insertions(+), 108 deletions(-) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index 3df2ea81..d13af358 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -270,8 +270,6 @@ env: value: {{ .Values.azure.abs_account_name }} - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} - - name: SERVICEBUS_SAS_KEY - value: {{ .Values.azure.servicebus_sas_key }} {{- end }} {{- end }} diff --git a/model-engine/Dockerfile b/model-engine/Dockerfile index 80939559..23eacd9c 100644 --- a/model-engine/Dockerfile +++ b/model-engine/Dockerfile @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:experimental -FROM python:3.8.8-slim as model-engine +FROM python:3.8.18-slim as model-engine WORKDIR /workspace RUN apt-get update && apt-get install -y \ diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index b65f1189..713938d1 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -1,5 +1,6 @@ import asyncio import os +import time from dataclasses import dataclass from typing import Callable, Optional @@ -442,6 +443,7 @@ async def verify_authentication( def get_or_create_aioredis_pool() -> aioredis.ConnectionPool: global _pool - if _pool is None: + expiration_timestamp = hmi_config.cache_redis_url_expiration_timestamp + if _pool is None or (expiration_timestamp is not None and time.time() > expiration_timestamp): _pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) return _pool diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py index 556566d5..d3c093f0 100644 --- a/model-engine/model_engine_server/api/files_v1.py +++ b/model-engine/model_engine_server/api/files_v1.py @@ -44,7 +44,7 @@ async def upload_file( ) return await use_case.execute( user=auth, - filename=file.filename, + filename=file.filename or "", content=file.file.read(), ) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 6c7088fc..ac92cf43 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -30,6 +30,8 @@ SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH) +redis_cache_expiration_timestamp = None + # duplicated from llm/ia3_finetune def get_model_cache_directory_name(model_name: str): @@ -81,9 +83,17 @@ def cache_redis_url(self) -> str: assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure" username = os.getenv("AZURE_OBJECT_ID") - password = DefaultAzureCredential().get_token("https://redis.azure.com/.default").token + token = DefaultAzureCredential().get_token("https://redis.azure.com/.default") + password = token.token + global redis_cache_expiration_timestamp + redis_cache_expiration_timestamp = token.expires_on return f"rediss://{username}:{password}@{self.cache_redis_azure_host}" + @property + def cache_redis_url_expiration_timestamp(self) -> Optional[int]: + global redis_cache_expiration_timestamp + return redis_cache_expiration_timestamp + @property def cache_redis_host_port(self) -> str: # redis://redis.url:6379/ diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index b6949617..9acf95c0 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -58,7 +58,7 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool user = os.environ.get("AZURE_IDENTITY_NAME") password = ( DefaultAzureCredential() - .get_token("https://ossrdbms-aad.database.windows.net") + .get_token("https://ossrdbms-aad.database.windows.net/.default") .token ) logger.info(f"Connecting to db {db} as user {user}") @@ -81,7 +81,9 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool # For async postgres, we need to use an async dialect. if not sync: - engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://") + engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://").replace( + "sslmode", "ssl" + ) return engine_url diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 32a16bd8..c64e3beb 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -41,6 +41,12 @@ class DockerImageNotFoundException(DomainException): tag: str +class DockerRepositoryNotFoundException(DomainException): + """ + Thrown when a Docker repository that is trying to be accessed doesn't exist. + """ + + class DockerBuildFailedException(DomainException): """ Thrown if the server failed to build a docker image. diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py index 1f6dd7b0..5fac2841 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py @@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, Sequence, Set, Type, Union from fastapi import routing +from fastapi._compat import GenerateJsonSchema, get_model_definitions +from fastapi.openapi.constants import REF_TEMPLATE from fastapi.openapi.utils import get_openapi_path -from fastapi.utils import get_model_definitions from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, @@ -119,8 +120,13 @@ def get_openapi( if isinstance(route, routing.APIRoute): prefix = model_endpoint_name model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) + schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) result = get_openapi_path( - route=route, model_name_map=model_name_map, operation_ids=operation_ids + route=route, + model_name_map=model_name_map, + operation_ids=operation_ids, + schema_generator=schema_generator, + field_mapping={}, ) if result: path, security_schemes, path_definitions = result diff --git a/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py index 2d6e1cc3..7f9137fe 100644 --- a/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py @@ -6,6 +6,7 @@ from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException from model_engine_server.domain.repositories import DockerRepository logger = make_logger(logger_name()) @@ -36,7 +37,11 @@ def get_latest_image_tag(self, repository_name: str) -> str: credential = DefaultAzureCredential() client = ContainerRegistryClient(endpoint, credential) - image = client.list_manifest_properties( - repository_name, order_by="time_desc", results_per_page=1 - ).next() - return image.tags[0] + try: + image = client.list_manifest_properties( + repository_name, order_by="time_desc", results_per_page=1 + ).next() + # Azure automatically deletes empty ACR repositories, so repos will always have at least one image + return image.tags[0] + except ResourceNotFoundError: + raise DockerRepositoryNotFoundException diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 5d5c9d13..beab3ec8 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -7,6 +7,7 @@ from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState +from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException from model_engine_server.domain.repositories import DockerRepository from model_engine_server.infra.gateways.resources.image_cache_gateway import ( CachedImages, @@ -78,11 +79,14 @@ def _cache_finetune_llm_images( vllm_image_032 = DockerImage( f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" ) - latest_tag = ( - self.docker_repository.get_latest_image_tag(hmi_config.batch_inference_vllm_repository) - if not CIRCLECI - else "fake_docker_repository_latest_image_tag" - ) + latest_tag = "fake_docker_repository_latest_image_tag" + if not CIRCLECI: + try: # pragma: no cover + latest_tag = self.docker_repository.get_latest_image_tag( + hmi_config.batch_inference_vllm_repository + ) + except DockerRepositoryNotFoundException: + pass vllm_batch_image_latest = DockerImage( f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", latest_tag, diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 49984a54..756df6c3 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -9,7 +9,7 @@ azure-identity~=1.15.0 azure-keyvault-secrets~=4.7.0 azure-servicebus~=7.11.4 azure-storage-blob~=12.19.0 -boto3-stubs[essential]==1.26.67 +boto3-stubs[essential]~=1.26.67 boto3~=1.21 botocore~=1.24 build==0.8.0 @@ -17,13 +17,14 @@ celery[redis,sqs,tblib]~=5.3.6 click~=8.1 cloudpickle==2.1.0 croniter==1.4.1 +cryptography>=42.0.4 # not used directly, but needs to be pinned for Microsoft security scan dataclasses-json>=0.5.7 datadog-api-client==2.11.0 datadog~=0.47.0 ddtrace==1.8.3 deprecation~=2.1 docker~=5.0 -fastapi==0.78.0 +fastapi~=0.110.0 gitdb2~=2.0 gunicorn~=20.0 httptools==0.5.0 @@ -45,9 +46,10 @@ rich~=12.6 sentencepiece==0.1.99 sh~=1.13 smart-open~=5.2 -sqlalchemy[asyncio]==2.0.4 -sse-starlette==1.6.1 +sqlalchemy[asyncio]~=2.0.4 +sse-starlette==2.0.0 sseclient-py==1.7.2 +starlette[full]>=0.35.0 # not used directly, but needs to be pinned for Microsoft security scan stringcase==1.2.0 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 47d2fcef..bc0052c1 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -8,19 +8,21 @@ aiofiles==23.1.0 # via quart aiohttp==3.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r requirements.in + # via -r model-engine/requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r requirements.in + # via -r model-engine/requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 # via # azure-core + # httpx + # sse-starlette # starlette asgiref==3.7.2 # via uvicorn @@ -32,7 +34,7 @@ async-timeout==4.0.2 # aioredis # redis asyncpg==0.27.0 - # via -r requirements.in + # via -r model-engine/requirements.in attrs==23.1.0 # via # aiohttp @@ -43,7 +45,7 @@ attrs==23.1.0 azure-common==1.1.28 # via azure-keyvault-secrets azure-containerregistry==1.2.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-core==1.29.6 # via # azure-containerregistry @@ -52,13 +54,13 @@ azure-core==1.29.6 # azure-servicebus # azure-storage-blob azure-identity==1.15.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-keyvault-secrets==4.7.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-servicebus==7.11.4 - # via -r requirements.in + # via -r model-engine/requirements.in azure-storage-blob==12.19.0 - # via -r requirements.in + # via -r model-engine/requirements.in backports-zoneinfo[tzdata]==0.2.1 # via # celery @@ -71,22 +73,20 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via - # -r requirements.in - # boto3-stubs + # via -r model-engine/requirements.in botocore==1.31.1 # via - # -r requirements.in + # -r model-engine/requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r requirements.in + # via -r model-engine/requirements.in bytecode==0.14.2 # via ddtrace cachetools==5.3.1 @@ -94,12 +94,12 @@ cachetools==5.3.1 cattrs==23.1.2 # via ddtrace celery[redis,sqs,tblib]==5.3.6 - # via - # -r requirements.in - # celery + # via -r model-engine/requirements.in certifi==2023.7.22 # via # datadog-api-client + # httpcore + # httpx # kubernetes # kubernetes-asyncio # requests @@ -109,7 +109,7 @@ charset-normalizer==3.2.0 # via requests click==8.1.4 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # click-didyoumean # click-plugins @@ -123,34 +123,35 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r requirements.in -cryptography==41.0.7 + # via -r model-engine/requirements.in +cryptography==42.0.5 # via + # -r model-engine/requirements.in # azure-identity # azure-storage-blob # msal # pyjwt # secretstorage dataclasses-json==0.5.9 - # via -r requirements.in + # via -r model-engine/requirements.in datadog==0.47.0 - # via -r requirements.in + # via -r model-engine/requirements.in datadog-api-client==2.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in ddsketch==2.0.4 # via ddtrace ddtrace==1.8.3 - # via -r requirements.in + # via -r model-engine/requirements.in deprecation==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in docker==5.0.3 - # via -r requirements.in + # via -r model-engine/requirements.in docutils==0.20.1 # via readme-renderer envier==0.4.0 @@ -159,8 +160,8 @@ exceptiongroup==1.2.0 # via # anyio # cattrs -fastapi==0.78.0 - # via -r requirements.in +fastapi==0.110.0 + # via -r model-engine/requirements.in filelock==3.13.1 # via # huggingface-hub @@ -174,17 +175,18 @@ fsspec==2023.10.0 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r requirements.in + # via -r model-engine/requirements.in gitpython==3.1.41 - # via -r requirements.in + # via -r model-engine/requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in h11==0.14.0 # via + # httpcore # hypercorn # uvicorn # wsproto @@ -192,8 +194,12 @@ h2==4.1.0 # via hypercorn hpack==4.0.0 # via h2 +httpcore==1.0.4 + # via httpx httptools==0.5.0 - # via -r requirements.in + # via -r model-engine/requirements.in +httpx==0.27.0 + # via starlette huggingface-hub==0.20.3 # via # tokenizers @@ -205,6 +211,7 @@ hyperframe==6.0.1 idna==3.4 # via # anyio + # httpx # requests # yarl importlib-metadata==6.8.0 @@ -226,7 +233,9 @@ isodate==0.6.1 # azure-servicebus # azure-storage-blob itsdangerous==2.1.2 - # via quart + # via + # quart + # starlette jaraco-classes==3.3.0 # via keyring jeepney==0.8.0 @@ -235,14 +244,15 @@ jeepney==0.8.0 # secretstorage jinja2==3.0.3 # via - # -r requirements.in + # -r model-engine/requirements.in # quart + # starlette jmespath==1.0.1 # via # boto3 # botocore json-log-formatter==0.5.2 - # via -r requirements.in + # via -r model-engine/requirements.in jsonschema==4.19.0 # via ddtrace jsonschema-specifications==2023.7.1 @@ -252,11 +262,11 @@ keyring==24.2.0 kombu[sqs]==5.3.5 # via celery kubeconfig==1.1.1 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes==25.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes-asyncio==25.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -304,7 +314,7 @@ numpy==1.24.4 oauthlib==3.2.2 # via requests-oauthlib orjson==3.9.15 - # via -r requirements.in + # via -r model-engine/requirements.in packaging==23.1 # via # build @@ -330,13 +340,13 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r requirements.in + # -r model-engine/requirements.in # ddsketch # ddtrace psycopg2-binary==2.9.3 - # via -r requirements.in + # via -r model-engine/requirements.in py-xid==0.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -347,21 +357,19 @@ pycparser==2.21 # via cffi pycurl==7.45.2 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu pydantic==1.10.11 # via - # -r requirements.in + # -r model-engine/requirements.in # fastapi pygments==2.15.1 # via # readme-renderer # rich pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt + # via msal python-dateutil==2.8.2 # via # botocore @@ -372,16 +380,19 @@ python-dateutil==2.8.2 # kubernetes-asyncio # pg8000 python-multipart==0.0.7 - # via -r requirements.in + # via + # -r model-engine/requirements.in + # starlette pyyaml==6.0.1 # via # huggingface-hub # kubeconfig # kubernetes # kubernetes-asyncio + # starlette # transformers quart==0.18.3 - # via -r requirements.in + # via -r model-engine/requirements.in readme-renderer==40.0 # via twine redis==4.6.0 @@ -394,7 +405,7 @@ regex==2023.10.3 # via transformers requests==2.31.0 # via - # -r requirements.in + # -r model-engine/requirements.in # azure-core # datadog # docker @@ -407,7 +418,7 @@ requests==2.31.0 # transformers # twine requests-auth-aws-sigv4==0.7 - # via -r requirements.in + # via -r model-engine/requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -415,7 +426,7 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r requirements.in + # via -r model-engine/requirements.in rpds-py==0.10.0 # via # jsonschema @@ -431,9 +442,9 @@ scramp==1.4.4 secretstorage==3.3.3 # via keyring sentencepiece==0.1.99 - # via -r requirements.in + # via -r model-engine/requirements.in sh==1.14.3 - # via -r requirements.in + # via -r model-engine/requirements.in six==1.16.0 # via # azure-core @@ -447,7 +458,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r requirements.in + # via -r model-engine/requirements.in smmap==5.0.0 # via # gitdb @@ -455,35 +466,37 @@ smmap==5.0.0 smmap2==3.0.1 # via gitdb2 sniffio==1.3.0 - # via anyio + # via + # anyio + # httpx sqlalchemy[asyncio]==2.0.4 # via - # -r requirements.in + # -r model-engine/requirements.in # alembic - # sqlalchemy -sse-starlette==1.6.1 - # via -r requirements.in +sse-starlette==2.0.0 + # via -r model-engine/requirements.in sseclient-py==1.7.2 - # via -r requirements.in -starlette==0.19.1 + # via -r model-engine/requirements.in +starlette[full]==0.36.3 # via + # -r model-engine/requirements.in # fastapi # sse-starlette stringcase==1.2.0 - # via -r requirements.in + # via -r model-engine/requirements.in tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r requirements.in + # -r model-engine/requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in tokenizers==0.15.2 # via - # -r requirements.in + # -r model-engine/requirements.in # transformers tomli==2.0.1 # via @@ -492,21 +505,21 @@ tomli==2.0.1 # pep517 tqdm==4.65.0 # via - # -r requirements.in + # -r model-engine/requirements.in # huggingface-hub # transformers # twine transformers==4.38.0 - # via -r requirements.in + # via -r model-engine/requirements.in twine==3.7.1 - # via -r requirements.in + # via -r model-engine/requirements.in types-awscrt==0.16.23 # via # botocore-stubs # types-s3transfer types-s3transfer==0.6.1 # via boto3-stubs -typing-extensions==4.7.1 +typing-extensions==4.10.0 # via # aioredis # asgiref @@ -520,6 +533,7 @@ typing-extensions==4.7.1 # cattrs # datadog-api-client # ddtrace + # fastapi # huggingface-hub # kombu # mypy-boto3-cloudformation @@ -551,9 +565,11 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r requirements.in + # via + # -r model-engine/requirements.in + # sse-starlette uvloop==0.17.0 - # via -r requirements.in + # via -r model-engine/requirements.in vine==5.1.0 # via # amqp @@ -575,7 +591,7 @@ xmltodict==0.13.0 # via ddtrace yarl==1.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # aiohttp zipp==3.16.0 # via diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index 611195bd..f9a0f062 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -360,15 +360,14 @@ def test_create_streaming_task_success( fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, ) - response = client.post( - f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", + with client.stream( + method="POST", + url=f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", auth=(test_api_key, ""), json=endpoint_predict_request_1[1], - stream=True, - ) - assert response.status_code == 200 - count = 0 - for message in response: - assert message == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' - count += 1 - assert count == 1 + ) as response: + assert response.status_code == 200 + assert ( + response.read() + == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' + ) From 1d33b2764732ee9680c4c5d0741f739b8ef2c01b Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 15 Mar 2024 14:47:05 -0700 Subject: [PATCH 262/425] Metrics for stuck async requests (#471) --- .../core/celery/__init__.py | 10 +++++- .../model_engine_server/core/celery/app.py | 3 ++ .../gateways/monitoring_metrics_gateway.py | 9 ----- .../inference_monitoring_metrics_gateway.py | 20 +++++++++++ .../inference/forwarding/celery_forwarder.py | 25 +++++++++++--- ...og_inference_monitoring_metrics_gateway.py | 8 +++++ ..._async_model_endpoint_inference_gateway.py | 3 +- ...test_live_async_model_inference_gateway.py | 33 ++++++++++--------- 8 files changed, 80 insertions(+), 31 deletions(-) diff --git a/model-engine/model_engine_server/core/celery/__init__.py b/model-engine/model_engine_server/core/celery/__init__.py index af024891..3368bc69 100644 --- a/model-engine/model_engine_server/core/celery/__init__.py +++ b/model-engine/model_engine_server/core/celery/__init__.py @@ -1,6 +1,13 @@ from typing import Sequence -from .app import TaskVisibility, celery_app, get_all_db_indexes, get_redis_host_port, inspect_app +from .app import ( + DEFAULT_TASK_VISIBILITY_SECONDS, + TaskVisibility, + celery_app, + get_all_db_indexes, + get_redis_host_port, + inspect_app, +) __all__: Sequence[str] = ( "celery_app", @@ -8,4 +15,5 @@ "get_redis_host_port", "inspect_app", "TaskVisibility", + "DEFAULT_TASK_VISIBILITY_SECONDS", ) diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 0b37966f..167c01ba 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -31,6 +31,9 @@ ] = "model_engine_server.core.celery.abs:AzureBlockBlobBackend" +DEFAULT_TASK_VISIBILITY_SECONDS = 86400 + + @unique class TaskVisibility(IntEnum): """ diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index 9bca6a0d..38861ade 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -24,14 +24,12 @@ class MonitoringMetricsGateway(ABC): def emit_attempted_build_metric(self): """ Service builder attempted metric - """ @abstractmethod def emit_successful_build_metric(self): """ Service builder succeeded metric - """ @abstractmethod @@ -44,42 +42,36 @@ def emit_build_time_metric(self, duration_seconds: float): def emit_image_build_cache_hit_metric(self, image_type: str): """ Service builder image build cache hit metric - """ @abstractmethod def emit_image_build_cache_miss_metric(self, image_type: str): """ Service builder image build cache miss metric - """ @abstractmethod def emit_docker_failed_build_metric(self): """ Service builder docker build failed metric - """ @abstractmethod def emit_database_cache_hit_metric(self): """ Successful database cache metric - """ @abstractmethod def emit_database_cache_miss_metric(self): """ Missed database cache metric - """ @abstractmethod def emit_route_call_metric(self, route: str, metadata: MetricMetadata): """ Route call metric - """ pass @@ -87,6 +79,5 @@ def emit_route_call_metric(self, route: str, metadata: MetricMetadata): def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata): """ Token count metrics - """ pass diff --git a/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py index 13586992..15602563 100644 --- a/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py @@ -30,3 +30,23 @@ def emit_successful_post_inference_hook(self, hook: str): Args: hook: The name of the hook """ + + @abstractmethod + def emit_async_task_received_metric(self, queue_name: str): + """ + Async task received metric + + Args: + queue_name: The name of the Celery queue + """ + pass + + @abstractmethod + def emit_async_task_stuck_metric(self, queue_name: str): + """ + Async task stuck metric + + Args: + queue_name: The name of the Celery queue + """ + pass diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 264f6af5..d9c841f2 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -1,12 +1,17 @@ import argparse import json +from datetime import datetime, timedelta from typing import Any, Dict, Optional, TypedDict, Union from celery import Celery, Task, states from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.tasks import EndpointPredictV1Request -from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.core.celery import ( + DEFAULT_TASK_VISIBILITY_SECONDS, + TaskVisibility, + celery_app, +) from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.core.utils.format import format_stacktrace @@ -15,6 +20,9 @@ LoadForwarder, load_named_config, ) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) logger = make_logger(logger_name()) @@ -68,6 +76,8 @@ def create_celery_service( backend_protocol=backend_protocol, ) + monitoring_metrics_gateway = DatadogInferenceMonitoringMetricsGateway() + class ErrorHandlingTask(Task): """Sets a 'custom' field with error in the Task response for FAILURE. @@ -112,13 +122,18 @@ def after_return( # See documentation for options: # https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options @app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True) - def exec_func(payload, *ignored_args, **ignored_kwargs): + def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): if len(ignored_args) > 0: logger.warning(f"Ignoring {len(ignored_args)} positional arguments: {ignored_args=}") if len(ignored_kwargs) > 0: logger.warning(f"Ignoring {len(ignored_kwargs)} keyword arguments: {ignored_kwargs=}") try: - return forwarder(payload) + monitoring_metrics_gateway.emit_async_task_received_metric(queue_name) + result = forwarder(payload) + request_duration = datetime.now() - arrival_timestamp + if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS): + monitoring_metrics_gateway.emit_async_task_stuck_metric(queue_name) + return result except Exception: logger.exception("Celery service failed to respond to request.") raise @@ -131,8 +146,8 @@ def exec_func(payload, *ignored_args, **ignored_kwargs): name=DEFAULT_CELERY_TASK_NAME, track_started=True, ) - def exec_func_pre_lira(payload, *ignored_args, **ignored_kwargs): - return exec_func(payload, *ignored_args, **ignored_kwargs) + def exec_func_pre_lira(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): + return exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs) return app diff --git a/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py index a8999723..802066cb 100644 --- a/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py @@ -10,3 +10,11 @@ def emit_attempted_post_inference_hook(self, hook: str): def emit_successful_post_inference_hook(self, hook: str): statsd.increment(f"scale_launch.post_inference_hook.{hook}.success") + + def emit_async_task_received_metric(self, queue_name: str): + statsd.increment( + "scale_launch.async_task.received.count", tags=[f"queue_name:{queue_name}"] + ) # pragma: no cover + + def emit_async_task_stuck_metric(self, queue_name: str): + statsd.increment("scale_launch.async_task.stuck.count", tags=[f"queue_name:{queue_name}"]) diff --git a/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py index 7976e24f..3c0408c8 100644 --- a/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py @@ -1,4 +1,5 @@ import json +from datetime import datetime from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME from model_engine_server.common.dtos.tasks import ( @@ -37,7 +38,7 @@ def create_task( send_task_response = self.task_queue_gateway.send_task( task_name=task_name, queue_name=topic, - args=[predict_args, predict_request.return_pickled], + args=[predict_args, datetime.now(), predict_request.return_pickled], expires=task_timeout_seconds, ) return CreateAsyncTaskV1Response(task_id=send_task_response.task_id) diff --git a/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py index 2b4939df..e86f0f1f 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py @@ -1,4 +1,5 @@ import json +from datetime import datetime, timedelta from typing import Any import pytest @@ -22,10 +23,11 @@ def test_task_create_get_url( task_id = create_response.task_id task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway assert len(task_queue_gateway.queue) == 1 - assert task_queue_gateway.queue[task_id]["args"] == [ - endpoint_predict_request_1[0].dict(), - endpoint_predict_request_1[0].return_pickled, - ] + assert task_queue_gateway.queue[task_id]["args"][0] == endpoint_predict_request_1[0].dict() + assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1) + assert ( + task_queue_gateway.queue[task_id]["args"][2] == endpoint_predict_request_1[0].return_pickled + ) get_response_1 = fake_live_async_model_inference_gateway.get_task(task_id) assert get_response_1 == GetAsyncTaskV1Response(task_id=task_id, status=TaskStatus.PENDING) @@ -49,17 +51,18 @@ def test_task_create_get_args_callback( task_id = create_response.task_id task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway assert len(task_queue_gateway.queue) == 1 - assert task_queue_gateway.queue[task_id]["args"] == [ - { - "args": endpoint_predict_request_2[0].args.__root__, - "url": None, - "cloudpickle": None, - "callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()), - "callback_url": endpoint_predict_request_2[0].callback_url, - "return_pickled": endpoint_predict_request_2[0].return_pickled, - }, - endpoint_predict_request_2[0].return_pickled, - ] + assert task_queue_gateway.queue[task_id]["args"][0] == { + "args": endpoint_predict_request_2[0].args.__root__, + "url": None, + "cloudpickle": None, + "callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()), + "callback_url": endpoint_predict_request_2[0].callback_url, + "return_pickled": endpoint_predict_request_2[0].return_pickled, + } + assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1) + assert ( + task_queue_gateway.queue[task_id]["args"][2] == endpoint_predict_request_2[0].return_pickled + ) get_response_1 = fake_live_async_model_inference_gateway.get_task(task_id) assert get_response_1 == GetAsyncTaskV1Response(task_id=task_id, status=TaskStatus.PENDING) From 98e1f43faa7e25742383e135a9c5be204e966f8d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:41:12 -0700 Subject: [PATCH 263/425] Fix cacher (#472) * Fix cacher * repo name * unit test --- .../model_engine_server/infra/services/image_cache_service.py | 4 +--- .../tests/unit/infra/services/test_image_cache_service.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index beab3ec8..4966e9f4 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -91,9 +91,7 @@ def _cache_finetune_llm_images( f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", latest_tag, ) - forwarder_image = DockerImage( - f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG - ) + forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/model-engine", GIT_TAG) for llm_image in [ istio_image, diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py index bf578c6d..3dd1913d 100644 --- a/model-engine/tests/unit/infra/services/test_image_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -66,7 +66,7 @@ async def test_caching_finetune_llm_images( f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", latest_tag, ) - forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/launch/gateway", GIT_TAG) + forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/model-engine", GIT_TAG) for key in ["a10", "a100"]: for llm_image in [ From 6db2d488e96b17060c6af96716010d73b475cf4a Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 19 Mar 2024 12:55:54 -0700 Subject: [PATCH 264/425] Add retries to deflake integration tests (#473) --- integration_tests/rest_api_utils.py | 1 + integration_tests/test_endpoints.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index abec0cf4..286e6ef0 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -516,6 +516,7 @@ def get_model_endpoint(name: str, user_id: str) -> Dict[str, Any]: return response.json()["model_endpoints"][0] +@retry(stop=stop_after_attempt(3), wait=wait_fixed(20)) def update_model_endpoint( endpoint_name: str, update_model_endpoint_request: Dict[str, Any], user_id: str ) -> Dict[str, Any]: diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index 26f2dbe2..5b0a6404 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -59,6 +59,14 @@ def ensure_async_inference_works(user, create_endpoint_request, inference_payloa ensure_all_async_tasks_success(task_ids, user, return_pickled) +@retry(stop=stop_after_attempt(3), wait=wait_fixed(20)) +def ensure_endpoint_updated(create_endpoint_request, update_endpoint_request, user): + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert endpoint["deployment_state"]["max_workers"] == update_endpoint_request["max_workers"] + + @pytest.mark.parametrize( "create_endpoint_request,update_endpoint_request,inference_requests", [ @@ -99,13 +107,7 @@ def test_async_model_endpoint( ensure_n_ready_endpoints_short(1, user) print("Checking endpoint state...") - endpoint = get_model_endpoint(create_endpoint_request["name"], user) - assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] - assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] - assert ( - endpoint["deployment_state"]["max_workers"] - == update_endpoint_request["max_workers"] - ) + ensure_endpoint_updated(create_endpoint_request, update_endpoint_request, user) time.sleep(20) From 990409142dbcdcf5862fce2ec1f71ec8d382842f Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 19 Mar 2024 15:58:15 -0700 Subject: [PATCH 265/425] add suffix to integration tests (#474) * add suffix to integration tests * dont suffix bundle --- integration_tests/rest_api_utils.py | 45 ++++++++++++++++++++--------- integration_tests/test_docs.py | 33 ++++++++++----------- model-engine/requirements-test.txt | 15 +++++----- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 286e6ef0..1e780c37 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -15,6 +15,14 @@ print(f"Integration tests using gateway {BASE_PATH=}") DEFAULT_NETWORK_TIMEOUT_SEC = 10 +# add suffix to avoid name collisions +SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "") + + +def format_name(name: str) -> str: + return f"{name}-{SERVICE_IDENTIFIER}" if SERVICE_IDENTIFIER else name + + # Use the scale-launch-integration-tests id USER_ID_0 = os.getenv("TEST_USER_ID", "fakeuser") @@ -97,7 +105,7 @@ def my_model(**keyword_args): CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE = { "bundle_name": "model_bundle_simple", - "name": "model-endpoint-simple-async", + "name": format_name("model-endpoint-simple-async"), "endpoint_type": "async", "cpus": "0.5", "memory": "500Mi", @@ -110,12 +118,12 @@ def my_model(**keyword_args): } CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE = CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE.copy() -CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["name"] = "model-endpoint-simple-sync" +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["name"] = format_name("model-endpoint-simple-sync") CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["endpoint_type"] = "sync" CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = { "bundle_name": "model_bundle_runnable_image", - "name": "model-endpoint-runnable-image-async", + "name": format_name("model-endpoint-runnable-async"), "post_inference_hooks": [], "endpoint_type": "async", "cpus": "1", @@ -132,9 +140,9 @@ def my_model(**keyword_args): CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = ( CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE.copy() ) -CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE[ - "name" -] = "model-endpoint-runnable-image-sync-streaming" +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE["name"] = format_name( + "model-endpoint-runnable-sync-streaming" +) CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE["endpoint_type"] = "streaming" UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE = { @@ -175,7 +183,7 @@ def my_model(**keyword_args): } CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { - "name": "di_batch_job_bundle_1", + "name": format_name("di_batch_job_bundle_1"), "image_repository": "model-engine", "image_tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", "command": ["jq", ".", "/launch_mount_location/file"], @@ -188,14 +196,14 @@ def my_model(**keyword_args): } CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST: Dict[str, Any] = { - "docker_image_batch_job_bundle_name": "di_batch_job_bundle_1", + "docker_image_batch_job_bundle_name": format_name("di_batch_job_bundle_1"), "job_config": {"data": {"to": "mount"}}, "labels": {"team": "infra", "product": "testing"}, "resource_requests": {"cpus": 0.15, "memory": "15Mi"}, } CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { - "name": "fine_tune_di_batch_job_bundle_1", + "name": format_name("fine_tune_di_batch_job_bundle_1"), "image_repository": "model-engine", "image_tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", "command": ["cat", "/launch_mount_location/file"], @@ -700,9 +708,16 @@ def ensure_n_ready_endpoints_short(n: int, user_id: str): assert len(ready_endpoints) >= n -def delete_all_endpoints(user_id): +def delete_all_endpoints(user_id: str, delete_suffix_only: bool): endpoints = list_model_endpoints(user_id) for i, endpoint in enumerate(endpoints): + if ( + delete_suffix_only + and SERVICE_IDENTIFIER + and not endpoint["name"].endswith(SERVICE_IDENTIFIER) + ): + continue + response = delete_model_endpoint(endpoint["name"], user_id) assert response["deleted"] print(f"[{i + 1}/{len(endpoints)}] Deleted {endpoint=}") @@ -745,7 +760,9 @@ def ensure_all_async_tasks_success(task_ids: List[str], user_id: str, return_pic ensure_inference_task_response_is_correct(response, return_pickled) -def delete_existing_endpoints(users: Sequence[str] = DEFAULT_USERS) -> None: +def delete_existing_endpoints( + users: Sequence[str] = DEFAULT_USERS, delete_suffix_only: bool = True +) -> None: if len(users) == 0: raise ValueError("Must supply at least one user!") @@ -778,8 +795,9 @@ def delete_existing_endpoints(users: Sequence[str] = DEFAULT_USERS) -> None: print(f"[{len({users})}] Deleting all user endpoints...") try: for i, u in enumerate(users): - print(f"[{i + 1}/{len(users)}] Deleting all endpoints for user with ID {u}") - delete_all_endpoints(u) + suffix_msg = f" with suffix {SERVICE_IDENTIFIER}" if delete_suffix_only else "" + print(f"[{i + 1}/{len(users)}] Deleting all endpoints{suffix_msg} for user with ID {u}") + delete_all_endpoints(u, delete_suffix_only) except Exception: # noqa try: j: str = json.dumps(all_endpoint_info, indent=2) @@ -788,5 +806,4 @@ def delete_existing_endpoints(users: Sequence[str] = DEFAULT_USERS) -> None: barrier: str = "-" * 80 print(f"ERROR! Deletion failed. All endpoint information:\n{barrier}\n{j}\n{barrier}") raise - time.sleep(15) diff --git a/integration_tests/test_docs.py b/integration_tests/test_docs.py index 2185154e..d5ee2435 100644 --- a/integration_tests/test_docs.py +++ b/integration_tests/test_docs.py @@ -2,7 +2,6 @@ # flake8: noqa: W605 import importlib.util import os -import random import re from pathlib import Path from textwrap import dedent @@ -13,6 +12,7 @@ ROOT_DIR = Path(__file__).parent.parent TEST_SKIP_MAGIC_STRING = "# test='skip'" +SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "") @pytest.fixture @@ -50,18 +50,12 @@ def env(): setenv.clear() -@pytest.fixture() -def seed() -> int: - """Returns a random seed between 0 and 999, inclusive.""" - return random.randint(0, 999) - - @pytest.fixture() def integration_test_user_id() -> str: - return "62bc820451dbea002b1c5421" + return os.getenv("TEST_USER_ID", "fakeuser") -def modify_source(source: str, seed: int) -> str: +def modify_source(source: str) -> str: # Adds some custom logic to update code from docs to comply with some requirements. source = re.sub(r"('team'|\"team\"): ('\w+'|\"\w+\")", r"'team': 'infra'", source) source = re.sub( @@ -72,15 +66,21 @@ def modify_source(source: str, seed: int) -> str: # Add suffix to avoid name collisions source = re.sub( - r"('endpoint_name'|\"endpoint_name\"): ('(\w+)'|\"(\w+)\")", - f"'endpoint_name': '\g<3>\g<4>-{seed}'", + r"('endpoint_name'|\"endpoint_name\"): ('([\w-]+)'|\"([\w-]+)\")", + rf"'endpoint_name': '\g<3>\g<4>-{SERVICE_IDENTIFIER}'", source, ) source = re.sub( - r"endpoint_name=('(\w+)'|\"(\w+)\")", - f"endpoint_name='\g<2>\g<3>-{seed}'", + r"endpoint_name=('([\w-]+)'|\"([\w-]+)\")", + rf"endpoint_name='\g<2>\g<3>-{SERVICE_IDENTIFIER}'", source, ) + source = re.sub( + r"get_model_endpoint\(\"([\w-]+)\"\)", + rf'get_model_endpoint("\g<1>-{SERVICE_IDENTIFIER}")', + source, + ) + source = re.sub(r'"repository": "..."', '"repository": "launch_rearch"', source) source = re.sub( r'"tag": "..."', '"tag": "11d9d42047cc9a0c6435b19e5e91bc7e0ad31efc-cpu"', source @@ -126,7 +126,7 @@ def modify_source(source: str, seed: int) -> str: @pytest.fixture def import_execute(request, tmp_work_path: Path): - def _import_execute(module_name: str, source: str, seed: int, rewrite_assertions: bool = False): + def _import_execute(module_name: str, source: str, rewrite_assertions: bool = False): if rewrite_assertions: loader = AssertionRewritingHook(config=request.config) loader.mark_rewrite(module_name) @@ -134,7 +134,7 @@ def _import_execute(module_name: str, source: str, seed: int, rewrite_assertions loader = None module_path = tmp_work_path / f"{module_name}.py" - modified_source = modify_source(source, seed) + modified_source = modify_source(source) module_path.write_text(modified_source) spec = importlib.util.spec_from_file_location("__main__", str(module_path), loader=loader) module = importlib.util.module_from_spec(spec) @@ -196,7 +196,6 @@ def test_docs_examples( source_code, import_execute, env, - seed, integration_test_user_id, ): if source_code == "__skip__": @@ -205,6 +204,6 @@ def test_docs_examples( env("LAUNCH_API_KEY", os.getenv("LAUNCH_TEST_API_KEY", integration_test_user_id)) try: - import_execute(module_name, source_code, seed, True) + import_execute(module_name, source_code, True) except Exception: raise diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index bbe191e4..55a4b9f2 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -1,16 +1,17 @@ +coverage==5.5 +diff-cover==7.7.0 +frozendict==2.3.4 +func-timeout==4.3.5 multiprocess==0.70.14 +moto==3.1.12 +mypy==1.3.0 +pylint<3.0.0 pytest==7.2.0 pytest-asyncio==0.20.1 pytest-cov==2.10.0 -diff-cover==7.7.0 -moto==3.1.12 -coverage==5.5 -mypy==1.3.0 pytest-mypy==0.9.1 pytest-mypy-plugins==1.10.1 -pytest-asyncio==0.20.1 pytest-pylint==0.18.0 -pylint<3.0.0 types-cachetools==5.3.0.5 types-croniter==1.4.0.0 types-PyYAML==6.0.7 @@ -23,5 +24,3 @@ types-toml==0.10.8 types-ujson==5.5.0 types-urllib3==1.26.14 types-waitress==2.1.4 -frozendict==2.3.4 -func-timeout==4.3.5 From 2e5eec22052e4223702c3e31b0f399c69cca82e2 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 19 Mar 2024 18:38:45 -0700 Subject: [PATCH 266/425] fix docs tests gateway endpoint (#475) * fix docs tests gateway endpoint * update comments * delete docs endpoints after test --- integration_tests/test_docs.py | 35 +++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/integration_tests/test_docs.py b/integration_tests/test_docs.py index d5ee2435..1b5ff941 100644 --- a/integration_tests/test_docs.py +++ b/integration_tests/test_docs.py @@ -9,10 +9,16 @@ import pytest from _pytest.assertion.rewrite import AssertionRewritingHook +from .rest_api_utils import ( + BASE_PATH, + SERVICE_IDENTIFIER, + delete_existing_endpoints, + ensure_gateway_ready, +) + ROOT_DIR = Path(__file__).parent.parent TEST_SKIP_MAGIC_STRING = "# test='skip'" -SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "") @pytest.fixture @@ -56,11 +62,17 @@ def integration_test_user_id() -> str: def modify_source(source: str) -> str: - # Adds some custom logic to update code from docs to comply with some requirements. - source = re.sub(r"('team'|\"team\"): ('\w+'|\"\w+\")", r"'team': 'infra'", source) + """Modify the source code from docs to be compatible with the integration tests.""" + + # Ensure the correct base path is used source = re.sub( - r"('product'|\"product\"): ('\w+'|\"\w+\")", - r"'product': 'launch-integration-test'", + r"get_launch_client\((.*)\)\n", + rf'get_launch_client(\g<1>, gateway_endpoint="{BASE_PATH}")\n', + source, + ) + source = re.sub( + r"LaunchClient\((.*)\)\n", + rf'LaunchClient(\g<1>, endpoint="{BASE_PATH}")\n', source, ) @@ -81,6 +93,15 @@ def modify_source(source: str) -> str: source, ) + # Set particular tag values for cost tracking + source = re.sub(r"('team'|\"team\"): ('\w+'|\"\w+\")", r"'team': 'infra'", source) + source = re.sub( + r"('product'|\"product\"): ('\w+'|\"\w+\")", + r"'product': 'launch-integration-test'", + source, + ) + + # Fill in empty values in docs source = re.sub(r'"repository": "..."', '"repository": "launch_rearch"', source) source = re.sub( r'"tag": "..."', '"tag": "11d9d42047cc9a0c6435b19e5e91bc7e0ad31efc-cpu"', source @@ -203,7 +224,11 @@ def test_docs_examples( env("LAUNCH_API_KEY", os.getenv("LAUNCH_TEST_API_KEY", integration_test_user_id)) + ensure_gateway_ready() + try: import_execute(module_name, source_code, True) except Exception: raise + finally: + delete_existing_endpoints() From 5f6cd3238435e774b0e7870490d5c7b3879c3424 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 20 Mar 2024 21:47:02 -0700 Subject: [PATCH 267/425] Guided decoding (#476) * Guided decoding * endpoints * fix * update client * unit tests * fix test * coverage * coverage * fix * try to bump coverage * more tests! * lint --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/completion.py | 42 ++++- clients/python/llmengine/data_types.py | 8 + clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- docs/guides/completions.md | 53 +++++++ .../model_engine_server/common/dtos/llms.py | 24 +++ .../use_cases/llm_model_endpoint_use_cases.py | 32 ++++ .../inference/vllm/requirements.txt | 3 +- .../inference/vllm/vllm_server.py | 32 +++- model-engine/tests/unit/conftest.py | 132 ++++++++++++++++ .../tests/unit/domain/test_llm_use_cases.py | 149 ++++++++++++++++++ 12 files changed, 474 insertions(+), 7 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 17dacfa9..dfae78cf 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b27" +__version__ = "0.0.0b28" import os from typing import Sequence diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 43d0813c..0181b733 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -1,4 +1,4 @@ -from typing import AsyncIterable, Iterator, List, Optional, Union +from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union from llmengine.api_engine import APIEngine from llmengine.data_types import ( @@ -43,6 +43,10 @@ async def acreate( frequency_penalty: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + include_stop_str_in_output: Optional[bool] = False, + guided_json: Optional[Dict[str, Any]] = None, + guided_regex: Optional[str] = None, + guided_choice: Optional[List[str]] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: @@ -102,6 +106,18 @@ async def acreate( Float that controls the cumulative probability of the top tokens to consider. Range: (0.0, 1.0]. 1.0 means consider all tokens. + include_stop_str_in_output (Optional[bool]): + Whether to include the stop sequence in the output. Default to False. + + guided_json (Optional[Dict[str, Any]]): + If specified, the output will follow the JSON schema. For examples see https://json-schema.org/learn/miscellaneous-examples. + + guided_regex (Optional[str]): + If specified, the output will follow the regex pattern. + + guided_choice (Optional[List[str]]): + If specified, the output will be exactly one of the choices. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -198,6 +214,10 @@ async def _acreate_stream( frequency_penalty=frequency_penalty, top_k=top_k, top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, timeout=timeout, ) @@ -237,6 +257,10 @@ def create( frequency_penalty: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + include_stop_str_in_output: Optional[bool] = False, + guided_json: Optional[Dict[str, Any]] = None, + guided_regex: Optional[str] = None, + guided_choice: Optional[List[str]] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: @@ -297,6 +321,18 @@ def create( Float that controls the cumulative probability of the top tokens to consider. Range: (0.0, 1.0]. 1.0 means consider all tokens. + include_stop_str_in_output (Optional[bool]): + Whether to include the stop sequence in the output. Default to False. + + guided_json (Optional[Dict[str, Any]]): + If specified, the output will follow the JSON schema. + + guided_regex (Optional[str]): + If specified, the output will follow the regex pattern. + + guided_choice (Optional[List[str]]): + If specified, the output will be exactly one of the choices. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -396,6 +432,10 @@ def _create_stream(**kwargs): frequency_penalty=frequency_penalty, top_k=top_k, top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, ).dict() response = cls.post_sync( resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 70abd6cb..a3ed3209 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -279,6 +279,10 @@ class CompletionSyncV1Request(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=-1) top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + include_stop_str_in_output: Optional[bool] = Field(default=False) + guided_json: Optional[Dict[str, Any]] = Field(default=None) + guided_regex: Optional[str] = Field(default=None) + guided_choice: Optional[List[str]] = Field(default=None) class TokenOutput(BaseModel): @@ -349,6 +353,10 @@ class CompletionStreamV1Request(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=-1) top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + include_stop_str_in_output: Optional[bool] = Field(default=False) + guided_json: Optional[Dict[str, Any]] = Field(default=None) + guided_regex: Optional[str] = Field(default=None) + guided_choice: Optional[List[str]] = Field(default=None) class CompletionStreamOutput(BaseModel): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2563b814..8ddec08f 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta27" +version = "0.0.0.beta28" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 257516fc..a33e6a03 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta27", + version="0.0.0.beta28", packages=find_packages(), ) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index 69dfe1bd..86bb9f0b 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -193,6 +193,59 @@ response = Completion.batch_create( print(response.json()) ``` +## Guided decoding + +Guided decoding is supported by vLLM and backed by [Outlines](https://github.com/outlines-dev/outlines). +It enforces certain token generation patterns by tinkering with the sampling logits. + +=== "Guided decoding with regex" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_regex="Sean.*", +) + +print(response.json()) +# {"request_id":"c19f0fae-317e-4f69-8e06-c04189299b9c","output":{"text":"Sean. I'm a 2","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} +``` + +=== "Guided decoding with choice" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_choice=["Sean", "Brian", "Tim"], +) + +print(response.json()) +# {"request_id":"641e2af3-a3e3-4493-98b9-d38115ba0d22","output":{"text":"Sean","num_prompt_tokens":6,"num_completion_tokens":4,"tokens":null}} +``` + +=== "Guided decoding with JSON schema" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_json={"properties":{"myString":{"type":"string"}},"required":["myString"]}, +) + +print(response.json()) +# {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} +``` + ## Which model should I use? See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions. diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 8d335d8d..9fb8ed1d 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -184,6 +184,18 @@ class CompletionSyncV1Request(BaseModel): """ Whether to include the stop strings in output text. """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. + """ class TokenOutput(BaseModel): @@ -248,6 +260,18 @@ class CompletionStreamV1Request(BaseModel): """ Whether to include the stop strings in output text. """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ class CompletionStreamOutput(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index b458343c..65973ced 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1365,6 +1365,26 @@ def validate_and_update_completion_params( "include_stop_str_in_output is only supported in vllm." ) + guided_count = 0 + if request.guided_choice is not None: + guided_count += 1 + if request.guided_json is not None: + guided_count += 1 + if request.guided_regex is not None: + guided_count += 1 + + if guided_count > 1: + raise ObjectHasInvalidValueException( + "Only one of guided_json, guided_choice, guided_regex can be enabled." + ) + + if ( + request.guided_choice is not None + or request.guided_regex is not None + or request.guided_json is not None + ) and not inference_framework == LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.") + return request @@ -1656,6 +1676,12 @@ async def execute( vllm_args["logprobs"] = 1 if request.include_stop_str_in_output is not None: vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output + if request.guided_choice is not None: + vllm_args["guided_choice"] = request.guided_choice + if request.guided_regex is not None: + vllm_args["guided_regex"] = request.guided_regex + if request.guided_json is not None: + vllm_args["guided_json"] = request.guided_json inference_request = SyncEndpointPredictV1Request( args=vllm_args, @@ -1918,6 +1944,12 @@ async def execute( args["logprobs"] = 1 if request.include_stop_str_in_output is not None: args["include_stop_str_in_output"] = request.include_stop_str_in_output + if request.guided_choice is not None: + args["guided_choice"] = request.guided_choice + if request.guided_regex is not None: + args["guided_regex"] = request.guided_regex + if request.guided_json is not None: + args["guided_json"] = request.guided_json args["stream"] = True elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: args = { diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 78e033bb..3c1cf851 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,2 @@ -ray>=2.9 -vllm==0.3.2 +vllm==0.3.3 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 5bd3f6e4..c4dd0eed 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -7,10 +7,12 @@ from typing import AsyncGenerator import uvicorn -from fastapi import BackgroundTasks, FastAPI, Request +from fastapi import BackgroundTasks, FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest +from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -38,7 +40,35 @@ async def generate(request: Request) -> Response: request_dict = await request.json() prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) + guided_json = request_dict.pop("guided_json", None) + guided_regex = request_dict.pop("guided_regex", None) + guided_choice = request_dict.pop("guided_choice", None) sampling_params = SamplingParams(**request_dict) + + # Dummy request to get guided decode logit processor + try: + partial_openai_request = OpenAICompletionRequest.model_validate( + { + "model": "", + "prompt": "", + "guided_json": guided_json, + "guided_regex": guided_regex, + "guided_choice": guided_choice, + } + ) + except Exception: + raise HTTPException( + status_code=400, detail="Bad request: failed to parse guided decoding parameters." + ) + + guided_decode_logit_processor = await get_guided_decoding_logits_processor( + partial_openai_request, engine.get_tokenizer() + ) + if guided_decode_logit_processor is not None: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(guided_decode_logit_processor) + request_id = random_uuid() results_generator = engine.generate(prompt, sampling_params, request_id) diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 61473b37..4b57afa1 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3725,6 +3725,138 @@ def llm_model_endpoint_sync( return model_endpoint, model_endpoint_json +@pytest.fixture +def llm_model_endpoint_stream( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.STREAMING, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "streaming", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": "1", + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + @pytest.fixture def llm_model_endpoint_sync_tgi( test_api_key: str, model_bundle_1: ModelBundle diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 10b37c7d..b2496ff9 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -51,6 +51,7 @@ UpdateLLMModelEndpointV1UseCase, _include_safetensors_bin_or_pt, infer_hardware_from_model_name, + validate_and_update_completion_params, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -602,6 +603,8 @@ async def test_completion_sync_use_case_success( llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): + completion_sync_request.include_stop_str_in_output = True + completion_sync_request.guided_json = {} fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( @@ -987,6 +990,42 @@ async def test_completion_sync_use_case_not_sync_endpoint_raises( ) +@pytest.mark.asyncio +async def test_validate_and_update_completion_params(): + completion_sync_request = CompletionSyncV1Request( + prompt="What is machine learning?", + max_new_tokens=10, + temperature=0.5, + return_token_log_probs=True, + ) + + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + + completion_sync_request.include_stop_str_in_output = True + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + completion_sync_request.include_stop_str_in_output = None + + completion_sync_request.guided_regex = "" + completion_sync_request.guided_json = {} + completion_sync_request.guided_choice = [""] + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + + completion_sync_request.guided_regex = None + completion_sync_request.guided_choice = None + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + + @pytest.mark.asyncio @mock.patch( "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", @@ -1079,6 +1118,116 @@ async def test_completion_stream_use_case_success( i += 1 +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) +async def test_completion_stream_vllm_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_stream: Tuple[ModelEndpoint, Any], + completion_stream_request: CompletionStreamV1Request, +): + completion_stream_request.guided_json = {} + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_stream[0]) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": "I", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 1, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " am", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 2, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " a", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 3, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " new", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 4, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": "bie", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 5, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": ".", + "finished": True, + "count_prompt_tokens": 7, + "count_output_tokens": 6, + } + }, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_stream[0].record.name, + request=completion_stream_request, + ) + output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] + i = 0 + async for message in response_1: + assert message.dict()["output"]["text"] == output_texts[i] + if i == 5: + assert message.dict()["output"]["num_prompt_tokens"] == 7 + assert message.dict()["output"]["num_completion_tokens"] == 6 + i += 1 + + @pytest.mark.asyncio @mock.patch( "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", From b785d2552cb69e532c1e960604d7b9734393c471 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:15:22 -0700 Subject: [PATCH 268/425] Add emitting token count metrics to datadog statsd (#458) We want to be able to view token count metrics. --- .../model_engine_server/api/dependencies.py | 7 +- .../model_engine_server/api/llms_v1.py | 4 + .../model_engine_server/common/dtos/llms.py | 16 +++ .../model_engine_server/core/utils/timer.py | 15 ++- .../infra/gateways/__init__.py | 2 + .../datadog_monitoring_metrics_gateway.py | 87 ++++++++++++++ .../tests/unit/core/utils/test_timer.py | 15 +++ ...test_datadog_monitoring_metrics_gateway.py | 106 ++++++++++++++++++ 8 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py create mode 100644 model-engine/tests/unit/core/utils/test_timer.py create mode 100644 model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 713938d1..eb2ee227 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -56,6 +56,7 @@ ABSFilesystemGateway, ABSLLMArtifactGateway, CeleryTaskQueueGateway, + DatadogMonitoringMetricsGateway, FakeMonitoringMetricsGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobOrchestrationGateway, @@ -159,7 +160,11 @@ class ExternalInterfaces: def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway: - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + # dd_trace_enabled is a good enough proxy for determining if we should use Datadog + if hmi_config.dd_trace_enabled: + monitoring_metrics_gateway: MonitoringMetricsGateway = DatadogMonitoringMetricsGateway() + else: + monitoring_metrics_gateway = FakeMonitoringMetricsGateway() return monitoring_metrics_gateway diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 9660f0d0..614cc6bb 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -391,8 +391,11 @@ async def create_completion_stream_task( async def event_generator(): try: + time_to_first_token = None with timer() as use_case_timer: async for message in response: + if time_to_first_token is None and message.output is not None: + time_to_first_token = use_case_timer.lap() yield {"data": message.json()} background_tasks.add_task( external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, @@ -402,6 +405,7 @@ async def event_generator(): if message.output else None, total_duration=use_case_timer.duration, + time_to_first_token=time_to_first_token, ), metric_metadata, ) diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 9fb8ed1d..6f63e712 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -321,6 +321,8 @@ class TokenUsage(BaseModel): total_duration: Optional[float] = None """Includes time spent waiting for the model to be ready.""" + time_to_first_token: Optional[float] = None # Only for streaming requests + @property def num_total_tokens(self) -> int: return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) @@ -333,6 +335,20 @@ def total_tokens_per_second(self) -> float: else 0.0 ) + @property + def inter_token_latency(self) -> Optional[float]: # Only for streaming requests + # Note: we calculate a single inter-token latency for the entire request. + # Calculating latency between each token seems a bit heavyweight, although we can do this if we wanted + if ( + self.time_to_first_token is None + or self.num_completion_tokens is None + or self.total_duration is None + ): + return None + if self.num_completion_tokens < 2: + return None + return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) + class CreateFineTuneRequest(BaseModel): model: str diff --git a/model-engine/model_engine_server/core/utils/timer.py b/model-engine/model_engine_server/core/utils/timer.py index 53a6f8fe..5a2bd1be 100644 --- a/model-engine/model_engine_server/core/utils/timer.py +++ b/model-engine/model_engine_server/core/utils/timer.py @@ -33,7 +33,7 @@ class timer: # pylint: disable=invalid-name >>> f() """ - __slots__ = ("logger", "name", "_duration", "start") + __slots__ = ("logger", "name", "_duration", "start", "start_lap") def __init__(self, logger: Optional[Logger] = None, name: str = "") -> None: self.logger = logger @@ -42,6 +42,7 @@ def __init__(self, logger: Optional[Logger] = None, name: str = "") -> None: # for start, -1 is the uninitialized value # it is set at the context-block entering method: __enter__ self.start: float = -1.0 + self.start_lap: float = -1.0 def __enter__(self) -> "timer": """Records start time: context-block entering function.""" @@ -62,6 +63,18 @@ def __exit__(self, *args) -> None: ) self._maybe_log_end_time() + def lap(self) -> float: + # Records a "lap time". Specifically if start is called at t_0, and lap is + # called at t_1 and t_2, then the returned values are t_1 - t_0 and t_2 - t_1. + # This does introduce extra overhead, however. + current_time = time.monotonic() + if self.start_lap == -1: + duration = current_time - self.start + else: + duration = current_time - self.start_lap + self.start_lap = current_time + return duration + def _maybe_log_end_time(self) -> None: if self.logger is not None: caller_namespace = "" diff --git a/model-engine/model_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py index b36fb641..5a0d7a90 100644 --- a/model-engine/model_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -6,6 +6,7 @@ from .batch_job_orchestration_gateway import BatchJobOrchestrationGateway from .batch_job_progress_gateway import BatchJobProgressGateway from .celery_task_queue_gateway import CeleryTaskQueueGateway +from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway from .fake_model_primitive_gateway import FakeModelPrimitiveGateway from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway @@ -31,6 +32,7 @@ "BatchJobOrchestrationGateway", "BatchJobProgressGateway", "CeleryTaskQueueGateway", + "DatadogMonitoringMetricsGateway", "FakeModelPrimitiveGateway", "FakeMonitoringMetricsGateway", "LiveAsyncModelEndpointInferenceGateway", diff --git a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py new file mode 100644 index 00000000..8732615d --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py @@ -0,0 +1,87 @@ +from typing import List, Optional + +from datadog import statsd +from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.core.config import infra_config +from model_engine_server.domain.gateways.monitoring_metrics_gateway import ( + MetricMetadata, + MonitoringMetricsGateway, +) + + +def get_model_tags(model_name: Optional[str]) -> List[str]: + """ + Returns a tag for the model name and whether it is a finetuned model + """ + tags = [] + if model_name: + parts = model_name.split(".") + tags.extend([f"model_name:{parts[0]}"]) + return tags + + +class DatadogMonitoringMetricsGateway(MonitoringMetricsGateway): + def __init__(self, prefix: str = "model_engine"): + self.prefix = prefix + self.tags = [f"env:{infra_config().env}"] + + def emit_attempted_build_metric(self): + statsd.increment("scale_launch.service_builder.attempt", tags=self.tags) + + def emit_successful_build_metric(self): + statsd.increment("scale_launch.service_builder.success", tags=self.tags) + + def emit_build_time_metric(self, duration_seconds: float): + statsd.distribution( + "scale_launch.service_builder.endpoint_build_time", duration_seconds, tags=self.tags + ) + + def emit_image_build_cache_hit_metric(self, image_type: str): + statsd.increment( + f"scale_launch.service_builder.{image_type}_image_cache_hit", tags=self.tags + ) + + def emit_image_build_cache_miss_metric(self, image_type: str): + statsd.increment( + f"scale_launch.service_builder.{image_type}_image_cache_miss", tags=self.tags + ) + + def emit_docker_failed_build_metric(self): + statsd.increment("scale_launch.service_builder.docker_failed", tags=self.tags) + + def emit_database_cache_hit_metric(self): + statsd.increment("scale_launch.database_cache.hit", tags=self.tags) + + def emit_database_cache_miss_metric(self): + statsd.increment("scale_launch.database_cache.miss", tags=self.tags) + + def _format_call_tags(self, metadata: MetricMetadata) -> List[str]: + tags = self.tags + tags.extend(get_model_tags(metadata.model_name)) + return tags + + def emit_route_call_metric(self, route: str, metadata: MetricMetadata): + statsd.increment(f"{self.prefix}.{route}.call", tags=self._format_call_tags(metadata)) + + def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata): + tags = self._format_call_tags(metadata) + + token_count_metric = f"{self.prefix}.token_count" + statsd.increment( + f"{token_count_metric}.prompt", (token_usage.num_prompt_tokens or 0), tags=tags + ) + statsd.increment( + f"{token_count_metric}.completion", (token_usage.num_completion_tokens or 0), tags=tags + ) + statsd.increment(f"{token_count_metric}.total", token_usage.num_total_tokens, tags=tags) + + total_tokens_per_second = f"{self.prefix}.total_tokens_per_second" + statsd.histogram(total_tokens_per_second, token_usage.total_tokens_per_second, tags=tags) + + time_to_first_token = f"{self.prefix}.time_to_first_token" + if token_usage.time_to_first_token is not None: + statsd.distribution(time_to_first_token, token_usage.time_to_first_token, tags=tags) + + inter_token_latency = f"{self.prefix}.inter_token_latency" + if token_usage.inter_token_latency is not None: + statsd.distribution(inter_token_latency, token_usage.inter_token_latency, tags=tags) diff --git a/model-engine/tests/unit/core/utils/test_timer.py b/model-engine/tests/unit/core/utils/test_timer.py new file mode 100644 index 00000000..f5d3b2d1 --- /dev/null +++ b/model-engine/tests/unit/core/utils/test_timer.py @@ -0,0 +1,15 @@ +import time + +from model_engine_server.core.utils.timer import timer + + +def test_timer(): + with timer() as t: + time.sleep(0.1) + lap_time = t.lap() + time.sleep(0.01) + new_lap_time = t.lap() + + assert new_lap_time >= 0.009 + assert lap_time >= 0.09 + assert t.duration >= 0.1 diff --git a/model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py b/model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py new file mode 100644 index 00000000..e3e295a6 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py @@ -0,0 +1,106 @@ +from unittest.mock import Mock + +import pytest +from datadog import statsd +from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.infra.gateways import DatadogMonitoringMetricsGateway + + +@pytest.fixture(autouse=True) +def mock_statsd(): + # https://github.com/DataDog/datadogpy/issues/183 for how dd mocks statsd + statsd.socket = Mock() + # also mock the methods we use or may use, there might be more + statsd.gauge = Mock() + statsd.increment = Mock() + statsd.decrement = Mock() + statsd.histogram = Mock() + statsd.distribution = Mock() + + +@pytest.fixture +def sync_token_count(): + return TokenUsage( + num_prompt_tokens=100, + num_completion_tokens=200, + total_duration=30, + ) + + +@pytest.fixture +def streaming_token_count(): + return TokenUsage( + num_prompt_tokens=100, + num_completion_tokens=200, + total_duration=30, + time_to_first_token=5, + ) + + +@pytest.fixture +def datadog_monitoring_metrics_gateway(): + gateway = DatadogMonitoringMetricsGateway(prefix="model_engine_unit_test") + return gateway + + +def test_datadog_monitoring_metrics_gateway_build_metrics(datadog_monitoring_metrics_gateway): + datadog_monitoring_metrics_gateway.emit_attempted_build_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_successful_build_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_build_time_metric(300) + statsd.distribution.assert_called_once() + statsd.distribution.reset_mock() + datadog_monitoring_metrics_gateway.emit_image_build_cache_hit_metric("test_image") + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_image_build_cache_miss_metric("test_image_2") + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_docker_failed_build_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + + +def test_datadog_monitoring_metrics_gateway_db_metrics(datadog_monitoring_metrics_gateway): + datadog_monitoring_metrics_gateway.emit_database_cache_hit_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_database_cache_miss_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + + +def test_datadog_monitoring_metrics_gateway_route_call_metrics(datadog_monitoring_metrics_gateway): + metadata = MetricMetadata( + user=User(user_id="test_user", team_id="test_team", email="test_email"), + model_name="test_model", + ) + datadog_monitoring_metrics_gateway.emit_route_call_metric("test_route", metadata) + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + + +def test_datadog_monitoring_metrics_gateway_token_count_metrics( + datadog_monitoring_metrics_gateway, sync_token_count, streaming_token_count +): + metadata = MetricMetadata( + user=User(user_id="test_user", team_id="test_team", email="test_email"), + model_name="test_model", + ) + datadog_monitoring_metrics_gateway.emit_token_count_metrics(sync_token_count, metadata) + statsd.increment.assert_called() + statsd.increment.reset_mock() + statsd.histogram.assert_called() + statsd.histogram.reset_mock() + datadog_monitoring_metrics_gateway.emit_token_count_metrics(streaming_token_count, metadata) + statsd.increment.assert_called() + statsd.increment.reset_mock() + statsd.histogram.assert_called() + statsd.histogram.reset_mock() + statsd.distribution.assert_called() + statsd.distribution.reset_mock() From bdf4a25098c1dc05a35422f61b6760fb3a61205d Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Thu, 28 Mar 2024 13:26:26 -0700 Subject: [PATCH 269/425] Downgrade sse-starlette version (#478) --- model-engine/requirements.in | 4 ++-- model-engine/requirements.txt | 7 ++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 756df6c3..5cf95a51 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -47,9 +47,9 @@ sentencepiece==0.1.99 sh~=1.13 smart-open~=5.2 sqlalchemy[asyncio]~=2.0.4 -sse-starlette==2.0.0 +sse-starlette==1.6.1 sseclient-py==1.7.2 -starlette[full]>=0.35.0 # not used directly, but needs to be pinned for Microsoft security scan +starlette[full]>=0.36.2 # not used directly, but needs to be pinned for Microsoft security scan stringcase==1.2.0 tenacity>=6.0.0,<=6.2.0 testing-postgresql==1.3.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index bc0052c1..c261e668 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -22,7 +22,6 @@ anyio==3.7.1 # via # azure-core # httpx - # sse-starlette # starlette asgiref==3.7.2 # via uvicorn @@ -473,7 +472,7 @@ sqlalchemy[asyncio]==2.0.4 # via # -r model-engine/requirements.in # alembic -sse-starlette==2.0.0 +sse-starlette==1.6.1 # via -r model-engine/requirements.in sseclient-py==1.7.2 # via -r model-engine/requirements.in @@ -565,9 +564,7 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via - # -r model-engine/requirements.in - # sse-starlette + # via -r model-engine/requirements.in uvloop==0.17.0 # via -r model-engine/requirements.in vine==5.1.0 From 5524f80344040d9e5fbf1c114a6a4d63b565f1f3 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 1 Apr 2024 09:57:08 -0700 Subject: [PATCH 270/425] Return 400 for botocore client errors (#479) * Return 400 for invalid SQS message * add tests --- .../model_engine_server/api/tasks_v1.py | 6 +++ .../gateways/celery_task_queue_gateway.py | 18 ++++++--- .../inference/test_async_inference.py | 24 ++++++++++++ model-engine/tests/unit/api/test_tasks.py | 38 +++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/model-engine/model_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py index 524f2f46..663b3e0c 100644 --- a/model-engine/model_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -18,6 +18,7 @@ from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( EndpointUnsupportedInferenceTypeException, + InvalidRequestException, ObjectNotAuthorizedException, ObjectNotFoundException, UpstreamServiceError, @@ -66,6 +67,11 @@ async def create_async_inference_task( status_code=400, detail=f"Unsupported inference type: {str(exc)}", ) from exc + except InvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Invalid request: {str(exc)}", + ) from exc @inference_task_router_v1.get("/async-tasks/{task_id}", response_model=GetAsyncTaskV1Response) diff --git a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 7a8f6911..676a1274 100644 --- a/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +import botocore from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, @@ -9,6 +10,7 @@ from model_engine_server.core.celery import TaskVisibility, celery_app from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import InvalidRequestException from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway logger = make_logger(logger_name()) @@ -68,12 +70,16 @@ def send_task( ) -> CreateAsyncTaskV1Response: celery_dest = self._get_celery_dest() - res = celery_dest.send_task( - name=task_name, - args=args, - kwargs=kwargs, - queue=queue_name, - ) + try: + res = celery_dest.send_task( + name=task_name, + args=args, + kwargs=kwargs, + queue=queue_name, + ) + except botocore.exceptions.ClientError as e: + logger.exception(f"Error sending task to queue {queue_name}: {e}") + raise InvalidRequestException(f"Error sending celery task: {e}") logger.info(f"Task {res.id} sent to queue {queue_name} from gateway") # pragma: no cover return CreateAsyncTaskV1Response(task_id=res.id) diff --git a/model-engine/tests/integration/inference/test_async_inference.py b/model-engine/tests/integration/inference/test_async_inference.py index e96164d7..db9bc9a7 100644 --- a/model-engine/tests/integration/inference/test_async_inference.py +++ b/model-engine/tests/integration/inference/test_async_inference.py @@ -4,7 +4,9 @@ import subprocess from functools import lru_cache from typing import Any, List, Optional, Tuple +from unittest.mock import MagicMock +import botocore import pytest import redis import requests @@ -17,6 +19,7 @@ TaskStatus, ) from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.domain.exceptions import InvalidRequestException from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, LiveAsyncModelEndpointInferenceGateway, @@ -157,3 +160,24 @@ def test_async_callbacks( assert actual_payload == expected_callback_payload assert callback_stats["last_auth"][callback_version] == expected_credentials + + +def test_async_callbacks_botocore_exception( + queue: str, +): + gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) + + mock_dest = MagicMock() + mock_dest.send_task = MagicMock( + side_effect=botocore.exceptions.ClientError(error_response={}, operation_name="") + ) + mock_get = MagicMock() + mock_get.return_value = mock_dest + gateway._get_celery_dest = mock_get + + with pytest.raises(InvalidRequestException): + gateway.send_task( + task_name="test_task", + queue_name=queue, + args=[1, 2], + ) diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index f9a0f062..3d019016 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -4,6 +4,7 @@ from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.domain.entities import ModelBundle, ModelEndpoint from model_engine_server.domain.exceptions import ( + InvalidRequestException, ObjectNotAuthorizedException, ObjectNotFoundException, UpstreamServiceError, @@ -104,6 +105,43 @@ def test_create_async_task_raises_404_not_found( assert response.status_code == 404 +def test_create_async_task_raises_400_invalid_requests( + model_bundle_1_v1: Tuple[ModelBundle, Any], + model_endpoint_1: Tuple[ModelEndpoint, Any], + endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]], + test_api_key: str, + get_test_client_wrapper, +): + assert model_endpoint_1[0].infra_state is not None + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={ + model_endpoint_1[0].record.id: model_endpoint_1[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + model_endpoint_1[0].infra_state.deployment_name: model_endpoint_1[0].infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + mock_use_case = MagicMock() + mock_use_case.return_value.execute = MagicMock(side_effect=InvalidRequestException) + with patch( + "model_engine_server.api.tasks_v1.CreateAsyncInferenceTaskV1UseCase", + mock_use_case, + ): + response = client.post( + "/v1/async-tasks?model_endpoint_id=invalid_model_endpoint_id", + auth=(test_api_key, ""), + json=endpoint_predict_request_1[1], + ) + assert response.status_code == 400 + + def test_get_async_task_success( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], From f187c00fa83286778472d2e2aa9d0b5c01ef56ca Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:07:40 -0700 Subject: [PATCH 271/425] Increase Kaniko Memory (#481) * Increase Kaniko Memory * oop --- .../model_engine_server/core/docker/kaniko_template.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/core/docker/kaniko_template.yaml b/model-engine/model_engine_server/core/docker/kaniko_template.yaml index a5bb5384..4f842367 100644 --- a/model-engine/model_engine_server/core/docker/kaniko_template.yaml +++ b/model-engine/model_engine_server/core/docker/kaniko_template.yaml @@ -39,6 +39,7 @@ spec: - "--use-new-run" - "--image-fs-extract-retry=5" - "--log-format=json" + - "--push-retry=2" # The --use-new-run flag should fix docker builds eating up a lot of memory and consequently oom/failing env: - name: AWS_REGION @@ -50,11 +51,11 @@ spec: resources: requests: cpu: 3.5 - memory: 30Gi + memory: 90Gi ephemeral-storage: 80G limits: cpu: 3.5 - memory: 30Gi + memory: 90Gi ephemeral-storage: 80G volumes: - name: pipconf From 3d9ea7583af5289fc6201222b2521dedd27d55c3 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:08:36 -0700 Subject: [PATCH 272/425] Batch job metrics (#480) * Batch job metrics * fix * fix * use DatadogInferenceMonitoringMetricsGateway * formmating * fix lint * fix * move tests * fix --- .../inference/batch_inference/Dockerfile_vllm | 6 +-- .../batch_inference/requirements.txt | 3 +- .../inference/batch_inference/vllm_batch.py | 23 ++++++++++- ...og_inference_monitoring_metrics_gateway.py | 33 ++++++++++++++++ ...og_inference_monitoring_metrics_gateway.py | 39 +++++++++++++++++++ 5 files changed, 99 insertions(+), 5 deletions(-) create mode 100644 model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm index d0a3b36b..3b08756c 100644 --- a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm +++ b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm @@ -6,9 +6,6 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean -COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt -RUN pip install -r requirements.txt - RUN pip uninstall torch -y RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121 @@ -18,6 +15,9 @@ RUN pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz +COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt +RUN pip install -r requirements.txt + COPY model-engine /workspace/model-engine RUN pip install -e /workspace/model-engine COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index bbc99b04..e83b4ccd 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -4,4 +4,5 @@ boto3==1.34.15 smart-open==6.4.0 ddtrace==2.4.0 docker==7.0.0 -func-timeout==4.3.5 \ No newline at end of file +func-timeout==4.3.5 +datadog==0.49.1 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 718d5e24..37b99b82 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -18,6 +18,9 @@ TokenOutput, ToolConfig, ) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) from model_engine_server.inference.tool_completion.tools import TOOL_MAP, BaseTool, Tools, tokenizer from tqdm import tqdm @@ -145,6 +148,7 @@ async def generate_with_tool( content: CreateBatchCompletionsRequestContent, prompts, tool: Type[BaseTool], + is_finetuned: bool, ): class IterativeGeneration: def __init__(self, prompt, max_new_tokens): @@ -195,6 +199,8 @@ def __repr__(self) -> str: content.top_p, [iter[0] for iter in iter_prompts], bar, + use_tool=True, + is_finetuned=is_finetuned, ) bar = tqdm( @@ -296,6 +302,7 @@ async def batch_inference(): model = ( MODEL_WEIGHTS_FOLDER if request.model_config.checkpoint_path else request.model_config.model ) + is_finetuned = request.model_config.checkpoint_path is not None llm = get_vllm_engine(model, request) @@ -313,7 +320,9 @@ async def batch_inference(): if request.tool_config is not None: tool_enum = Tools(request.tool_config.name) tool = TOOL_MAP[tool_enum] - outputs = await generate_with_tool(llm, request.tool_config, content, prompts, tool) + outputs = await generate_with_tool( + llm, request.tool_config, content, prompts, tool, is_finetuned + ) else: bar = tqdm(total=len(prompts), desc="Processed prompts") @@ -329,6 +338,8 @@ async def batch_inference(): content.top_p, prompts, bar, + use_tool=False, + is_finetuned=is_finetuned, ) bar.close() @@ -361,9 +372,15 @@ async def generate_with_vllm( top_p, prompts, bar, + use_tool, + is_finetuned, ) -> List[CompletionOutput]: # pragma: no cover from vllm import SamplingParams + model = (await engine.get_model_config()).model + + metrics_gateway = DatadogInferenceMonitoringMetricsGateway() + # Add the requests to the engine. results_generators = [] for idx, prompt in enumerate(prompts): @@ -414,6 +431,10 @@ async def generate_with_vllm( if return_token_log_probs: output.tokens = tokens + metrics_gateway.emit_batch_completions_metric( + model, use_tool, num_prompt_tokens, num_completion_tokens, is_finetuned + ) + outputs.append(output) return outputs diff --git a/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py index 802066cb..30aca62b 100644 --- a/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py @@ -18,3 +18,36 @@ def emit_async_task_received_metric(self, queue_name: str): def emit_async_task_stuck_metric(self, queue_name: str): statsd.increment("scale_launch.async_task.stuck.count", tags=[f"queue_name:{queue_name}"]) + + def emit_batch_completions_metric( + self, + model: str, + use_tool: bool, + num_prompt_tokens: int, + num_completion_tokens: int, + is_finetuned: bool, + ): + tags = [ + f"model:{model}", + f"use_tool:{use_tool}", + f"is_finetuned:{is_finetuned}", + ] + statsd.increment( + "model_engine.batch_inference.vllm.generation_count", + tags=tags, + ) + statsd.increment( + "model_engine.batch_inference.vllm.token_count.total", + num_prompt_tokens + num_completion_tokens, + tags=tags, + ) + statsd.increment( + "model_engine.batch_inference.vllm.token_count.completion", + num_completion_tokens, + tags=tags, + ) + statsd.increment( + "model_engine.batch_inference.vllm.token_count.prompt", + num_prompt_tokens, + tags=tags, + ) diff --git a/model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py b/model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py new file mode 100644 index 00000000..cb99d9b7 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py @@ -0,0 +1,39 @@ +from unittest.mock import Mock + +import pytest +from datadog import statsd +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) + + +@pytest.fixture(autouse=True) +def mock_statsd(): + # https://github.com/DataDog/datadogpy/issues/183 for how dd mocks statsd + statsd.socket = Mock() + # also mock the methods we use or may use, there might be more + statsd.gauge = Mock() + statsd.increment = Mock() + statsd.decrement = Mock() + statsd.histogram = Mock() + statsd.distribution = Mock() + + +@pytest.fixture +def datadog_inference_monitoring_metrics_gateway(): + return DatadogInferenceMonitoringMetricsGateway() + + +def test_datadog_inference_monitoring_metrics_gateway_batch_completion_metrics( + datadog_inference_monitoring_metrics_gateway, +): + model = "test_model" + use_tool = True + num_prompt_tokens = 100 + num_completion_tokens = 200 + is_finetuned = True + datadog_inference_monitoring_metrics_gateway.emit_batch_completions_metric( + model, use_tool, num_prompt_tokens, num_completion_tokens, is_finetuned + ) + statsd.increment.assert_called() + statsd.increment.reset_mock() From e924ffa1afbb63195eeb2e617b00714c93ce971a Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 3 Apr 2024 11:52:53 -0700 Subject: [PATCH 273/425] Use base model name as metric tag (#483) --- .../inference/batch_inference/vllm_batch.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 37b99b82..f0dbb041 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -149,6 +149,7 @@ async def generate_with_tool( prompts, tool: Type[BaseTool], is_finetuned: bool, + model: str, ): class IterativeGeneration: def __init__(self, prompt, max_new_tokens): @@ -201,6 +202,7 @@ def __repr__(self) -> str: bar, use_tool=True, is_finetuned=is_finetuned, + model=model, ) bar = tqdm( @@ -321,7 +323,13 @@ async def batch_inference(): tool_enum = Tools(request.tool_config.name) tool = TOOL_MAP[tool_enum] outputs = await generate_with_tool( - llm, request.tool_config, content, prompts, tool, is_finetuned + llm, + request.tool_config, + content, + prompts, + tool, + is_finetuned, + request.model_config.model, ) else: bar = tqdm(total=len(prompts), desc="Processed prompts") @@ -340,6 +348,7 @@ async def batch_inference(): bar, use_tool=False, is_finetuned=is_finetuned, + model=request.model_config.model, ) bar.close() @@ -374,11 +383,10 @@ async def generate_with_vllm( bar, use_tool, is_finetuned, + model, ) -> List[CompletionOutput]: # pragma: no cover from vllm import SamplingParams - model = (await engine.get_model_config()).model - metrics_gateway = DatadogInferenceMonitoringMetricsGateway() # Add the requests to the engine. From 2b4466baedef454f017537ca35dcfe9bce537e03 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 3 Apr 2024 18:14:26 -0700 Subject: [PATCH 274/425] Change LLM Engine base path from global var (#482) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/api_engine.py | 45 +++++++++++++++++--------- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index dfae78cf..09dd8526 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b28" +__version__ = "0.0.0b29" import os from typing import Sequence diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index 3abf86d6..a1b955be 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -12,12 +12,23 @@ from llmengine.errors import parse_error SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine/" -LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) DEFAULT_TIMEOUT: int = 10 +base_path = None api_key = None +def set_base_path(path): + global base_path + base_path = path + + +def get_base_path() -> str: + if base_path is not None: + return base_path + return os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) + + def set_api_key(key): global api_key api_key = key @@ -33,7 +44,7 @@ def get_api_key() -> str: def assert_self_hosted(func): @wraps(func) def inner(*args, **kwargs): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH: + if SPELLBOOK_API_URL == get_base_path(): raise ValueError("This feature is only available for self-hosted users.") return func(*args, **kwargs) @@ -43,16 +54,17 @@ def inner(*args, **kwargs): class APIEngine: @classmethod def validate_api_key(cls): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH and not get_api_key(): + if SPELLBOOK_API_URL == get_base_path() and not get_api_key(): raise ValueError( "You must set SCALE_API_KEY in your environment to to use the LLM Engine API." ) @classmethod def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.get( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), timeout=timeout, headers={"x-api-key": api_key}, auth=(api_key, ""), @@ -66,9 +78,10 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: def put( cls, resource_name: str, data: Optional[Dict[str, Any]], timeout: int ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.put( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -81,9 +94,10 @@ def put( @classmethod def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.delete( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), timeout=timeout, headers={"x-api-key": api_key}, auth=(api_key, ""), @@ -95,9 +109,10 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: @classmethod def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -112,9 +127,10 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di def post_stream( cls, resource_name: str, data: Dict[str, Any], timeout: int ) -> Iterator[Dict[str, Any]]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, headers={"x-api-key": api_key}, @@ -144,9 +160,10 @@ def post_stream( def post_file( cls, resource_name: str, files: Dict[str, BufferedReader], timeout: int ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), files=files, timeout=timeout, headers={"x-api-key": api_key}, @@ -161,15 +178,14 @@ def post_file( async def apost_sync( cls, resource_name: str, data: Dict[str, Any], timeout: int ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() async with ClientSession( timeout=ClientTimeout(timeout), headers={"x-api-key": api_key}, auth=BasicAuth(api_key, ""), ) as session: - async with session.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data - ) as resp: + async with session.post(urljoin(base_path, resource_name), json=data) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) payload = await resp.json() @@ -179,15 +195,14 @@ async def apost_sync( async def apost_stream( cls, resource_name: str, data: Dict[str, Any], timeout: int ) -> AsyncIterable[Dict[str, Any]]: + base_path = get_base_path() api_key = get_api_key() async with ClientSession( timeout=ClientTimeout(timeout), headers={"x-api-key": api_key}, auth=BasicAuth(api_key, ""), ) as session: - async with session.post( - urljoin(LLM_ENGINE_BASE_PATH, resource_name), json=data - ) as resp: + async with session.post(urljoin(base_path, resource_name), json=data) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) async for byte_payload in resp.content: diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 8ddec08f..97719609 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta28" +version = "0.0.0.beta29" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index a33e6a03..5da0008a 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta28", + version="0.0.0.beta29", packages=find_packages(), ) From 077c5a5206d12bf624420d55baae3d89423bbbfa Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Thu, 4 Apr 2024 15:20:24 -0700 Subject: [PATCH 275/425] Remove fine-tune limit for internal users (#484) --- .../use_cases/llm_fine_tuning_use_cases.py | 16 ++++++---------- .../tests/unit/domain/test_llm_use_cases.py | 8 ++++---- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 268569be..0aaead3b 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -28,7 +28,6 @@ REQUIRED_COLUMNS = ["prompt", "response"] MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER = 5 -MAX_LLM_ENDPOINTS_PER_INTERNAL_USER = 15 MAX_SUFFIX_LENGTH = 28 # k8s labels need to be <= 62 characters, timestamp takes 13 characters, 2 characters for periods, @@ -115,17 +114,14 @@ async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFin current_jobs_and_endpoints = len(in_progress_jobs) + len(model_endpoints) - max_llm_endpoints_per_user = ( - MAX_LLM_ENDPOINTS_PER_INTERNAL_USER - if user.is_privileged_user - else MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER - ) - - if current_jobs_and_endpoints >= max_llm_endpoints_per_user: + if ( + not user.is_privileged_user + and current_jobs_and_endpoints >= MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER + ): raise LLMFineTuningQuotaReached( - f"Limit {max_llm_endpoints_per_user} fine-tunes/fine-tuned endpoints per user. " + f"Limit {MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER} fine-tunes/fine-tuned endpoints per user. " f"Cancel/delete a total of " - f"{current_jobs_and_endpoints - max_llm_endpoints_per_user + 1} pending or " + f"{current_jobs_and_endpoints - MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER + 1} pending or " f"running fine-tune(s) or fine-tuned endpoints to run another fine-tune." ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index b2496ff9..1c3fb086 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -33,7 +33,7 @@ UpstreamServiceError, ) from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( - MAX_LLM_ENDPOINTS_PER_INTERNAL_USER, + MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER, CreateFineTuneV1UseCase, GetFineTuneEventsV1UseCase, is_model_name_suffix_valid, @@ -1416,7 +1416,7 @@ async def test_create_fine_tune_limit( fake_llm_fine_tuning_events_repository, fake_file_storage_gateway, ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=False) request = CreateFineTuneRequest( model="base_model", training_file="file1", @@ -1425,8 +1425,8 @@ async def test_create_fine_tune_limit( hyperparameters={}, suffix=None, ) - for i in range(MAX_LLM_ENDPOINTS_PER_INTERNAL_USER): - if i == MAX_LLM_ENDPOINTS_PER_INTERNAL_USER: + for i in range(MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER): + if i == MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER: with pytest.raises(LLMFineTuningQuotaReached): await use_case.execute(user=user, request=request) else: From c46162a4af68822b3af8fdd2d1cc55f138cae274 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 5 Apr 2024 16:27:07 -0700 Subject: [PATCH 276/425] Parallel Python execution for tool completion (#470) * test * Update * Don't fail * dont' fail * Accurate CPU count * format * comments --- .../batch_inference/sample_config_tool.json | 2 +- .../inference/batch_inference/vllm_batch.py | 40 +++++++++++++++---- .../inference/vllm/vllm_server.py | 12 ++++-- 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json b/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json index d9a3af4a..3f21befe 100644 --- a/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json +++ b/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json @@ -3,7 +3,7 @@ "output_data_path":"./sample_output_tool.json", "model_config":{ "model":"mistral-7b", - "checkpoint_path":"s3://scale-ml/models/mistral-7b", + "checkpoint_path":"my_path", "num_shards": 1, "labels": {"team": "my_team"} }, diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index f0dbb041..5bb30a4b 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -1,10 +1,12 @@ import asyncio import json +import multiprocessing import os import subprocess import sys import time import uuid +from multiprocessing.pool import ThreadPool from typing import List, Optional, Type from urllib.parse import urlparse @@ -31,6 +33,23 @@ os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") +def get_cpu_cores_in_container(): + cpu_count = multiprocessing.cpu_count() + try: + with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") as fp: + cfs_quota_us = int(fp.read()) + with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us") as fp: + cfs_period_us = int(fp.read()) + if cfs_quota_us != -1: + cpu_count = cfs_quota_us // cfs_period_us + except FileNotFoundError: + pass + return cpu_count + + +CPU_COUNT = get_cpu_cores_in_container() + + def get_s3_client(): session = boto3.Session(profile_name=os.getenv("S3_WRITE_AWS_PROFILE")) return session.client("s3", region_name=AWS_REGION) @@ -210,7 +229,8 @@ def __repr__(self) -> str: desc=f"Running tools, iteration {num_iters}", file=sys.stdout, ) - for i in range(len(iter_prompts)): + + def tool_func(i): bar.update(1) response = outputs[i] gen_item = generations[iter_prompts[i][1]] @@ -225,7 +245,7 @@ def __repr__(self) -> str: # break the loop if generation is complete even if remaining_tokens>0 if len(new_text) == 0: gen_item.completed = True - continue + return # To-do write tools to receive response object itself rather than the text try: @@ -273,7 +293,9 @@ def tool_func(text: str, past_context: Optional[str]): or gen_item.remaining_tokens <= 0 ): gen_item.completed = True - continue + + pool = ThreadPool(CPU_COUNT) + pool.map(tool_func, range(len(iter_prompts))) results = [ CompletionOutput( @@ -450,9 +472,11 @@ async def generate_with_vllm( def get_gpu_free_memory(): # pragma: no cover """Get GPU free memory using nvidia-smi.""" try: - output = subprocess.check_output( - ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"] - ).decode("utf-8") + output = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + ).stdout gpu_memory = [int(x) for x in output.strip().split("\n")] return gpu_memory except subprocess.CalledProcessError: @@ -471,7 +495,9 @@ def check_unknown_startup_memory_usage(): # pragma: no cover f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) # nosemgrep - output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8") + output = subprocess.run( + ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True + ).stdout print(f"Processes using GPU: {output}") diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index c4dd0eed..d9b502ef 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -125,9 +125,11 @@ async def abort_request() -> None: def get_gpu_free_memory(): """Get GPU free memory using nvidia-smi.""" try: - output = subprocess.check_output( - ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"] - ).decode("utf-8") + output = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + ).stdout gpu_memory = [int(x) for x in output.strip().split("\n")] return gpu_memory except subprocess.CalledProcessError: @@ -145,7 +147,9 @@ def check_unknown_startup_memory_usage(): f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) # nosemgrep - output = subprocess.check_output(["fuser -v /dev/nvidia*"], shell=True).decode("utf-8") + output = subprocess.run( + ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True + ).stdout print(f"Processes using GPU: {output}") From 85231413fdf0a85dd96cc678f3543f8ae5c35e6b Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 9 Apr 2024 10:15:47 -0700 Subject: [PATCH 277/425] Allow JSONL for fine-tuning datasets --- .../use_cases/llm_fine_tuning_use_cases.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 0aaead3b..70da8a9e 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -1,5 +1,6 @@ import csv import datetime +import json import re from typing import Optional @@ -25,7 +26,7 @@ from model_engine_server.domain.services import LLMFineTuningService, ModelEndpointService DEFAULT_FINE_TUNING_METHOD = "lora" -REQUIRED_COLUMNS = ["prompt", "response"] +REQUIRED_COLUMNS = [["prompt", "response"], ["input", "output"]] MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER = 5 @@ -52,6 +53,7 @@ def ensure_model_name_is_valid_k8s_label(model_name: str): def read_csv_headers(file_location: str): """ Read the headers of a csv file. + This will also parse for a JSONL file and will return the first row of the file split by comma. """ with smart_open.open(file_location, transport_params=dict(buffer_size=1024)) as file: csv_reader = csv.DictReader(file) @@ -63,18 +65,30 @@ def are_dataset_headers_valid(file_location: str): Ensure the dataset headers are valid with required columns 'prompt' and 'response'. """ current_headers = read_csv_headers(file_location) - return all(required_header in current_headers for required_header in REQUIRED_COLUMNS) + first_line = ",".join(current_headers) + try: + object = json.loads(first_line) # JSONL file format + current_headers = object.keys() + except json.decoder.JSONDecodeError: # CSV file format + pass + return any( + [ + all(header in current_headers for header in header_group) + for header_group in REQUIRED_COLUMNS + ] + ) def check_file_is_valid(file_name: Optional[str], file_type: str): """ Ensure the file is valid with required columns 'prompt' and 'response', isn't malformatted, and exists. + Accepts CSV and JSONL formats. file_type: 'training' or 'validation' """ try: if file_name is not None and not are_dataset_headers_valid(file_name): raise InvalidRequestException( - f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in {file_type} dataset" + f"Required column headers (one subset of {REQUIRED_COLUMNS}) not found in {file_type} dataset" ) except FileNotFoundError: raise InvalidRequestException( From 38d94de9041555e0195753308747080c3de77e9a Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 9 Apr 2024 17:34:01 -0700 Subject: [PATCH 278/425] Fix throughput_benchmarks ITL calculation, add option to use a json file of prompts (#485) * Fix throughput_benchmarks ITL calculation * fix a div/0 * add prompts file override option * undo vllm version changes * fix token skipping in vllm localhost case * rerun unit test --------- Co-authored-by: Yunfeng Bai --- scripts/throughput_benchmarks.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py index d67614a5..7d2cb32f 100644 --- a/scripts/throughput_benchmarks.py +++ b/scripts/throughput_benchmarks.py @@ -68,6 +68,10 @@ def send_request(url, request, user=None): inter_token_latencies = [] last_token_time = None for byte_payload in response.iter_lines(): + # Skip line + if byte_payload == b"\n" or byte_payload == b"": + continue + token_time = time.time() if first_line: time_to_first_token = token_time - start @@ -77,10 +81,6 @@ def send_request(url, request, user=None): inter_token_latencies.append(token_time - last_token_time) last_token_time = token_time - # Skip line - if byte_payload == b"\n": - continue - payload = byte_payload.decode("utf-8") # Event data @@ -174,11 +174,16 @@ def send_requests( concurrency: int, framework: InferenceFramework, local_port: int = 5005, + prompts_list_override: Optional[List] = None, ): thread_results: queue.Queue = queue.Queue() requests_queue: queue.Queue = queue.Queue() - for output_token_count in output_token_counts: - request = generate_request(framework, prompt, output_token_count, use_localhost) + for i, output_token_count in enumerate(output_token_counts): + if prompts_list_override is not None: + new_prompt = prompts_list_override[i % len(prompts_list_override)] + else: + new_prompt = prompt + request = generate_request(framework, new_prompt, output_token_count, use_localhost) requests_queue.put(request) threads = [] for i in range(concurrency): @@ -239,7 +244,7 @@ def generate_output_token_counts_from_existing( return output -def read_distribution_from_file(fpath: str): +def read_data_from_json_file(fpath: str): # Assumes the distribution is some json-formatted string that represents a list try: with open(fpath, "r") as fin: @@ -260,6 +265,7 @@ def run_benchmark( verbose: bool, local_port: int, response_token_count_distribution: Optional[List] = None, + prompts_list_override: Optional[List] = None, ): prompt = generate_prompt(config.input_token_count, hf_model) @@ -286,6 +292,7 @@ def run_benchmark( concurrency, framework, local_port=local_port, + prompts_list_override=prompts_list_override, ) end = time.time() elapsed = end - start @@ -302,7 +309,7 @@ def run_benchmark( all_inter_token_latencies = [] # one value per token (except the first generated token) for result in results: avg_time_per_token = (result["total_time"] - result["time_to_first_token"]) / ( - result["num_completion_tokens"] - 1 + max(1, result["num_completion_tokens"] - 1) ) time_to_first_token.append(result["time_to_first_token"]) time_to_process_prompt.append(result["time_to_first_token"] - avg_time_per_token) @@ -387,6 +394,7 @@ def run_benchmarks( hf_model: Optional[str] = None, local_port: int = 5005, response_token_count_distribution_file: Optional[str] = None, + prompts_list_override_file: Optional[str] = None, ): """Run benchmarks.""" all_statistics = [] @@ -394,9 +402,12 @@ def run_benchmarks( response_token_count_distribution = None if response_token_count_distribution_file is not None: - response_token_count_distribution = read_distribution_from_file( + response_token_count_distribution = read_data_from_json_file( response_token_count_distribution_file ) + prompts_list_override = None + if prompts_list_override_file is not None: + prompts_list_override = read_data_from_json_file(prompts_list_override_file) try: if verbose: @@ -418,6 +429,7 @@ def run_benchmarks( verbose, local_port, response_token_count_distribution, + prompts_list_override, ) all_statistics.append(statistics) except Exception: @@ -448,6 +460,7 @@ def run_benchmarks_concurrency_range( hf_model: Optional[str] = None, local_port: int = 5005, response_token_count_distribution_file: Optional[str] = None, + prompts_list_override_file: Optional[str] = None, ): if output_file is not None: # Create empty file @@ -467,6 +480,7 @@ def run_benchmarks_concurrency_range( hf_model, local_port, response_token_count_distribution_file, + prompts_list_override_file, ) From 3c7d40bd4996831492cf74d47d91e69d080ec06c Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Thu, 11 Apr 2024 11:40:21 -0700 Subject: [PATCH 279/425] Add Model.update() to Python client (#490) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 45 +++++++ clients/python/llmengine/model.py | 177 +++++++++++++++++++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- docs/api/python_client.md | 1 + 6 files changed, 226 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 09dd8526..df8ffbd0 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b29" +__version__ = "0.0.0b30" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index a3ed3209..07d2622b 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -254,6 +254,51 @@ class ListLLMEndpointsResponse(BaseModel): """ +class UpdateLLMEndpointRequest(BaseModel): + # LLM specific fields + model_name: Optional[str] + source: Optional[LLMSource] + inference_framework_image_tag: Optional[str] + num_shards: Optional[int] + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] + post_inference_hooks: Optional[List[str]] + cpus: Optional[CpuSpecificationType] + gpus: Optional[int] + memory: Optional[StorageSpecificationType] + gpu_type: Optional[GpuType] + storage: Optional[StorageSpecificationType] + optimize_costs: Optional[bool] + min_workers: Optional[int] + max_workers: Optional[int] + per_worker: Optional[int] + labels: Optional[Dict[str, str]] + prewarm: Optional[bool] + high_priority: Optional[bool] + billing_tags: Optional[Dict[str, Any]] + default_callback_url: Optional[HttpUrl] + default_callback_auth: Optional[CallbackAuth] + public_inference: Optional[bool] + + +class UpdateLLMEndpointResponse(BaseModel): + endpoint_creation_task_id: str + + class DeleteLLMEndpointResponse(BaseModel): """ Response object for deleting a Model. diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 35e26631..b242abea 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -15,6 +15,8 @@ ModelEndpointType, PostInferenceHooks, Quantization, + UpdateLLMEndpointRequest, + UpdateLLMEndpointResponse, ) @@ -416,6 +418,181 @@ def list(cls) -> ListLLMEndpointsResponse: response = cls._get("v1/llm/model-endpoints", timeout=DEFAULT_TIMEOUT) return ListLLMEndpointsResponse.parse_obj(response) + @classmethod + @assert_self_hosted + def update( + cls, + name: str, + # LLM specific fields + model: Optional[str] = None, + inference_framework_image_tag: Optional[str] = None, + source: Optional[LLMSource] = None, + num_shards: Optional[int] = None, + quantize: Optional[Quantization] = None, + checkpoint_path: Optional[str] = None, + # General endpoint fields + cpus: Optional[int] = None, + memory: Optional[str] = None, + storage: Optional[str] = None, + gpus: Optional[int] = None, + min_workers: Optional[int] = None, + max_workers: Optional[int] = None, + per_worker: Optional[int] = None, + endpoint_type: Optional[ModelEndpointType] = None, + gpu_type: Optional[str] = None, + high_priority: Optional[bool] = None, + post_inference_hooks: Optional[List[PostInferenceHooks]] = None, + default_callback_url: Optional[str] = None, + public_inference: Optional[bool] = None, + labels: Optional[Dict[str, str]] = None, + ) -> UpdateLLMEndpointResponse: + """ + Update an LLM model. Note: This API is only available for self-hosted users. + + Args: + name (`str`): + Name of the endpoint + + model (`Optional[str]`): + Name of the base model + + inference_framework_image_tag (`Optional[str]`): + Image tag for the inference framework. Use "latest" for the most recent image + + source (`Optional[LLMSource]`): + Source of the LLM. Currently only HuggingFace is supported + + num_shards (`Optional[int]`): + Number of shards for the LLM. When bigger than 1, LLM will be sharded + to multiple GPUs. Number of GPUs must be equal or larger than num_shards. + + quantize (`Optional[Quantization]`): + Quantization method for the LLM. `text_generation_inference` supports `bitsandbytes` and `vllm` supports `awq`. + + checkpoint_path (`Optional[str]`): + Remote path to the checkpoint for the LLM. LLM engine must have permission to access the given path. + Can be either a folder or a tar file. Folder is preferred since we don't need to untar and model loads faster. + For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). + + cpus (`Optional[int]`): + Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater + than or equal to 1. Recommendation is set it to 8 * GPU count. + + memory (`Optional[str]`): + Amount of memory each worker should get, e.g. "4Gi", "512Mi", etc. This must + be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. + + storage (`Optional[str]`): + Amount of local ephemeral storage each worker should get, e.g. "4Gi", + "512Mi", etc. This must be a positive amount of storage. + Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. + + gpus (`Optional[int]`): + Number of gpus each worker should get, e.g. 0, 1, etc. + + min_workers (`Optional[int]`): + The minimum number of workers. Must be greater than or equal to 0. This + should be determined by computing the minimum throughput of your workload and + dividing it by the throughput of a single worker. When this number is 0, + max_workers must be 1, and the endpoint will autoscale between + 0 and 1 pods. When this number is greater than 0, max_workers can be any number + greater or equal to min_workers. + + max_workers (`Optional[int]`): + The maximum number of workers. Must be greater than or equal to 0, + and as well as greater than or equal to ``min_workers``. This should be determined by + computing the maximum throughput of your workload and dividing it by the throughput + of a single worker + + per_worker (`Optional[int]`): + The maximum number of concurrent requests that an individual worker can + service. LLM engine automatically scales the number of workers for the endpoint so that + each worker is processing ``per_worker`` requests, subject to the limits defined by + ``min_workers`` and ``max_workers`` + - If the average number of concurrent requests per worker is lower than + ``per_worker``, then the number of workers will be reduced. - Otherwise, + if the average number of concurrent requests per worker is higher than + ``per_worker``, then the number of workers will be increased to meet the elevated + traffic. + Here is our recommendation for computing ``per_worker``: + 1. Compute ``min_workers`` and ``max_workers`` per your minimum and maximum + throughput requirements. 2. Determine a value for the maximum number of + concurrent requests in the workload. Divide this number by ``max_workers``. Doing + this ensures that the number of workers will "climb" to ``max_workers``. + + endpoint_type (`Optional[ModelEndpointType]`): + Currently only ``"streaming"`` endpoints are supported. + + gpu_type (`Optional[str]`): + If specifying a non-zero number of gpus, this controls the type of gpu + requested. Here are the supported values: + + - ``nvidia-tesla-t4`` + - ``nvidia-ampere-a10`` + - ``nvidia-ampere-a100`` + - ``nvidia-ampere-a100e`` + + high_priority (`Optional[bool]`): + Either ``True`` or ``False``. Enabling this will allow the created + endpoint to leverage the shared pool of prewarmed nodes for faster spinup time + + post_inference_hooks (`Optional[List[PostInferenceHooks]]`): + List of hooks to trigger after inference tasks are served + + default_callback_url (`Optional[str]`): + The default callback url to use for sync completion requests. + This can be overridden in the task parameters for each individual task. + post_inference_hooks must contain "callback" for the callback to be triggered + + public_inference (`Optional[bool]`): + If ``True``, this endpoint will be available to all user IDs for + inference + + labels (`Optional[Dict[str, str]]`): + An optional dictionary of key/value pairs to associate with this endpoint + Returns: + UpdateLLMEndpointResponse: creation task ID of the updated Model. Currently not used. + """ + post_inference_hooks_strs = None + if post_inference_hooks is not None: + post_inference_hooks_strs = [] + for hook in post_inference_hooks: + if isinstance(hook, PostInferenceHooks): + post_inference_hooks_strs.append(hook.value) + else: + post_inference_hooks_strs.append(hook) + + request = UpdateLLMEndpointRequest( + model_name=model, + source=source, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + cpus=cpus, + endpoint_type=ModelEndpointType(endpoint_type) if endpoint_type is not None else None, + gpus=gpus, + gpu_type=GpuType(gpu_type) if gpu_type is not None else None, + labels=labels, + max_workers=max_workers, + memory=memory, + metadata={}, + min_workers=min_workers, + per_worker=per_worker, + high_priority=high_priority, + post_inference_hooks=post_inference_hooks_strs, + # Pydantic automatically validates the url + default_callback_url=default_callback_url, # type: ignore + storage=storage, + public_inference=public_inference, + ) + response = cls.put( + resource_name=f"v1/llm/model-endpoints/{name}", + data=request.dict(), + timeout=DEFAULT_TIMEOUT, + ) + return UpdateLLMEndpointResponse.parse_obj(response) + @classmethod def delete(cls, model_endpoint_name: str) -> DeleteLLMEndpointResponse: """ diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 97719609..24c5a3b5 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta29" +version = "0.0.0.beta30" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 5da0008a..eaeb7507 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta29", + version="0.0.0.beta30", packages=find_packages(), ) diff --git a/docs/api/python_client.md b/docs/api/python_client.md index c9e22723..3e338388 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -22,6 +22,7 @@ - create - get - list + - update - delete - download From 740c12addb0161671603470e2a5070eac24da1c5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 12:34:47 -0700 Subject: [PATCH 280/425] Bump idna from 3.4 to 3.7 in /clients/python (#491) Bumps [idna](https://github.com/kjd/idna) from 3.4 to 3.7. - [Release notes](https://github.com/kjd/idna/releases) - [Changelog](https://github.com/kjd/idna/blob/master/HISTORY.rst) - [Commits](https://github.com/kjd/idna/compare/v3.4...v3.7) --- updated-dependencies: - dependency-name: idna dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- clients/python/poetry.lock | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index f2d221f7..8d98a933 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -483,13 +483,13 @@ files = [ [[package]] name = "idna" -version = "3.4" +version = "3.7" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] [[package]] @@ -1234,4 +1234,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "deae6349cd55b6da7e03a9a858e7bbfb678e97982b34324cef3af0be5dfa3a4a" +content-hash = "e172656b142f767ce252f458226edc093bec9cee800a0a608340742d11bfa911" From 795d6241e3092a43321d89a4cb25cf6f14a8e948 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:17:02 -0700 Subject: [PATCH 281/425] Bump idna from 3.4 to 3.7 in /model-engine (#492) Bumps [idna](https://github.com/kjd/idna) from 3.4 to 3.7. - [Release notes](https://github.com/kjd/idna/releases) - [Changelog](https://github.com/kjd/idna/blob/master/HISTORY.rst) - [Commits](https://github.com/kjd/idna/compare/v3.4...v3.7) --- updated-dependencies: - dependency-name: idna dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> --- model-engine/requirements.txt | 130 +++++++++++++++++----------------- 1 file changed, 66 insertions(+), 64 deletions(-) diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index c261e668..623b12af 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -8,14 +8,14 @@ aiofiles==23.1.0 # via quart aiohttp==3.9.2 # via - # -r model-engine/requirements.in + # -r requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r model-engine/requirements.in + # via -r requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r model-engine/requirements.in + # via -r requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 @@ -33,7 +33,7 @@ async-timeout==4.0.2 # aioredis # redis asyncpg==0.27.0 - # via -r model-engine/requirements.in + # via -r requirements.in attrs==23.1.0 # via # aiohttp @@ -44,7 +44,7 @@ attrs==23.1.0 azure-common==1.1.28 # via azure-keyvault-secrets azure-containerregistry==1.2.0 - # via -r model-engine/requirements.in + # via -r requirements.in azure-core==1.29.6 # via # azure-containerregistry @@ -53,13 +53,13 @@ azure-core==1.29.6 # azure-servicebus # azure-storage-blob azure-identity==1.15.0 - # via -r model-engine/requirements.in + # via -r requirements.in azure-keyvault-secrets==4.7.0 - # via -r model-engine/requirements.in + # via -r requirements.in azure-servicebus==7.11.4 - # via -r model-engine/requirements.in + # via -r requirements.in azure-storage-blob==12.19.0 - # via -r model-engine/requirements.in + # via -r requirements.in backports-zoneinfo[tzdata]==0.2.1 # via # celery @@ -72,20 +72,20 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via -r model-engine/requirements.in + # via -r requirements.in botocore==1.31.1 # via - # -r model-engine/requirements.in + # -r requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs build==0.8.0 - # via -r model-engine/requirements.in + # via -r requirements.in bytecode==0.14.2 # via ddtrace cachetools==5.3.1 @@ -93,7 +93,7 @@ cachetools==5.3.1 cattrs==23.1.2 # via ddtrace celery[redis,sqs,tblib]==5.3.6 - # via -r model-engine/requirements.in + # via -r requirements.in certifi==2023.7.22 # via # datadog-api-client @@ -108,7 +108,7 @@ charset-normalizer==3.2.0 # via requests click==8.1.4 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # click-didyoumean # click-plugins @@ -122,35 +122,35 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r model-engine/requirements.in + # via -r requirements.in cryptography==42.0.5 # via - # -r model-engine/requirements.in + # -r requirements.in # azure-identity # azure-storage-blob # msal # pyjwt # secretstorage dataclasses-json==0.5.9 - # via -r model-engine/requirements.in + # via -r requirements.in datadog==0.47.0 - # via -r model-engine/requirements.in + # via -r requirements.in datadog-api-client==2.11.0 - # via -r model-engine/requirements.in + # via -r requirements.in ddsketch==2.0.4 # via ddtrace ddtrace==1.8.3 - # via -r model-engine/requirements.in + # via -r requirements.in deprecation==2.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in docker==5.0.3 - # via -r model-engine/requirements.in + # via -r requirements.in docutils==0.20.1 # via readme-renderer envier==0.4.0 @@ -160,7 +160,7 @@ exceptiongroup==1.2.0 # anyio # cattrs fastapi==0.110.0 - # via -r model-engine/requirements.in + # via -r requirements.in filelock==3.13.1 # via # huggingface-hub @@ -174,15 +174,15 @@ fsspec==2023.10.0 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r model-engine/requirements.in + # via -r requirements.in gitpython==3.1.41 - # via -r model-engine/requirements.in + # via -r requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r model-engine/requirements.in + # via -r requirements.in h11==0.14.0 # via # httpcore @@ -196,7 +196,7 @@ hpack==4.0.0 httpcore==1.0.4 # via httpx httptools==0.5.0 - # via -r model-engine/requirements.in + # via -r requirements.in httpx==0.27.0 # via starlette huggingface-hub==0.20.3 @@ -207,7 +207,7 @@ hypercorn==0.14.4 # via quart hyperframe==6.0.1 # via h2 -idna==3.4 +idna==3.7 # via # anyio # httpx @@ -243,7 +243,7 @@ jeepney==0.8.0 # secretstorage jinja2==3.0.3 # via - # -r model-engine/requirements.in + # -r requirements.in # quart # starlette jmespath==1.0.1 @@ -251,7 +251,7 @@ jmespath==1.0.1 # boto3 # botocore json-log-formatter==0.5.2 - # via -r model-engine/requirements.in + # via -r requirements.in jsonschema==4.19.0 # via ddtrace jsonschema-specifications==2023.7.1 @@ -261,11 +261,11 @@ keyring==24.2.0 kombu[sqs]==5.3.5 # via celery kubeconfig==1.1.1 - # via -r model-engine/requirements.in + # via -r requirements.in kubernetes==25.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in kubernetes-asyncio==25.11.0 - # via -r model-engine/requirements.in + # via -r requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -313,7 +313,7 @@ numpy==1.24.4 oauthlib==3.2.2 # via requests-oauthlib orjson==3.9.15 - # via -r model-engine/requirements.in + # via -r requirements.in packaging==23.1 # via # build @@ -339,13 +339,13 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r model-engine/requirements.in + # -r requirements.in # ddsketch # ddtrace psycopg2-binary==2.9.3 - # via -r model-engine/requirements.in + # via -r requirements.in py-xid==0.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -356,19 +356,21 @@ pycparser==2.21 # via cffi pycurl==7.45.2 # via - # -r model-engine/requirements.in + # -r requirements.in # celery # kombu pydantic==1.10.11 # via - # -r model-engine/requirements.in + # -r requirements.in # fastapi pygments==2.15.1 # via # readme-renderer # rich pyjwt[crypto]==2.8.0 - # via msal + # via + # msal + # pyjwt python-dateutil==2.8.2 # via # botocore @@ -380,7 +382,7 @@ python-dateutil==2.8.2 # pg8000 python-multipart==0.0.7 # via - # -r model-engine/requirements.in + # -r requirements.in # starlette pyyaml==6.0.1 # via @@ -391,7 +393,7 @@ pyyaml==6.0.1 # starlette # transformers quart==0.18.3 - # via -r model-engine/requirements.in + # via -r requirements.in readme-renderer==40.0 # via twine redis==4.6.0 @@ -404,7 +406,7 @@ regex==2023.10.3 # via transformers requests==2.31.0 # via - # -r model-engine/requirements.in + # -r requirements.in # azure-core # datadog # docker @@ -417,7 +419,7 @@ requests==2.31.0 # transformers # twine requests-auth-aws-sigv4==0.7 - # via -r model-engine/requirements.in + # via -r requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -425,7 +427,7 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r model-engine/requirements.in + # via -r requirements.in rpds-py==0.10.0 # via # jsonschema @@ -441,9 +443,9 @@ scramp==1.4.4 secretstorage==3.3.3 # via keyring sentencepiece==0.1.99 - # via -r model-engine/requirements.in + # via -r requirements.in sh==1.14.3 - # via -r model-engine/requirements.in + # via -r requirements.in six==1.16.0 # via # azure-core @@ -457,7 +459,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r model-engine/requirements.in + # via -r requirements.in smmap==5.0.0 # via # gitdb @@ -470,32 +472,32 @@ sniffio==1.3.0 # httpx sqlalchemy[asyncio]==2.0.4 # via - # -r model-engine/requirements.in + # -r requirements.in # alembic sse-starlette==1.6.1 - # via -r model-engine/requirements.in + # via -r requirements.in sseclient-py==1.7.2 - # via -r model-engine/requirements.in + # via -r requirements.in starlette[full]==0.36.3 # via - # -r model-engine/requirements.in + # -r requirements.in # fastapi # sse-starlette stringcase==1.2.0 - # via -r model-engine/requirements.in + # via -r requirements.in tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r model-engine/requirements.in + # -r requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r model-engine/requirements.in + # via -r requirements.in tokenizers==0.15.2 # via - # -r model-engine/requirements.in + # -r requirements.in # transformers tomli==2.0.1 # via @@ -504,14 +506,14 @@ tomli==2.0.1 # pep517 tqdm==4.65.0 # via - # -r model-engine/requirements.in + # -r requirements.in # huggingface-hub # transformers # twine transformers==4.38.0 - # via -r model-engine/requirements.in + # via -r requirements.in twine==3.7.1 - # via -r model-engine/requirements.in + # via -r requirements.in types-awscrt==0.16.23 # via # botocore-stubs @@ -564,9 +566,9 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r model-engine/requirements.in + # via -r requirements.in uvloop==0.17.0 - # via -r model-engine/requirements.in + # via -r requirements.in vine==5.1.0 # via # amqp @@ -588,7 +590,7 @@ xmltodict==0.13.0 # via ddtrace yarl==1.9.2 # via - # -r model-engine/requirements.in + # -r requirements.in # aiohttp zipp==3.16.0 # via From ee3a367e46040af1693f708ff47de518086740c8 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:50:09 -0700 Subject: [PATCH 282/425] Properly add mixtral 8x22b (#493) * Properly add mixtral 8x22b * unit test --- docs/model_zoo.md | 1 + .../domain/use_cases/llm_model_endpoint_use_cases.py | 10 +++++++++- .../infra/repositories/live_tokenizer_repository.py | 1 + model-engine/tests/unit/domain/test_llm_use_cases.py | 7 +++++++ 4 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 8805418c..67b9d727 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -22,6 +22,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | | `mixtral-8x7b` | ✅ | | vllm | 32768 | | `mixtral-8x7b-instruct` | ✅ | | vllm | 32768 | +| `mixtral-8x22b` | ✅ | | vllm | 65536 | | `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | | `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | | `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 65973ced..e24efe10 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -180,6 +180,7 @@ "mistral-7b-instruct", "mixtral-8x7b", "mixtral-8x7b-instruct", + "mixtral-8x22b", "mammoth-coder-llama-2-7b", "mammoth-coder-llama-2-13b", "mammoth-coder-llama-2-34b", @@ -230,7 +231,8 @@ "gemma": {"max_model_len": 8192, "max_num_batched_tokens": 8192}, "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, - "mixtral": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, + "mixtral-8x7b": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, + "mixtral-8x22b": {"max_model_len": 65536, "max_num_batched_tokens": 65536}, "zephyr": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, } @@ -2200,6 +2202,12 @@ def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJob memory = "160Gi" storage = "160Gi" gpu_type = GpuType.NVIDIA_AMPERE_A100E + elif "mixtral-8x22b" in model_name: + cpus = "80" + gpus = 8 + memory = "800Gi" + storage = "460Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A100E else: numbers = re.findall(r"\d+", model_name) if len(numbers) == 0: diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 41356aef..4670a965 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -62,6 +62,7 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "mistral-7b-instruct": ModelInfo("mistralai/Mistral-7B-Instruct-v0.1", None), "mixtral-8x7b": ModelInfo("mistralai/Mixtral-8x7B-v0.1", None), "mixtral-8x7b-instruct": ModelInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", None), + "mixtral-8x22b": ModelInfo("mistral-community/Mixtral-8x22B-v0.1", None), "mammoth-coder-llama-2-7b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-7B", None), "mammoth-coder-llama-2-13b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-13B", None), "mammoth-coder-llama-2-34b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-34B", None), diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 1c3fb086..6796135e 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1705,6 +1705,13 @@ def test_infer_hardware_from_model_name(): assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + hardware = infer_hardware_from_model_name("mixtral-8x22b") + assert hardware.cpus == "80" + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "460Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + hardware = infer_hardware_from_model_name("llama-2-7b") assert hardware.cpus == "10" assert hardware.gpus == 1 From 040622ab37218497f1a6b786f5db8105f76ac8d0 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Wed, 17 Apr 2024 11:07:43 -0700 Subject: [PATCH 283/425] support mixtral 8x22b instruct (#495) --- docs/model_zoo.md | 1 + .../domain/use_cases/llm_model_endpoint_use_cases.py | 1 + .../infra/repositories/live_tokenizer_repository.py | 3 ++- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 67b9d727..dd834219 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -23,6 +23,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `mixtral-8x7b` | ✅ | | vllm | 32768 | | `mixtral-8x7b-instruct` | ✅ | | vllm | 32768 | | `mixtral-8x22b` | ✅ | | vllm | 65536 | +| `mixtral-8x22b-instruct` | ✅ | | vllm | 65536 | | `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | | `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | | `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index e24efe10..924d3391 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -181,6 +181,7 @@ "mixtral-8x7b", "mixtral-8x7b-instruct", "mixtral-8x22b", + "mixtral-8x22b-instruct", "mammoth-coder-llama-2-7b", "mammoth-coder-llama-2-13b", "mammoth-coder-llama-2-34b", diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 4670a965..180bea31 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -62,7 +62,8 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "mistral-7b-instruct": ModelInfo("mistralai/Mistral-7B-Instruct-v0.1", None), "mixtral-8x7b": ModelInfo("mistralai/Mixtral-8x7B-v0.1", None), "mixtral-8x7b-instruct": ModelInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", None), - "mixtral-8x22b": ModelInfo("mistral-community/Mixtral-8x22B-v0.1", None), + "mixtral-8x22b": ModelInfo("mistralai/Mixtral-8x22B-v0.1", None), + "mixtral-8x22b-instruct": ModelInfo("mistralai/Mixtral-8x22B-Instruct-v0.1", None), "mammoth-coder-llama-2-7b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-7B", None), "mammoth-coder-llama-2-13b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-13B", None), "mammoth-coder-llama-2-34b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-34B", None), From 10d84ca8c81c9562fac955528c61e800aa6ec151 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 22 Apr 2024 22:05:19 -0700 Subject: [PATCH 284/425] fix return_token_log_probs on vLLM > 0.3.3 endpoints (#498) * fix return_token_log_probs * fix fr * undo extra change --- .../inference/vllm/requirements.txt | 2 +- .../inference/vllm/vllm_server.py | 23 +++++++++++++++---- model-engine/tests/unit/inference/conftest.py | 14 ++++++++--- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 3c1cf851..d0e331f4 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1,2 @@ -vllm==0.3.3 +vllm==0.4.0.post1 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index d9b502ef..c7ef4b43 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -4,7 +4,7 @@ import signal import subprocess import traceback -from typing import AsyncGenerator +from typing import AsyncGenerator, Dict, List, Optional import uvicorn from fastapi import BackgroundTasks, FastAPI, HTTPException, Request @@ -13,7 +13,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor +from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob from vllm.utils import random_uuid TIMEOUT_KEEP_ALIVE = 5 # seconds. @@ -76,13 +78,12 @@ async def generate(request: Request) -> Response: async def stream_results() -> AsyncGenerator[str, None]: last_output_text = "" async for request_output in results_generator: + log_probs = format_logprobs(request_output) ret = { "text": request_output.outputs[-1].text[len(last_output_text) :], "count_prompt_tokens": len(request_output.prompt_token_ids), "count_output_tokens": len(request_output.outputs[0].token_ids), - "log_probs": ( - request_output.outputs[0].logprobs[-1] if sampling_params.logprobs else None - ), + "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None, "finished": request_output.finished, } last_output_text = request_output.outputs[-1].text @@ -116,7 +117,7 @@ async def abort_request() -> None: "text": final_output.outputs[0].text, "count_prompt_tokens": len(final_output.prompt_token_ids), "count_output_tokens": len(final_output.outputs[0].token_ids), - "log_probs": final_output.outputs[0].logprobs, + "log_probs": format_logprobs(final_output), "tokens": tokens, } return Response(content=json.dumps(ret)) @@ -166,6 +167,18 @@ def debug(sig, frame): i.interact(message) +def format_logprobs(request_output: CompletionOutput) -> Optional[List[Dict[int, float]]]: + """Given a request output, format the logprobs if they exist.""" + output_logprobs = request_output.outputs[0].logprobs + if output_logprobs is None: + return None + + def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: + return {k: v.logprob for k, v in logprobs.items()} + + return [extract_logprobs(logprobs) for logprobs in output_logprobs] + + if __name__ == "__main__": check_unknown_startup_memory_usage() parser = argparse.ArgumentParser() diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py index 26a3a0a3..20c4aae8 100644 --- a/model-engine/tests/unit/inference/conftest.py +++ b/model-engine/tests/unit/inference/conftest.py @@ -58,13 +58,19 @@ def create_batch_completions_request_content(): @pytest.fixture def create_vllm_request_outputs(): + class Logprob: + """mock, from https://github.com/vllm-project/vllm/blob/v0.4.1/vllm/sequence.py#L18""" + + def __init__(self, logprob: float): + self.logprob = logprob + mock_vllm_request_output1 = MagicMock() mock_vllm_request_output1.outputs = [ MagicMock(text="text1"), ] mock_vllm_request_output1.prompt_token_ids = [1, 2, 3] mock_vllm_request_output1.outputs[0].token_ids = [4] - mock_vllm_request_output1.outputs[0].logprobs = [{4: 0.1}] + mock_vllm_request_output1.outputs[0].logprobs = [{4: Logprob(0.1)}] mock_vllm_request_output2 = MagicMock() mock_vllm_request_output2.outputs = [ @@ -72,7 +78,7 @@ def create_vllm_request_outputs(): ] mock_vllm_request_output2.prompt_token_ids = [1, 2, 3] mock_vllm_request_output2.outputs[0].token_ids = [4, 5] - mock_vllm_request_output2.outputs[0].logprobs = [{4: 0.1, 5: 0.2}] + mock_vllm_request_output2.outputs[0].logprobs = [{4: Logprob(0.1), 5: Logprob(0.2)}] mock_vllm_request_output3 = MagicMock() mock_vllm_request_output3.outputs = [ @@ -80,7 +86,9 @@ def create_vllm_request_outputs(): ] mock_vllm_request_output3.prompt_token_ids = [1, 2, 3] mock_vllm_request_output3.outputs[0].token_ids = [4, 5, 6] - mock_vllm_request_output3.outputs[0].logprobs = [{4: 0.1, 5: 0.2, 6: 0.3}] + mock_vllm_request_output3.outputs[0].logprobs = [ + {4: Logprob(0.1), 5: Logprob(0.2), 6: Logprob(0.3)} + ] return [mock_vllm_request_output1, mock_vllm_request_output2, mock_vllm_request_output3] From 9673b3f590b51853094bcd83919f479d40dba4d8 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 23 Apr 2024 22:32:21 -0700 Subject: [PATCH 285/425] Package update + more docs on dev setup (#500) * Add more docs to be able to run tests properly * Pull out os.getenv from method call * Add highpri test case + fixtures/mocks for coverage * Remove README update --- .../services/live_endpoint_builder_service.py | 16 +- model-engine/requirements.in | 4 +- model-engine/requirements.txt | 141 +++++++++--------- model-engine/tests/unit/conftest.py | 42 ++++++ .../test_live_endpoint_builder_service.py | 5 + requirements-dev.txt | 4 +- 6 files changed, 131 insertions(+), 81 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index bef91df0..aecef2c7 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -78,6 +78,7 @@ ECR_AWS_PROFILE: str = os.getenv("ECR_READ_AWS_PROFILE", "default") # type: ignore GIT_TAG: str = os.getenv("GIT_TAG") # type: ignore ENV: str = os.getenv("DD_ENV") # type: ignore +WORKSPACE_PATH = os.getenv("WORKSPACE", ".") INITIAL_K8S_CACHE_TTL_SECONDS: int = 60 MAX_IMAGE_TAG_LEN = 128 @@ -494,7 +495,6 @@ def get_base_image_params( # The context should be whatever WORKDIR is in the container running the build app itself. inference_folder = "model-engine/model_engine_server/inference" - base_path: str = os.getenv("WORKSPACE") # type: ignore logger_adapter.info(f"inference_folder: {inference_folder}") logger_adapter.info(f"dockerfile: {inference_folder}/{dockerfile}") @@ -502,7 +502,7 @@ def get_base_image_params( repo=hmi_config.user_inference_base_repository, image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, # type: ignore - base_path=base_path, + base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, requirements_folder=None, @@ -557,9 +557,7 @@ def _get_user_image_params( # The context should be whatever WORKDIR is in the container running the build app itself. inference_folder = "model-engine/model_engine_server/inference" - base_path: str = os.getenv("WORKSPACE") # type: ignore - - requirements_folder = os.path.join(base_path, f"requirements_{requirements_hash}") + requirements_folder = os.path.join(WORKSPACE_PATH, f"requirements_{requirements_hash}") try: os.mkdir(requirements_folder) except FileExistsError: @@ -577,7 +575,7 @@ def _get_user_image_params( repo=ecr_repo, image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, - base_path=base_path, + base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, requirements_folder=requirements_folder, @@ -609,9 +607,7 @@ def _get_inject_bundle_image_params( # The context should be whatever WORKDIR is in the container running the build app itself. dockerfile = "inject_bundle.Dockerfile" inference_folder = "model-engine/model_engine_server/inference" - base_path: str = os.getenv("WORKSPACE") # type: ignore - - bundle_folder = os.path.join(base_path, f"bundle_{service_image_hash}") + bundle_folder = os.path.join(WORKSPACE_PATH, f"bundle_{service_image_hash}") try: os.mkdir(bundle_folder) except FileExistsError: @@ -635,7 +631,7 @@ def _get_inject_bundle_image_params( repo=ecr_repo, image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, - base_path=base_path, + base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, requirements_folder=bundle_folder, diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 5cf95a51..2ef63150 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -12,7 +12,7 @@ azure-storage-blob~=12.19.0 boto3-stubs[essential]~=1.26.67 boto3~=1.21 botocore~=1.24 -build==0.8.0 +build~=1.0.3 celery[redis,sqs,tblib]~=5.3.6 click~=8.1 cloudpickle==2.1.0 @@ -37,7 +37,7 @@ protobuf~=3.20 psycopg2-binary==2.9.3 py-xid==0.3.0 pycurl~=7.44 # For celery[sqs] -pydantic~=1.10.11 +pydantic==1.10.14 python-multipart~=0.0.7 quart==0.18.3 requests-auth-aws-sigv4~=0.7 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 623b12af..71e7440d 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -8,14 +8,14 @@ aiofiles==23.1.0 # via quart aiohttp==3.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # kubernetes-asyncio aioredis==2.0.1 - # via -r requirements.in + # via -r model-engine/requirements.in aiosignal==1.3.1 # via aiohttp alembic==1.8.1 - # via -r requirements.in + # via -r model-engine/requirements.in amqp==5.1.1 # via kombu anyio==3.7.1 @@ -33,7 +33,7 @@ async-timeout==4.0.2 # aioredis # redis asyncpg==0.27.0 - # via -r requirements.in + # via -r model-engine/requirements.in attrs==23.1.0 # via # aiohttp @@ -44,7 +44,7 @@ attrs==23.1.0 azure-common==1.1.28 # via azure-keyvault-secrets azure-containerregistry==1.2.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-core==1.29.6 # via # azure-containerregistry @@ -53,13 +53,13 @@ azure-core==1.29.6 # azure-servicebus # azure-storage-blob azure-identity==1.15.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-keyvault-secrets==4.7.0 - # via -r requirements.in + # via -r model-engine/requirements.in azure-servicebus==7.11.4 - # via -r requirements.in + # via -r model-engine/requirements.in azure-storage-blob==12.19.0 - # via -r requirements.in + # via -r model-engine/requirements.in backports-zoneinfo[tzdata]==0.2.1 # via # celery @@ -72,20 +72,22 @@ blinker==1.6.2 # via quart boto3==1.28.1 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu boto3-stubs[essential]==1.26.67 - # via -r requirements.in + # via + # -r model-engine/requirements.in + # boto3-stubs botocore==1.31.1 # via - # -r requirements.in + # -r model-engine/requirements.in # boto3 # s3transfer botocore-stubs==1.29.165 # via boto3-stubs -build==0.8.0 - # via -r requirements.in +build==1.0.3 + # via -r model-engine/requirements.in bytecode==0.14.2 # via ddtrace cachetools==5.3.1 @@ -93,7 +95,9 @@ cachetools==5.3.1 cattrs==23.1.2 # via ddtrace celery[redis,sqs,tblib]==5.3.6 - # via -r requirements.in + # via + # -r model-engine/requirements.in + # celery certifi==2023.7.22 # via # datadog-api-client @@ -108,7 +112,7 @@ charset-normalizer==3.2.0 # via requests click==8.1.4 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # click-didyoumean # click-plugins @@ -122,35 +126,35 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cloudpickle==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in colorama==0.4.6 # via twine commonmark==0.9.1 # via rich croniter==1.4.1 - # via -r requirements.in + # via -r model-engine/requirements.in cryptography==42.0.5 # via - # -r requirements.in + # -r model-engine/requirements.in # azure-identity # azure-storage-blob # msal # pyjwt # secretstorage dataclasses-json==0.5.9 - # via -r requirements.in + # via -r model-engine/requirements.in datadog==0.47.0 - # via -r requirements.in + # via -r model-engine/requirements.in datadog-api-client==2.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in ddsketch==2.0.4 # via ddtrace ddtrace==1.8.3 - # via -r requirements.in + # via -r model-engine/requirements.in deprecation==2.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in docker==5.0.3 - # via -r requirements.in + # via -r model-engine/requirements.in docutils==0.20.1 # via readme-renderer envier==0.4.0 @@ -160,7 +164,7 @@ exceptiongroup==1.2.0 # anyio # cattrs fastapi==0.110.0 - # via -r requirements.in + # via -r model-engine/requirements.in filelock==3.13.1 # via # huggingface-hub @@ -174,15 +178,15 @@ fsspec==2023.10.0 gitdb==4.0.10 # via gitpython gitdb2==2.0.6 - # via -r requirements.in + # via -r model-engine/requirements.in gitpython==3.1.41 - # via -r requirements.in + # via -r model-engine/requirements.in google-auth==2.21.0 # via kubernetes greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 - # via -r requirements.in + # via -r model-engine/requirements.in h11==0.14.0 # via # httpcore @@ -196,7 +200,7 @@ hpack==4.0.0 httpcore==1.0.4 # via httpx httptools==0.5.0 - # via -r requirements.in + # via -r model-engine/requirements.in httpx==0.27.0 # via starlette huggingface-hub==0.20.3 @@ -216,6 +220,7 @@ idna==3.7 importlib-metadata==6.8.0 # via # alembic + # build # keyring # quart # twine @@ -243,7 +248,7 @@ jeepney==0.8.0 # secretstorage jinja2==3.0.3 # via - # -r requirements.in + # -r model-engine/requirements.in # quart # starlette jmespath==1.0.1 @@ -251,7 +256,7 @@ jmespath==1.0.1 # boto3 # botocore json-log-formatter==0.5.2 - # via -r requirements.in + # via -r model-engine/requirements.in jsonschema==4.19.0 # via ddtrace jsonschema-specifications==2023.7.1 @@ -261,11 +266,11 @@ keyring==24.2.0 kombu[sqs]==5.3.5 # via celery kubeconfig==1.1.1 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes==25.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in kubernetes-asyncio==25.11.0 - # via -r requirements.in + # via -r model-engine/requirements.in mako==1.2.4 # via alembic markupsafe==2.1.3 @@ -313,7 +318,7 @@ numpy==1.24.4 oauthlib==3.2.2 # via requests-oauthlib orjson==3.9.15 - # via -r requirements.in + # via -r model-engine/requirements.in packaging==23.1 # via # build @@ -323,8 +328,6 @@ packaging==23.1 # marshmallow # msal-extensions # transformers -pep517==0.13.0 - # via build pg8000==1.29.8 # via testing-postgresql pkginfo==1.9.6 @@ -339,13 +342,13 @@ prompt-toolkit==3.0.39 # via click-repl protobuf==3.20.3 # via - # -r requirements.in + # -r model-engine/requirements.in # ddsketch # ddtrace psycopg2-binary==2.9.3 - # via -r requirements.in + # via -r model-engine/requirements.in py-xid==0.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in pyasn1==0.5.0 # via # pyasn1-modules @@ -356,12 +359,12 @@ pycparser==2.21 # via cffi pycurl==7.45.2 # via - # -r requirements.in + # -r model-engine/requirements.in # celery # kombu -pydantic==1.10.11 +pydantic==1.10.14 # via - # -r requirements.in + # -r model-engine/requirements.in # fastapi pygments==2.15.1 # via @@ -371,6 +374,8 @@ pyjwt[crypto]==2.8.0 # via # msal # pyjwt +pyproject-hooks==1.0.0 + # via build python-dateutil==2.8.2 # via # botocore @@ -382,7 +387,7 @@ python-dateutil==2.8.2 # pg8000 python-multipart==0.0.7 # via - # -r requirements.in + # -r model-engine/requirements.in # starlette pyyaml==6.0.1 # via @@ -393,7 +398,7 @@ pyyaml==6.0.1 # starlette # transformers quart==0.18.3 - # via -r requirements.in + # via -r model-engine/requirements.in readme-renderer==40.0 # via twine redis==4.6.0 @@ -406,7 +411,7 @@ regex==2023.10.3 # via transformers requests==2.31.0 # via - # -r requirements.in + # -r model-engine/requirements.in # azure-core # datadog # docker @@ -419,7 +424,7 @@ requests==2.31.0 # transformers # twine requests-auth-aws-sigv4==0.7 - # via -r requirements.in + # via -r model-engine/requirements.in requests-oauthlib==1.3.1 # via kubernetes requests-toolbelt==1.0.0 @@ -427,7 +432,7 @@ requests-toolbelt==1.0.0 rfc3986==2.0.0 # via twine rich==12.6.0 - # via -r requirements.in + # via -r model-engine/requirements.in rpds-py==0.10.0 # via # jsonschema @@ -443,9 +448,9 @@ scramp==1.4.4 secretstorage==3.3.3 # via keyring sentencepiece==0.1.99 - # via -r requirements.in + # via -r model-engine/requirements.in sh==1.14.3 - # via -r requirements.in + # via -r model-engine/requirements.in six==1.16.0 # via # azure-core @@ -459,7 +464,7 @@ six==1.16.0 # python-dateutil # tenacity smart-open==5.2.1 - # via -r requirements.in + # via -r model-engine/requirements.in smmap==5.0.0 # via # gitdb @@ -472,48 +477,50 @@ sniffio==1.3.0 # httpx sqlalchemy[asyncio]==2.0.4 # via - # -r requirements.in + # -r model-engine/requirements.in # alembic + # sqlalchemy sse-starlette==1.6.1 - # via -r requirements.in + # via -r model-engine/requirements.in sseclient-py==1.7.2 - # via -r requirements.in + # via -r model-engine/requirements.in starlette[full]==0.36.3 # via - # -r requirements.in + # -r model-engine/requirements.in # fastapi # sse-starlette + # starlette stringcase==1.2.0 - # via -r requirements.in + # via -r model-engine/requirements.in tblib==2.0.0 # via celery tenacity==6.2.0 # via - # -r requirements.in + # -r model-engine/requirements.in # ddtrace testing-common-database==2.0.3 # via testing-postgresql testing-postgresql==1.3.0 - # via -r requirements.in + # via -r model-engine/requirements.in tokenizers==0.15.2 # via - # -r requirements.in + # -r model-engine/requirements.in # transformers tomli==2.0.1 # via # build # hypercorn - # pep517 + # pyproject-hooks tqdm==4.65.0 # via - # -r requirements.in + # -r model-engine/requirements.in # huggingface-hub # transformers # twine transformers==4.38.0 - # via -r requirements.in + # via -r model-engine/requirements.in twine==3.7.1 - # via -r requirements.in + # via -r model-engine/requirements.in types-awscrt==0.16.23 # via # botocore-stubs @@ -566,9 +573,9 @@ urllib3==1.26.16 # kubernetes-asyncio # requests uvicorn==0.17.6 - # via -r requirements.in + # via -r model-engine/requirements.in uvloop==0.17.0 - # via -r requirements.in + # via -r model-engine/requirements.in vine==5.1.0 # via # amqp @@ -590,7 +597,7 @@ xmltodict==0.13.0 # via ddtrace yarl==1.9.2 # via - # -r requirements.in + # -r model-engine/requirements.in # aiohttp zipp==3.16.0 # via diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 4b57afa1..fd6b9f0e 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3383,6 +3383,48 @@ def build_endpoint_request_async_custom( return build_endpoint_request +@pytest.fixture +def build_endpoint_request_async_zipartifact_highpri( + test_api_key: str, model_bundle_3: ModelBundle +) -> BuildEndpointRequest: + build_endpoint_request = BuildEndpointRequest( + model_endpoint_record=ModelEndpointRecord( + id="test_model_endpoint_id_3", + name="test_model_endpoint_name_3", + created_by=test_api_key, + created_at=datetime(2022, 1, 4), + last_updated_at=datetime(2022, 1, 4), + metadata={}, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.ASYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_3, + owner=test_api_key, + ), + high_priority=True, + deployment_name=f"{test_api_key}-test_model_endpoint_name_3", + aws_role="default", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + post_inference_hooks=None, + labels=dict(team="test_team", product="test_product"), + min_workers=1, + max_workers=3, + per_worker=2, + cpus=1, + gpus=0, + memory="1G", + gpu_type=None, + storage=None, + optimize_costs=True, + broker_type=BrokerType.SQS, + default_callback_url=None, + default_callback_auth=None, + ) + return build_endpoint_request + + @pytest.fixture def build_endpoint_request_sync_custom( test_api_key: str, model_bundle_3: ModelBundle diff --git a/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py index bf568c9a..a0e876eb 100644 --- a/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py +++ b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py @@ -102,8 +102,11 @@ def set_env_vars(): live_endpoint_builder_service.ECR_AWS_PROFILE = "default" live_endpoint_builder_service.GIT_TAG = "test_tag" live_endpoint_builder_service.ENV = "test_env" + live_endpoint_builder_service.WORKSPACE_PATH = ".." live_endpoint_builder_service.open = mock_open() live_endpoint_builder_service.os.mkdir = Mock() + live_endpoint_builder_service.open_wrapper = mock_open() + live_endpoint_builder_service.tempfile.mkstemp = Mock(return_value=["", ""]) @pytest.mark.asyncio @@ -114,6 +117,7 @@ async def test_build_endpoint( build_endpoint_request_async_runnable_image: BuildEndpointRequest, build_endpoint_request_sync_runnable_image: BuildEndpointRequest, build_endpoint_request_streaming_runnable_image: BuildEndpointRequest, + build_endpoint_request_async_zipartifact_highpri: BuildEndpointRequest, endpoint_builder_service_empty_docker_built: LiveEndpointBuilderService, endpoint_builder_service_empty_docker_not_built: LiveEndpointBuilderService, fake_model_endpoint_cache_repository: ModelEndpointCacheRepository, @@ -131,6 +135,7 @@ async def test_build_endpoint( build_endpoint_request_async_runnable_image, build_endpoint_request_sync_runnable_image, build_endpoint_request_streaming_runnable_image, + build_endpoint_request_async_zipartifact_highpri, ]: fake_monitoring_metrics_gateway.reset() repo.add_model_endpoint_record(request.model_endpoint_record) diff --git a/requirements-dev.txt b/requirements-dev.txt index f6e4d22c..f785f8de 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,5 +5,5 @@ ipython==8.12.0 # 8.12.0 is the last version to support Python 3.8 isort==5.12.0 mypy==1.3.0 pip-tools==7.0.0 -poetry==1.5.1 -pre-commit==3.3.3 \ No newline at end of file +poetry==1.8.2 +pre-commit==3.3.3 From edecf56b6deb00d6c914e107fd6a96760a3f27a0 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 24 Apr 2024 10:52:43 -0700 Subject: [PATCH 286/425] Add Llama 3 models (#501) * Add Llama 3 models * fix --- docs/model_zoo.md | 4 ++++ .../domain/use_cases/llm_model_endpoint_use_cases.py | 5 +++++ .../infra/repositories/live_tokenizer_repository.py | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index dd834219..a8f4ae63 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -11,6 +11,10 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096 | | `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | | `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-3-8b` | ✅ | | vllm | 8192 | +| `llama-3-8b-instruct` | ✅ | | vllm | 8192 | +| `llama-3-70b` | ✅ | | vllm | 8192 | +| `llama-3-70b-instruct` | ✅ | | vllm | 8192 | | `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | | `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | | `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 924d3391..40e47716 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -162,6 +162,10 @@ "llama-2-13b-chat", "llama-2-70b", "llama-2-70b-chat", + "llama-3-8b", + "llama-3-8b-instruct", + "llama-3-70b", + "llama-3-70b-instruct", "falcon-7b", "falcon-7b-instruct", "falcon-40b", @@ -231,6 +235,7 @@ # Can also see 13B, 34B there too "gemma": {"max_model_len": 8192, "max_num_batched_tokens": 8192}, "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, + "llama-3": {"max_model_len": None, "max_num_batched_tokens": 8192}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, "mixtral-8x7b": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, "mixtral-8x22b": {"max_model_len": 65536, "max_num_batched_tokens": 65536}, diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 180bea31..ea7b93d9 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -40,6 +40,10 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "llama-2-13b-chat": ModelInfo("meta-llama/Llama-2-13b-chat-hf", None), "llama-2-70b": ModelInfo("meta-llama/Llama-2-70b-hf", None), "llama-2-70b-chat": ModelInfo("meta-llama/Llama-2-70b-chat-hf", None), + "llama-3-8b": ModelInfo("meta-llama/Meta-Llama-3-8B", None), + "llama-3-8b-instruct": ModelInfo("meta-llama/Meta-Llama-3-8B-Instruct", None), + "llama-3-70b": ModelInfo("meta-llama/Meta-Llama-3-70B", None), + "llama-3-70b-instruct": ModelInfo("meta-llama/Meta-Llama-3-70B-Instruct", None), "falcon-7b": ModelInfo("tiiuae/falcon-7b", None), "falcon-7b-instruct": ModelInfo("tiiuae/falcon-7b-instruct", None), "falcon-40b": ModelInfo("tiiuae/falcon-40b", None), From 0079f7eb3c9c9827beb5ad4f35e0670ec2318b23 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 26 Apr 2024 12:36:53 -0700 Subject: [PATCH 287/425] Enforce model checkpoints existing for endpoint/bundle creation (#503) * Enforce model checkpoints existing for endpoint/bundle creation * Add test mock for good models info * Clean up checkpoint validation * Rename validate to get for semantics --- .../use_cases/llm_model_endpoint_use_cases.py | 117 +++++++++--------- .../repositories/live_tokenizer_repository.py | 7 +- .../tests/unit/domain/test_llm_use_cases.py | 95 +++++++++++++- 3 files changed, 157 insertions(+), 62 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 40e47716..189a4c24 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -329,6 +329,27 @@ def validate_quantization( ) +def validate_checkpoint_path_uri(checkpoint_path: str) -> None: + if not checkpoint_path.startswith("s3://"): + raise ObjectHasInvalidValueException( + f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." + ) + + +def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: + checkpoint_path = ( + SUPPORTED_MODELS_INFO[model_name].s3_repo + if not checkpoint_path_override + else checkpoint_path_override + ) + + if not checkpoint_path: + raise InvalidRequestException(f"No checkpoint path found for model {model_name}") + + validate_checkpoint_path_uri(checkpoint_path) + return checkpoint_path + + class CreateLLMModelBundleV1UseCase: def __init__( self, @@ -449,22 +470,16 @@ async def create_text_generation_inference_bundle( max_total_tokens = 4096 subcommands = [] - if checkpoint_path is not None: - if checkpoint_path.startswith("s3://"): - final_weights_folder = "model_files" - subcommands += self.load_model_weights_sub_commands( - LLMInferenceFramework.TEXT_GENERATION_INFERENCE, - framework_image_tag, - checkpoint_path, - final_weights_folder, - ) - else: - raise ObjectHasInvalidValueException( - f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." - ) - else: - final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + final_weights_folder = "model_files" + + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) subcommands.append( f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}" @@ -672,25 +687,19 @@ async def create_vllm_bundle( break subcommands = [] - if checkpoint_path is not None: - if checkpoint_path.startswith("s3://"): - # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder - if "mistral" in model_name: - final_weights_folder = "mistral_files" - else: - final_weights_folder = "model_files" - subcommands += self.load_model_weights_sub_commands( - LLMInferenceFramework.VLLM, - framework_image_tag, - checkpoint_path, - final_weights_folder, - ) - else: - raise ObjectHasInvalidValueException( - f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." - ) + + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder + if "mistral" in model_name: + final_weights_folder = "mistral_files" else: - final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo + final_weights_folder = "model_files" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.VLLM, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) if max_model_len: subcommands.append( @@ -770,21 +779,15 @@ async def create_lightllm_bundle( max_req_total_len = 4096 subcommands = [] - if checkpoint_path is not None: - if checkpoint_path.startswith("s3://"): - final_weights_folder = "model_files" - subcommands += self.load_model_weights_sub_commands( - LLMInferenceFramework.LIGHTLLM, - framework_image_tag, - checkpoint_path, - final_weights_folder, - ) - else: - raise ObjectHasInvalidValueException( - f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." - ) - else: - final_weights_folder = SUPPORTED_MODELS_INFO[model_name].hf_repo + + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + final_weights_folder = "model_files" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.LIGHTLLM, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) subcommands.append( f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto" @@ -835,20 +838,18 @@ async def create_tensorrt_llm_bundle( command = [] subcommands = [] - if checkpoint_path is not None: - if checkpoint_path.startswith("s3://"): - subcommands += self.load_model_files_sub_commands_trt_llm( - checkpoint_path, - ) - else: - raise ObjectHasInvalidValueException( - f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." - ) - else: + + if not checkpoint_path: raise ObjectHasInvalidValueException( "Checkpoint must be provided for TensorRT-LLM models." ) + validate_checkpoint_path_uri(checkpoint_path) + + subcommands += self.load_model_files_sub_commands_trt_llm( + checkpoint_path, + ) + subcommands.append( f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/" ) diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index ea7b93d9..779f08a7 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -1,7 +1,6 @@ import os -from collections import namedtuple from functools import lru_cache -from typing import Dict, Optional +from typing import Dict, NamedTuple, Optional from huggingface_hub import list_repo_refs from huggingface_hub.utils._errors import RepositoryNotFoundError @@ -25,7 +24,9 @@ TOKENIZER_TARGET_DIR = "/opt/.cache/model_engine_server/tokenizers" -ModelInfo = namedtuple("ModelInfo", ["hf_repo", "s3_repo"]) +class ModelInfo(NamedTuple): + hf_repo: str + s3_repo: Optional[str] def get_default_supported_models_info() -> Dict[str, ModelInfo]: diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 6796135e..4dcfaba5 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple from unittest import mock import pytest @@ -54,9 +54,22 @@ validate_and_update_completion_params, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from model_engine_server.infra.repositories import live_tokenizer_repository +from model_engine_server.infra.repositories.live_tokenizer_repository import ModelInfo + + +def good_models_info() -> Dict[str, ModelInfo]: + return { + k: ModelInfo(v.hf_repo, "s3://test-s3.tar") + for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items() + } @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO", + good_models_info(), +) async def test_create_model_endpoint_use_case_success( test_api_key: str, fake_model_bundle_repository, @@ -170,6 +183,82 @@ async def test_create_model_endpoint_use_case_success( assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] +def bad_models_info() -> Dict[str, ModelInfo]: + info = { + k: ModelInfo(v.hf_repo, v.s3_repo) + for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items() + } + info.update( + { + "mpt-7b": ModelInfo("mosaicml/mpt-7b", None), + "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", "gibberish"), + } + ) + return info + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "inference_framework, model_name, expected_error", + [ + (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", InvalidRequestException), + ( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + "mpt-7b-instruct", + ObjectHasInvalidValueException, + ), + (LLMInferenceFramework.LIGHTLLM, "mpt-7b", InvalidRequestException), + (LLMInferenceFramework.LIGHTLLM, "mpt-7b-instruct", ObjectHasInvalidValueException), + (LLMInferenceFramework.VLLM, "mpt-7b", InvalidRequestException), + (LLMInferenceFramework.VLLM, "mpt-7b-instruct", ObjectHasInvalidValueException), + ], +) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO", + bad_models_info(), +) +async def test_create_model_bundle_fails_if_no_checkpoint( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, + inference_framework, + model_name, + expected_error, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + + with pytest.raises(expected_error): + await use_case.execute( + user=user, + endpoint_name=request.name, + model_name=model_name, + source=request.source, + framework=inference_framework, + framework_image_tag="0.0.0", + endpoint_type=request.endpoint_type, + num_shards=request.num_shards, + quantize=request.quantize, + checkpoint_path=None, + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( "valid, inference_framework, inference_framework_image_tag", @@ -180,6 +269,10 @@ async def test_create_model_endpoint_use_case_success( (True, LLMInferenceFramework.VLLM, "0.1.3.6"), ], ) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO", + good_models_info(), +) async def test_create_model_bundle_inference_framework_image_tag_validation( test_api_key: str, fake_model_bundle_repository, From 866bcd19a0e5f54f34a96c28fd770b3c6ce4fcb8 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:56:22 -0700 Subject: [PATCH 288/425] guided decoding with grammar (#488) * support guided decoding with grammar * 0.4.1 fixes --- docs/guides/completions.md | 16 ++++++++++++++++ .../model_engine_server/common/dtos/llms.py | 14 +++++++++++--- .../use_cases/llm_model_endpoint_use_cases.py | 9 ++++++++- .../inference/vllm/requirements.txt | 2 +- .../inference/vllm/vllm_server.py | 5 ++++- .../tests/unit/domain/test_llm_use_cases.py | 2 ++ 6 files changed, 42 insertions(+), 6 deletions(-) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index 86bb9f0b..8bfad184 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -246,6 +246,22 @@ print(response.json()) # {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} ``` +=== "Guided decoding with Context-Free Grammar" + +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_grammar="start: \"John\"" +) + +print(response.json()) +# {"request_id": "34621b44-c655-402c-a459-f108b3e49b12", "output": {"text": "John", "num_prompt_tokens": 6, "num_completion_tokens": 4, "tokens": None}} + ## Which model should I use? See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions. diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 6f63e712..c84c6931 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -186,15 +186,19 @@ class CompletionSyncV1Request(BaseModel): """ guided_json: Optional[Dict[str, Any]] = None """ - JSON schema for guided decoding. + JSON schema for guided decoding. Only supported in vllm. """ guided_regex: Optional[str] = None """ - Regex for guided decoding. + Regex for guided decoding. Only supported in vllm. """ guided_choice: Optional[List[str]] = None """ - Choices for guided decoding. + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. """ @@ -272,6 +276,10 @@ class CompletionStreamV1Request(BaseModel): """ Choices for guided decoding. Only supported in vllm. """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ class CompletionStreamOutput(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 189a4c24..5ac3084e 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1381,16 +1381,19 @@ def validate_and_update_completion_params( guided_count += 1 if request.guided_regex is not None: guided_count += 1 + if request.guided_grammar is not None: + guided_count += 1 if guided_count > 1: raise ObjectHasInvalidValueException( - "Only one of guided_json, guided_choice, guided_regex can be enabled." + "Only one of guided_json, guided_choice, guided_regex, guided_grammar can be enabled." ) if ( request.guided_choice is not None or request.guided_regex is not None or request.guided_json is not None + or request.guided_grammar is not None ) and not inference_framework == LLMInferenceFramework.VLLM: raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.") @@ -1691,6 +1694,8 @@ async def execute( vllm_args["guided_regex"] = request.guided_regex if request.guided_json is not None: vllm_args["guided_json"] = request.guided_json + if request.guided_grammar is not None: + vllm_args["guided_grammar"] = request.guided_grammar inference_request = SyncEndpointPredictV1Request( args=vllm_args, @@ -1959,6 +1964,8 @@ async def execute( args["guided_regex"] = request.guided_regex if request.guided_json is not None: args["guided_json"] = request.guided_json + if request.guided_grammar is not None: + args["guided_grammar"] = request.guided_grammar args["stream"] = True elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: args = { diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index d0e331f4..c4d967d7 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1,2 @@ -vllm==0.4.0.post1 +vllm==0.4.1 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index c7ef4b43..94e49fcf 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -45,6 +45,7 @@ async def generate(request: Request) -> Response: guided_json = request_dict.pop("guided_json", None) guided_regex = request_dict.pop("guided_regex", None) guided_choice = request_dict.pop("guided_choice", None) + guided_grammar = request_dict.pop("guided_grammar", None) sampling_params = SamplingParams(**request_dict) # Dummy request to get guided decode logit processor @@ -56,6 +57,7 @@ async def generate(request: Request) -> Response: "guided_json": guided_json, "guided_regex": guided_regex, "guided_choice": guided_choice, + "guided_grammar": guided_grammar, } ) except Exception: @@ -63,8 +65,9 @@ async def generate(request: Request) -> Response: status_code=400, detail="Bad request: failed to parse guided decoding parameters." ) + guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend guided_decode_logit_processor = await get_guided_decoding_logits_processor( - partial_openai_request, engine.get_tokenizer() + guided_decoding_backend, partial_openai_request, await engine.get_tokenizer() ) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4dcfaba5..232e0626 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1108,11 +1108,13 @@ async def test_validate_and_update_completion_params(): completion_sync_request.guided_regex = "" completion_sync_request.guided_json = {} completion_sync_request.guided_choice = [""] + completion_sync_request.guided_grammar = "" with pytest.raises(ObjectHasInvalidValueException): validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) completion_sync_request.guided_regex = None completion_sync_request.guided_choice = None + completion_sync_request.guided_grammar = None with pytest.raises(ObjectHasInvalidValueException): validate_and_update_completion_params( LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request From 9d0e4334c7a72c6742884e5957de474eebdd9749 Mon Sep 17 00:00:00 2001 From: Ian Macleod <139901935+ian-scale@users.noreply.github.com> Date: Tue, 30 Apr 2024 14:52:07 -0700 Subject: [PATCH 289/425] adding asyncenginedead error catch (#504) * adding asyncenginedead error catch * catch error in generation --- .../model_engine_server/inference/vllm/vllm_server.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 94e49fcf..3a966f15 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -10,7 +10,7 @@ from fastapi import BackgroundTasks, FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.outputs import CompletionOutput @@ -75,7 +75,11 @@ async def generate(request: Request) -> Response: sampling_params.logits_processors.append(guided_decode_logit_processor) request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id) + try: + results_generator = engine.generate(prompt, sampling_params, request_id) + except AsyncEngineDeadError as e: + print(f"The vllm engine is dead, exiting the pod: {e}") + exit(1) # Streaming case async def stream_results() -> AsyncGenerator[str, None]: @@ -192,6 +196,7 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) + engine.check_health() signal.signal(signal.SIGUSR1, debug) From 6f8870c04cfe94893a82385976c9503180d155b2 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Thu, 2 May 2024 15:08:29 -0700 Subject: [PATCH 290/425] Default include_stop_str_in_output to None (#506) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/completion.py | 4 ++-- clients/python/llmengine/data_types.py | 4 ++-- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index df8ffbd0..7a3d4c96 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b30" +__version__ = "0.0.0b31" import os from typing import Sequence diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 0181b733..01aa86a9 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -43,7 +43,7 @@ async def acreate( frequency_penalty: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - include_stop_str_in_output: Optional[bool] = False, + include_stop_str_in_output: Optional[bool] = None, guided_json: Optional[Dict[str, Any]] = None, guided_regex: Optional[str] = None, guided_choice: Optional[List[str]] = None, @@ -257,7 +257,7 @@ def create( frequency_penalty: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - include_stop_str_in_output: Optional[bool] = False, + include_stop_str_in_output: Optional[bool] = None, guided_json: Optional[Dict[str, Any]] = None, guided_regex: Optional[str] = None, guided_choice: Optional[List[str]] = None, diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 07d2622b..a9d65d1a 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -324,7 +324,7 @@ class CompletionSyncV1Request(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=-1) top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - include_stop_str_in_output: Optional[bool] = Field(default=False) + include_stop_str_in_output: Optional[bool] = Field(default=None) guided_json: Optional[Dict[str, Any]] = Field(default=None) guided_regex: Optional[str] = Field(default=None) guided_choice: Optional[List[str]] = Field(default=None) @@ -398,7 +398,7 @@ class CompletionStreamV1Request(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=-1) top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - include_stop_str_in_output: Optional[bool] = Field(default=False) + include_stop_str_in_output: Optional[bool] = Field(default=None) guided_json: Optional[Dict[str, Any]] = Field(default=None) guided_regex: Optional[str] = Field(default=None) guided_choice: Optional[List[str]] = Field(default=None) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 24c5a3b5..81c1f98a 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta30" +version = "0.0.0.beta31" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index eaeb7507..a4ba34de 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta30", + version="0.0.0.beta31", packages=find_packages(), ) From a2bf698e38777eb49ab184c9a38cbe3cc6066c29 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 3 May 2024 14:11:20 -0700 Subject: [PATCH 291/425] get latest inference framework tag from configmap (#505) * get latest inference framework tag from configmap * comments * fix for test * make namespace a config * fix s3 prefix bug * fix checkpoint path fn + tests * values change * quotes --- .../templates/inference_framework_config.yaml | 16 +++++ .../templates/service_config_map.yaml | 2 + .../model_engine_server/common/config.py | 1 + .../model_engine_server/core/configmap.py | 35 +++++++++ .../model_engine_server/domain/exceptions.py | 6 ++ .../use_cases/llm_model_endpoint_use_cases.py | 36 +++++++--- .../service_config_circleci.yaml | 3 + model-engine/tests/unit/domain/conftest.py | 3 +- .../tests/unit/domain/test_llm_use_cases.py | 71 +++++++++---------- 9 files changed, 122 insertions(+), 51 deletions(-) create mode 100644 charts/model-engine/templates/inference_framework_config.yaml create mode 100644 model-engine/model_engine_server/core/configmap.py diff --git a/charts/model-engine/templates/inference_framework_config.yaml b/charts/model-engine/templates/inference_framework_config.yaml new file mode 100644 index 00000000..d81d5be2 --- /dev/null +++ b/charts/model-engine/templates/inference_framework_config.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: model-engine-inference-framework-latest-config + labels: + product: common + team: infra + annotations: + "helm.sh/hook": pre-install + "helm.sh/hook-weight": "-2" +data: + deepspeed: "latest" + text_generation_inference: "latest" + vllm: "latest" + lightllm: "latest" + tensorrt_llm: "latest" diff --git a/charts/model-engine/templates/service_config_map.yaml b/charts/model-engine/templates/service_config_map.yaml index 70a12755..403bb552 100644 --- a/charts/model-engine/templates/service_config_map.yaml +++ b/charts/model-engine/templates/service_config_map.yaml @@ -11,6 +11,7 @@ metadata: data: launch_service_config: |- dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }} + gateway_namespace: {{ .Release.Namespace | quote }} {{- with .Values.config.values.launch }} {{- range $key, $value := . }} {{ $key }}: {{ $value | quote }} @@ -39,6 +40,7 @@ metadata: data: launch_service_config: |- dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }} + gateway_namespace: {{ .Release.Namespace | quote }} {{- with .Values.config.values.launch }} {{- range $key, $value := . }} {{ $key }}: {{ $value | quote }} diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index ac92cf43..dd18a1c5 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -47,6 +47,7 @@ def get_model_cache_directory_name(model_name: str): @dataclass class HostedModelInferenceServiceConfig: + gateway_namespace: str endpoint_namespace: str billing_queue_arn: str sqs_profile: str diff --git a/model-engine/model_engine_server/core/configmap.py b/model-engine/model_engine_server/core/configmap.py new file mode 100644 index 00000000..d3edb669 --- /dev/null +++ b/model-engine/model_engine_server/core/configmap.py @@ -0,0 +1,35 @@ +"""Read configmap from k8s.""" + +from typing import Dict + +from kubernetes_asyncio import client +from kubernetes_asyncio import config as kube_config +from kubernetes_asyncio.client.rest import ApiException +from kubernetes_asyncio.config.config_exception import ConfigException +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import logger_name, make_logger + +DEFAULT_NAMESPACE = "default" + +logger = make_logger(logger_name()) + + +async def read_config_map( + config_map_name: str, namespace: str = hmi_config.gateway_namespace +) -> Dict[str, str]: + try: + kube_config.load_incluster_config() + except ConfigException: + logger.info("No incluster kubernetes config, falling back to local") + await kube_config.load_kube_config() + + core_api = client.CoreV1Api() + + try: + config_map = await core_api.read_namespaced_config_map( + name=config_map_name, namespace=namespace + ) + return config_map.data + except ApiException as e: + logger.exception(f"Error reading configmap {config_map_name}") + raise e diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index c64e3beb..e9ded985 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -182,3 +182,9 @@ class PostInferenceHooksException(DomainException): """ Thrown if the post inference hooks are invalid. """ + + +class LatestImageTagNotFoundException(DomainException): + """ + Thrown if the latest image tag cannot be found. + """ diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 5ac3084e..5f68a5bd 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -39,6 +39,7 @@ from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.configmap import read_config_map from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, @@ -67,6 +68,7 @@ EndpointLabelsException, EndpointUnsupportedInferenceTypeException, InvalidRequestException, + LatestImageTagNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, @@ -82,7 +84,10 @@ ) from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway -from model_engine_server.infra.repositories.live_tokenizer_repository import SUPPORTED_MODELS_INFO +from model_engine_server.infra.repositories.live_tokenizer_repository import ( + SUPPORTED_MODELS_INFO, + get_models_s3_uri, +) from ...common.datadog_utils import add_trace_request_id from ..authorization.live_authorization_module import LiveAuthorizationModule @@ -246,6 +251,8 @@ NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes +LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = "model-engine-inference-framework-latest-config" + def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: """ @@ -255,6 +262,15 @@ def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRep return len(tokenizer.encode(input)) +async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str: + config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) + if inference_framework not in config_map: + raise LatestImageTagNotFoundException( + f"Could not find latest tag for inference framework {inference_framework}." + ) + return config_map[inference_framework] + + def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]: """ This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files @@ -337,11 +353,11 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None: def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: - checkpoint_path = ( - SUPPORTED_MODELS_INFO[model_name].s3_repo - if not checkpoint_path_override - else checkpoint_path_override - ) + checkpoint_path = None + if SUPPORTED_MODELS_INFO[model_name].s3_repo: + checkpoint_path = get_models_s3_uri(SUPPORTED_MODELS_INFO[model_name].s3_repo, "") + if checkpoint_path_override: + checkpoint_path = checkpoint_path_override if not checkpoint_path: raise InvalidRequestException(f"No checkpoint path found for model {model_name}") @@ -931,8 +947,8 @@ async def execute( ) if request.inference_framework_image_tag == "latest": - request.inference_framework_image_tag = self.docker_repository.get_latest_image_tag( - INFERENCE_FRAMEWORK_REPOSITORY[request.inference_framework] + request.inference_framework_image_tag = await _get_latest_tag( + request.inference_framework ) bundle = await self.create_llm_model_bundle_use_case.execute( @@ -1149,9 +1165,7 @@ async def execute( inference_framework = llm_metadata["inference_framework"] if request.inference_framework_image_tag == "latest": - inference_framework_image_tag = self.docker_repository.get_latest_image_tag( - INFERENCE_FRAMEWORK_REPOSITORY[inference_framework] - ) + inference_framework_image_tag = await _get_latest_tag(inference_framework) else: inference_framework_image_tag = ( request.inference_framework_image_tag diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index 001a54b7..de998c27 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -1,3 +1,6 @@ +# Config to know where model-engine is running +gateway_namespace: default + # Config for Model Engine running in CircleCI model_primitive_host: "none" diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index bbf25058..18aa4470 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -222,7 +222,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", - checkpoint_path="s3://test_checkpoint_path", + checkpoint_path="s3://test-s3.tar", ) @@ -286,6 +286,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test-s3.tar", ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 232e0626..166c2149 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple from unittest import mock import pytest @@ -54,21 +54,19 @@ validate_and_update_completion_params, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase -from model_engine_server.infra.repositories import live_tokenizer_repository -from model_engine_server.infra.repositories.live_tokenizer_repository import ModelInfo -def good_models_info() -> Dict[str, ModelInfo]: - return { - k: ModelInfo(v.hf_repo, "s3://test-s3.tar") - for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items() - } +def mocked__get_latest_tag(): + async def async_mock(*args, **kwargs): # noqa + return "fake_docker_repository_latest_image_tag" + + return mock.AsyncMock(side_effect=async_mock) @pytest.mark.asyncio @mock.patch( - "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO", - good_models_info(), + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag", + mocked__get_latest_tag(), ) async def test_create_model_endpoint_use_case_success( test_api_key: str, @@ -183,40 +181,33 @@ async def test_create_model_endpoint_use_case_success( assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] -def bad_models_info() -> Dict[str, ModelInfo]: - info = { - k: ModelInfo(v.hf_repo, v.s3_repo) - for k, v in live_tokenizer_repository.SUPPORTED_MODELS_INFO.items() - } - info.update( - { - "mpt-7b": ModelInfo("mosaicml/mpt-7b", None), - "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", "gibberish"), - } - ) - return info - - @pytest.mark.asyncio @pytest.mark.parametrize( - "inference_framework, model_name, expected_error", + "inference_framework, model_name, checkpoint_path, expected_error", [ - (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", InvalidRequestException), + (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", None, InvalidRequestException), ( LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b-instruct", + "gibberish", + ObjectHasInvalidValueException, + ), + (LLMInferenceFramework.LIGHTLLM, "mpt-7b", None, InvalidRequestException), + ( + LLMInferenceFramework.LIGHTLLM, + "mpt-7b-instruct", + "gibberish", + ObjectHasInvalidValueException, + ), + (LLMInferenceFramework.VLLM, "mpt-7b", None, InvalidRequestException), + ( + LLMInferenceFramework.VLLM, + "mpt-7b-instruct", + "gibberish", ObjectHasInvalidValueException, ), - (LLMInferenceFramework.LIGHTLLM, "mpt-7b", InvalidRequestException), - (LLMInferenceFramework.LIGHTLLM, "mpt-7b-instruct", ObjectHasInvalidValueException), - (LLMInferenceFramework.VLLM, "mpt-7b", InvalidRequestException), - (LLMInferenceFramework.VLLM, "mpt-7b-instruct", ObjectHasInvalidValueException), ], ) -@mock.patch( - "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO", - bad_models_info(), -) async def test_create_model_bundle_fails_if_no_checkpoint( test_api_key: str, fake_model_bundle_repository, @@ -227,6 +218,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint( create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, inference_framework, model_name, + checkpoint_path, expected_error, ): fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository @@ -255,7 +247,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint( endpoint_type=request.endpoint_type, num_shards=request.num_shards, quantize=request.quantize, - checkpoint_path=None, + checkpoint_path=checkpoint_path, ) @@ -269,10 +261,6 @@ async def test_create_model_bundle_fails_if_no_checkpoint( (True, LLMInferenceFramework.VLLM, "0.1.3.6"), ], ) -@mock.patch( - "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.SUPPORTED_MODELS_INFO", - good_models_info(), -) async def test_create_model_bundle_inference_framework_image_tag_validation( test_api_key: str, fake_model_bundle_repository, @@ -307,6 +295,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() request.inference_framework = inference_framework request.inference_framework_image_tag = inference_framework_image_tag + request.checkpoint_path = "s3://test-s3.tar" user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) if valid: await use_case.execute(user=user, request=request) @@ -592,6 +581,10 @@ async def test_get_llm_model_endpoint_use_case_raises_not_authorized( @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag", + mocked__get_latest_tag(), +) async def test_update_model_endpoint_use_case_success( test_api_key: str, fake_model_bundle_repository, From 70d0e771869c67ef0874bd433d11d722d747bfd5 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Fri, 3 May 2024 14:12:41 -0700 Subject: [PATCH 292/425] integration tests for completions (#507) * get latest inference framework tag from configmap * comments * fix for test * make namespace a config * fix s3 prefix bug * fix checkpoint path fn + tests * integration tests for completions * values change * quotes --- integration_tests/rest_api_utils.py | 242 +++++++++++++++++++++++++- integration_tests/test_completions.py | 96 ++++++++++ 2 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 integration_tests/test_completions.py diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 1e780c37..285e2c1d 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -2,8 +2,9 @@ import inspect import json import os +import re import time -from typing import Any, Dict, List, Sequence +from typing import Any, Dict, List, Optional, Sequence import aiohttp import requests @@ -14,6 +15,7 @@ BASE_PATH = os.environ.get("BASE_PATH", _DEFAULT_BASE_PATH) print(f"Integration tests using gateway {BASE_PATH=}") DEFAULT_NETWORK_TIMEOUT_SEC = 10 +LONG_NETWORK_TIMEOUT_SEC = 30 # add suffix to avoid name collisions SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "") @@ -164,12 +166,87 @@ def my_model(**keyword_args): "url": None, } +CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = { + "name": format_name("llama-2-7b-test"), + "model_name": "llama-2-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "endpoint_type": "streaming", + "cpus": 20, + "gpus": 1, + "memory": "20Gi", + "gpu_type": "nvidia-ampere-a10", + "storage": "40Gi", + "optimize_costs": False, + "min_workers": 1, + "max_workers": 1, + "per_worker": 1, + "labels": {"team": "infra", "product": "launch"}, + "metadata": {"key": "value"}, + "public_inference": False, +} + + INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE: Dict[str, Any] = INFERENCE_PAYLOAD.copy() INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE["return_pickled"] = False INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE: Dict[str, Any] = INFERENCE_PAYLOAD.copy() INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE["return_pickled"] = True +LLM_PAYLOAD: Dict[str, Any] = { + "prompt": "Hello, my name is", + "max_new_tokens": 10, + "temperature": 0.2, +} + +LLM_PAYLOAD_WITH_STOP_SEQUENCE: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_STOP_SEQUENCE["stop_sequences"] = ["\n"] + +LLM_PAYLOAD_WITH_PRESENCE_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_PRESENCE_PENALTY["presence_penalty"] = 0.5 + +LLM_PAYLOAD_WITH_FREQUENCY_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_FREQUENCY_PENALTY["frequency_penalty"] = 0.5 + +LLM_PAYLOAD_WITH_TOP_K: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_TOP_K["top_k"] = 10 + +LLM_PAYLOAD_WITH_TOP_P: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_TOP_P["top_p"] = 0.5 + +LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT["include_stop_str_in_output"] = True + +LLM_PAYLOAD_WITH_GUIDED_JSON: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_JSON["guided_json"] = { + "properties": {"myString": {"type": "string"}}, + "required": ["myString"], +} + +LLM_PAYLOAD_WITH_GUIDED_REGEX: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_REGEX["guided_regex"] = "Sean.*" + +LLM_PAYLOAD_WITH_GUIDED_CHOICE: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_CHOICE["guided_choice"] = ["dog", "cat"] + +LLM_PAYLOAD_WITH_GUIDED_GRAMMAR: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_GRAMMAR["guided_grammar"] = 'start: "John"' + +LLM_PAYLOADS_WITH_EXPECTED_RESPONSES = [ + (LLM_PAYLOAD, None, None), + (LLM_PAYLOAD_WITH_STOP_SEQUENCE, None, None), + (LLM_PAYLOAD_WITH_PRESENCE_PENALTY, None, None), + (LLM_PAYLOAD_WITH_FREQUENCY_PENALTY, None, None), + (LLM_PAYLOAD_WITH_TOP_K, None, None), + (LLM_PAYLOAD_WITH_TOP_P, None, None), + (LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT, ["tokens"], None), + (LLM_PAYLOAD_WITH_GUIDED_JSON, None, None), + (LLM_PAYLOAD_WITH_GUIDED_REGEX, None, "Sean.*"), + (LLM_PAYLOAD_WITH_GUIDED_CHOICE, None, "dog|cat"), + (LLM_PAYLOAD_WITH_GUIDED_GRAMMAR, None, "John"), +] + CREATE_BATCH_JOB_REQUEST: Dict[str, Any] = { "bundle_name": "model_bundle_simple", "input_path": "TBA", @@ -524,6 +601,18 @@ def get_model_endpoint(name: str, user_id: str) -> Dict[str, Any]: return response.json()["model_endpoints"][0] +@retry(stop=stop_after_attempt(6), wait=wait_fixed(1)) +def get_llm_model_endpoint(name: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/llm/model-endpoints/{name}", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(20)) def update_model_endpoint( endpoint_name: str, update_model_endpoint_request: Dict[str, Any], user_id: str @@ -556,6 +645,18 @@ def delete_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]: return response.json() +def delete_llm_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]: + response = requests.delete( + f"{BASE_PATH}/v1/llm/model-endpoints/{endpoint_name}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]: response = requests.get( @@ -568,6 +669,44 @@ def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]: return response.json()["model_endpoints"] +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def list_llm_model_endpoints(user_id: str) -> List[Dict[str, Any]]: + response = requests.get( + f"{BASE_PATH}/v1/llm/model-endpoints", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json()["model_endpoints"] + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def create_llm_model_endpoint( + create_llm_model_endpoint_request: Dict[str, Any], + user_id: str, + inference_framework: Optional[str], + inference_framework_image_tag: Optional[str], +) -> Dict[str, Any]: + create_model_endpoint_request = create_llm_model_endpoint_request.copy() + if inference_framework: + create_model_endpoint_request["inference_framework"] = inference_framework + if inference_framework_image_tag: + create_model_endpoint_request[ + "inference_framework_image_tag" + ] = inference_framework_image_tag + response = requests.post( + f"{BASE_PATH}/v1/llm/model-endpoints", + json=create_model_endpoint_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + async def create_async_task( model_endpoint_id: str, create_async_task_request: Dict[str, Any], @@ -615,6 +754,23 @@ async def create_sync_task( return await response.json() +async def create_llm_sync_task( + model_endpoint_name: str, + create_sync_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/llm/completions-sync?model_endpoint_name={model_endpoint_name}", + json=create_sync_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=LONG_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return await response.json() + + async def create_streaming_task( model_endpoint_id: str, create_streaming_task_request: Dict[str, Any], @@ -632,6 +788,23 @@ async def create_streaming_task( return (await response.read()).decode() +async def create_llm_streaming_task( + model_endpoint_name: str, + create_streaming_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/llm/completions-stream?model_endpoint_name={model_endpoint_name}", + json=create_streaming_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=LONG_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return await response.json() + + async def create_sync_tasks( endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str ) -> List[Any]: @@ -646,6 +819,19 @@ async def create_sync_tasks( return result # type: ignore +async def create_llm_sync_tasks( + endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + async with aiohttp.ClientSession() as session: + tasks = [] + for create_sync_task_request in create_sync_task_requests: + task = create_llm_sync_task(endpoint_name, create_sync_task_request, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + async def create_streaming_tasks( endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str ) -> List[Any]: @@ -662,6 +848,21 @@ async def create_streaming_tasks( return result # type: ignore +async def create_llm_streaming_tasks( + endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + async with aiohttp.ClientSession() as session: + tasks = [] + for create_streaming_task_request in create_streaming_task_requests: + task = create_llm_streaming_task( + endpoint_name, create_streaming_task_request, user_id, session + ) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + async def get_async_task( task_id: str, user_id: str, session: aiohttp.ClientSession ) -> Dict[str, Any]: @@ -708,6 +909,22 @@ def ensure_n_ready_endpoints_short(n: int, user_id: str): assert len(ready_endpoints) >= n +# Wait 2 minutes (120 seconds) for endpoints to build. +@retry(stop=stop_after_attempt(12), wait=wait_fixed(10)) +def ensure_n_ready_private_llm_endpoints_short(n: int, user_id: str): + endpoints = list_llm_model_endpoints(user_id) + private_endpoints = [ + endpoint for endpoint in endpoints if not endpoint["spec"]["public_inference"] + ] + ready_endpoints = [endpoint for endpoint in private_endpoints if endpoint["status"] == "READY"] + print( + f"User {user_id} Current num endpoints: {len(private_endpoints)}, num ready endpoints: {len(ready_endpoints)}" + ) + assert ( + len(ready_endpoints) >= n + ), f"Expected {n} ready endpoints, got {len(ready_endpoints)}. Look through endpoint builder for errors." + + def delete_all_endpoints(user_id: str, delete_suffix_only: bool): endpoints = list_model_endpoints(user_id) for i, endpoint in enumerate(endpoints): @@ -737,6 +954,13 @@ def ensure_nonzero_available_workers(endpoint_name: str, user_id: str): assert simple_endpoint.get("deployment_state", {}).get("available_workers", 0) +# Wait up to 20 minutes (1200 seconds) for the pods to spin up. +@retry(stop=stop_after_attempt(120), wait=wait_fixed(10)) +def ensure_nonzero_available_llm_workers(endpoint_name: str, user_id: str): + simple_endpoint = get_llm_model_endpoint(endpoint_name, user_id) + assert simple_endpoint["spec"].get("deployment_state", {}).get("available_workers", 0) + + def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_pickled: bool): print(response) assert response["status"] == "SUCCESS" @@ -747,6 +971,22 @@ def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_p assert response["result"] == {"result": '{"y": 1}'} +def ensure_llm_task_response_is_correct( + response: Dict[str, Any], + required_output_fields: Optional[List[str]], + response_text_regex: Optional[str], +): + print(response) + assert response["output"] is not None + + if required_output_fields is not None: + for field in required_output_fields: + assert field in response["output"] + + if response_text_regex is not None: + assert re.search(response_text_regex, response["output"]["text"]) + + # Wait up to 30 seconds for the tasks to be returned. @retry( stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) diff --git a/integration_tests/test_completions.py b/integration_tests/test_completions.py new file mode 100644 index 00000000..01dcdc2d --- /dev/null +++ b/integration_tests/test_completions.py @@ -0,0 +1,96 @@ +import asyncio +import os + +import pytest + +from .rest_api_utils import ( + CREATE_LLM_MODEL_ENDPOINT_REQUEST, + LLM_PAYLOADS_WITH_EXPECTED_RESPONSES, + USER_ID_0, + create_llm_model_endpoint, + create_llm_streaming_tasks, + create_llm_sync_tasks, + delete_llm_model_endpoint, + ensure_llm_task_response_is_correct, + ensure_n_ready_private_llm_endpoints_short, + ensure_nonzero_available_llm_workers, +) + +TEST_INFERENCE_FRAMEWORK = os.environ.get("TEST_INFERENCE_FRAMEWORK", None) +TEST_INFERENCE_FRAMEWORK_IMAGE_TAG = os.environ.get("TEST_INFERENCE_FRAMEWORK_IMAGE_TAG", None) +print(f"TEST_INFERENCE_FRAMEWORK={TEST_INFERENCE_FRAMEWORK}") + + +@pytest.mark.skipif( + (not TEST_INFERENCE_FRAMEWORK) or (not TEST_INFERENCE_FRAMEWORK_IMAGE_TAG), + reason="Skip unless running inference framework tests", +) +def test_completions(capsys): + with capsys.disabled(): + try: + user = USER_ID_0 + create_endpoint_request = CREATE_LLM_MODEL_ENDPOINT_REQUEST + + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_llm_model_endpoint( + create_endpoint_request, + user, + TEST_INFERENCE_FRAMEWORK, + TEST_INFERENCE_FRAMEWORK_IMAGE_TAG, + ) + ensure_n_ready_private_llm_endpoints_short(1, user) + ensure_nonzero_available_llm_workers(create_endpoint_request["name"], user) + + for ( + completions_payload, + required_output_fields, + response_text_regex, + ) in LLM_PAYLOADS_WITH_EXPECTED_RESPONSES: + print( + f"Sending sync tasks to {create_endpoint_request['name']} for user {user}, {completions_payload=}..." + ) + try: + task_responses = asyncio.run( + create_llm_sync_tasks( + create_endpoint_request["name"], + [completions_payload], + user, + ) + ) + for response in task_responses: + ensure_llm_task_response_is_correct( + response, required_output_fields, response_text_regex + ) + except Exception as e: + if hasattr(e, "response") and e.response.status_code // 100 == 4: + print(f"Got 4xx status code for {completions_payload=}, which is expected") + else: + raise e + + for ( + completions_payload, + required_output_fields, + response_text_regex, + ) in LLM_PAYLOADS_WITH_EXPECTED_RESPONSES: + print( + f"Sending streaming tasks to {create_endpoint_request['name']} for user {user}, {completions_payload=}..." + ) + try: + task_responses = asyncio.run( + create_llm_streaming_tasks( + create_endpoint_request["name"], + [completions_payload], + user, + ) + ) + for response in task_responses: + ensure_llm_task_response_is_correct( + response, required_output_fields, response_text_regex + ) + except Exception as e: + if hasattr(e, "response") and e.response.status_code // 100 == 4: + print(f"Got 4xx status code for {completions_payload=}, which is expected") + else: + raise e + finally: + delete_llm_model_endpoint(create_endpoint_request["name"], user) From 13da4c136c11f6ceb9a7c6e17e698df893eddfe0 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Sat, 4 May 2024 14:17:22 -0700 Subject: [PATCH 293/425] patch service config identifier (#509) --- .../model-engine/templates/inference_framework_config.yaml | 2 +- .../domain/use_cases/llm_model_endpoint_use_cases.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/templates/inference_framework_config.yaml b/charts/model-engine/templates/inference_framework_config.yaml index d81d5be2..d97d1920 100644 --- a/charts/model-engine/templates/inference_framework_config.yaml +++ b/charts/model-engine/templates/inference_framework_config.yaml @@ -1,7 +1,7 @@ apiVersion: v1 kind: ConfigMap metadata: - name: model-engine-inference-framework-latest-config + name: {{ include "modelEngine.fullname" . }}-inference-framework-latest-config labels: product: common team: infra diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 5f68a5bd..4f0e2fbf 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -251,7 +251,11 @@ NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes -LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = "model-engine-inference-framework-latest-config" +SERVICE_NAME = "model-engine" +SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") +if SERVICE_IDENTIFIER: + SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" +LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = f"{SERVICE_NAME}-inference-framework-latest-config" def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: From a87e5aaba4eef9896a76443bca88683f4240ea40 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Mon, 6 May 2024 13:26:57 -0400 Subject: [PATCH 294/425] require safetensors (#510) --- .../use_cases/llm_model_endpoint_use_cases.py | 73 ++++++------------- model-engine/tests/unit/conftest.py | 6 +- model-engine/tests/unit/domain/conftest.py | 9 ++- .../tests/unit/domain/test_llm_use_cases.py | 39 +++------- 4 files changed, 43 insertions(+), 84 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 4f0e2fbf..ceef5d2c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -275,24 +275,6 @@ async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str: return config_map[inference_framework] -def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]: - """ - This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files - based on which file type is present most often in the checkpoint folder. The most - frequently present file type is included. - In case of ties, priority is given to "*.safetensors", then "*.bin", then "*.pt". - """ - num_safetensors = len([f for f in model_files if f.endswith(".safetensors")]) - num_bin = len([f for f in model_files if f.endswith(".bin")]) - num_pt = len([f for f in model_files if f.endswith(".pt")]) - maximum = max(num_safetensors, num_bin, num_pt) - if num_safetensors == maximum: - return "*.safetensors" - if num_bin == maximum: - return "*.bin" - return "*.pt" - - def _model_endpoint_entity_to_get_llm_model_endpoint_response( model_endpoint: ModelEndpoint, ) -> GetLLMModelEndpointV1Response: @@ -354,6 +336,10 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None: raise ObjectHasInvalidValueException( f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." ) + if checkpoint_path.endswith(".tar"): + raise ObjectHasInvalidValueException( + f"Tar files are not supported. Given checkpoint path: {checkpoint_path}." + ) def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: @@ -370,6 +356,14 @@ def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str] return checkpoint_path +def validate_checkpoint_files(checkpoint_files: List[str]) -> None: + """Require safetensors in the checkpoint path.""" + model_files = [f for f in checkpoint_files if "model" in f] + num_safetensors = len([f for f in model_files if f.endswith(".safetensors")]) + if num_safetensors == 0: + raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.") + + class CreateLLMModelBundleV1UseCase: def __init__( self, @@ -557,27 +551,14 @@ def load_model_weights_sub_commands( else: s5cmd = "./s5cmd" - base_path = checkpoint_path.split("/")[-1] - if base_path.endswith(".tar"): - # If the checkpoint file is a tar file, extract it into final_weights_folder - subcommands.extend( - [ - f"{s5cmd} cp {checkpoint_path} .", - f"mkdir -p {final_weights_folder}", - f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}", - ] - ) - else: - # Let's check whether to exclude "*.safetensors" or "*.bin" files - checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) - model_files = [f for f in checkpoint_files if "model" in f] - - include_str = _include_safetensors_bin_or_pt(model_files) - file_selection_str = f"--include '*.model' --include '*.json' --include '{include_str}' --exclude 'optimizer*'" - subcommands.append( - f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" - ) + checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) + validate_checkpoint_files(checkpoint_files) + # filter to configs ('*.model' and '*.json') and weights ('*.safetensors') + file_selection_str = "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'" + subcommands.append( + f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) return subcommands def load_model_files_sub_commands_trt_llm( @@ -591,19 +572,9 @@ def load_model_files_sub_commands_trt_llm( See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt """ - subcommands = [] - - base_path = checkpoint_path.split("/")[-1] - - if base_path.endswith(".tar"): - raise ObjectHasInvalidValueException( - "Checkpoint for TensorRT-LLM models must be a folder, not a tar file." - ) - else: - subcommands.append( - f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" - ) - + subcommands = [ + f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" + ] return subcommands async def create_deepspeed_bundle( diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index fd6b9f0e..50aa4018 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -757,10 +757,12 @@ class FakeLLMArtifactGateway(LLMArtifactGateway): def __init__(self): self.existing_models = [] self.s3_bucket = { - "fake-checkpoint": ["fake.bin, fake2.bin", "fake3.safetensors"], + "fake-checkpoint": ["model-fake.bin, model-fake2.bin", "model-fake.safetensors"], "llama-7b/tokenizer.json": ["llama-7b/tokenizer.json"], "llama-7b/tokenizer_config.json": ["llama-7b/tokenizer_config.json"], "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"], + "llama-2-7b": ["model-fake.safetensors"], + "mpt-7b": ["model-fake.safetensors"], } self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} @@ -768,10 +770,12 @@ def _add_model(self, owner: str, model_name: str): self.existing_models.append((owner, model_name)) def list_files(self, path: str, **kwargs) -> List[str]: + path = path.lstrip("s3://") if path in self.s3_bucket: return self.s3_bucket[path] def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + path = path.lstrip("s3://") if path in self.s3_bucket: return self.s3_bucket[path] diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 18aa4470..a3419890 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -196,6 +196,7 @@ def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request: labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://mpt-7b", ) @@ -222,7 +223,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", - checkpoint_path="s3://test-s3.tar", + checkpoint_path="s3://llama-2-7b", ) @@ -249,6 +250,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://mpt-7b", ) @@ -256,7 +258,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: return UpdateLLMModelEndpointV1Request( inference_framework_image_tag="latest", - checkpoint_path="s3://test_checkpoint_path", + checkpoint_path="s3://mpt-7b", memory="4G", min_workers=0, max_workers=1, @@ -286,7 +288,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", - checkpoint_path="s3://test-s3.tar", + checkpoint_path="s3://llama-2-7b", ) @@ -315,6 +317,7 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( labels={"team": "infra", "product": "my_product"}, aws_role="test_aws_role", results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://mpt-7b", ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 166c2149..9ee2fa57 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -49,9 +49,9 @@ GpuType, ModelDownloadV1UseCase, UpdateLLMModelEndpointV1UseCase, - _include_safetensors_bin_or_pt, infer_hardware_from_model_name, validate_and_update_completion_params, + validate_checkpoint_files, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -141,7 +141,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_sync.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_sync.num_shards, "quantize": None, - "checkpoint_path": None, + "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path, } } @@ -166,7 +166,7 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, - "checkpoint_path": None, + "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path, } } @@ -295,7 +295,6 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() request.inference_framework = inference_framework request.inference_framework_image_tag = inference_framework_image_tag - request.checkpoint_path = "s3://test-s3.tar" user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) if valid: await use_case.execute(user=user, request=request) @@ -1755,34 +1754,16 @@ async def test_delete_public_inference_model_raises_not_authorized( @pytest.mark.asyncio -async def test_include_safetensors_bin_or_pt_majority_safetensors(): - fake_model_files = ["fake.bin", "fake2.safetensors", "model.json", "optimizer.pt"] - assert _include_safetensors_bin_or_pt(fake_model_files) == "*.safetensors" - - -@pytest.mark.asyncio -async def test_include_safetensors_bin_or_pt_majority_bin(): - fake_model_files = [ - "fake.bin", - "fake2.bin", - "fake3.safetensors", - "model.json", - "optimizer.pt", - "fake4.pt", - ] - assert _include_safetensors_bin_or_pt(fake_model_files) == "*.bin" +async def test_validate_checkpoint_files_no_safetensors(): + fake_model_files = ["model-fake.bin", "model.json", "optimizer.pt"] + with pytest.raises(ObjectHasInvalidValueException): + validate_checkpoint_files(fake_model_files) @pytest.mark.asyncio -async def test_include_safetensors_bin_or_pt_majority_pt(): - fake_model_files = [ - "fake.bin", - "fake2.safetensors", - "model.json", - "optimizer.pt", - "fake3.pt", - ] - assert _include_safetensors_bin_or_pt(fake_model_files) == "*.pt" +async def test_validate_checkpoint_files_safetensors_with_other_files(): + fake_model_files = ["model-fake.bin", "model-fake2.safetensors", "model.json", "optimizer.pt"] + validate_checkpoint_files(fake_model_files) # No exception should be raised def test_infer_hardware_from_model_name(): From e1da2431f455ea71c1569b87dfbde8be149c42f8 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 7 May 2024 10:52:21 -0700 Subject: [PATCH 295/425] Add py.typed for proper typechecking support on clients (#513) * Add py.typed for proper typechecking support on clients * update client version --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/py.typed | 0 clients/python/pyproject.toml | 2 +- clients/python/setup.py | 3 ++- 4 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 clients/python/llmengine/py.typed diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 7a3d4c96..110ed4cc 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b31" +__version__ = "0.0.0b32" import os from typing import Sequence diff --git a/clients/python/llmengine/py.typed b/clients/python/llmengine/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 81c1f98a..b7459272 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta31" +version = "0.0.0.beta32" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index a4ba34de..6bcdf689 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta31", + version="0.0.0.beta32", packages=find_packages(), + package_data={"scale_llm_engine": ["py.typed"]}, ) From 110643566ba28ecd6b10a337e4aef0add1d7f3e4 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 7 May 2024 11:28:39 -0700 Subject: [PATCH 296/425] Fix package name mapping (#514) --- clients/python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/setup.py b/clients/python/setup.py index 6bcdf689..489e428a 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -5,5 +5,5 @@ python_requires=">=3.7", version="0.0.0.beta32", packages=find_packages(), - package_data={"scale_llm_engine": ["py.typed"]}, + package_data={"llmengine": ["py.typed"]}, ) From c019a6a7bf852d2a1bdb344a4e441bd2cef961b4 Mon Sep 17 00:00:00 2001 From: Sam Denton <106690182+sam-scale@users.noreply.github.com> Date: Tue, 14 May 2024 14:09:40 -0400 Subject: [PATCH 297/425] Necessary Changes for long context llama-3-8b (#516) * all necessary changes * tests --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 8 ++++++++ .../infra/repositories/live_tokenizer_repository.py | 1 + model-engine/tests/unit/domain/test_llm_use_cases.py | 7 +++++++ 3 files changed, 16 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index ceef5d2c..375206c7 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -169,6 +169,7 @@ "llama-2-70b-chat", "llama-3-8b", "llama-3-8b-instruct", + "llama-3-8b-instruct-262k", "llama-3-70b", "llama-3-70b-instruct", "falcon-7b", @@ -240,6 +241,7 @@ # Can also see 13B, 34B there too "gemma": {"max_model_len": 8192, "max_num_batched_tokens": 8192}, "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, + "llama-3-8b-instruct-262k": {"max_model_len": None, "max_num_batched_tokens": 262144}, "llama-3": {"max_model_len": None, "max_num_batched_tokens": 8192}, "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, "mixtral-8x7b": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, @@ -2211,6 +2213,12 @@ def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJob memory = "800Gi" storage = "460Gi" gpu_type = GpuType.NVIDIA_AMPERE_A100E + elif "llama-3-8b-instruct-262k" in model_name: + cpus = "20" + gpus = 2 + memory = "40Gi" + storage = "40Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A100E else: numbers = re.findall(r"\d+", model_name) if len(numbers) == 0: diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 779f08a7..8dff922b 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -43,6 +43,7 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "llama-2-70b-chat": ModelInfo("meta-llama/Llama-2-70b-chat-hf", None), "llama-3-8b": ModelInfo("meta-llama/Meta-Llama-3-8B", None), "llama-3-8b-instruct": ModelInfo("meta-llama/Meta-Llama-3-8B-Instruct", None), + "llama-3-8b-instruct-262k": ModelInfo("gradientai/Llama-3-8B-Instruct-262k", None), "llama-3-70b": ModelInfo("meta-llama/Meta-Llama-3-70B", None), "llama-3-70b-instruct": ModelInfo("meta-llama/Meta-Llama-3-70B-Instruct", None), "falcon-7b": ModelInfo("tiiuae/falcon-7b", None), diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 9ee2fa57..4cbde450 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1809,6 +1809,13 @@ def test_infer_hardware_from_model_name(): assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + hardware = infer_hardware_from_model_name("llama-3-8b-instruct-262k") + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "40Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + with pytest.raises(ObjectHasInvalidValueException): infer_hardware_from_model_name("unsupported_model") From fbe7417e2e65f81c7ee16a9b7a3ba7eb0a6d1feb Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 14 May 2024 19:46:35 -0700 Subject: [PATCH 298/425] Increase max gpu utilization for 70b models (#517) * Increase max gpu utilization for 70b models * Separate Gateway DTO and engine DTO * Update test fixtures --- .../model_engine_server/common/dtos/llms.py | 24 +++++++ .../use_cases/llm_model_endpoint_use_cases.py | 51 ++++++++++--- .../inference/batch_inference/vllm_batch.py | 8 +-- model-engine/tests/unit/inference/conftest.py | 17 +++-- .../tests/unit/inference/test_vllm_batch.py | 72 ++++++++++++------- 5 files changed, 127 insertions(+), 45 deletions(-) diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index c84c6931..b9744e5f 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -535,6 +535,30 @@ class CreateBatchCompletionsRequest(BaseModel): """ +class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): + """ + Internal model for representing request to the llm engine. This contains additional fields that we want + hidden from the DTO exposed to the client. + """ + + max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0) + """ + Maximum GPU memory utilization for the batch inference. Default to 90%. + """ + + @staticmethod + def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsEngineRequest": + return CreateBatchCompletionsEngineRequest( + input_data_path=request.input_data_path, + output_data_path=request.output_data_path, + content=request.content, + model_config=request.model_config, + data_parallelism=request.data_parallelism, + max_runtime_sec=request.max_runtime_sec, + tool_config=request.tool_config, + ) + + class CreateBatchCompletionsResponse(BaseModel): job_id: str diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 375206c7..9ce99364 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -9,7 +9,7 @@ import math import os import re -from dataclasses import asdict +from dataclasses import asdict, dataclass from typing import Any, AsyncIterable, Dict, List, Optional, Union from model_engine_server.common.config import hmi_config @@ -21,6 +21,7 @@ CompletionStreamV1Response, CompletionSyncV1Request, CompletionSyncV1Response, + CreateBatchCompletionsEngineRequest, CreateBatchCompletionsRequest, CreateBatchCompletionsResponse, CreateLLMModelEndpointV1Request, @@ -2200,6 +2201,27 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl return ModelDownloadResponse(urls=urls) +@dataclass +class VLLMEngineArgs: + gpu_memory_utilization: Optional[float] = None + + +def infer_addition_engine_args_from_model_name(model_name: str) -> VLLMEngineArgs: + numbers = re.findall(r"\d+", model_name) + if len(numbers) == 0: + raise ObjectHasInvalidValueException( + f"Model {model_name} is not supported for batch completions." + ) + + b_params = int(numbers[-1]) + if b_params >= 70: + gpu_memory_utilization = 0.95 + else: + gpu_memory_utilization = 0.9 + + return VLLMEngineArgs(gpu_memory_utilization=gpu_memory_utilization) + + def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJobResourceRequests: if "mixtral-8x7b" in model_name: cpus = "20" @@ -2324,14 +2346,25 @@ async def execute( assert hardware.gpus is not None if request.model_config.num_shards: hardware.gpus = max(hardware.gpus, request.model_config.num_shards) - request.model_config.num_shards = hardware.gpus - if request.tool_config and request.tool_config.name != "code_evaluator": + engine_request = CreateBatchCompletionsEngineRequest.from_api(request) + engine_request.model_config.num_shards = hardware.gpus + + if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator": raise ObjectHasInvalidValueException( "Only code_evaluator tool is supported for batch completions." ) - batch_bundle = await self.create_batch_job_bundle(user, request, hardware) + additional_engine_args = infer_addition_engine_args_from_model_name( + engine_request.model_config.model + ) + + if additional_engine_args.gpu_memory_utilization is not None: + engine_request.max_gpu_memory_utilization = ( + additional_engine_args.gpu_memory_utilization + ) + + batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware) validate_resource_requests( bundle=batch_bundle, @@ -2342,21 +2375,21 @@ async def execute( gpu_type=hardware.gpu_type, ) - if request.max_runtime_sec is None or request.max_runtime_sec < 1: + if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1: raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=user.user_id, owner=user.team_id, - job_config=request.dict(), + job_config=engine_request.dict(), env=batch_bundle.env, command=batch_bundle.command, repo=batch_bundle.image_repository, tag=batch_bundle.image_tag, resource_requests=hardware, - labels=request.model_config.labels, + labels=engine_request.model_config.labels, mount_location=batch_bundle.mount_location, - override_job_max_runtime_s=request.max_runtime_sec, - num_workers=request.data_parallelism, + override_job_max_runtime_s=engine_request.max_runtime_sec, + num_workers=engine_request.data_parallelism, ) return CreateBatchCompletionsResponse(job_id=job_id) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 5bb30a4b..e8b9fabe 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -15,7 +15,7 @@ from func_timeout import FunctionTimedOut, func_set_timeout from model_engine_server.common.dtos.llms import ( CompletionOutput, - CreateBatchCompletionsRequest, + CreateBatchCompletionsEngineRequest, CreateBatchCompletionsRequestContent, TokenOutput, ToolConfig, @@ -145,7 +145,7 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -def get_vllm_engine(model, request): +def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest): from vllm import AsyncEngineArgs, AsyncLLMEngine engine_args = AsyncEngineArgs( @@ -154,7 +154,7 @@ def get_vllm_engine(model, request): tensor_parallel_size=request.model_config.num_shards, seed=request.model_config.seed or 0, disable_log_requests=True, - gpu_memory_utilization=0.9, + gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, ) llm = AsyncLLMEngine.from_engine_args(engine_args) @@ -313,7 +313,7 @@ def tool_func(text: str, past_context: Optional[str]): async def batch_inference(): job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0)) - request = CreateBatchCompletionsRequest.parse_file(CONFIG_FILE) + request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE) if request.model_config.checkpoint_path is not None: download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER) diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py index 20c4aae8..e8bdca29 100644 --- a/model-engine/tests/unit/inference/conftest.py +++ b/model-engine/tests/unit/inference/conftest.py @@ -3,6 +3,7 @@ import pytest from model_engine_server.common.dtos.llms import ( CompletionOutput, + CreateBatchCompletionsEngineRequest, CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, @@ -12,14 +13,20 @@ @pytest.fixture -def create_batch_completions_request(): - return CreateBatchCompletionsRequest( +def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest: + return CreateBatchCompletionsEngineRequest( + input_data_path="input_data_path", + output_data_path="output_data_path", model_config=CreateBatchCompletionsModelConfig( - checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={} + model="model", + checkpoint_path="checkpoint_path", + labels={}, + seed=123, + num_shards=4, ), data_parallelism=1, - input_data_path="input_data_path", - output_data_path="output_data_path", + max_runtime_sec=86400, + max_gpu_memory_utilization=0.95, ) diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py index 7dbaad42..45223b20 100644 --- a/model-engine/tests/unit/inference/test_vllm_batch.py +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -7,7 +7,9 @@ @pytest.mark.asyncio @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") -@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" ) @@ -25,9 +27,9 @@ async def test_batch_inference( mock_get_s3_client, mock_generate_with_vllm, mock_create_batch_completions_request_content, - mock_create_batch_completions_request, + mock_create_batch_completions_engine_request, mock_vllm, - create_batch_completions_request, + create_batch_completions_engine_request, create_batch_completions_request_content, mock_s3_client, mock_process, @@ -36,7 +38,9 @@ async def test_batch_inference( # Mock the necessary objects and data mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + mock_create_batch_completions_engine_request.parse_file.return_value = ( + create_batch_completions_engine_request + ) mock_create_batch_completions_request_content.parse_raw.return_value = ( create_batch_completions_request_content ) @@ -48,7 +52,7 @@ async def test_batch_inference( await batch_inference() # Assertions - mock_create_batch_completions_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.parse_file.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -61,7 +65,9 @@ async def test_batch_inference( @pytest.mark.asyncio @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") -@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" ) @@ -79,9 +85,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed( mock_get_s3_client, mock_generate_with_vllm, mock_create_batch_completions_request_content, - mock_create_batch_completions_request, + mock_create_batch_completions_engine_request, mock_vllm, - create_batch_completions_request, + create_batch_completions_engine_request, create_batch_completions_request_content, mock_s3_client, mock_process, @@ -91,7 +97,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed( mock_process.returncode = 1 # Failed to download model mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + mock_create_batch_completions_engine_request.parse_file.return_value = ( + create_batch_completions_engine_request + ) mock_create_batch_completions_request_content.parse_raw.return_value = ( create_batch_completions_request_content ) @@ -103,7 +111,7 @@ async def test_batch_inference_failed_to_download_model_but_proceed( await batch_inference() # Assertions - mock_create_batch_completions_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.parse_file.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -116,7 +124,9 @@ async def test_batch_inference_failed_to_download_model_but_proceed( @pytest.mark.asyncio @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") -@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" ) @@ -136,9 +146,9 @@ async def test_batch_inference_two_workers( mock_get_s3_client, mock_generate_with_vllm, mock_create_batch_completions_request_content, - mock_create_batch_completions_request, + mock_create_batch_completions_engine_request, mock_vllm, - create_batch_completions_request, + create_batch_completions_engine_request, create_batch_completions_request_content, mock_s3_client, mock_process, @@ -147,8 +157,10 @@ async def test_batch_inference_two_workers( # Mock the necessary objects and data mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - create_batch_completions_request.data_parallelism = 2 - mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + create_batch_completions_engine_request.data_parallelism = 2 + mock_create_batch_completions_engine_request.parse_file.return_value = ( + create_batch_completions_engine_request + ) mock_create_batch_completions_request_content.parse_raw.return_value = ( create_batch_completions_request_content ) @@ -168,7 +180,7 @@ def side_effect(key, default): await batch_inference() # Assertions - mock_create_batch_completions_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.parse_file.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -198,7 +210,9 @@ def side_effect(key, default): @pytest.mark.asyncio @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") -@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" ) @@ -218,9 +232,9 @@ async def test_batch_inference_delete_chunks( mock_get_s3_client, mock_generate_with_vllm, mock_create_batch_completions_request_content, - mock_create_batch_completions_request, + mock_create_batch_completions_engine_request, mock_vllm, - create_batch_completions_request, + create_batch_completions_engine_request, create_batch_completions_request_content, mock_s3_client, mock_process, @@ -229,9 +243,11 @@ async def test_batch_inference_delete_chunks( # Mock the necessary objects and data mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - create_batch_completions_request.data_parallelism = 2 - create_batch_completions_request.output_data_path = "s3://bucket/key" - mock_create_batch_completions_request.parse_file.return_value = create_batch_completions_request + create_batch_completions_engine_request.data_parallelism = 2 + create_batch_completions_engine_request.output_data_path = "s3://bucket/key" + mock_create_batch_completions_engine_request.parse_file.return_value = ( + create_batch_completions_engine_request + ) mock_create_batch_completions_request_content.parse_raw.return_value = ( create_batch_completions_request_content ) @@ -251,7 +267,7 @@ def side_effect(key, default): await batch_inference() # Assertions - mock_create_batch_completions_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.parse_file.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -310,7 +326,9 @@ def test_file_exists_no_such_key(): @pytest.mark.asyncio @patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") -@patch("model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequest") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) @patch( "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" ) @@ -330,7 +348,7 @@ async def test_batch_inference_tool_completion( mock_get_s3_client, mock_generate_with_vllm, mock_create_batch_completions_request_content, - mock_create_batch_completions_request, + mock_create_batch_completions_engine_request, mock_vllm, create_batch_completions_tool_completion_request, create_batch_completions_tool_completion_request_content, @@ -344,7 +362,7 @@ async def test_batch_inference_tool_completion( mock_run.return_value = mock_run_output mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - mock_create_batch_completions_request.parse_file.return_value = ( + mock_create_batch_completions_engine_request.parse_file.return_value = ( create_batch_completions_tool_completion_request ) mock_create_batch_completions_request_content.parse_raw.return_value = ( @@ -361,7 +379,7 @@ async def test_batch_inference_tool_completion( await batch_inference() # Assertions - mock_create_batch_completions_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.parse_file.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), From ba68b8da2f6ef1f6df3f703b35456b72fbe20176 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 14 May 2024 20:45:16 -0700 Subject: [PATCH 299/425] Infer hardware from model name (#515) * Infer hardware from model name * fix * fix lint * fix * Use formula instead of hardcode * tests * remove print and cache * fixes --- .../model_engine_server/api/llms_v1.py | 15 +- .../model_engine_server/common/dtos/llms.py | 8 +- .../domain/gateways/llm_artifact_gateway.py | 12 +- .../use_cases/llm_model_endpoint_use_cases.py | 239 ++++++++------- .../gateways/abs_llm_artifact_gateway.py | 11 +- .../infra/gateways/s3_llm_artifact_gateway.py | 14 +- model-engine/tests/unit/api/conftest.py | 4 +- model-engine/tests/unit/conftest.py | 26 ++ model-engine/tests/unit/domain/conftest.py | 24 +- .../tests/unit/domain/test_llm_use_cases.py | 273 +++++++++++++++++- 10 files changed, 487 insertions(+), 139 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 614cc6bb..07a93d78 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -1,5 +1,6 @@ """LLM Model Endpoint routes for the hosted model inference service. """ + import traceback from datetime import datetime from typing import Optional @@ -169,6 +170,7 @@ async def create_model_endpoint( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, docker_repository=external_interfaces.docker_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: @@ -331,9 +333,9 @@ async def create_completion_sync_task( external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, TokenUsage( num_prompt_tokens=response.output.num_prompt_tokens if response.output else None, - num_completion_tokens=response.output.num_completion_tokens - if response.output - else None, + num_completion_tokens=( + response.output.num_completion_tokens if response.output else None + ), total_duration=use_case_timer.duration, ), metric_metadata, @@ -401,9 +403,9 @@ async def event_generator(): external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, TokenUsage( num_prompt_tokens=message.output.num_prompt_tokens if message.output else None, - num_completion_tokens=message.output.num_completion_tokens - if message.output - else None, + num_completion_tokens=( + message.output.num_completion_tokens if message.output else None + ), total_duration=use_case_timer.duration, time_to_first_token=time_to_first_token, ), @@ -593,6 +595,7 @@ async def create_batch_completions( docker_image_batch_job_gateway=external_interfaces.docker_image_batch_job_gateway, docker_repository=external_interfaces.docker_repository, docker_image_batch_job_bundle_repo=external_interfaces.docker_image_batch_job_bundle_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, ) return await use_case.execute(user=auth, request=request) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index b9744e5f..90498c23 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -51,10 +51,10 @@ class CreateLLMModelEndpointV1Request(BaseModel): metadata: Dict[str, Any] # TODO: JSON type post_inference_hooks: Optional[List[str]] endpoint_type: ModelEndpointType = ModelEndpointType.SYNC - cpus: CpuSpecificationType - gpus: int - memory: StorageSpecificationType - gpu_type: GpuType + cpus: Optional[CpuSpecificationType] + gpus: Optional[int] + memory: Optional[StorageSpecificationType] + gpu_type: Optional[GpuType] storage: Optional[StorageSpecificationType] optimize_costs: Optional[bool] min_workers: int diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py index 017bedea..8f8ece69 100644 --- a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import Any, Dict, List class LLMArtifactGateway(ABC): @@ -39,3 +39,13 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[ model_name (str): name of the model """ pass + + @abstractmethod + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + """ + Gets the model config from the model files live at given folder. + + Args: + path (str): path to model files + """ + pass diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 9ce99364..8d257bf1 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -10,6 +10,7 @@ import os import re from dataclasses import asdict, dataclass +from functools import lru_cache from typing import Any, AsyncIterable, Dict, List, Optional, Union from model_engine_server.common.config import hmi_config @@ -226,30 +227,6 @@ LLMInferenceFramework.TENSORRT_LLM: [], } -# We need a dict where if we need to override we can -# NOTE: These are in *descending* order of priority. e.g. if you see 'mammoth-coder' -# you'll use that override and not listen to the 'llama-2' override -_VLLM_MODEL_LENGTH_OVERRIDES: Dict[str, Dict[str, Optional[int]]] = { - "mammoth-coder": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, - # Based on config here: https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B/blob/main/config.json#L12 - # Can also see 13B, 34B there too - "code-llama": {"max_model_len": 16384, "max_num_batched_tokens": 16384}, - "codellama": { - "max_model_len": 16384, - "max_num_batched_tokens": 16384, - }, # setting both for backwards compatibility, will phase code-llama out in a future pr - # Based on config here: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json#L12 - # Can also see 13B, 34B there too - "gemma": {"max_model_len": 8192, "max_num_batched_tokens": 8192}, - "llama-2": {"max_model_len": None, "max_num_batched_tokens": 4096}, - "llama-3-8b-instruct-262k": {"max_model_len": None, "max_num_batched_tokens": 262144}, - "llama-3": {"max_model_len": None, "max_num_batched_tokens": 8192}, - "mistral": {"max_model_len": 8000, "max_num_batched_tokens": 8000}, - "mixtral-8x7b": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, - "mixtral-8x22b": {"max_model_len": 65536, "max_num_batched_tokens": 65536}, - "zephyr": {"max_model_len": 32768, "max_num_batched_tokens": 32768}, -} - NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes @@ -319,9 +296,9 @@ def validate_num_shards( raise ObjectHasInvalidValueException( f"Num shard {num_shards} must be the same as number of GPUs {gpus} for DeepSpeed." ) - if num_shards > gpus: + if num_shards != gpus: raise ObjectHasInvalidValueException( - f"Num shard {num_shards} must be less than or equal to the number of GPUs {gpus}." + f"Num shard {num_shards} must be equal to the number of GPUs {gpus}." ) @@ -670,16 +647,6 @@ async def create_vllm_bundle( checkpoint_path: Optional[str], ): command = [] - - max_num_batched_tokens: Optional[int] = 2560 # vLLM's default - max_model_len: Optional[int] = None - - for key, value in _VLLM_MODEL_LENGTH_OVERRIDES.items(): - if key in model_name: - max_model_len = value["max_model_len"] - max_num_batched_tokens = value["max_num_batched_tokens"] - break - subcommands = [] checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) @@ -695,14 +662,9 @@ async def create_vllm_bundle( final_weights_folder, ) - if max_model_len: - subcommands.append( - f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens} --max-model-len {max_model_len}" - ) - else: - subcommands.append( - f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005 --max-num-batched-tokens {max_num_batched_tokens}" - ) + subcommands.append( + f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" + ) if quantize: if quantize == Quantization.AWQ: @@ -890,15 +852,26 @@ def __init__( create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, model_endpoint_service: ModelEndpointService, docker_repository: DockerRepository, + llm_artifact_gateway: LLMArtifactGateway, ): self.authz_module = LiveAuthorizationModule() self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case self.model_endpoint_service = model_endpoint_service self.docker_repository = docker_repository + self.llm_artifact_gateway = llm_artifact_gateway async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: + _fill_hardware_info(self.llm_artifact_gateway, request) + if not ( + request.gpus + and request.gpu_type + and request.cpus + and request.memory + and request.storage + ): + raise RuntimeError("Some hardware info is missing unexpectedly.") validate_deployment_resources( min_workers=request.min_workers, max_workers=request.max_workers, @@ -2201,35 +2174,107 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl return ModelDownloadResponse(urls=urls) -@dataclass -class VLLMEngineArgs: - gpu_memory_utilization: Optional[float] = None - - -def infer_addition_engine_args_from_model_name(model_name: str) -> VLLMEngineArgs: - numbers = re.findall(r"\d+", model_name) - if len(numbers) == 0: - raise ObjectHasInvalidValueException( - f"Model {model_name} is not supported for batch completions." - ) +def _fill_hardware_info( + llm_artifact_gateway: LLMArtifactGateway, request: CreateLLMModelEndpointV1Request +): + if ( + request.gpus is None + or request.gpu_type is None + or request.cpus is None + or request.memory is None + or request.storage is None + ): + if not ( + request.gpus is None + and request.gpu_type is None + and request.cpus is None + and request.memory is None + and request.storage is None + ): + raise ObjectHasInvalidValueException( + "All hardware spec fields (gpus, gpu_type, cpus, memory, storage) must be provided if any hardware spec field is missing." + ) + checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) + hardware_info = _infer_hardware(llm_artifact_gateway, request.model_name, checkpoint_path) + request.gpus = hardware_info.gpus + request.gpu_type = hardware_info.gpu_type + request.cpus = hardware_info.cpus + request.memory = hardware_info.memory + request.storage = hardware_info.storage + if hardware_info.gpus: # make lint happy + request.num_shards = hardware_info.gpus + + +@lru_cache() +def _infer_hardware( + llm_artifact_gateway: LLMArtifactGateway, + model_name: str, + checkpoint_path: str, +) -> CreateDockerImageBatchJobResourceRequests: + config = llm_artifact_gateway.get_model_config(checkpoint_path) + + dtype_size = 2 + + min_kv_cache_size = ( + 2 + * dtype_size + * config["num_hidden_layers"] + * config["hidden_size"] + * config["max_position_embeddings"] + // (config["num_attention_heads"] // config["num_key_value_heads"]) + ) - b_params = int(numbers[-1]) - if b_params >= 70: - gpu_memory_utilization = 0.95 + if "mixtral-8x7b" in model_name: + model_param_count_b = 47 + elif "mixtral-8x22b" in model_name: + model_param_count_b = 140 else: - gpu_memory_utilization = 0.9 + numbers = re.findall(r"(\d+)b", model_name) + if len(numbers) == 0: + raise ObjectHasInvalidValueException( + f"Unable to infer number of parameters for {model_name}." + ) + model_param_count_b = int(numbers[-1]) - return VLLMEngineArgs(gpu_memory_utilization=gpu_memory_utilization) + model_weights_size = dtype_size * model_param_count_b * 1_000_000_000 + min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) -def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJobResourceRequests: - if "mixtral-8x7b" in model_name: + logger.info( + f"Memory calculation result: {min_memory_gb=} for {model_name}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}" + ) + + if min_memory_gb <= 24: + cpus = "10" + gpus = 1 + memory = "24Gi" + storage = "80Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A10 + elif min_memory_gb <= 48: + cpus = "20" + gpus = 2 + memory = "48Gi" + storage = "80Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A10 + elif min_memory_gb <= 96: + cpus = "40" + gpus = 4 + memory = "96Gi" + storage = "96Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A10 + elif min_memory_gb <= 180: cpus = "20" gpus = 2 memory = "160Gi" storage = "160Gi" gpu_type = GpuType.NVIDIA_AMPERE_A100E - elif "mixtral-8x22b" in model_name: + elif min_memory_gb <= 320: + cpus = "40" + gpus = 4 + memory = "320Gi" + storage = "320Gi" + gpu_type = GpuType.NVIDIA_AMPERE_A100E + elif min_memory_gb <= 640: cpus = "80" gpus = 8 memory = "800Gi" @@ -2242,57 +2287,46 @@ def infer_hardware_from_model_name(model_name: str) -> CreateDockerImageBatchJob storage = "40Gi" gpu_type = GpuType.NVIDIA_AMPERE_A100E else: - numbers = re.findall(r"\d+", model_name) - if len(numbers) == 0: - raise ObjectHasInvalidValueException( - f"Model {model_name} is not supported for batch completions." - ) - - b_params = int(numbers[-1]) - if b_params <= 7: - cpus = "10" - gpus = 1 - memory = "24Gi" - storage = "80Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A10 - elif b_params <= 13: - cpus = "20" - gpus = 2 - memory = "48Gi" - storage = "80Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A10 - elif b_params <= 34: - cpus = "40" - gpus = 4 - memory = "96Gi" - storage = "96Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A10 - elif b_params <= 70: - cpus = "20" - gpus = 2 - memory = "160Gi" - storage = "160Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A100E - else: - raise ObjectHasInvalidValueException( - f"Model {model_name} is not supported for batch completions." - ) + raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") return CreateDockerImageBatchJobResourceRequests( cpus=cpus, gpus=gpus, memory=memory, storage=storage, gpu_type=gpu_type ) +@dataclass +class VLLMEngineArgs: + gpu_memory_utilization: Optional[float] = None + + +def infer_addition_engine_args_from_model_name(model_name: str) -> VLLMEngineArgs: + numbers = re.findall(r"\d+", model_name) + if len(numbers) == 0: + raise ObjectHasInvalidValueException( + f"Model {model_name} is not supported for batch completions." + ) + + b_params = int(numbers[-1]) + if b_params >= 70: + gpu_memory_utilization = 0.95 + else: + gpu_memory_utilization = 0.9 + + return VLLMEngineArgs(gpu_memory_utilization=gpu_memory_utilization) + + class CreateBatchCompletionsUseCase: def __init__( self, docker_image_batch_job_gateway: DockerImageBatchJobGateway, docker_repository: DockerRepository, docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository, + llm_artifact_gateway: LLMArtifactGateway, ): self.docker_image_batch_job_gateway = docker_image_batch_job_gateway self.docker_repository = docker_repository self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo + self.llm_artifact_gateway = llm_artifact_gateway async def create_batch_job_bundle( self, @@ -2341,7 +2375,14 @@ async def create_batch_job_bundle( async def execute( self, user: User, request: CreateBatchCompletionsRequest ) -> CreateBatchCompletionsResponse: - hardware = infer_hardware_from_model_name(request.model_config.model) + request.model_config.checkpoint_path = get_checkpoint_path( + request.model_config.model, request.model_config.checkpoint_path + ) + hardware = _infer_hardware( + self.llm_artifact_gateway, + request.model_config.model, + request.model_config.checkpoint_path, + ) # Reconcile gpus count with num_shards from request assert hardware.gpus is not None if request.model_config.num_shards: diff --git a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py index 8ebbeda3..d68d539c 100644 --- a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py @@ -1,5 +1,6 @@ +import json import os -from typing import List +from typing import Any, Dict, List from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient, ContainerClient @@ -73,3 +74,11 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[ for blob_name in container_client.list_blob_names(name_starts_with=prefix): model_files.append(f"https://{account}.blob.core.windows.net/{bucket}/{blob_name}") return model_files + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = os.path.join(parsed_remote.key, "config.json") + + container_client = _get_abs_container_client(bucket) + return json.loads(container_client.download_blob(blob=key).readall()) diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index ebc6b2fd..b48d1eef 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -1,5 +1,6 @@ +import json import os -from typing import List +from typing import Any, Dict, List import boto3 from model_engine_server.common.config import get_model_cache_directory_name, hmi_config @@ -71,3 +72,14 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[ for obj in s3_bucket.objects.filter(Prefix=prefix): model_files.append(f"s3://{bucket}/{obj.key}") return model_files + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = os.path.join(parsed_remote.key, "config.json") + s3_bucket = s3.Bucket(bucket) + filepath = os.path.join("/tmp", key).replace("/", "_") + s3_bucket.download_file(key, filepath) + with open(filepath, "r") as f: + return json.load(f) diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index 703c7c12..725b7795 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -1177,7 +1177,7 @@ def create_llm_model_endpoint_request_sync() -> Dict[str, Any]: "gpus": 2, "memory": "1G", "gpu_type": "nvidia-tesla-t4", - "storage": None, + "storage": "1Gi", "min_workers": 1, "max_workers": 5, "per_worker": 3, @@ -1282,7 +1282,7 @@ def create_batch_completions_request() -> Dict[str, Any]: }, "model_config": { "model": "mpt-7b", - "checkpoint_path": "test_checkpoint_path", + "checkpoint_path": "s3://test_checkpoint_path", "labels": [], "num_shards": 2, }, diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 50aa4018..7de4ec47 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -765,6 +765,29 @@ def __init__(self): "mpt-7b": ["model-fake.safetensors"], } self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} + self.model_config = { + "_name_or_path": "meta-llama/Llama-2-7b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + } def _add_model(self, owner: str, model_name: str): self.existing_models.append((owner, model_name)) @@ -784,6 +807,9 @@ def get_model_weights_urls(self, owner: str, model_name: str): return self.urls raise ObjectNotFoundException + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + return self.model_config + class FakeTriggerRepository(TriggerRepository): def __init__(self, contents: Optional[Dict[str, Trigger]] = None): diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index a3419890..c721960e 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -85,7 +85,7 @@ def create_model_endpoint_request_sync( gpus=1, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -189,7 +189,7 @@ def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request: gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -216,7 +216,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=0, max_workers=3, per_worker=2, @@ -243,7 +243,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -281,7 +281,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -310,7 +310,7 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -340,7 +340,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -366,7 +366,7 @@ def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpo gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -394,7 +394,7 @@ def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -421,7 +421,7 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -448,7 +448,7 @@ def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEn gpus=2, memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, + storage="10G", min_workers=1, max_workers=3, per_worker=2, @@ -489,7 +489,7 @@ def create_batch_completions_request() -> CreateBatchCompletionsRequest: ), model_config=CreateBatchCompletionsModelConfig( model="mpt-7b", - checkpoint_path="test_checkpoint_path", + checkpoint_path="s3://test_checkpoint_path", labels=[], num_shards=2, ), diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4cbde450..950f915f 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -49,7 +49,8 @@ GpuType, ModelDownloadV1UseCase, UpdateLLMModelEndpointV1UseCase, - infer_hardware_from_model_name, + _fill_hardware_info, + _infer_hardware, validate_and_update_completion_params, validate_checkpoint_files, ) @@ -96,6 +97,7 @@ async def test_create_model_endpoint_use_case_success( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) @@ -290,6 +292,7 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() @@ -331,6 +334,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -443,6 +447,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response_1 = await use_case.execute( @@ -504,6 +509,7 @@ async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -538,6 +544,7 @@ async def test_create_llm_model_endpoint_use_case_quantization_exception( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectHasInvalidValueException): @@ -611,6 +618,7 @@ async def test_update_model_endpoint_use_case_success( create_llm_model_bundle_use_case=llm_bundle_use_case, model_endpoint_service=fake_model_endpoint_service, docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, ) update_use_case = UpdateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=llm_bundle_use_case, @@ -1766,61 +1774,298 @@ async def test_validate_checkpoint_files_safetensors_with_other_files(): validate_checkpoint_files(fake_model_files) # No exception should be raised -def test_infer_hardware_from_model_name(): - hardware = infer_hardware_from_model_name("mixtral-8x7b") +def test_infer_hardware(fake_llm_artifact_gateway): + fake_llm_artifact_gateway.model_config = { + "architectures": ["MixtralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_experts_per_tok": 2, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "num_local_experts": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "torch_dtype": "bfloat16", + "transformers_version": "4.36.0.dev0", + "vocab_size": 32000, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "") assert hardware.cpus == "20" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E - hardware = infer_hardware_from_model_name("mixtral-8x22b") + fake_llm_artifact_gateway.model_config = { + "architectures": ["MixtralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 65536, + "model_type": "mixtral", + "num_attention_heads": 48, + "num_experts_per_tok": 2, + "num_hidden_layers": 56, + "num_key_value_heads": 8, + "num_local_experts": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "router_jitter_noise": 0.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "vocab_size": 32000, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "") assert hardware.cpus == "80" assert hardware.gpus == 8 assert hardware.memory == "800Gi" assert hardware.storage == "460Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E - hardware = infer_hardware_from_model_name("llama-2-7b") + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "meta-llama/Llama-2-7b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "vocab_size": 32000, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "") assert hardware.cpus == "10" assert hardware.gpus == 1 assert hardware.memory == "24Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 - hardware = infer_hardware_from_model_name("llama-2-13b") + fake_llm_artifact_gateway.model_config = { + "architectures": ["LlamaForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "vocab_size": 128256, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "") + assert hardware.cpus == "10" + assert hardware.gpus == 1 + assert hardware.memory == "24Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.32.0.dev0", + "vocab_size": 32000, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "") assert hardware.cpus == "20" assert hardware.gpus == 2 assert hardware.memory == "48Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 - hardware = infer_hardware_from_model_name("codellama-34b") + fake_llm_artifact_gateway.model_config = { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 22016, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "torch_dtype": "bfloat16", + "transformers_version": "4.32.0.dev0", + "vocab_size": 32000, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "") assert hardware.cpus == "40" assert hardware.gpus == 4 assert hardware.memory == "96Gi" assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 - hardware = infer_hardware_from_model_name("llama-2-70b") + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "meta-llama/Llama-2-70b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.32.0.dev0", + "vocab_size": 32000, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "") assert hardware.cpus == "20" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E - hardware = infer_hardware_from_model_name("llama-3-8b-instruct-262k") + fake_llm_artifact_gateway.model_config = { + "architectures": ["LlamaForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "vocab_size": 128256, + } + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "") assert hardware.cpus == "20" assert hardware.gpus == 2 - assert hardware.memory == "40Gi" - assert hardware.storage == "40Gi" + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + # (TODO) figure out how to calculate memory for llama-3-8b-instruct-262k + # fake_llm_artifact_gateway.model_config = { + # "_name_or_path": "gradientai/llama3-8b-stage65k-chat", + # "architectures": ["LlamaForCausalLM"], + # "attention_dropout": 0.0, + # "bos_token_id": 128000, + # "eos_token_id": 128001, + # "hidden_act": "silu", + # "hidden_size": 4096, + # "initializer_range": 0.02, + # "intermediate_size": 14336, + # "max_position_embeddings": 262144, + # "model_type": "llama", + # "num_attention_heads": 32, + # "num_hidden_layers": 32, + # "num_key_value_heads": 8, + # "pretraining_tp": 1, + # "rms_norm_eps": 1e-05, + # "rope_theta": 283461213.0, + # "torch_dtype": "bfloat16", + # "transformers_version": "4.41.0.dev0", + # "vocab_size": 128256, + # } + # hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") + # assert hardware.cpus == "20" + # assert hardware.gpus == 2 + # assert hardware.memory == "160Gi" + # assert hardware.storage == "160Gi" + # assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + with pytest.raises(ObjectHasInvalidValueException): - infer_hardware_from_model_name("unsupported_model") + _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") + + +def test_fill_hardware_info(fake_llm_artifact_gateway): + request = CreateLLMModelEndpointV1Request( + name="mixtral-8x7b", + model_name="mixtral-8x7b", + checkpoint_path="s3://checkpoint", + metadata={}, + min_workers=1, + max_workers=1, + per_worker=1, + labels={}, + ) + _fill_hardware_info(fake_llm_artifact_gateway, request) + assert request.cpus == "20" + assert request.gpus == 2 + assert request.memory == "160Gi" + assert request.storage == "160Gi" + assert request.gpu_type == GpuType.NVIDIA_AMPERE_A100E + + request = CreateLLMModelEndpointV1Request( + name="mixtral-8x7b", + model_name="mixtral-8x7b", + checkpoint_path="s3://checkpoint", + metadata={}, + min_workers=1, + max_workers=1, + per_worker=1, + labels={}, + gpus=1, + ) with pytest.raises(ObjectHasInvalidValueException): - infer_hardware_from_model_name("falcon-180b") + _fill_hardware_info(fake_llm_artifact_gateway, request) @pytest.mark.asyncio @@ -1828,6 +2073,7 @@ async def test_create_batch_completions( fake_docker_image_batch_job_gateway, fake_docker_repository_image_always_exists, fake_docker_image_batch_job_bundle_repository, + fake_llm_artifact_gateway, test_api_key: str, create_batch_completions_request: CreateBatchCompletionsRequest, ): @@ -1835,6 +2081,7 @@ async def test_create_batch_completions( docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, docker_repository=fake_docker_repository_image_always_exists, docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) From 1470aacb0a14781a786475b042cb7cacbc12ccd0 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 15 May 2024 14:03:44 -0700 Subject: [PATCH 300/425] Improve TensorRT-LLM Functionality (#487) Changes to get tensorrtllm to work with Mixtral Update tensorrt llm included code/build processes to a newer version Add some bits to mitigate some tokenization issues Note: the logprobs returned aren't correct still, haven't investigated. Stop sequences don't completely work, to my knowledge this is somewhat of a limitation of how tensorrt/triton works, but there may be another way around this. --- .../use_cases/llm_model_endpoint_use_cases.py | 27 +- .../inference/tensorrt-llm/Dockerfile | 2 +- .../inference/tensorrt-llm/README.md | 14 + .../tensorrt-llm/launch_triton_server.py | 14 +- .../inference/tensorrt-llm/requirements.txt | 3 +- .../triton_model_repo/ensemble/config.pbtxt | 217 +++++++++- .../postprocessing/1/model.py | 83 +++- .../postprocessing/config.pbtxt | 59 ++- .../preprocessing/1/model.py | 212 ++++++++-- .../preprocessing/config.pbtxt | 58 ++- .../tensorrt_llm/config.pbtxt | 188 ++++++++- .../tensorrt_llm_bls/1/model.py | 389 ++++++++++++++++++ .../tensorrt_llm_bls/config.pbtxt | 221 ++++++++++ .../tests/unit/domain/test_llm_use_cases.py | 47 ++- scripts/throughput_benchmarks.py | 22 +- 15 files changed, 1443 insertions(+), 113 deletions(-) create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/README.md mode change 100755 => 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt mode change 100755 => 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py create mode 100644 model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 8d257bf1..761d8ab9 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -216,7 +216,9 @@ "llama-2-70b-chat", ] ), - LLMInferenceFramework.TENSORRT_LLM: set(["llama-2-7b"]), + LLMInferenceFramework.TENSORRT_LLM: set( + ["llama-2-7b", "mixtral-8x7b", "mixtral-8x7b-instruct"] + ), } _SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = { @@ -1467,11 +1469,28 @@ def model_output_to_completion_output( num_prompt_tokens = count_tokens( prompt, model_content.model_name, self.tokenizer_repository ) - return CompletionOutput( + if "token_ids" in model_output: + # TensorRT 23.10 has this field, TensorRT 24.03 does not + # For backwards compatibility with pre-2024/05/02 + num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens # Output is " prompt output" - text=model_output["text_output"][(len(prompt) + 4) :], + text = model_output["text_output"][(len(prompt) + 4) :] + elif "output_log_probs" in model_output: + # TensorRT 24.01 + surrounding code. + # For some reason TRT returns output_log_probs as either a list or a float + # Also the log probs don't look right, so returning log-probs is still broken + num_completion_tokens = ( + len(model_output["output_log_probs"]) + if type(model_output["output_log_probs"]) == list + else 1 + ) + # Output is just "output". See `exclude_input_in_output` inside of + # inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt + text = model_output["text_output"] + return CompletionOutput( + text=text, num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=len(model_output["token_ids"]) - num_prompt_tokens, + num_completion_tokens=num_completion_tokens, ) else: raise EndpointUnsupportedInferenceTypeException( diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile index 7bae22fd..be1fa7e9 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile +++ b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 +FROM nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3 COPY requirements.txt /workspace/requirements.txt WORKDIR /workspace diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/README.md b/model-engine/model_engine_server/inference/tensorrt-llm/README.md new file mode 100644 index 00000000..0468de7d --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/README.md @@ -0,0 +1,14 @@ +# Preparing the model weights/tokenizers + +Our TensorRT-LLM docker image expects weights to live in s3/other blob store with the following directory structure: + +root/ + model_tokenizer/ + + model_weights/ + config.json + rank.engine + +You can obtain `model_weights` by building a TRT-LLM engine via the directions found on Nvidia's site (e.g. https://github.com/NVIDIA/TensorRT-LLM/blob/main/README.md#installation, https://github.com/NVIDIA/TensorRT-LLM/blob/v0.8.0/examples/llama/convert_checkpoint.py) + +The inference image is built via the Dockerfile in the same directory as this readme. \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py index 0ce46d2b..1a3434ee 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py +++ b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py @@ -9,6 +9,12 @@ def parse_arguments(): "--world_size", type=int, default=1, help="world size, only support tensor parallelism now" ) parser.add_argument("--tritonserver", type=str, default="/opt/tritonserver/bin/tritonserver") + parser.add_argument( + "--http-address", + type=str, + default="ipv6:[::1]", + help="Default HTTP address to ipv6:[::1].", + ) parser.add_argument( "--http-port", type=int, @@ -20,14 +26,16 @@ def parse_arguments(): return parser.parse_args() -def get_cmd(world_size, tritonserver, model_repo, http_port): +def get_cmd(world_size, tritonserver, model_repo, http_address, http_port): cmd = "mpirun --allow-run-as-root " for i in range(world_size): - cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address ipv6:[::1] --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : " + cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address {http_address} --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : " return cmd if __name__ == "__main__": args = parse_arguments() - cmd = get_cmd(int(args.world_size), args.tritonserver, args.model_repo, args.http_port) + cmd = get_cmd( + int(args.world_size), args.tritonserver, args.model_repo, args.http_address, args.http_port + ) subprocess.call(cmd, shell=True) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt index e2e60684..7d75f3fc 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt +++ b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt @@ -1,2 +1,3 @@ sentencepiece==0.1.99 -protobuf==4.24.4 \ No newline at end of file +protobuf==4.24.4 +torch==2.2.2 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt old mode 100755 new mode 100644 index 7a7662d3..55a52eaf --- a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,7 +26,7 @@ name: "ensemble" platform: "ensemble" -max_batch_size: 128 +max_batch_size: 128 input [ { name: "text_input" @@ -35,34 +35,36 @@ input [ }, { name: "max_tokens" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ -1 ] }, { name: "bad_words" data_type: TYPE_STRING dims: [ -1 ] + optional: true }, { name: "stop_words" data_type: TYPE_STRING dims: [ -1 ] + optional: true }, { name: "end_id" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] optional: true }, { name: "pad_id" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] optional: true }, { name: "top_k" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] optional: true }, @@ -92,7 +94,7 @@ input [ }, { name: "min_length" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] optional: true }, @@ -102,15 +104,39 @@ input [ dims: [ 1 ] optional: true }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, { name: "random_seed" data_type: TYPE_UINT64 dims: [ 1 ] optional: true }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, { name: "beam_width" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] optional: true }, @@ -119,18 +145,57 @@ input [ data_type: TYPE_BOOL dims: [ 1 ] optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "embedding_bias_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "embedding_bias_weights" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true } ] output [ { name: "text_output" data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 dims: [ -1, -1 ] }, { - name: "token_ids" - data_type: TYPE_INT32 + name: "context_logits" + data_type: TYPE_FP32 dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] } ] ensemble_scheduling { @@ -154,6 +219,22 @@ ensemble_scheduling { key: "STOP_WORDS_DICT" value: "stop_words" } + input_map { + key: "EMBEDDING_BIAS_WORDS" + value: "embedding_bias_words" + } + input_map { + key: "EMBEDDING_BIAS_WEIGHTS" + value: "embedding_bias_weights" + } + input_map { + key: "END_ID" + value: "end_id" + } + input_map { + key: "PAD_ID" + value: "pad_id" + } output_map { key: "REQUEST_INPUT_LEN" value: "_REQUEST_INPUT_LEN" @@ -166,6 +247,26 @@ ensemble_scheduling { key: "REQUEST_OUTPUT_LEN" value: "_REQUEST_OUTPUT_LEN" } + output_map { + key: "STOP_WORDS_IDS" + value: "_STOP_WORDS_IDS" + } + output_map { + key: "BAD_WORDS_IDS" + value: "_BAD_WORDS_IDS" + } + output_map { + key: "EMBEDDING_BIAS" + value: "_EMBEDDING_BIAS" + } + output_map { + key: "OUT_END_ID" + value: "_PREPROCESSOR_END_ID" + } + output_map { + key: "OUT_PAD_ID" + value: "_PREPROCESSOR_PAD_ID" + } }, { model_name: "tensorrt_llm" @@ -184,11 +285,15 @@ ensemble_scheduling { } input_map { key: "end_id" - value: "end_id" + value: "_PREPROCESSOR_END_ID" } input_map { key: "pad_id" - value: "pad_id" + value: "_PREPROCESSOR_PAD_ID" + } + input_map { + key: "embedding_bias" + value: "_EMBEDDING_BIAS" } input_map { key: "runtime_top_k" @@ -218,10 +323,26 @@ ensemble_scheduling { key: "presence_penalty" value: "presence_penalty" } + input_map { + key: "frequency_penalty" + value: "frequency_penalty" + } input_map { key: "random_seed" value: "random_seed" } + input_map { + key: "return_log_probs" + value: "return_log_probs" + } + input_map { + key: "return_context_logits" + value: "return_context_logits" + } + input_map { + key: "return_generation_logits" + value: "return_generation_logits" + } input_map { key: "beam_width" value: "beam_width" @@ -230,10 +351,46 @@ ensemble_scheduling { key: "streaming" value: "stream" } + input_map { + key: "prompt_embedding_table" + value: "prompt_embedding_table" + } + input_map { + key: "prompt_vocab_size" + value: "prompt_vocab_size" + } + input_map { + key: "stop_words_list" + value: "_STOP_WORDS_IDS" + } + input_map { + key: "bad_words_list" + value: "_BAD_WORDS_IDS" + } output_map { key: "output_ids" value: "_TOKENS_BATCH" } + output_map { + key: "sequence_length" + value: "_SEQUENCE_LENGTH" + }, + output_map { + key: "cum_log_probs" + value: "_CUM_LOG_PROBS" + } + output_map { + key: "output_log_probs" + value: "_OUTPUT_LOG_PROBS" + }, + output_map { + key: "context_logits" + value: "_CONTEXT_LOGITS" + }, + output_map { + key: "generation_logits" + value: "_GENERATION_LOGITS" + } }, { model_name: "postprocessing" @@ -242,13 +399,45 @@ ensemble_scheduling { key: "TOKENS_BATCH" value: "_TOKENS_BATCH" } + input_map { + key: "CUM_LOG_PROBS" + value: "_CUM_LOG_PROBS" + } + input_map { + key: "OUTPUT_LOG_PROBS" + value: "_OUTPUT_LOG_PROBS" + } + input_map { + key: "CONTEXT_LOGITS" + value: "_CONTEXT_LOGITS" + } + input_map { + key: "GENERATION_LOGITS" + value: "_GENERATION_LOGITS" + } + input_map { + key: "SEQUENCE_LENGTH" + value: "_SEQUENCE_LENGTH" + } output_map { key: "OUTPUT" value: "text_output" } output_map { - key: "OUTPUT_TOKEN_IDS" - value: "token_ids" + key: "OUT_OUTPUT_LOG_PROBS" + value: "output_log_probs" + } + output_map { + key: "OUT_CUM_LOG_PROBS" + value: "cum_log_probs" + } + output_map { + key: "OUT_CONTEXT_LOGITS" + value: "context_logits" + } + output_map { + key: "OUT_GENERATION_LOGITS" + value: "generation_logits" } } ] diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py index 1cd809d9..c1c6353b 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -28,7 +28,7 @@ import numpy as np import triton_python_backend_utils as pb_utils -from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer +from transformers import SPIECE_UNDERLINE, AutoTokenizer, LlamaTokenizer, T5Tokenizer class TritonPythonModel: @@ -55,11 +55,16 @@ def initialize(self, args): model_config = json.loads(args["model_config"]) tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + self.skip_special_tokens = model_config["parameters"].get( + "skip_special_tokens", {"string_value": "true"} + )["string_value"].lower() in ["true", "1", "t", "y", "yes"] if tokenizer_type == "t5": self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") elif tokenizer_type == "auto": - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side="left") + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left", trust_remote_code=True + ) elif tokenizer_type == "llama": self.tokenizer = LlamaTokenizer.from_pretrained( tokenizer_dir, legacy=False, padding_side="left" @@ -70,17 +75,10 @@ def initialize(self, args): # Parse model output configs output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") - output_token_ids_config = pb_utils.get_output_config_by_name( - model_config, "OUTPUT_TOKEN_IDS" - ) # Convert Triton types to numpy types self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - self.output_token_ids_dtype = pb_utils.triton_string_to_numpy( - output_token_ids_config["data_type"] - ) - def execute(self, requests): """`execute` must be implemented in every Python model. `execute` function receives a list of pb_utils.InferenceRequest as the only @@ -109,20 +107,45 @@ def execute(self, requests): # Get input tensors tokens_batch = pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH").as_numpy() + # Get sequence length + sequence_lengths = pb_utils.get_input_tensor_by_name( + request, "SEQUENCE_LENGTH" + ).as_numpy() + + # Get cum log probs + cum_log_probs = pb_utils.get_input_tensor_by_name(request, "CUM_LOG_PROBS").as_numpy() + + # Get sequence length + output_log_probs = pb_utils.get_input_tensor_by_name( + request, "OUTPUT_LOG_PROBS" + ).as_numpy() + + # Get context logits + context_logits = pb_utils.get_input_tensor_by_name(request, "CONTEXT_LOGITS").as_numpy() + + # Get generation logits + generation_logits = pb_utils.get_input_tensor_by_name( + request, "GENERATION_LOGITS" + ).as_numpy() + # Reshape Input # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) # tokens_batch = tokens_batch.T # Postprocessing output data. - outputs = self._postprocessing(tokens_batch) + outputs = self._postprocessing(tokens_batch, sequence_lengths) # Create output tensors. You need pb_utils.Tensor # objects to create pb_utils.InferenceResponse. output_tensor = pb_utils.Tensor("OUTPUT", np.array(outputs).astype(self.output_dtype)) - output_token_ids = pb_utils.Tensor( - "OUTPUT_TOKEN_IDS", np.array(tokens_batch).astype(self.output_token_ids_dtype) - ) + out_cum_log_probs = pb_utils.Tensor("OUT_CUM_LOG_PROBS", cum_log_probs) + + out_output_log_probs = pb_utils.Tensor("OUT_OUTPUT_LOG_PROBS", output_log_probs) + + out_context_logits = pb_utils.Tensor("OUT_CONTEXT_LOGITS", context_logits) + + out_generation_logits = pb_utils.Tensor("OUT_GENERATION_LOGITS", generation_logits) # Create InferenceResponse. You can set an error here in case # there was a problem with handling this inference request. @@ -132,7 +155,13 @@ def execute(self, requests): # pb_utils.InferenceResponse( # output_tensors=..., TritonError("An error occurred")) inference_response = pb_utils.InferenceResponse( - output_tensors=[output_tensor, output_token_ids] + output_tensors=[ + output_tensor, + out_cum_log_probs, + out_output_log_probs, + out_context_logits, + out_generation_logits, + ] ) responses.append(inference_response) @@ -147,10 +176,26 @@ def finalize(self): """ print("Cleaning up...") - def _postprocessing(self, tokens_batch): + def _postprocessing(self, tokens_batch, sequence_lengths): outputs = [] - for beam_tokens in tokens_batch: - for tokens in beam_tokens: - output = self.tokenizer.decode(tokens) + for batch_idx, beam_tokens in enumerate(tokens_batch): + for beam_idx, tokens in enumerate(beam_tokens): + seq_len = sequence_lengths[batch_idx][beam_idx] + output = self.tokenizer.decode( + tokens[:seq_len], skip_special_tokens=self.skip_special_tokens + ) + # Adapted from https://github.com/triton-inference-server/tensorrtllm_backend/pull/423 + # This is somewhat of a hack: add a space before the output if the first token starts with a space + # This may add a space in front of the first token though when we don't want it. + if seq_len > 0: + token_id_string = self.tokenizer.convert_ids_to_tokens( + tokens[:1], skip_special_tokens=self.skip_special_tokens + ) + if ( + len(token_id_string) > 0 + and len(token_id_string[0]) > 0 + and token_id_string[0][0] == SPIECE_UNDERLINE + ): + output = " " + output outputs.append(output.encode("utf8")) return outputs diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt old mode 100755 new mode 100644 index cc61a24e..93af4eec --- a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,31 +26,73 @@ name: "postprocessing" backend: "python" -max_batch_size: 128 +max_batch_size: 128 input [ { name: "TOKENS_BATCH" data_type: TYPE_INT32 dims: [ -1, -1 ] + }, + { + name: "SEQUENCE_LENGTH" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "CUM_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "OUTPUT_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "CONTEXT_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + optional: true + }, + { + name: "GENERATION_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + optional: true } ] output [ { name: "OUTPUT" data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "OUT_CUM_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "OUT_OUTPUT_LOG_PROBS" + data_type: TYPE_FP32 dims: [ -1, -1 ] }, { - name: "OUTPUT_TOKEN_IDS" - data_type: TYPE_INT32 + name: "OUT_CONTEXT_LOGITS" + data_type: TYPE_FP32 dims: [ -1, -1 ] + }, + { + name: "OUT_GENERATION_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] } ] parameters { key: "tokenizer_dir" value: { - string_value: "model_tokenizer" + string_value: "model_tokenizer" } } @@ -61,6 +103,13 @@ parameters { } } +parameters { + key: "skip_special_tokens" + value: { + string_value: "True" + } +} + instance_group [ { count: 1 diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py index b5996f87..ea2d4789 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -24,14 +24,11 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import csv import json from typing import List import numpy as np -import torch import triton_python_backend_utils as pb_utils -from torch.nn.utils.rnn import pad_sequence from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer @@ -59,11 +56,16 @@ def initialize(self, args): model_config = json.loads(args["model_config"]) tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + self.add_special_tokens = model_config["parameters"].get( + "add_special_tokens", {"string_value": "false"} + )["string_value"].lower() in ["true", "1", "t", "y", "yes"] if tokenizer_type == "t5": self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") elif tokenizer_type == "auto": - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side="left") + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left", trust_remote_code=True + ) elif tokenizer_type == "llama": self.tokenizer = LlamaTokenizer.from_pretrained( tokenizer_dir, legacy=False, padding_side="left" @@ -72,16 +74,38 @@ def initialize(self, args): raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") self.tokenizer.pad_token = self.tokenizer.eos_token - self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token, add_special_tokens=False)[0] + self.tokenizer_end_id = self.tokenizer.encode( + self.tokenizer.eos_token, add_special_tokens=False + )[0] + self.tokenizer_pad_id = self.tokenizer.encode( + self.tokenizer.pad_token, add_special_tokens=False + )[0] # Parse model output configs and convert Triton types to numpy types - input_names = ["INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS"] + output_names = [ + "INPUT_ID", + "REQUEST_INPUT_LEN", + "BAD_WORDS_IDS", + "STOP_WORDS_IDS", + "OUT_END_ID", + "OUT_PAD_ID", + ] + input_names = ["EMBEDDING_BIAS_WORDS", "EMBEDDING_BIAS_WEIGHTS"] for input_name in input_names: setattr( self, input_name.lower() + "_dtype", pb_utils.triton_string_to_numpy( - pb_utils.get_output_config_by_name(model_config, input_name)["data_type"] + pb_utils.get_input_config_by_name(model_config, input_name)["data_type"] + ), + ) + + for output_name in output_names: + setattr( + self, + output_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, output_name)["data_type"] ), ) @@ -109,43 +133,84 @@ def execute(self, requests): # Every Python backend must iterate over everyone of the requests # and create a pb_utils.InferenceResponse for each of them. + logger = pb_utils.Logger for idx, request in enumerate(requests): # Get input tensors query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + batch_dim = query.shape[0] + if batch_dim != 1: + + err_str = "Inflight batching backend expects requests with batch size of 1." + logger.log_error(err_str) + responses.append( + pb_utils.InferenceResponse( + output_tensors=[], error=pb_utils.TritonError(err_str) + ) + ) + continue + request_output_len = pb_utils.get_input_tensor_by_name( request, "REQUEST_OUTPUT_LEN" ).as_numpy() - bad_words_dict = pb_utils.get_input_tensor_by_name(request, "BAD_WORDS_DICT").as_numpy() - stop_words_dict = pb_utils.get_input_tensor_by_name( - request, "STOP_WORDS_DICT" - ).as_numpy() + bad_words_dict = pb_utils.get_input_tensor_by_name(request, "BAD_WORDS_DICT") + if bad_words_dict is not None: + bad_words_dict = bad_words_dict.as_numpy() + + stop_words_dict = pb_utils.get_input_tensor_by_name(request, "STOP_WORDS_DICT") + if stop_words_dict is not None: + stop_words_dict = stop_words_dict.as_numpy() + + embedding_bias_words = pb_utils.get_input_tensor_by_name( + request, "EMBEDDING_BIAS_WORDS" + ) + if embedding_bias_words is not None: + embedding_bias_words = embedding_bias_words.as_numpy() + + embedding_bias_weights = pb_utils.get_input_tensor_by_name( + request, "EMBEDDING_BIAS_WEIGHTS" + ) + if embedding_bias_weights is not None: + embedding_bias_weights = embedding_bias_weights.as_numpy() + + # Take the end_id from the input tensors + # If not specified, use tokenizer to get end_id + end_id = pb_utils.get_input_tensor_by_name(request, "END_ID") + if end_id is not None: + end_id = end_id.as_numpy() + else: + end_id = [[self.tokenizer_end_id]] + + # Take the pad_id from the input tensors + # If not specified, use tokenizer to get pad_id + pad_id = pb_utils.get_input_tensor_by_name(request, "PAD_ID") + if pad_id is not None: + pad_id = pad_id.as_numpy() + else: + pad_id = [[self.tokenizer_pad_id]] # Preprocessing input data. input_id, request_input_len = self._create_request(query) bad_words = self._to_word_list_format(bad_words_dict) stop_words = self._to_word_list_format(stop_words_dict) + embedding_bias = self._get_embedding_bias( + embedding_bias_words, embedding_bias_weights, self.embedding_bias_weights_dtype + ) + # Create output tensors. You need pb_utils.Tensor # objects to create pb_utils.InferenceResponse. - input_id_tensor = pb_utils.Tensor( - "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) - ) + input_id_tensor = pb_utils.Tensor("INPUT_ID", input_id.astype(self.input_id_dtype)) request_input_len_tensor = pb_utils.Tensor( - "REQUEST_INPUT_LEN", - np.array(request_input_len).astype(self.request_input_len_dtype), + "REQUEST_INPUT_LEN", request_input_len.astype(self.request_input_len_dtype) ) request_output_len_tensor = pb_utils.Tensor("REQUEST_OUTPUT_LEN", request_output_len) bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) + embedding_bias_tensor = pb_utils.Tensor("EMBEDDING_BIAS", embedding_bias) + end_id_tensor = pb_utils.Tensor("OUT_END_ID", np.array(end_id, dtype=np.int32)) + pad_id_tensor = pb_utils.Tensor("OUT_PAD_ID", np.array(pad_id, dtype=np.int32)) - # Create InferenceResponse. You can set an error here in case - # there was a problem with handling this inference request. - # Below is an example of how you can set errors in inference - # response: - # - # pb_utils.InferenceResponse( - # output_tensors=..., TritonError("An error occurred")) inference_response = pb_utils.InferenceResponse( output_tensors=[ input_id_tensor, @@ -153,6 +218,9 @@ def execute(self, requests): stop_words_ids_tensor, request_input_len_tensor, request_output_len_tensor, + embedding_bias_tensor, + end_id_tensor, + pad_id_tensor, ] ) responses.append(inference_response) @@ -172,46 +240,70 @@ def _create_request(self, query): """ query : batch string (2D numpy array) """ - start_ids = [torch.IntTensor(self.tokenizer.encode(s[0].decode())) for s in query] - start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + start_ids = [ + np.array( + self.tokenizer.encode(s[0].decode(), add_special_tokens=self.add_special_tokens) + ).astype(int) + for s in query + ] + start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) - start_ids = pad_sequence(start_ids, batch_first=True, padding_value=self.pad_id) - # input_len = min(start_lengths) - # attn_mask = torch.ones((batch_size, input_len, input_len)).tril() + max_len = 0 + for seq in start_ids: + max_len = max(max_len, seq.shape[0]) + start_ids = np.stack( + [ + np.pad( + seq, + (0, max_len - seq.shape[0]), + "constant", + constant_values=(0, self.tokenizer_pad_id), + ) + for seq in start_ids + ] + ) return start_ids, start_lengths - def _to_word_list_format(self, word_dict: List[List[str]]): + def _to_word_list_format(self, word_lists: List[List[str | bytes]]): """ - format of word_dict - len(word_dict) should be same to batch_size - word_dict[i] means the words for batch i - len(word_dict[i]) must be 1, which means it only contains 1 string - This string can contains several sentences and split by ",". - For example, if word_dict[2] = " I am happy, I am sad", then this function will return - the ids for two short sentences " I am happy" and " I am sad". + word_lists format: + len(word_lists) == batch_size + word_lists[i] means the words associated to batch item i. A "word" may actually be any string. Like "lorem" or "lorem ipsum". """ assert self.tokenizer is not None, "need to set tokenizer" + if word_lists is None: + # Return an empty array of shape (1,2,0) + return np.empty([1, 2, 0], dtype="int32") + flat_ids = [] offsets = [] - for word_dict_item in word_dict: + for word_list in word_lists: item_flat_ids = [] item_offsets = [] - if isinstance(word_dict_item[0], bytes): - word_dict_item = [word_dict_item[0].decode()] - - words = list(csv.reader(word_dict_item))[0] - for word in words: - ids = self.tokenizer.encode(word) + for word in word_list: + if isinstance(word, bytes): + word = word.decode() + ids = self.tokenizer.encode(word, add_special_tokens=False) if len(ids) == 0: continue item_flat_ids += ids item_offsets.append(len(ids)) + # Add a case where ids[0] decodes to empty string, then add another set of ids here + # Unfortunately, we don't have access to the entire sequence of returned response tokens when decoding, + # so we have to do what we can to get a reasonable list of token ids corresponding to a stop sequence. + # True correctness would look like figuring out all the ways of decoding a stop sequence, and then + # adding all of them to this item_flat_ids map. + if len(ids) > 1 and self.tokenizer.decode(ids[0]) == "": + new_ids = ids[1:] + item_flat_ids += new_ids + item_offsets.append(len(new_ids)) + flat_ids.append(np.array(item_flat_ids)) offsets.append(np.cumsum(np.array(item_offsets))) @@ -222,3 +314,35 @@ def _to_word_list_format(self, word_dict: List[List[str]]): offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) + + def _get_embedding_bias(self, embedding_bias_words, embedding_bias_weights, bias_dtype): + + assert self.tokenizer is not None, "need to set tokenizer" + + if embedding_bias_words is None or embedding_bias_weights is None: + return np.empty([1, 0], dtype=self.embedding_bias_weights_dtype) + + batch_embedding_bias = [] + for words, weights in zip(embedding_bias_words, embedding_bias_weights): + + vocab_size = self.tokenizer.vocab_size + embedding_bias = [0.0] * vocab_size + + assert len(words) == len( + weights + ), "Embedding bias words must have same dimension as embedding bias weights" + + for word, weight in zip(words, weights): + if isinstance(word, bytes): + word = word.decode() + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + for id in ids: + embedding_bias[id] += weight + + batch_embedding_bias.append(np.array(embedding_bias)) + + return np.array(batch_embedding_bias, dtype=bias_dtype) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt index 89d9c91e..3a77e264 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,27 +26,53 @@ name: "preprocessing" backend: "python" -max_batch_size: 128 +max_batch_size: 128 input [ { name: "QUERY" data_type: TYPE_STRING dims: [ -1 ] }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_INT32 + dims: [ -1 ] + }, { name: "BAD_WORDS_DICT" data_type: TYPE_STRING dims: [ -1 ] + optional: true }, { name: "STOP_WORDS_DICT" data_type: TYPE_STRING dims: [ -1 ] + optional: true }, { - name: "REQUEST_OUTPUT_LEN" - data_type: TYPE_UINT32 + name: "EMBEDDING_BIAS_WORDS" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "EMBEDDING_BIAS_WEIGHTS" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + }, + { + name: "END_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + }, + { + name: "PAD_ID" + data_type: TYPE_INT32 dims: [ -1 ] + optional: true } ] output [ @@ -70,9 +96,24 @@ output [ data_type: TYPE_INT32 dims: [ 2, -1 ] }, + { + name: "EMBEDDING_BIAS" + data_type: TYPE_FP32 + dims: [ -1 ] + }, { name: "REQUEST_OUTPUT_LEN" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "OUT_END_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "OUT_PAD_ID" + data_type: TYPE_INT32 dims: [ -1 ] } ] @@ -91,6 +132,13 @@ parameters { } } +parameters { + key: "add_special_tokens" + value: { + string_value: "False" + } +} + instance_group [ { count: 1 diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt index e24a95b4..f1b466eb 100644 --- a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -26,17 +26,23 @@ name: "tensorrt_llm" backend: "tensorrtllm" -max_batch_size: 128 +max_batch_size: 128 model_transaction_policy { decoupled: true } +dynamic_batching { + preferred_batch_size: [ 128 ] + max_queue_delay_microseconds: 100000 +} + input [ { name: "input_ids" data_type: TYPE_INT32 dims: [ -1 ] + allow_ragged_batch: true }, { name: "input_lengths" @@ -46,26 +52,54 @@ input [ }, { name: "request_output_len" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] }, + { + name: "draft_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, { name: "end_id" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] reshape: { shape: [ ] } optional: true }, { name: "pad_id" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] reshape: { shape: [ ] } optional: true }, + { + name: "stop_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "bad_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "embedding_bias" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, { name: "beam_width" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] reshape: { shape: [ ] } optional: true @@ -79,7 +113,7 @@ input [ }, { name: "runtime_top_k" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] reshape: { shape: [ ] } optional: true @@ -107,7 +141,7 @@ input [ }, { name: "min_length" - data_type: TYPE_UINT32 + data_type: TYPE_INT32 dims: [ 1 ] reshape: { shape: [ ] } optional: true @@ -119,6 +153,13 @@ input [ reshape: { shape: [ ] } optional: true }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, { name: "random_seed" data_type: TYPE_UINT64 @@ -126,6 +167,27 @@ input [ reshape: { shape: [ ] } optional: true }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, { name: "stop" data_type: TYPE_BOOL @@ -137,6 +199,51 @@ input [ data_type: TYPE_BOOL dims: [ 1 ] optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] + # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer + # each of the in / out tensors are first flattened and then concatenated together in the format above. + # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out. + { + name: "lora_weights" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # module identifier (same size a first dimension of lora_weights) + # See LoraModule::ModuleType for model id mapping + # + # "attn_qkv": 0 # compbined qkv adapter + # "attn_q": 1 # q adapter + # "attn_k": 2 # k adapter + # "attn_v": 3 # v adapter + # "attn_dense": 4 # adapter for the dense layer in attention + # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection + # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection + # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate + # + # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ] + { + name: "lora_config" + data_type: TYPE_INT32 + dims: [ -1, 3 ] + optional: true + allow_ragged_batch: true } ] output [ @@ -144,6 +251,31 @@ output [ name: "output_ids" data_type: TYPE_INT32 dims: [ -1, -1 ] + }, + { + name: "sequence_length" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] } ] instance_group [ @@ -182,10 +314,16 @@ parameters: { string_value: "${max_tokens_in_paged_kv_cache}" } } +parameters: { + key: "max_attention_window_size" + value: { + string_value: "${max_attention_window_size}" + } +} parameters: { key: "batch_scheduler_policy" value: { - string_value: "${batch_scheduler_policy}" + string_value: "max_utilization" } } parameters: { @@ -195,14 +333,38 @@ parameters: { } } parameters: { - key: "max_num_sequences" + key: "enable_trt_overlap" value: { - string_value: "${max_num_sequences}" + string_value: "${enable_trt_overlap}" } } parameters: { - key: "enable_trt_overlap" + key: "exclude_input_in_output" value: { - string_value: "${enable_trt_overlap}" + string_value: "true" + } +} +parameters: { + key: "enable_kv_cache_reuse" + value: { + string_value: "${enable_kv_cache_reuse}" + } +} +parameters: { + key: "normalize_log_probs" + value: { + string_value: "${normalize_log_probs}" + } +} +parameters: { + key: "enable_chunked_context" + value: { + string_value: "${enable_chunked_context}" + } +} +parameters: { + key: "gpu_device_ids" + value: { + string_value: "${gpu_device_ids}" } } diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py new file mode 100644 index 00000000..545e3a7d --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py @@ -0,0 +1,389 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import traceback + +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + + # Parse model configs + model_config = json.loads(args["model_config"]) + + params = model_config["parameters"] + + accumulate_tokens_str = "" + if "accumulate_tokens" in params: + accumulate_tokens_str = params["accumulate_tokens"]["string_value"] + + self.accumulate_tokens = accumulate_tokens_str.lower() in ["true", "yes", "1", "t"] + + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(model_config) + + self.logger = pb_utils.Logger + + self.bls_input_tensor_names = [ + "text_input", + "max_tokens", + "bad_words", + "stop_words", + "end_id", + "pad_id", + "top_k", + "top_p", + "temperature", + "length_penalty", + "repetition_penalty", + "min_length", + "presence_penalty", + "frequency_penalty", + "random_seed", + "return_log_probs", + "return_context_logits", + "return_generation_logits", + "beam_width", + "stream", + "prompt_embedding_table", + "prompt_vocab_size", + "embedding_bias_words", + "embedding_bias_weights", + ] + + self.preproc_input_to_bls_input_map = { + "QUERY": "text_input", + "REQUEST_OUTPUT_LEN": "max_tokens", + "BAD_WORDS_DICT": "bad_words", + "STOP_WORDS_DICT": "stop_words", + "EMBEDDING_BIAS_WORDS": "embedding_bias_words", + "EMBEDDING_BIAS_WEIGHTS": "embedding_bias_weights", + "END_ID": "end_id", + "PAD_ID": "pad_id", + } + + self.preproc_output_to_trtllm_input_map = { + "INPUT_ID": "input_ids", + "REQUEST_INPUT_LEN": "input_lengths", + "REQUEST_OUTPUT_LEN": "request_output_len", + "BAD_WORDS_IDS": "bad_words_list", + "STOP_WORDS_IDS": "stop_words_list", + "EMBEDDING_BIAS": "embedding_bias", + "OUT_END_ID": "end_id", + "OUT_PAD_ID": "pad_id", + } + + self.trtllm_input_to_bls_input_map = { + "beam_width": "beam_width", + "runtime_top_k": "top_k", + "runtime_top_p": "top_p", + "len_penalty": "length_penalty", + "repetition_penalty": "repetition_penalty", + "min_length": "min_length", + "presence_penalty": "presence_penalty", + "frequency_penalty": "frequency_penalty", + "random_seed": "random_seed", + "return_log_probs": "return_log_probs", + "return_context_logits": "return_context_logits", + "return_generation_logits": "return_generation_logits", + "streaming": "stream", + "prompt_embedding_table": "prompt_embedding_table", + "prompt_vocab_size": "prompt_vocab_size", + } + + self.trtllm_output_to_postproc_input_map = { + "output_ids": "TOKENS_BATCH", + "sequence_length": "SEQUENCE_LENGTH", + "cum_log_probs": "CUM_LOG_PROBS", + "output_log_probs": "OUTPUT_LOG_PROBS", + "context_logits": "CONTEXT_LOGITS", + "generation_logits": "GENERATION_LOGITS", + } + + self.postproc_output_to_bls_output_map = { + "OUTPUT": "text_output", + "OUT_CUM_LOG_PROBS": "cum_log_probs", + "OUT_OUTPUT_LOG_PROBS": "output_log_probs", + "OUT_CONTEXT_LOGITS": "context_logits", + "OUT_GENERATION_LOGITS": "generation_logits", + } + + def _get_bls_input_tensors_map(self, request): + + bls_input_tensors_map = {} + for input_tensor_name in self.bls_input_tensor_names: + tensor = pb_utils.get_input_tensor_by_name(request, input_tensor_name) + if tensor is not None: + bls_input_tensors_map[input_tensor_name] = tensor + + return bls_input_tensors_map + + def _get_preproc_input_tensors(self, bls_input_tensors_map): + + preproc_input_tensors = [] + + for preproc_name, bls_name in self.preproc_input_to_bls_input_map.items(): + + if bls_name in bls_input_tensors_map: + tensor = bls_input_tensors_map[bls_name] + # Change the name to what the preprocessor expects + preproc_input_tensors.append(pb_utils.Tensor(preproc_name, tensor.as_numpy())) + + return preproc_input_tensors + + def _get_trtllm_input_tensors(self, bls_input_tensors_map, preproc_output_tensors): + + trtllm_input_tensors = [] + + # Set input tensors from preprocessor outputs + for preproc_output_tensor in preproc_output_tensors: + + trtllm_tensor_name = self.preproc_output_to_trtllm_input_map[ + preproc_output_tensor.name() + ] + trtllm_input_tensors.append( + pb_utils.Tensor(trtllm_tensor_name, preproc_output_tensor.as_numpy()) + ) + + # Set input tensors from bls inputs + for trtllm_name, bls_name in self.trtllm_input_to_bls_input_map.items(): + + if bls_name in bls_input_tensors_map: + tensor = bls_input_tensors_map[bls_name] + # Change the name to what the preprocessor expects + trtllm_input_tensors.append(pb_utils.Tensor(trtllm_name, tensor.as_numpy())) + + return trtllm_input_tensors + + def _get_postproc_input_tensors(self, tokens, trtllm_output_tensors): + + postproc_input_tensors = [] + + for trtllm_output_tensor in trtllm_output_tensors: + + # If in decoupled mode, option to append new tokens to existing tokens before calling postprocessor + # This might be needed for some tokenizers + # Note that in that case, the client must overwrite previously received output text + if ( + self.accumulate_tokens + and self.decoupled + and trtllm_output_tensor.name() == "output_ids" + ): + + new_tokens = trtllm_output_tensor.as_numpy() + if new_tokens.ndim != 3: + raise pb_utils.TritonModelException( + "Expected output_ids tensor to have 3 dims." + ) + if new_tokens.shape[0] != 1: + raise pb_utils.TritonModelException( + "Expected output_ids tensor to have batch size of 1" + ) + if new_tokens.shape[1] != 1: + raise pb_utils.TritonModelException( + "Accumulation of tokens is only implemented for beam width = 1" + ) + + tokens = ( + new_tokens if (tokens is None) else np.concatenate((tokens, new_tokens), axis=2) + ) + + # output ids + postproc_output_ids_name = self.trtllm_output_to_postproc_input_map["output_ids"] + postproc_input_tensors.append(pb_utils.Tensor(postproc_output_ids_name, tokens)) + + # sequence length + np_seq_len_tensor = np.array([[tokens.shape[2]]], dtype=np.int32) + postproc_seq_len_name = self.trtllm_output_to_postproc_input_map["sequence_length"] + postproc_input_tensors.append( + pb_utils.Tensor(postproc_seq_len_name, np_seq_len_tensor) + ) + + # Set input tensors from trtllm outputs + for trtllm_output_tensor in trtllm_output_tensors: + + # output_ids and sequence_length were handled earlier + if ( + self.accumulate_tokens + and self.decoupled + and ( + trtllm_output_tensor.name() == "output_ids" + or trtllm_output_tensor.name() == "sequence_length" + ) + ): + continue + + postproc_tensor_name = self.trtllm_output_to_postproc_input_map[ + trtllm_output_tensor.name() + ] + + postproc_input_tensors.append( + pb_utils.Tensor(postproc_tensor_name, trtllm_output_tensor.as_numpy()) + ) + + return tokens, postproc_input_tensors + + def _get_bls_output_tensors(self, postproc_output_tensors): + + bls_output_tensors = [] + + # Set input tensors from trtllm outputs + for postproc_output_tensor in postproc_output_tensors: + + bls_tensor_name = self.postproc_output_to_bls_output_map[postproc_output_tensor.name()] + bls_output_tensors.append( + pb_utils.Tensor(bls_tensor_name, postproc_output_tensor.as_numpy()) + ) + + return bls_output_tensors + + def execute(self, requests): + + responses = [] + bls_response_sender = None + + for request in requests: + + # Get the response sender for the BLS + if self.decoupled: + bls_response_sender = request.get_response_sender() + + try: + # Get the bls input tensors + bls_input_tensors_map = self._get_bls_input_tensors_map(request) + + # Check the batch dimension + for name, tensor in bls_input_tensors_map.items(): + batch_dim = tensor.as_numpy().shape[0] + + if batch_dim != 1: + + err_str = "Inflight batching backend expects requests with batch size of 1." + self.logger.log_error(err_str) + raise pb_utils.TritonModelException(err_str) + + # Create the preprocessor input tensors + preproc_input_tensors = self._get_preproc_input_tensors(bls_input_tensors_map) + + preproc_request = pb_utils.InferenceRequest( + model_name="preprocessing", + inputs=preproc_input_tensors, + requested_output_names=list(self.preproc_output_to_trtllm_input_map.keys()), + ) + + # Execute preprocessor + preproc_response = preproc_request.exec() + + if preproc_response.has_error(): + raise pb_utils.TritonModelException(preproc_response.error().message()) + + # Create the trtllm input tensors + trtllm_input_tensors = self._get_trtllm_input_tensors( + bls_input_tensors_map, preproc_response.output_tensors() + ) + + trtllm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + inputs=trtllm_input_tensors, + requested_output_names=list(self.trtllm_output_to_postproc_input_map.keys()), + ) + + # Execute trtllm + trtllm_responses = trtllm_request.exec(decoupled=self.decoupled) + + if not self.decoupled: + trtllm_responses = [trtllm_responses] + + tokens = None + + # Loop over the trtllm responses + for trtllm_response in trtllm_responses: + + if trtllm_response.has_error(): + raise pb_utils.TritonModelException(trtllm_response.error().message()) + + trtllm_output_tensors = trtllm_response.output_tensors() + + tokens, postproc_input_tensors = self._get_postproc_input_tensors( + tokens, trtllm_output_tensors + ) + + postproc_request = pb_utils.InferenceRequest( + model_name="postprocessing", + inputs=postproc_input_tensors, + requested_output_names=list(self.postproc_output_to_bls_output_map.keys()), + ) + + # Execute postprocessor + postproc_response = postproc_request.exec() + + if postproc_response.has_error(): + raise pb_utils.TritonModelException(postproc_response.error().message()) + + # Create the BLS response + bls_output_tensors = self._get_bls_output_tensors( + postproc_response.output_tensors() + ) + + bls_response = pb_utils.InferenceResponse(output_tensors=bls_output_tensors) + + if self.decoupled: + bls_response_sender.send(bls_response) + else: + responses.append(bls_response) + + # All responses have been sent, set final flag + if self.decoupled: + bls_response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + + except Exception: + + self.logger.log_error(traceback.format_exc()) + # If encountering an error, send a response with err msg + error_response = pb_utils.InferenceResponse( + output_tensors=[], error=pb_utils.TritonError(traceback.format_exc()) + ) + + if self.decoupled: + bls_response_sender.send(error_response) + bls_response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + else: + responses.append(error_response) + + if self.decoupled: + return None + else: + assert len(responses) == len(requests) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt new file mode 100644 index 00000000..168c819c --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt @@ -0,0 +1,221 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "tensorrt_llm_bls" +backend: "python" +max_batch_size: 128 + +model_transaction_policy { + decoupled: true +} + +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "embedding_bias_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "embedding_bias_weights" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] + +parameters: { + key: "accumulate_tokens" + value: { + string_value: "true" + } +} + +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 950f915f..2beb12a1 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -894,7 +894,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", return_value=6, ) -async def test_completion_sync_trt_llm_use_case_success( +async def test_completion_sync_trt_llm_use_case_success_23_10( test_api_key: str, fake_model_endpoint_service, fake_llm_model_endpoint_service, @@ -929,6 +929,51 @@ async def test_completion_sync_trt_llm_use_case_success( ) +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=6, +) +@pytest.mark.parametrize( + "output_log_probs,output_tokens", [("[0.0,0.0,0.0,0.0,0.0]", 5), ("0.0", 1)] +) +async def test_completion_sync_trt_llm_use_case_success_24_01( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_trt_llm: ModelEndpoint, + completion_sync_request: CompletionSyncV1Request, + output_log_probs: str, + output_tokens: int, +): + completion_sync_request.return_token_log_probs = False # not yet supported + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": f'{{"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":{output_log_probs},"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":" Machine learning is a branch"}}' + }, + traceback=None, + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_sync_request, + ) + assert response_1.output == CompletionOutput( + text=" Machine learning is a branch", + num_prompt_tokens=6, + num_completion_tokens=output_tokens, + ) + + @pytest.mark.asyncio async def test_completion_sync_use_case_predict_failed( test_api_key: str, diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py index 7d2cb32f..e970ead7 100644 --- a/scripts/throughput_benchmarks.py +++ b/scripts/throughput_benchmarks.py @@ -67,6 +67,8 @@ def send_request(url, request, user=None): first_line = True inter_token_latencies = [] last_token_time = None + payload_json: dict = {} + num_completion_tokens = 0 # We calculate this value manually since tensorrt llm doesn't give it for byte_payload in response.iter_lines(): # Skip line if byte_payload == b"\n" or byte_payload == b"": @@ -87,12 +89,14 @@ def send_request(url, request, user=None): if payload.startswith("data:"): payload_data = payload.lstrip("data:").rstrip("/n") payload_json = json.loads(payload_data) + num_completion_tokens += 1 return { "payload": payload_json, "time_to_first_token": time_to_first_token, "total_time": time.time() - start, "inter_token_latencies": inter_token_latencies, + "num_completion_tokens": num_completion_tokens, } @@ -109,7 +113,13 @@ def pull_and_send_request_from_queue( if use_localhost: if framework == InferenceFramework.VLLM: response = send_request(f"http://localhost:{local_port}/stream", request) - response["num_completion_tokens"] = response["payload"]["count_output_tokens"] + response["num_completion_tokens"] = response["payload"][ + "count_output_tokens" + ] # vLLM gives us completion token count, use that. + elif framework == InferenceFramework.TENSORRT_LLM: + response = send_request( + f"http://localhost:{local_port}/v2/models/ensemble/generate_stream", request + ) else: raise NotImplementedError() else: @@ -128,8 +138,10 @@ def pull_and_send_request_from_queue( def generate_request( framework: InferenceFramework, prompt: str, output_token_count: int, localhost: bool ): + temperature = 0.0 + if not localhost: - return {"prompt": prompt, "max_new_tokens": output_token_count, "temperature": 0.0} + return {"prompt": prompt, "max_new_tokens": output_token_count, "temperature": temperature} if framework == InferenceFramework.TEXT_GENERATION_INFERENCE: return { @@ -144,7 +156,7 @@ def generate_request( return { "prompt": prompt, "max_tokens": output_token_count, - "temperature": 0, + "temperature": temperature, "stream": True, } elif framework == InferenceFramework.LIGHTLLM: @@ -161,6 +173,10 @@ def generate_request( "text_input": prompt, "bad_words": "", "stop_words": "", + "parameters": { + "temperature": temperature, + "stream": True, + }, } else: raise NotImplementedError() From 80e5276f7d29c1f5672dd7b0a7e49dc112de7e99 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 15 May 2024 14:29:59 -0700 Subject: [PATCH 301/425] Upgrade vLLM version for batch completion (#518) * bump pydantic==2.7.1 * Add fallback to v1 models if pydantic >2 * version bump vllm * Update docker image to manually install flash attention * skip coverage --- .../model_engine_server/common/dtos/llms.py | 7 +++- .../inference/batch_inference/Dockerfile_vllm | 40 +++++++++++++++++++ .../batch_inference/requirements-build.txt | 8 ++++ .../batch_inference/requirements-dev.txt | 1 + .../batch_inference/requirements.txt | 6 +-- 5 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 model-engine/model_engine_server/inference/batch_inference/requirements-build.txt create mode 100644 model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 90498c23..1232e1af 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional +import pydantic from model_engine_server.common.dtos.model_endpoints import ( CpuSpecificationType, GetModelEndpointV1Response, @@ -21,7 +22,11 @@ ModelEndpointStatus, Quantization, ) -from pydantic import BaseModel, Field, HttpUrl + +if int(pydantic.__version__.split(".")[0]) > 1: + from pydantic.v1 import BaseModel, Field, HttpUrl # pragma: no cover +else: + from pydantic import BaseModel, Field, HttpUrl class CreateLLMModelEndpointV1Request(BaseModel): diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm index 3b08756c..90611589 100644 --- a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm +++ b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm @@ -1,3 +1,37 @@ +#################### BASE BUILD IMAGE #################### +FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev +RUN apt-get update -y \ + && apt-get install -y python3-pip git +# Workaround for https://github.com/openai/triton/issues/2507 and +# https://github.com/pytorch/pytorch/issues/107960 -- hopefully +# this won't be needed for future versions of this docker image +# or future versions of triton. +RUN ldconfig /usr/local/cuda-12.1/compat/ +WORKDIR /workspace + +COPY model-engine/model_engine_server/inference/batch_inference/requirements-build.txt requirements-build.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-build.txt +#################### BASE BUILD IMAGE #################### + +#################### FLASH_ATTENTION Build IMAGE #################### +FROM dev as flash-attn-builder +# max jobs used for build +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} +# flash attention version +ARG flash_attn_version=v2.5.6 +ENV FLASH_ATTN_VERSION=${flash_attn_version} + +WORKDIR /usr/src/flash-attention-v2 + +# Download the wheel or build it if a pre-compiled release doesn't exist +RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ + --no-build-isolation --no-deps --no-cache-dir + +#################### FLASH_ATTENTION Build IMAGE #################### + +#################### Runtime IMAGE #################### FROM nvcr.io/nvidia/pytorch:23.09-py3 RUN apt-get update && \ @@ -6,6 +40,10 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean +# Install flash attention (from pre-built wheel) +RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ + pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir + RUN pip uninstall torch -y RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121 @@ -21,3 +59,5 @@ RUN pip install -r requirements.txt COPY model-engine /workspace/model-engine RUN pip install -e /workspace/model-engine COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py + +#################### Runtime IMAGE #################### \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements-build.txt b/model-engine/model_engine_server/inference/batch_inference/requirements-build.txt new file mode 100644 index 00000000..020e4532 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/requirements-build.txt @@ -0,0 +1,8 @@ +# Copied from https://github.com/vllm-project/vllm/blob/main/requirements-build.txt +# Needed to build flash-attn into docker image +cmake>=3.21 +ninja +packaging +setuptools>=49.4.0 +torch==2.3.0 +wheel \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt b/model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt new file mode 100644 index 00000000..5d66e061 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt @@ -0,0 +1 @@ +-e ../../.. # Need to install model_engine_server as a package \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index e83b4ccd..e27ece73 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,6 +1,6 @@ -vllm==0.2.5 -pydantic==1.10.13 -boto3==1.34.15 +vllm==0.4.2 +pydantic==2.7.1 +boto3>=1.34.105 smart-open==6.4.0 ddtrace==2.4.0 docker==7.0.0 From a36f7a298e98b1e4120b3e70f3150a8bbde7966c Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 15 May 2024 19:46:05 -0700 Subject: [PATCH 302/425] Revert "Upgrade vLLM version for batch completion (#518)" (#520) This reverts commit 80e5276f7d29c1f5672dd7b0a7e49dc112de7e99. --- .../model_engine_server/common/dtos/llms.py | 7 +--- .../inference/batch_inference/Dockerfile_vllm | 40 ------------------- .../batch_inference/requirements-build.txt | 8 ---- .../batch_inference/requirements-dev.txt | 1 - .../batch_inference/requirements.txt | 6 +-- 5 files changed, 4 insertions(+), 58 deletions(-) delete mode 100644 model-engine/model_engine_server/inference/batch_inference/requirements-build.txt delete mode 100644 model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 1232e1af..90498c23 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional -import pydantic from model_engine_server.common.dtos.model_endpoints import ( CpuSpecificationType, GetModelEndpointV1Response, @@ -22,11 +21,7 @@ ModelEndpointStatus, Quantization, ) - -if int(pydantic.__version__.split(".")[0]) > 1: - from pydantic.v1 import BaseModel, Field, HttpUrl # pragma: no cover -else: - from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, Field, HttpUrl class CreateLLMModelEndpointV1Request(BaseModel): diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm index 90611589..3b08756c 100644 --- a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm +++ b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm @@ -1,37 +1,3 @@ -#################### BASE BUILD IMAGE #################### -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev -RUN apt-get update -y \ - && apt-get install -y python3-pip git -# Workaround for https://github.com/openai/triton/issues/2507 and -# https://github.com/pytorch/pytorch/issues/107960 -- hopefully -# this won't be needed for future versions of this docker image -# or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ -WORKDIR /workspace - -COPY model-engine/model_engine_server/inference/batch_inference/requirements-build.txt requirements-build.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-build.txt -#################### BASE BUILD IMAGE #################### - -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.6 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - -#################### Runtime IMAGE #################### FROM nvcr.io/nvidia/pytorch:23.09-py3 RUN apt-get update && \ @@ -40,10 +6,6 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean -# Install flash attention (from pre-built wheel) -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir - RUN pip uninstall torch -y RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121 @@ -59,5 +21,3 @@ RUN pip install -r requirements.txt COPY model-engine /workspace/model-engine RUN pip install -e /workspace/model-engine COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py - -#################### Runtime IMAGE #################### \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements-build.txt b/model-engine/model_engine_server/inference/batch_inference/requirements-build.txt deleted file mode 100644 index 020e4532..00000000 --- a/model-engine/model_engine_server/inference/batch_inference/requirements-build.txt +++ /dev/null @@ -1,8 +0,0 @@ -# Copied from https://github.com/vllm-project/vllm/blob/main/requirements-build.txt -# Needed to build flash-attn into docker image -cmake>=3.21 -ninja -packaging -setuptools>=49.4.0 -torch==2.3.0 -wheel \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt b/model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt deleted file mode 100644 index 5d66e061..00000000 --- a/model-engine/model_engine_server/inference/batch_inference/requirements-dev.txt +++ /dev/null @@ -1 +0,0 @@ --e ../../.. # Need to install model_engine_server as a package \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index e27ece73..e83b4ccd 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,6 +1,6 @@ -vllm==0.4.2 -pydantic==2.7.1 -boto3>=1.34.105 +vllm==0.2.5 +pydantic==1.10.13 +boto3==1.34.15 smart-open==6.4.0 ddtrace==2.4.0 docker==7.0.0 From 110833b621217be96ebdaaee665df12169a2895b Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 17 May 2024 15:10:18 -0700 Subject: [PATCH 303/425] Allow H100 to be used (#522) * Allow H100 to be used * Add MIG groups --- clients/python/llmengine/data_types.py | 3 +++ clients/python/llmengine/model.py | 2 ++ docs/guides/self_hosting.md | 5 +++-- .../model_engine_server/common/resource_limits.py | 6 ++++++ .../model_engine_server/domain/entities/gpu_type.py | 3 +++ .../domain/use_cases/llm_model_endpoint_use_cases.py | 8 ++++---- model-engine/tests/unit/domain/test_llm_use_cases.py | 12 ++++++------ 7 files changed, 27 insertions(+), 12 deletions(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index a9d65d1a..bcaf4112 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -41,6 +41,9 @@ class GpuType(str, Enum): NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" + NVIDIA_HOPPER_H100 = "nvidia-hopper-h100" + NVIDIA_HOPPER_H100_1G_20GB = "nvidia-hopper-h100-1g20gb" + NVIDIA_HOPPER_H100_3G_40GB = "nvidia-hopper-h100-3g40gb" class ModelEndpointType(str, Enum): diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index b242abea..77a0c1d8 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -148,6 +148,7 @@ def create( - ``nvidia-ampere-a10`` - ``nvidia-ampere-a100`` - ``nvidia-ampere-a100e`` + - ``nvidia-hopper-h100`` high_priority (`Optional[bool]`): Either ``True`` or ``False``. Enabling this will allow the created @@ -531,6 +532,7 @@ def update( - ``nvidia-ampere-a10`` - ``nvidia-ampere-a100`` - ``nvidia-ampere-a100e`` + - ``nvidia-hopper-h100`` high_priority (`Optional[bool]`): Either ``True`` or ``False``. Enabling this will allow the created diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 348e94be..442a16dc 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -21,8 +21,9 @@ Additionally, they must have the `k8s.amazonaws.com/accelerator` label set appro | --- | --- | | g4dn | nvidia-tesla-t4 | | g5 | nvidia-tesla-a10 | -| p4d | nvidia-tesla-a100 | -| p4de | nvidia-tesla-a100e | +| p4d | nvidia-ampere-a100 | +| p4de | nvidia-ampere-a100e | +| p5 | nvidia-hopper-h100 | We also recommend setting the following taint on your GPU nodes to prevent pods requiring GPU resources from being scheduled on them: - { key = "nvidia.com/gpu", value = "true", effect = "NO_SCHEDULE" } diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index 10bf0f0d..c65e6cd6 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -34,6 +34,9 @@ ) # Should we allow multi-gpu instances? This allows the largest single-gpu g5dn instance. # p4d.24xlarge, p4de.24xlarge A100_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") +H100_INSTANCE_LIMITS = dict(cpus=191, memory="2000Gi") +H100_1G_20GB_INSTANCE_LIMITS = dict(cpus=47, memory="500Gi") +H100_3G_40GB_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") STORAGE_LIMIT = "500G" # TODO: figure out an actual limit. REQUESTS_BY_GPU_TYPE = { None: CPU_INSTANCE_LIMITS, @@ -41,6 +44,9 @@ GpuType.NVIDIA_AMPERE_A10: A10_INSTANCE_LIMITS, GpuType.NVIDIA_AMPERE_A100: A100_INSTANCE_LIMITS, GpuType.NVIDIA_AMPERE_A100E: A100_INSTANCE_LIMITS, + GpuType.NVIDIA_HOPPER_H100: H100_INSTANCE_LIMITS, + GpuType.NVIDIA_HOPPER_H100_1G_20GB: H100_1G_20GB_INSTANCE_LIMITS, + GpuType.NVIDIA_HOPPER_H100_3G_40GB: H100_3G_40GB_INSTANCE_LIMITS, } FORWARDER_CPU_USAGE = 1 diff --git a/model-engine/model_engine_server/domain/entities/gpu_type.py b/model-engine/model_engine_server/domain/entities/gpu_type.py index 5dc2c459..6c686c01 100644 --- a/model-engine/model_engine_server/domain/entities/gpu_type.py +++ b/model-engine/model_engine_server/domain/entities/gpu_type.py @@ -8,3 +8,6 @@ class GpuType(str, Enum): NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" + NVIDIA_HOPPER_H100 = "nvidia-hopper-h100" + NVIDIA_HOPPER_H100_1G_20GB = "nvidia-hopper-h100-1g20gb" + NVIDIA_HOPPER_H100_3G_40GB = "nvidia-hopper-h100-3g40gb" diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 761d8ab9..e8549ad9 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2286,25 +2286,25 @@ def _infer_hardware( gpus = 2 memory = "160Gi" storage = "160Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A100E + gpu_type = GpuType.NVIDIA_HOPPER_H100 elif min_memory_gb <= 320: cpus = "40" gpus = 4 memory = "320Gi" storage = "320Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A100E + gpu_type = GpuType.NVIDIA_HOPPER_H100 elif min_memory_gb <= 640: cpus = "80" gpus = 8 memory = "800Gi" storage = "460Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A100E + gpu_type = GpuType.NVIDIA_HOPPER_H100 elif "llama-3-8b-instruct-262k" in model_name: cpus = "20" gpus = 2 memory = "40Gi" storage = "40Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A100E + gpu_type = GpuType.NVIDIA_HOPPER_H100 else: raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 2beb12a1..9b2cbc38 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1848,7 +1848,7 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 fake_llm_artifact_gateway.model_config = { "architectures": ["MixtralForCausalLM"], @@ -1879,7 +1879,7 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpus == 8 assert hardware.memory == "800Gi" assert hardware.storage == "460Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-7b-hf", @@ -2015,7 +2015,7 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], @@ -2043,7 +2043,7 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 # (TODO) figure out how to calculate memory for llama-3-8b-instruct-262k # fake_llm_artifact_gateway.model_config = { @@ -2073,7 +2073,7 @@ def test_infer_hardware(fake_llm_artifact_gateway): # assert hardware.gpus == 2 # assert hardware.memory == "160Gi" # assert hardware.storage == "160Gi" - # assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A100E + # assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 with pytest.raises(ObjectHasInvalidValueException): _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") @@ -2095,7 +2095,7 @@ def test_fill_hardware_info(fake_llm_artifact_gateway): assert request.gpus == 2 assert request.memory == "160Gi" assert request.storage == "160Gi" - assert request.gpu_type == GpuType.NVIDIA_AMPERE_A100E + assert request.gpu_type == GpuType.NVIDIA_HOPPER_H100 request = CreateLLMModelEndpointV1Request( name="mixtral-8x7b", From e207936904cb534f12e046fa355f02fd1bb553d0 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 20 May 2024 09:43:42 -0700 Subject: [PATCH 304/425] vLLM version 0.4.2 Docker image (#521) --- .../inference/vllm/Dockerfile | 40 +++++++++++++++++++ .../inference/vllm/requirements-build.txt | 8 ++++ .../inference/vllm/requirements.txt | 2 +- 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 model-engine/model_engine_server/inference/vllm/requirements-build.txt diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index 75b9e1f5..227b3e16 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -1,3 +1,37 @@ +#################### BASE BUILD IMAGE #################### +FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev +RUN apt-get update -y \ + && apt-get install -y python3-pip git +# Workaround for https://github.com/openai/triton/issues/2507 and +# https://github.com/pytorch/pytorch/issues/107960 -- hopefully +# this won't be needed for future versions of this docker image +# or future versions of triton. +RUN ldconfig /usr/local/cuda-12.1/compat/ +WORKDIR /workspace + +COPY requirements-build.txt requirements-build.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -r requirements-build.txt +#################### BASE BUILD IMAGE #################### + +#################### FLASH_ATTENTION Build IMAGE #################### +FROM dev as flash-attn-builder +# max jobs used for build +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} +# flash attention version +ARG flash_attn_version=v2.4.2 +ENV FLASH_ATTN_VERSION=${flash_attn_version} + +WORKDIR /usr/src/flash-attention-v2 + +# Download the wheel or build it if a pre-compiled release doesn't exist +RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ + --no-build-isolation --no-deps --no-cache-dir + +#################### FLASH_ATTENTION Build IMAGE #################### + +#################### Runtime IMAGE #################### FROM nvcr.io/nvidia/pytorch:23.09-py3 RUN apt-get update \ @@ -7,6 +41,10 @@ RUN apt-get update \ && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* +# Install flash attention (from pre-built wheel) +RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ + pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir + RUN pip uninstall torch -y COPY requirements.txt /workspace/requirements.txt RUN pip install -r requirements.txt @@ -15,3 +53,5 @@ RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linu RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz COPY vllm_server.py /workspace/vllm_server.py + +#################### Runtime IMAGE #################### diff --git a/model-engine/model_engine_server/inference/vllm/requirements-build.txt b/model-engine/model_engine_server/inference/vllm/requirements-build.txt new file mode 100644 index 00000000..020e4532 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements-build.txt @@ -0,0 +1,8 @@ +# Copied from https://github.com/vllm-project/vllm/blob/main/requirements-build.txt +# Needed to build flash-attn into docker image +cmake>=3.21 +ninja +packaging +setuptools>=49.4.0 +torch==2.3.0 +wheel \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index c4d967d7..9b106e07 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1,2 @@ -vllm==0.4.1 +vllm==0.4.2 pydantic>=2.0 From 2f71b89c23dc6f99344cd4ea163f625065cccb4b Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 20 May 2024 16:21:50 -0700 Subject: [PATCH 305/425] Image cache and balloon on H100s, also temporarily stop people from using A100 (#523) * Cache H100 * Stop people from using A100 * no cover * no cover * update client version --- .../templates/balloon_h100_deployment.yaml | 50 +++++++++++++++++++ charts/model-engine/values_circleci.yaml | 1 + charts/model-engine/values_sample.yaml | 23 +++++++++ clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/model.py | 4 ++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- .../use_cases/llm_model_endpoint_use_cases.py | 4 ++ 8 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 charts/model-engine/templates/balloon_h100_deployment.yaml diff --git a/charts/model-engine/templates/balloon_h100_deployment.yaml b/charts/model-engine/templates/balloon_h100_deployment.yaml new file mode 100644 index 00000000..03bce9aa --- /dev/null +++ b/charts/model-engine/templates/balloon_h100_deployment.yaml @@ -0,0 +1,50 @@ +{{- if not .Values.serviceIdentifier }} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ .Chart.Name }}-balloon-h100 + labels: + team: infra + product: common-warm-nodes +spec: + replicas: {{ .Values.replicaCount.balloonH100 }} + selector: + matchLabels: + app: {{ .Chart.Name }}-balloon-h100 + version: v1 + template: + metadata: + labels: + app: {{ .Chart.Name }}-balloon-h100 + product: common-warm-nodes + team: infra + env: {{ .Values.context }} + version: v1 + annotations: + sidecar.istio.io/inject: "false" + spec: + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-ampere-h100 + {{- with .Values.balloonNodeSelector }} + {{- toYaml . | nindent 8 }} + {{- end }} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + containers: + - image: public.ecr.aws/ubuntu/ubuntu:latest + imagePullPolicy: IfNotPresent + name: main + resources: + limits: + memory: 28Gi + nvidia.com/gpu: 1 + cpu: 4 + command: + - /bin/bash + - -c + - "while true; do sleep 30; done" + terminationGracePeriodSeconds: 0 + priorityClassName: {{ .Chart.Name }}-low-priority +{{- end }} diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 0f9d9337..d4e7718b 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -8,6 +8,7 @@ replicaCount: balloonA100: 0 balloonCpu: 0 balloonT4: 0 + balloonH100: 0 # tag needs to be set dynamically every time. Usually it is set to the SHA1 hash of the git # commit from which the image was built. diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 2d002c00..97f68532 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -81,6 +81,8 @@ replicaCount: balloonCpu: 0 # balloonT4 is a low priority pod deployment for T4 GPU nodes balloonT4: 0 + # balloonH100 is a low priority pod deployment for H100 GPU nodes + balloonH100: 0 # autoscaling is the autoscaling configuration for LLM Engine server deployments (e.g gateway, cache, and builder deployments) autoscaling: @@ -254,6 +256,27 @@ imageCache: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" + - name: h100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: h100-mig-1g-20gb + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100-mig-1g-20gb + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: h100-mig-3g-40gb + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100-mig-3g-40gb + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" # celeryBrokerType specifies the celery broker type for async endpoints, either "sqs" or "elasticache" celeryBrokerType: sqs diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 110ed4cc..15b836da 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b32" +__version__ = "0.0.0b33" import os from typing import Sequence diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 77a0c1d8..1e18d3bf 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -149,6 +149,8 @@ def create( - ``nvidia-ampere-a100`` - ``nvidia-ampere-a100e`` - ``nvidia-hopper-h100`` + - ``nvidia-hopper-h100-1g20gb`` + - ``nvidia-hopper-h100-3g40gb`` high_priority (`Optional[bool]`): Either ``True`` or ``False``. Enabling this will allow the created @@ -533,6 +535,8 @@ def update( - ``nvidia-ampere-a100`` - ``nvidia-ampere-a100e`` - ``nvidia-hopper-h100`` + - ``nvidia-hopper-h100-1g20gb`` + - ``nvidia-hopper-h100-3g40gb`` high_priority (`Optional[bool]`): Either ``True`` or ``False``. Enabling this will allow the created diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index b7459272..7d645b53 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta32" +version = "0.0.0.beta33" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 489e428a..c11111cf 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta32", + version="0.0.0.beta33", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index e8549ad9..894096e8 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -879,6 +879,10 @@ async def execute( max_workers=request.max_workers, endpoint_type=request.endpoint_type, ) + if request.gpu_type == GpuType.NVIDIA_AMPERE_A100E: # pragma: no cover + raise ObjectHasInvalidValueException( + "We have migrated A100 usage to H100. Please request for H100 instead!" + ) if request.labels is None: raise EndpointLabelsException("Endpoint labels cannot be None!") validate_labels(request.labels) From 8993b18a3bbd1a08608f39ba2ed8ca470207fd38 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 21 May 2024 10:47:40 -0700 Subject: [PATCH 306/425] Hardcode llama 3 70b endpoint param (#524) * Hardcode some tuning for endpoints * remove mixtral 8x22b hardcode * test --- .../use_cases/llm_model_endpoint_use_cases.py | 3 +++ model-engine/tests/unit/conftest.py | 1 + model-engine/tests/unit/domain/conftest.py | 27 +++++++++++++++++++ .../tests/unit/domain/test_llm_use_cases.py | 11 ++++++++ 4 files changed, 42 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 894096e8..b27fe107 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -677,6 +677,9 @@ async def create_vllm_bundle( if hmi_config.sensitive_log_mode: # pragma: no cover subcommands[-1] = subcommands[-1] + " --disable-log-requests" + if "llama-3-70b" in model_name: + subcommands[-1] = subcommands[-1] + " --gpu-memory-utilization 0.95 --enforce-eager" + command = [ "/bin/bash", "-c", diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 7de4ec47..4300bbea 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -763,6 +763,7 @@ def __init__(self): "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"], "llama-2-7b": ["model-fake.safetensors"], "mpt-7b": ["model-fake.safetensors"], + "llama-3-70b": ["model-fake.safetensors"], } self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} self.model_config = { diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index c721960e..e9e37cf2 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -292,6 +292,33 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque ) +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_70b() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_70b", + model_name="llama-3-70b", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-70b", + ) + + @pytest.fixture def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( CreateLLMModelEndpointV1Request diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 9b2cbc38..770f3bda 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -80,6 +80,7 @@ async def test_create_model_endpoint_use_case_success( create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_llama_2: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_llama_3_70b: CreateLLMModelEndpointV1Request, ): fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository bundle_use_case = CreateModelBundleV2UseCase( @@ -182,6 +183,16 @@ async def test_create_model_endpoint_use_case_success( ) assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] + response_5 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_llama_3_70b + ) + assert response_5.endpoint_creation_task_id + assert isinstance(response_5, CreateLLMModelEndpointV1Response) + bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( + owner=user.team_id, name=create_llm_model_endpoint_request_llama_3_70b.name + ) + assert " --gpu-memory-utilization 0.95" in bundle.flavor.command[-1] + @pytest.mark.asyncio @pytest.mark.parametrize( From 028d4156b5273df30f0c718d40e1f4018ff8fb34 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 21 May 2024 20:34:24 -0700 Subject: [PATCH 307/425] Don't fail checking GPU memory (#525) --- .../inference/batch_inference/vllm_batch.py | 16 ++++++++++------ .../inference/vllm/vllm_server.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index e8b9fabe..e4887b20 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -479,7 +479,8 @@ def get_gpu_free_memory(): # pragma: no cover ).stdout gpu_memory = [int(x) for x in output.strip().split("\n")] return gpu_memory - except subprocess.CalledProcessError: + except Exception as e: + print(f"Error getting GPU memory: {e}") return None @@ -494,11 +495,14 @@ def check_unknown_startup_memory_usage(): # pragma: no cover print( f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) - # nosemgrep - output = subprocess.run( - ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True - ).stdout - print(f"Processes using GPU: {output}") + try: + # nosemgrep + output = subprocess.run( + ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True + ).stdout + print(f"Processes using GPU: {output}") + except Exception as e: + print(f"Error getting processes using GPU: {e}") if __name__ == "__main__": diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 3a966f15..e061b924 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -140,7 +140,8 @@ def get_gpu_free_memory(): ).stdout gpu_memory = [int(x) for x in output.strip().split("\n")] return gpu_memory - except subprocess.CalledProcessError: + except Exception as e: + print(f"Error getting GPU memory: {e}") return None @@ -154,11 +155,14 @@ def check_unknown_startup_memory_usage(): print( f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) - # nosemgrep - output = subprocess.run( - ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True - ).stdout - print(f"Processes using GPU: {output}") + try: + # nosemgrep + output = subprocess.run( + ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True + ).stdout + print(f"Processes using GPU: {output}") + except Exception as e: + print(f"Error getting processes using GPU: {e}") def debug(sig, frame): From 275f495199d51ab38ddabe65bcb9157800f3d76b Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 28 May 2024 16:35:40 -0700 Subject: [PATCH 308/425] Option to read Redis URL from AWS Secret (#526) Add an option to have the pods read Redis auth info from an AWS secret. Note: there are two places the redis auth info needs to be added, since Redis is used for both the model endpoint creation request message queue and a cache for endpoint info The secret is formatted as follows: It must contain a few keys, namely host, port, scheme (optional, defaults to redis://), auth_token (optional), query_params (optional). These control which Redis gets used as the message queue for the endpoint builder. Also must contain a key cache-url, the full Redis url of the redis to be used as a cache. --- charts/model-engine/values_sample.yaml | 23 +++++++++++++++++-- .../model_engine_server/common/config.py | 16 +++++++++++++ .../model_engine_server/core/celery/app.py | 12 ++++++++++ .../model_engine_server/core/config.py | 3 ++- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 97f68532..8f78f5a4 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -156,8 +156,22 @@ config: ml_account_id: "000000000000" # docker_repo_prefix [required] is the prefix for AWS ECR repositories docker_repo_prefix: "000000000000.dkr.ecr.us-east-1.amazonaws.com" - # redis_host [required] is the hostname of the redis cluster you wish to connect + # redis_host [required if redis_aws_secret_name not present] is the hostname of the redis cluster you wish to connect redis_host: llm-engine-prod-cache.use1.cache.amazonaws.com + # redis_aws_secret_name [optional] is the AWS secret that contains the connection info of the Redis cluster. + # The information provided should be as follows: + # scheme: either redis:// or rediss://, will default to redis:// + # auth_token (optional): an auth token for the Redis cluster + # host: the hostname of the Redis cluster + # port: the port of the Redis cluster + # query_params (optional): additional query parameters for the Redis cluster, will default to "" + # The url will be built as follows: + # {scheme}{host}:{port}/{db_index}{query_params} if auth_token is not provided, + # {scheme}:{auth_token}@{host}:{port}/{db_index}{query_params} if auth_token is provided + # db_index will be filled in by LLM Engine. + # This secret must be accessible by the default LLM Engine AWS role + # e.g. what is set by profile_ml_worker if provided + # redis_aws_secret_name: sample-prod/redis-credentials # s3_bucket [required] is the S3 bucket you wish to connect s3_bucket: "llm-engine" launch: @@ -165,9 +179,14 @@ config: endpoint_namespace: llm-engine # cache_redis_aws_url is the full url for the redis cluster you wish to connect, # cache_redis_azure_host is the redis cluster host when using cloud_provider azure - # one of cache_redis_aws_url and cache_redis_azure_host must be provided + # cache_redis_aws_secret_name is an AWS secret that contains the Redis credentials. + # It has a field "cache-url" with the full URL of the Redis cluster (including db number). + # Other fields are ignored; e.g. you can use the secret for multiple purposes. + # This secret must be accessible by the default LLM Engine AWS role + # exactly one of cache_redis_aws_url, cache_redis_azure_host, or cache_redis_aws_secret_name must be provided cache_redis_aws_url: redis://llm-engine-prod-cache.use1.cache.amazonaws.com:6379/15 cache_redis_azure_host: llm-engine-cache.redis.cache.windows.net:6380 + cache_redis_aws_secret_name: sample-prod/redis-credentials # s3_file_llm_fine_tuning_job_repository [required] is the S3 URI for the S3 bucket/key that you wish to save fine-tuned assests s3_file_llm_fine_tuning_job_repository: "s3://llm-engine/llm-ft-job-repository" # dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index dd18a1c5..c66b8df8 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -8,6 +8,7 @@ import yaml from azure.identity import DefaultAzureCredential +from model_engine_server.core.aws.secrets import get_key_file from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger @@ -68,8 +69,12 @@ class HostedModelInferenceServiceConfig: user_inference_tensorflow_repository: str docker_image_layer_cache_repository: str sensitive_log_mode: bool + # Exactly one of the following three must be specified cache_redis_aws_url: Optional[str] = None # also using this to store sync autoscaling metrics cache_redis_azure_host: Optional[str] = None + cache_redis_aws_secret_name: Optional[ + str + ] = None # Not an env var because the redis cache info is already here @classmethod def from_yaml(cls, yaml_path): @@ -80,7 +85,18 @@ def from_yaml(cls, yaml_path): @property def cache_redis_url(self) -> str: if self.cache_redis_aws_url: + assert infra_config().cloud_provider == "aws", "cache_redis_aws_url is only for AWS" + if self.cache_redis_aws_secret_name: + logger.warning( + "Both cache_redis_aws_url and cache_redis_aws_secret_name are set. Using cache_redis_aws_url" + ) return self.cache_redis_aws_url + elif self.cache_redis_aws_secret_name: + assert ( + infra_config().cloud_provider == "aws" + ), "cache_redis_aws_secret_name is only for AWS" + creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role + return creds["cache-url"] assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure" username = os.getenv("AZURE_OBJECT_ID") diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 167c01ba..80fda86b 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -10,6 +10,7 @@ from celery.app.control import Inspect from celery.result import AsyncResult from model_engine_server.core.aws.roles import session +from model_engine_server.core.aws.secrets import get_key_file from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import ( CustomJSONFormatter, @@ -195,6 +196,17 @@ def get_redis_host_port(): def get_redis_endpoint(db_index: int = 0) -> str: + if infra_config().redis_aws_secret_name is not None: + logger.info("Using infra_config().redis_aws_secret_name for Redis endpoint") + creds = get_key_file(infra_config().redis_aws_secret_name) # Use default role + scheme = creds.get("scheme", "redis://") + host = creds["host"] + port = creds["port"] + query_params = creds.get("query_params", "") + auth_token = creds.get("auth_token", None) + if auth_token is not None: + return f"{scheme}:{auth_token}@{host}:{port}/{db_index}{query_params}" + return f"{scheme}{host}:{port}/{db_index}{query_params}" host, port = get_redis_host_port() auth_token = os.getenv("REDIS_AUTH_TOKEN") if auth_token: diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 5bbe58bb..301b4f80 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -38,8 +38,9 @@ class InfraConfig: default_region: str ml_account_id: str docker_repo_prefix: str - redis_host: str s3_bucket: str + redis_host: Optional[str] = None + redis_aws_secret_name: Optional[str] = None profile_ml_worker: str = "default" profile_ml_inference_worker: str = "default" identity_service_url: Optional[str] = None From 8a4c745f027b1bd54c6be308ea25f6572211f340 Mon Sep 17 00:00:00 2001 From: Sai Atmakuri <87143260+saiatmakuri@users.noreply.github.com> Date: Tue, 28 May 2024 16:45:50 -0700 Subject: [PATCH 309/425] Fix formatting on completions documentation guide (#527) --- docs/guides/completions.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index 8bfad184..250bbbd1 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -261,6 +261,7 @@ response = Completion.create( print(response.json()) # {"request_id": "34621b44-c655-402c-a459-f108b3e49b12", "output": {"text": "John", "num_prompt_tokens": 6, "num_completion_tokens": 4, "tokens": None}} +``` ## Which model should I use? From 5bb8797b3f8e90ef949fb22dacd1300f12650ac4 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:15:34 -0700 Subject: [PATCH 310/425] Higher priority for gateway (#529) --- charts/model-engine/templates/gateway_deployment.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/charts/model-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml index ed1d6cae..e6466544 100644 --- a/charts/model-engine/templates/gateway_deployment.yaml +++ b/charts/model-engine/templates/gateway_deployment.yaml @@ -35,6 +35,7 @@ spec: imagePullSecrets: {{- toYaml . | nindent 8 }} {{- end }} + priorityClassName: model-engine-high-priority containers: - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" From bd192cbb387c203f75d1bf1c4f4098aebeecc0f3 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:15:33 -0700 Subject: [PATCH 311/425] Non-interactive installation during docker build (#533) --- .../model_engine_server/inference/pytorch_or_tf.base.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile index 8d8d3378..459fdaee 100644 --- a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile @@ -7,7 +7,7 @@ WORKDIR /app # TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. # It's probably more correct to add support for arbitrary user-specified base images, # otherwise this base image gets bloated over time. -RUN apt-get update && apt-get install -y \ +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ apt-utils \ dumb-init \ git \ From ad24f6515994c666c0021fb02b98fd699c22d87a Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:47:56 -0700 Subject: [PATCH 312/425] [Client] Add guided_grammar and other missing fields (#532) Add guided_grammar to the client, + add some missing fields to some codepaths --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/completion.py | 20 ++++++++++++++++++++ clients/python/llmengine/data_types.py | 2 ++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 15b836da..b2ea471a 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b33" +__version__ = "0.0.0b34" import os from typing import Sequence diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 01aa86a9..4cbbaf75 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -47,6 +47,7 @@ async def acreate( guided_json: Optional[Dict[str, Any]] = None, guided_regex: Optional[str] = None, guided_choice: Optional[List[str]] = None, + guided_grammar: Optional[str] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: @@ -118,6 +119,9 @@ async def acreate( guided_choice (Optional[List[str]]): If specified, the output will be exactly one of the choices. + guided_grammar (Optional[str]): + If specified, the output will follow the context-free grammar provided. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -218,6 +222,7 @@ async def _acreate_stream( guided_json=guided_json, guided_regex=guided_regex, guided_choice=guided_choice, + guided_grammar=guided_grammar, timeout=timeout, ) @@ -242,6 +247,11 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse: frequency_penalty=frequency_penalty, top_k=top_k, top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, + guided_grammar=guided_grammar, ) @classmethod @@ -261,6 +271,7 @@ def create( guided_json: Optional[Dict[str, Any]] = None, guided_regex: Optional[str] = None, guided_choice: Optional[List[str]] = None, + guided_grammar: Optional[str] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: @@ -333,6 +344,9 @@ def create( guided_choice (Optional[List[str]]): If specified, the output will be exactly one of the choices. + guided_grammar (Optional[str]): + If specified, the output will follow the context-free grammar provided. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -419,6 +433,11 @@ def _create_stream(**kwargs): frequency_penalty=frequency_penalty, top_k=top_k, top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, + guided_grammar=guided_grammar, ) else: @@ -436,6 +455,7 @@ def _create_stream(**kwargs): guided_json=guided_json, guided_regex=guided_regex, guided_choice=guided_choice, + guided_grammar=guided_grammar, ).dict() response = cls.post_sync( resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index bcaf4112..f1c9b56c 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -331,6 +331,7 @@ class CompletionSyncV1Request(BaseModel): guided_json: Optional[Dict[str, Any]] = Field(default=None) guided_regex: Optional[str] = Field(default=None) guided_choice: Optional[List[str]] = Field(default=None) + guided_grammar: Optional[str] = Field(default=None) class TokenOutput(BaseModel): @@ -405,6 +406,7 @@ class CompletionStreamV1Request(BaseModel): guided_json: Optional[Dict[str, Any]] = Field(default=None) guided_regex: Optional[str] = Field(default=None) guided_choice: Optional[List[str]] = Field(default=None) + guided_grammar: Optional[str] = Field(default=None) class CompletionStreamOutput(BaseModel): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 7d645b53..910fa162 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta33" +version = "0.0.0.beta34" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index c11111cf..c8d30e11 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta33", + version="0.0.0.beta34", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From f84adbb15aea29ba77e720fc44cae0c805177314 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 6 Jun 2024 09:41:41 -0700 Subject: [PATCH 313/425] Make balloon creation flexible (#531) * Make balloon creation flexible * fix ref --- .../templates/balloon_a10_deployment.yaml | 50 ------------------- .../templates/balloon_cpu_deployment.yaml | 18 ++++--- ...ployment.yaml => balloon_deployments.yaml} | 21 +++++--- .../templates/balloon_h100_deployment.yaml | 50 ------------------- .../templates/balloon_t4_deployment.yaml | 50 ------------------- charts/model-engine/values_circleci.yaml | 17 +++++-- charts/model-engine/values_sample.yaml | 27 ++++++---- 7 files changed, 53 insertions(+), 180 deletions(-) delete mode 100644 charts/model-engine/templates/balloon_a10_deployment.yaml rename charts/model-engine/templates/{balloon_a100_deployment.yaml => balloon_deployments.yaml} (65%) delete mode 100644 charts/model-engine/templates/balloon_h100_deployment.yaml delete mode 100644 charts/model-engine/templates/balloon_t4_deployment.yaml diff --git a/charts/model-engine/templates/balloon_a10_deployment.yaml b/charts/model-engine/templates/balloon_a10_deployment.yaml deleted file mode 100644 index 5e71af2b..00000000 --- a/charts/model-engine/templates/balloon_a10_deployment.yaml +++ /dev/null @@ -1,50 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: apps/v1 -kind: Deployment -metadata: - name: {{ .Chart.Name }}-balloon-a10 - labels: - team: infra - product: common-warm-nodes -spec: - replicas: {{ .Values.replicaCount.balloonA10 }} - selector: - matchLabels: - app: {{ .Chart.Name }}-balloon-a10 - version: v1 - template: - metadata: - labels: - app: {{ .Chart.Name }}-balloon-a10 - product: common-warm-nodes - team: infra - env: {{ .Values.context }} - version: v1 - annotations: - sidecar.istio.io/inject: "false" - spec: - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a10 - {{- with .Values.balloonNodeSelector }} - {{- toYaml . | nindent 8 }} - {{- end }} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - containers: - - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent - name: main - resources: - limits: - memory: 28Gi - nvidia.com/gpu: 1 - cpu: 4 - command: - - /bin/bash - - -c - - "while true; do sleep 30; done" - terminationGracePeriodSeconds: 0 - priorityClassName: {{ .Chart.Name }}-low-priority -{{- end }} diff --git a/charts/model-engine/templates/balloon_cpu_deployment.yaml b/charts/model-engine/templates/balloon_cpu_deployment.yaml index a7be9011..583e3c1e 100644 --- a/charts/model-engine/templates/balloon_cpu_deployment.yaml +++ b/charts/model-engine/templates/balloon_cpu_deployment.yaml @@ -1,29 +1,31 @@ {{- if not .Values.serviceIdentifier }} +{{- range .Values.balloons }} +{{- if eq .acceleratorName "cpu" }} apiVersion: apps/v1 kind: Deployment metadata: - name: {{ .Chart.Name }}-balloon-cpu + name: {{ $.Chart.Name }}-balloon-cpu labels: team: infra product: common-warm-nodes spec: - replicas: {{ .Values.replicaCount.balloonCpu }} + replicas: {{ .replicaCount }} selector: matchLabels: - app: {{ .Chart.Name }}-balloon-cpu + app: {{ $.Chart.Name }}-balloon-cpu version: v1 template: metadata: labels: - app: {{ .Chart.Name }}-balloon-cpu + app: {{ $.Chart.Name }}-balloon-cpu product: common-warm-nodes team: infra - env: {{ .Values.context }} + env: {{ $.Values.context }} version: v1 annotations: sidecar.istio.io/inject: "false" spec: - {{- with .Values.balloonNodeSelector }} + {{- with $.Values.balloonNodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} {{- end }} @@ -40,5 +42,7 @@ spec: - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: {{ .Chart.Name }}-low-priority + priorityClassName: {{ $.Chart.Name }}-low-priority +{{- end }} +{{- end }} {{- end }} diff --git a/charts/model-engine/templates/balloon_a100_deployment.yaml b/charts/model-engine/templates/balloon_deployments.yaml similarity index 65% rename from charts/model-engine/templates/balloon_a100_deployment.yaml rename to charts/model-engine/templates/balloon_deployments.yaml index 50dbfea4..3a4e1f20 100644 --- a/charts/model-engine/templates/balloon_a100_deployment.yaml +++ b/charts/model-engine/templates/balloon_deployments.yaml @@ -1,31 +1,33 @@ {{- if not .Values.serviceIdentifier }} +{{- range .Values.balloons }} +{{- if not (eq .acceleratorName "cpu") }} apiVersion: apps/v1 kind: Deployment metadata: - name: {{ .Chart.Name }}-balloon-a100 + name: {{ $.Chart.Name }}-balloon-{{ .acceleratorName }} labels: team: infra product: common-warm-nodes spec: - replicas: {{ .Values.replicaCount.balloonA100 }} + replicas: {{ .replicaCount }} selector: matchLabels: - app: {{ .Chart.Name }}-balloon-a100 + app: {{ $.Chart.Name }}-balloon-{{ .acceleratorName }} version: v1 template: metadata: labels: - app: {{ .Chart.Name }}-balloon-a100 + app: {{ $.Chart.Name }}-balloon-{{ .acceleratorName }} product: common-warm-nodes team: infra - env: {{ .Values.context }} + env: {{ $.Values.context }} version: v1 annotations: sidecar.istio.io/inject: "false" spec: nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a100 - {{- with .Values.balloonNodeSelector }} + k8s.amazonaws.com/accelerator: {{ .acceleratorName }} + {{- with $.Values.balloonNodeSelector }} {{- toYaml . | nindent 8 }} {{- end }} tolerations: @@ -46,5 +48,8 @@ spec: - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: {{ .Chart.Name }}-low-priority + priorityClassName: {{ $.Chart.Name }}-low-priority +--- {{- end }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/charts/model-engine/templates/balloon_h100_deployment.yaml b/charts/model-engine/templates/balloon_h100_deployment.yaml deleted file mode 100644 index 03bce9aa..00000000 --- a/charts/model-engine/templates/balloon_h100_deployment.yaml +++ /dev/null @@ -1,50 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: apps/v1 -kind: Deployment -metadata: - name: {{ .Chart.Name }}-balloon-h100 - labels: - team: infra - product: common-warm-nodes -spec: - replicas: {{ .Values.replicaCount.balloonH100 }} - selector: - matchLabels: - app: {{ .Chart.Name }}-balloon-h100 - version: v1 - template: - metadata: - labels: - app: {{ .Chart.Name }}-balloon-h100 - product: common-warm-nodes - team: infra - env: {{ .Values.context }} - version: v1 - annotations: - sidecar.istio.io/inject: "false" - spec: - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-h100 - {{- with .Values.balloonNodeSelector }} - {{- toYaml . | nindent 8 }} - {{- end }} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - containers: - - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent - name: main - resources: - limits: - memory: 28Gi - nvidia.com/gpu: 1 - cpu: 4 - command: - - /bin/bash - - -c - - "while true; do sleep 30; done" - terminationGracePeriodSeconds: 0 - priorityClassName: {{ .Chart.Name }}-low-priority -{{- end }} diff --git a/charts/model-engine/templates/balloon_t4_deployment.yaml b/charts/model-engine/templates/balloon_t4_deployment.yaml deleted file mode 100644 index 6a5e8292..00000000 --- a/charts/model-engine/templates/balloon_t4_deployment.yaml +++ /dev/null @@ -1,50 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: apps/v1 -kind: Deployment -metadata: - name: {{ .Chart.Name }}-balloon-t4 - labels: - team: infra - product: common-warm-nodes -spec: - replicas: {{ .Values.replicaCount.balloonT4 }} - selector: - matchLabels: - app: {{ .Chart.Name }}-balloon-t4 - version: v1 - template: - metadata: - labels: - app: {{ .Chart.Name }}-balloon-t4 - product: common-warm-nodes - team: infra - env: {{ .Values.context }} - version: v1 - annotations: - sidecar.istio.io/inject: "false" - spec: - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-tesla-t4 - {{- with .Values.balloonNodeSelector }} - {{- toYaml . | nindent 8 }} - {{- end }} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - containers: - - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent - name: main - resources: - limits: - memory: 28Gi - nvidia.com/gpu: 1 - cpu: 4 - command: - - /bin/bash - - -c - - "while true; do sleep 30; done" - terminationGracePeriodSeconds: 0 - priorityClassName: {{ .Chart.Name }}-low-priority -{{- end }} diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index d4e7718b..f874ac5f 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -4,11 +4,18 @@ replicaCount: gateway: 1 cacher: 1 builder: 1 - balloonA10: 0 - balloonA100: 0 - balloonCpu: 0 - balloonT4: 0 - balloonH100: 0 + +balloons: + - acceleratorName: nvidia-ampere-a10 + replicaCount: 0 + - acceleratorName: nvidia-ampere-a100 + replicaCount: 0 + - acceleratorName: cpu + replicaCount: 0 + - acceleratorName: nvidia-tesla-t4 + replicaCount: 0 + - acceleratorName: nvidia-hopper-h100 + replicaCount: 0 # tag needs to be set dynamically every time. Usually it is set to the SHA1 hash of the git # commit from which the image was built. diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 8f78f5a4..c0285039 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -73,16 +73,23 @@ replicaCount: cacher: 1 # builder is the endpoint builder deployment builder: 1 - # balloonA10 is a low priority pod deployment for A10 GPU nodes - balloonA10: 0 - # balloonA100 is a low priority pod deployment for A100 GPU nodes - balloonA100: 0 - # balloonCpu is a low priority pod deployment for CPU nodes - balloonCpu: 0 - # balloonT4 is a low priority pod deployment for T4 GPU nodes - balloonT4: 0 - # balloonH100 is a low priority pod deployment for H100 GPU nodes - balloonH100: 0 + +balloons: + # A low priority pod deployment for A10 GPU nodes + - acceleratorName: nvidia-ampere-a10 + replicaCount: 0 + # A low priority pod deployment for A100 GPU nodes + - acceleratorName: nvidia-ampere-a100 + replicaCount: 0 + # A low priority pod deployment for CPU nodes + - acceleratorName: cpu + replicaCount: 0 + # A low priority pod deployment for T4 GPU nodes + - acceleratorName: nvidia-tesla-t4 + replicaCount: 0 + # A low priority pod deployment for H100 GPU nodes + - acceleratorName: nvidia-hopper-h100 + replicaCount: 0 # autoscaling is the autoscaling configuration for LLM Engine server deployments (e.g gateway, cache, and builder deployments) autoscaling: From 6447c5f456475151e4b8af031001cec0dc0653d9 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 10 Jun 2024 10:29:06 -0700 Subject: [PATCH 314/425] Bump kv cache min memory for batch jobs (#536) * bump kv cache min for batch jobs * Add test for batch job * Bump multiplier to 18 to get batch job to use 4 GPU --- .../use_cases/llm_model_endpoint_use_cases.py | 7 ++- .../tests/unit/domain/test_llm_use_cases.py | 56 +++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index b27fe107..15bfaa69 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2236,13 +2236,15 @@ def _infer_hardware( llm_artifact_gateway: LLMArtifactGateway, model_name: str, checkpoint_path: str, + is_batch_job: bool = False, ) -> CreateDockerImageBatchJobResourceRequests: config = llm_artifact_gateway.get_model_config(checkpoint_path) dtype_size = 2 + kv_multiplier = 20 if is_batch_job else 2 min_kv_cache_size = ( - 2 + kv_multiplier * dtype_size * config["num_hidden_layers"] * config["hidden_size"] @@ -2267,7 +2269,7 @@ def _infer_hardware( min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) logger.info( - f"Memory calculation result: {min_memory_gb=} for {model_name}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}" + f"Memory calculation result: {min_memory_gb=} for {model_name}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" ) if min_memory_gb <= 24: @@ -2408,6 +2410,7 @@ async def execute( self.llm_artifact_gateway, request.model_config.model, request.model_config.checkpoint_path, + is_batch_job=True, ) # Reconcile gpus count with num_shards from request assert hardware.gpus is not None diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 770f3bda..4aa9e982 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1861,6 +1861,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "", is_batch_job=True) + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { "architectures": ["MixtralForCausalLM"], "attention_dropout": 0.0, @@ -1892,6 +1899,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "460Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "", is_batch_job=True) + assert hardware.cpus == "80" + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "460Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-7b-hf", "architectures": ["LlamaForCausalLM"], @@ -1919,6 +1933,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "48Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], "attention_dropout": 0.0, @@ -1947,6 +1968,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "48Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-13b-hf", "architectures": ["LlamaForCausalLM"], @@ -1974,6 +2002,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "", is_batch_job=True) + assert hardware.cpus == "40" + assert hardware.gpus == 4 + assert hardware.memory == "96Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], "bos_token_id": 1, @@ -2001,6 +2036,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + hardware = _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "", is_batch_job=True) + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-70b-hf", "architectures": ["LlamaForCausalLM"], @@ -2028,6 +2070,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "", is_batch_job=True) + assert hardware.cpus == "20" + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], "attention_dropout": 0.0, @@ -2056,6 +2105,13 @@ def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "", is_batch_job=True) + assert hardware.cpus == "40" + assert hardware.gpus == 4 + assert hardware.memory == "320Gi" + assert hardware.storage == "320Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + # (TODO) figure out how to calculate memory for llama-3-8b-instruct-262k # fake_llm_artifact_gateway.model_config = { # "_name_or_path": "gradientai/llama3-8b-stage65k-chat", From 4c6b176068bf89d0fa20fb9c680e2f231f60f11a Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 14 Jun 2024 11:37:06 -0700 Subject: [PATCH 315/425] DEBUG: Add additional logging for authz errors (#539) * Add debug log for authz errors * no cover --- model-engine/model_engine_server/api/llms_v1.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 07a93d78..e9fb8a43 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -237,6 +237,11 @@ async def get_model_endpoint( ) return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"GET /llm/model-endpoints/{model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + raise HTTPException( status_code=404, detail=f"Model Endpoint {model_endpoint_name} was not found.", From 69163b23a3d63e0ae34a7fd55cd1f9de439c8cca Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 14 Jun 2024 19:51:51 -0700 Subject: [PATCH 316/425] Add debug log for authz errors (#540) --- model-engine/model_engine_server/api/llms_v1.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index e9fb8a43..ea5f49cf 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -319,10 +319,10 @@ async def create_completion_sync_task( Runs a sync prompt completion on an LLM. """ if hmi_config.sensitive_log_mode: # pragma: no cover - logger.info(f"POST /completion_sync to endpoint {model_endpoint_name} for {auth}") + logger.info(f"POST /completions-sync to endpoint {model_endpoint_name} for {auth}") else: logger.info( - f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}" + f"POST /completions-sync with {request} to endpoint {model_endpoint_name} for {auth}" ) try: use_case = CompletionSyncV1UseCase( @@ -356,6 +356,11 @@ async def create_completion_sync_task( detail=f"Upstream service error for request_id {request_id}", ) except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"POST /completions-sync to endpoint {model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + raise HTTPException( status_code=404, detail="The specified endpoint could not be found.", @@ -384,10 +389,10 @@ async def create_completion_stream_task( Runs a stream prompt completion on an LLM. """ if hmi_config.sensitive_log_mode: # pragma: no cover - logger.info(f"POST /completion_stream to endpoint {model_endpoint_name} for {auth}") + logger.info(f"POST /completions-stream to endpoint {model_endpoint_name} for {auth}") else: logger.info( - f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" + f"POST /completions-stream with {request} to endpoint {model_endpoint_name} for {auth}" ) use_case = CompletionStreamV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, From dfb7b15dc4e1cb8a6695eaa2bfb6fd1d214df7b0 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 20 Jun 2024 10:42:49 -0700 Subject: [PATCH 317/425] Mitigation for AsyncEngineDeadError (#545) * Mitigation for AsyncEngineDeadError * Kill process directly --- .../inference/vllm/vllm_server.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index e061b924..365d271d 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -1,6 +1,7 @@ import argparse import code import json +import os import signal import subprocess import traceback @@ -10,7 +11,7 @@ from fastapi import BackgroundTasks, FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.outputs import CompletionOutput @@ -39,6 +40,13 @@ async def generate(request: Request) -> Response: - stream: whether to stream the results or not. - other fields: the sampling parameters (See `SamplingParams` for details). """ + # check health before accepting request and fail fast if engine isn't healthy + try: + await engine.check_health() + except Exception as e: + print(f"The vllm engine is dead, exiting the pod: {e}") + os.kill(os.getpid(), signal.SIGINT) + request_dict = await request.json() prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) @@ -75,34 +83,32 @@ async def generate(request: Request) -> Response: sampling_params.logits_processors.append(guided_decode_logit_processor) request_id = random_uuid() - try: - results_generator = engine.generate(prompt, sampling_params, request_id) - except AsyncEngineDeadError as e: - print(f"The vllm engine is dead, exiting the pod: {e}") - exit(1) - - # Streaming case - async def stream_results() -> AsyncGenerator[str, None]: - last_output_text = "" - async for request_output in results_generator: - log_probs = format_logprobs(request_output) - ret = { - "text": request_output.outputs[-1].text[len(last_output_text) :], - "count_prompt_tokens": len(request_output.prompt_token_ids), - "count_output_tokens": len(request_output.outputs[0].token_ids), - "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None, - "finished": request_output.finished, - } - last_output_text = request_output.outputs[-1].text - yield f"data:{json.dumps(ret)}\n\n" + + results_generator = engine.generate(prompt, sampling_params, request_id) async def abort_request() -> None: await engine.abort(request_id) if stream: + # Streaming case + async def stream_results() -> AsyncGenerator[str, None]: + last_output_text = "" + async for request_output in results_generator: + log_probs = format_logprobs(request_output) + ret = { + "text": request_output.outputs[-1].text[len(last_output_text) :], + "count_prompt_tokens": len(request_output.prompt_token_ids), + "count_output_tokens": len(request_output.outputs[0].token_ids), + "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None, + "finished": request_output.finished, + } + last_output_text = request_output.outputs[-1].text + yield f"data:{json.dumps(ret)}\n\n" + background_tasks = BackgroundTasks() # Abort the request if the client disconnects. background_tasks.add_task(abort_request) + return StreamingResponse(stream_results(), background=background_tasks) # Non-streaming case From f0fee2a7316d4148528384b400f20410aac502c9 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Thu, 20 Jun 2024 13:42:56 -0700 Subject: [PATCH 318/425] Infer hardware specs from config (#543) * Config infer WIP * unit tests * updates * remove * fix tests * comment * __init__ files * fix unit test --- .../recommended_hardware_config_map.yaml | 28 +++ charts/model-engine/values_circleci.yaml | 46 ++++ charts/model-engine/values_sample.yaml | 46 ++++ .../model_engine_server/api/llms_v1.py | 6 + .../model_engine_server/domain/exceptions.py | 9 + .../use_cases/llm_model_endpoint_use_cases.py | 91 ++++---- model-engine/tests/unit/__init__.py | 0 model-engine/tests/unit/api/__init__.py | 0 model-engine/tests/unit/api/test_llms.py | 6 + model-engine/tests/unit/conftest.py | 55 +++++ model-engine/tests/unit/domain/__init__.py | 0 .../tests/unit/domain/test_llm_use_cases.py | 216 ++++++++++-------- 12 files changed, 362 insertions(+), 141 deletions(-) create mode 100644 charts/model-engine/templates/recommended_hardware_config_map.yaml create mode 100644 model-engine/tests/unit/__init__.py create mode 100644 model-engine/tests/unit/api/__init__.py create mode 100644 model-engine/tests/unit/domain/__init__.py diff --git a/charts/model-engine/templates/recommended_hardware_config_map.yaml b/charts/model-engine/templates/recommended_hardware_config_map.yaml new file mode 100644 index 00000000..47474ceb --- /dev/null +++ b/charts/model-engine/templates/recommended_hardware_config_map.yaml @@ -0,0 +1,28 @@ +{{ if .Values.recommendedHardware }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "modelEngine.fullname" . }}-recommended-hardware-config + labels: + product: common + team: infra +data: + byGpuMemoryGb: |- +{{- range $.Values.recommendedHardware.byGpuMemoryGb }} + - gpu_memory_le: {{ .gpu_memory_le }} + cpus: {{ .cpus }} + gpus: {{ .gpus }} + memory: {{ .memory }} + storage: {{ .storage }} + gpu_type: {{ .gpu_type }} +{{- end }} + byModelName: |- +{{- range $.Values.recommendedHardware.byModelName }} + - name: {{ .name }} + cpus: {{ .cpus }} + gpus: {{ .gpus }} + memory: {{ .memory }} + storage: {{ .storage }} + gpu_type: {{ .gpu_type }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index f874ac5f..b29b18e9 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -233,3 +233,49 @@ celeryBrokerType: redis datadog: enabled: false + +recommendedHardware: + byGpuMemoryGb: + - gpu_memory_le: 24 + cpus: 10 + gpus: 1 + memory: 24Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + - gpu_memory_le: 48 + cpus: 20 + gpus: 2 + memory: 48Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + - gpu_memory_le: 96 + cpus: 40 + gpus: 4 + memory: 96Gi + storage: 96Gi + gpu_type: nvidia-ampere-a10 + - gpu_memory_le: 180 + cpus: 20 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 320 + cpus: 40 + gpus: 4 + memory: 320Gi + storage: 320Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 640 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 460Gi + gpu_type: nvidia-hopper-h100 + byModelName: + - name: llama-3-8b-instruct-262k + cpus: 20 + gpus: 2 + memory: 40Gi + storage: 40Gi + gpu_type: nvidia-hopper-h100 \ No newline at end of file diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index c0285039..38f631e0 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -309,3 +309,49 @@ celeryBrokerType: sqs datadog: enabled: false + +recommendedHardware: + byGpuMemoryGb: + - gpu_memory_le: 24 + cpus: 10 + gpus: 1 + memory: 24Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + - gpu_memory_le: 48 + cpus: 20 + gpus: 2 + memory: 48Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + - gpu_memory_le: 96 + cpus: 40 + gpus: 4 + memory: 96Gi + storage: 96Gi + gpu_type: nvidia-ampere-a10 + - gpu_memory_le: 180 + cpus: 20 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 320 + cpus: 40 + gpus: 4 + memory: 320Gi + storage: 320Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 640 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 460Gi + gpu_type: nvidia-hopper-h100 + byModelName: + - name: llama-3-8b-instruct-262k + cpus: 20 + gpus: 2 + memory: 40Gi + storage: 40Gi + gpu_type: nvidia-hopper-h100 diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index ea5f49cf..6f26ae1e 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -56,6 +56,7 @@ EndpointResourceInvalidRequestException, EndpointUnsupportedInferenceTypeException, ExistingEndpointOperationInProgressException, + FailToInferHardwareException, InvalidRequestException, LLMFineTuningMethodNotImplementedException, LLMFineTuningQuotaReached, @@ -200,6 +201,11 @@ async def create_model_endpoint( status_code=404, detail="The specified docker image could not be found.", ) from exc + except FailToInferHardwareException as exc: + raise HTTPException( + status_code=500, + detail="Failed to infer hardware exception.", + ) from exc @llm_router_v1.get("/model-endpoints", response_model=ListLLMModelEndpointsV1Response) diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index e9ded985..5b81a68e 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -188,3 +188,12 @@ class LatestImageTagNotFoundException(DomainException): """ Thrown if the latest image tag cannot be found. """ + + +@dataclass +class FailToInferHardwareException(DomainException): + """ + Thrown if failed to infer hardware. + """ + + message: str diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 15bfaa69..615cef8a 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -13,6 +13,7 @@ from functools import lru_cache from typing import Any, AsyncIterable, Dict, List, Optional, Union +import yaml from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( @@ -69,6 +70,7 @@ EndpointInfraStateNotFound, EndpointLabelsException, EndpointUnsupportedInferenceTypeException, + FailToInferHardwareException, InvalidRequestException, LatestImageTagNotFoundException, ObjectHasInvalidValueException, @@ -238,6 +240,7 @@ if SERVICE_IDENTIFIER: SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = f"{SERVICE_NAME}-inference-framework-latest-config" +RECOMMENDED_HARDWARE_CONFIG_MAP_NAME = f"{SERVICE_NAME}-recommended-hardware-config" def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: @@ -257,6 +260,19 @@ async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str: return config_map[inference_framework] +async def _get_recommended_hardware_config_map() -> Dict[str, Any]: + try: + config_map = await read_config_map(RECOMMENDED_HARDWARE_CONFIG_MAP_NAME) + except Exception as e: + logger.error( + f"Failed to read config map {RECOMMENDED_HARDWARE_CONFIG_MAP_NAME}, can't infer hardware config." + ) + raise FailToInferHardwareException( + f"Failed to read config map {RECOMMENDED_HARDWARE_CONFIG_MAP_NAME}, can't infer hardware config." + ) from e + return config_map + + def _model_endpoint_entity_to_get_llm_model_endpoint_response( model_endpoint: ModelEndpoint, ) -> GetLLMModelEndpointV1Response: @@ -868,7 +884,7 @@ def __init__( async def execute( self, user: User, request: CreateLLMModelEndpointV1Request ) -> CreateLLMModelEndpointV1Response: - _fill_hardware_info(self.llm_artifact_gateway, request) + await _fill_hardware_info(self.llm_artifact_gateway, request) if not ( request.gpus and request.gpu_type @@ -2200,7 +2216,7 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl return ModelDownloadResponse(urls=urls) -def _fill_hardware_info( +async def _fill_hardware_info( llm_artifact_gateway: LLMArtifactGateway, request: CreateLLMModelEndpointV1Request ): if ( @@ -2221,7 +2237,9 @@ def _fill_hardware_info( "All hardware spec fields (gpus, gpu_type, cpus, memory, storage) must be provided if any hardware spec field is missing." ) checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) - hardware_info = _infer_hardware(llm_artifact_gateway, request.model_name, checkpoint_path) + hardware_info = await _infer_hardware( + llm_artifact_gateway, request.model_name, checkpoint_path + ) request.gpus = hardware_info.gpus request.gpu_type = hardware_info.gpu_type request.cpus = hardware_info.cpus @@ -2232,7 +2250,7 @@ def _fill_hardware_info( @lru_cache() -def _infer_hardware( +async def _infer_hardware( llm_artifact_gateway: LLMArtifactGateway, model_name: str, checkpoint_path: str, @@ -2272,50 +2290,27 @@ def _infer_hardware( f"Memory calculation result: {min_memory_gb=} for {model_name}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" ) - if min_memory_gb <= 24: - cpus = "10" - gpus = 1 - memory = "24Gi" - storage = "80Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A10 - elif min_memory_gb <= 48: - cpus = "20" - gpus = 2 - memory = "48Gi" - storage = "80Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A10 - elif min_memory_gb <= 96: - cpus = "40" - gpus = 4 - memory = "96Gi" - storage = "96Gi" - gpu_type = GpuType.NVIDIA_AMPERE_A10 - elif min_memory_gb <= 180: - cpus = "20" - gpus = 2 - memory = "160Gi" - storage = "160Gi" - gpu_type = GpuType.NVIDIA_HOPPER_H100 - elif min_memory_gb <= 320: - cpus = "40" - gpus = 4 - memory = "320Gi" - storage = "320Gi" - gpu_type = GpuType.NVIDIA_HOPPER_H100 - elif min_memory_gb <= 640: - cpus = "80" - gpus = 8 - memory = "800Gi" - storage = "460Gi" - gpu_type = GpuType.NVIDIA_HOPPER_H100 - elif "llama-3-8b-instruct-262k" in model_name: - cpus = "20" - gpus = 2 - memory = "40Gi" - storage = "40Gi" - gpu_type = GpuType.NVIDIA_HOPPER_H100 + config_map = await _get_recommended_hardware_config_map() + by_model_name = {item["name"]: item for item in yaml.safe_load(config_map["byModelName"])} + by_gpu_memory_gb = yaml.safe_load(config_map["byGpuMemoryGb"]) + if model_name in by_model_name: + cpus = by_model_name[model_name]["cpus"] + gpus = by_model_name[model_name]["gpus"] + memory = by_model_name[model_name]["memory"] + storage = by_model_name[model_name]["storage"] + gpu_type = by_model_name[model_name]["gpu_type"] else: - raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") + by_gpu_memory_gb = sorted(by_gpu_memory_gb, key=lambda x: x["gpu_memory_le"]) + for recs in by_gpu_memory_gb: + if min_memory_gb <= recs["gpu_memory_le"]: + cpus = recs["cpus"] + gpus = recs["gpus"] + memory = recs["memory"] + storage = recs["storage"] + gpu_type = recs["gpu_type"] + break + else: + raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") return CreateDockerImageBatchJobResourceRequests( cpus=cpus, gpus=gpus, memory=memory, storage=storage, gpu_type=gpu_type @@ -2406,7 +2401,7 @@ async def execute( request.model_config.checkpoint_path = get_checkpoint_path( request.model_config.model, request.model_config.checkpoint_path ) - hardware = _infer_hardware( + hardware = await _infer_hardware( self.llm_artifact_gateway, request.model_config.model, request.model_config.checkpoint_path, diff --git a/model-engine/tests/unit/__init__.py b/model-engine/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/tests/unit/api/__init__.py b/model-engine/tests/unit/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 7ab55908..4ef3eafe 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -7,6 +7,8 @@ from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.domain.entities import ModelEndpoint +from ..conftest import mocked__get_recommended_hardware_config_map + def test_create_llm_model_endpoint_success( create_llm_model_endpoint_request_sync: Dict[str, Any], @@ -233,6 +235,10 @@ def test_completion_stream_endpoint_not_found_returns_404( assert "404" in message.decode("utf-8") +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) def test_create_batch_completions_success( create_batch_completions_request: Dict[str, Any], test_api_key: str, diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 4300bbea..2366019a 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -14,6 +14,7 @@ Set, Tuple, ) +from unittest import mock from unittest.mock import mock_open from uuid import uuid4 @@ -4535,3 +4536,57 @@ def llm_model_endpoint_trt_llm( image="test_image", ), ) + + +def mocked__get_recommended_hardware_config_map(): + async def async_mock(*args, **kwargs): # noqa + return { + "byGpuMemoryGb": """ + - gpu_memory_le: 20 + cpus: 5 + gpus: 1 + memory: 20Gi + storage: 40Gi + gpu_type: nvidia-hopper-h100-1g20gb + - gpu_memory_le: 40 + cpus: 10 + gpus: 1 + memory: 40Gi + storage: 80Gi + gpu_type: nvidia-hopper-h100-3g40gb + - gpu_memory_le: 80 + cpus: 20 + gpus: 1 + memory: 80Gi + storage: 96Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 160 + cpus: 40 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 320 + cpus: 80 + gpus: 4 + memory: 320Gi + storage: 320Gi + gpu_type: nvidia-hopper-h100 + - gpu_memory_le: 640 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + """, + "byModelName": """ + - name: llama-3-8b-instruct-262k + cpus: 40 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + """, + } + + return mock.AsyncMock(side_effect=async_mock) diff --git a/model-engine/tests/unit/domain/__init__.py b/model-engine/tests/unit/domain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4aa9e982..0d4a9899 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -56,6 +56,8 @@ ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from ..conftest import mocked__get_recommended_hardware_config_map + def mocked__get_latest_tag(): async def async_mock(*args, **kwargs): # noqa @@ -1830,7 +1832,12 @@ async def test_validate_checkpoint_files_safetensors_with_other_files(): validate_checkpoint_files(fake_model_files) # No exception should be raised -def test_infer_hardware(fake_llm_artifact_gateway): +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +async def test_infer_hardware(fake_llm_artifact_gateway): fake_llm_artifact_gateway.model_config = { "architectures": ["MixtralForCausalLM"], "attention_dropout": 0.0, @@ -1854,15 +1861,17 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.36.0.dev0", "vocab_size": 32000, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "") - assert hardware.cpus == "20" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "") + assert hardware.cpus == "40" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 - hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "", is_batch_job=True) - assert hardware.cpus == "20" + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "mixtral-8x7b", "", is_batch_job=True + ) + assert hardware.cpus == "40" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -1892,18 +1901,20 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.40.0.dev0", "vocab_size": 32000, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "") - assert hardware.cpus == "80" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "") + assert hardware.cpus == "160" assert hardware.gpus == 8 assert hardware.memory == "800Gi" - assert hardware.storage == "460Gi" + assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 - hardware = _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "", is_batch_job=True) - assert hardware.cpus == "80" + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "mixtral-8x22b", "", is_batch_job=True + ) + assert hardware.cpus == "160" assert hardware.gpus == 8 assert hardware.memory == "800Gi" - assert hardware.storage == "460Gi" + assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 fake_llm_artifact_gateway.model_config = { @@ -1926,19 +1937,19 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.31.0.dev0", "vocab_size": 32000, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "") - assert hardware.cpus == "10" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "") + assert hardware.cpus == "5" assert hardware.gpus == 1 - assert hardware.memory == "24Gi" - assert hardware.storage == "80Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) - assert hardware.cpus == "20" - assert hardware.gpus == 2 - assert hardware.memory == "48Gi" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) + assert hardware.cpus == "10" + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], @@ -1961,19 +1972,19 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.40.0.dev0", "vocab_size": 128256, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "") - assert hardware.cpus == "10" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "") + assert hardware.cpus == "5" assert hardware.gpus == 1 - assert hardware.memory == "24Gi" - assert hardware.storage == "80Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) - assert hardware.cpus == "20" - assert hardware.gpus == 2 - assert hardware.memory == "48Gi" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) + assert hardware.cpus == "10" + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-13b-hf", @@ -1995,19 +2006,21 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.32.0.dev0", "vocab_size": 32000, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "") - assert hardware.cpus == "20" - assert hardware.gpus == 2 - assert hardware.memory == "48Gi" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "") + assert hardware.cpus == "10" + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "", is_batch_job=True) - assert hardware.cpus == "40" - assert hardware.gpus == 4 - assert hardware.memory == "96Gi" + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-2-13b", "", is_batch_job=True + ) + assert hardware.cpus == "20" + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], @@ -2029,15 +2042,17 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.32.0.dev0", "vocab_size": 32000, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "") - assert hardware.cpus == "40" - assert hardware.gpus == 4 - assert hardware.memory == "96Gi" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "") + assert hardware.cpus == "20" + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" - assert hardware.gpu_type == GpuType.NVIDIA_AMPERE_A10 + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 - hardware = _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "", is_batch_job=True) - assert hardware.cpus == "20" + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "codellama-34b", "", is_batch_job=True + ) + assert hardware.cpus == "40" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -2063,18 +2078,20 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.32.0.dev0", "vocab_size": 32000, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "") - assert hardware.cpus == "20" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "") + assert hardware.cpus == "40" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "", is_batch_job=True) - assert hardware.cpus == "20" - assert hardware.gpus == 2 - assert hardware.memory == "160Gi" - assert hardware.storage == "160Gi" + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-2-70b", "", is_batch_job=True + ) + assert hardware.cpus == "80" + assert hardware.gpus == 4 + assert hardware.memory == "320Gi" + assert hardware.storage == "320Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 fake_llm_artifact_gateway.model_config = { @@ -2098,55 +2115,64 @@ def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.40.0.dev0", "vocab_size": 128256, } - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "") - assert hardware.cpus == "20" + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "") + assert hardware.cpus == "40" assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 - hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "", is_batch_job=True) - assert hardware.cpus == "40" + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-3-70b", "", is_batch_job=True + ) + assert hardware.cpus == "80" assert hardware.gpus == 4 assert hardware.memory == "320Gi" assert hardware.storage == "320Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 - # (TODO) figure out how to calculate memory for llama-3-8b-instruct-262k - # fake_llm_artifact_gateway.model_config = { - # "_name_or_path": "gradientai/llama3-8b-stage65k-chat", - # "architectures": ["LlamaForCausalLM"], - # "attention_dropout": 0.0, - # "bos_token_id": 128000, - # "eos_token_id": 128001, - # "hidden_act": "silu", - # "hidden_size": 4096, - # "initializer_range": 0.02, - # "intermediate_size": 14336, - # "max_position_embeddings": 262144, - # "model_type": "llama", - # "num_attention_heads": 32, - # "num_hidden_layers": 32, - # "num_key_value_heads": 8, - # "pretraining_tp": 1, - # "rms_norm_eps": 1e-05, - # "rope_theta": 283461213.0, - # "torch_dtype": "bfloat16", - # "transformers_version": "4.41.0.dev0", - # "vocab_size": 128256, - # } - # hardware = _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") - # assert hardware.cpus == "20" - # assert hardware.gpus == 2 - # assert hardware.memory == "160Gi" - # assert hardware.storage == "160Gi" - # assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "gradientai/llama3-8b-stage65k-chat", + "architectures": ["LlamaForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 262144, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 283461213.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.41.0.dev0", + "vocab_size": 128256, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") + assert hardware.cpus == "40" + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + + with pytest.raises(ObjectHasInvalidValueException): + await _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") with pytest.raises(ObjectHasInvalidValueException): - _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") + await _infer_hardware(fake_llm_artifact_gateway, "llama-3-999b", "") -def test_fill_hardware_info(fake_llm_artifact_gateway): +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +async def test_fill_hardware_info(fake_llm_artifact_gateway): request = CreateLLMModelEndpointV1Request( name="mixtral-8x7b", model_name="mixtral-8x7b", @@ -2157,8 +2183,8 @@ def test_fill_hardware_info(fake_llm_artifact_gateway): per_worker=1, labels={}, ) - _fill_hardware_info(fake_llm_artifact_gateway, request) - assert request.cpus == "20" + await _fill_hardware_info(fake_llm_artifact_gateway, request) + assert request.cpus == "40" assert request.gpus == 2 assert request.memory == "160Gi" assert request.storage == "160Gi" @@ -2177,10 +2203,14 @@ def test_fill_hardware_info(fake_llm_artifact_gateway): ) with pytest.raises(ObjectHasInvalidValueException): - _fill_hardware_info(fake_llm_artifact_gateway, request) + await _fill_hardware_info(fake_llm_artifact_gateway, request) @pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) async def test_create_batch_completions( fake_docker_image_batch_job_gateway, fake_docker_repository_image_always_exists, From 2756aeda0dceaaed283502e27c05cc3f31dd5011 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Fri, 21 Jun 2024 12:09:36 -0700 Subject: [PATCH 319/425] Add special token param to completions + batch completions apis (#544) * add field to dtos * add to client, add to batch inference * smoke test * temp change * oops * revert changes --- clients/python/llmengine/data_types.py | 6 ++++++ model-engine/model_engine_server/common/dtos/llms.py | 12 ++++++++++++ .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++++ .../inference/batch_inference/vllm_batch.py | 4 ++++ model-engine/tests/unit/domain/test_llm_use_cases.py | 1 + 5 files changed, 27 insertions(+) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index f1c9b56c..965006ec 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -332,6 +332,7 @@ class CompletionSyncV1Request(BaseModel): guided_regex: Optional[str] = Field(default=None) guided_choice: Optional[List[str]] = Field(default=None) guided_grammar: Optional[str] = Field(default=None) + skip_special_tokens: Optional[bool] = Field(default=True) class TokenOutput(BaseModel): @@ -407,6 +408,7 @@ class CompletionStreamV1Request(BaseModel): guided_regex: Optional[str] = Field(default=None) guided_choice: Optional[List[str]] = Field(default=None) guided_grammar: Optional[str] = Field(default=None) + skip_special_tokens: Optional[bool] = Field(default=True) class CompletionStreamOutput(BaseModel): @@ -699,6 +701,10 @@ class CreateBatchCompletionsRequestContent(BaseModel): """ Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ class CreateBatchCompletionsModelConfig(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 90498c23..9ee8ef24 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -200,6 +200,10 @@ class CompletionSyncV1Request(BaseModel): """ Context-free grammar for guided decoding. Only supported in vllm. """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ class TokenOutput(BaseModel): @@ -280,6 +284,10 @@ class CompletionStreamV1Request(BaseModel): """ Context-free grammar for guided decoding. Only supported in vllm. """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ class CompletionStreamOutput(BaseModel): @@ -450,6 +458,10 @@ class CreateBatchCompletionsRequestContent(BaseModel): """ Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ class CreateBatchCompletionsModelConfig(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 615cef8a..e615cdec 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1703,6 +1703,8 @@ async def execute( vllm_args["guided_json"] = request.guided_json if request.guided_grammar is not None: vllm_args["guided_grammar"] = request.guided_grammar + if request.skip_special_tokens is not None: + vllm_args["skip_special_tokens"] = request.skip_special_tokens inference_request = SyncEndpointPredictV1Request( args=vllm_args, @@ -1973,6 +1975,8 @@ async def execute( args["guided_json"] = request.guided_json if request.guided_grammar is not None: args["guided_grammar"] = request.guided_grammar + if request.skip_special_tokens is not None: + args["skip_special_tokens"] = request.skip_special_tokens args["stream"] = True elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: args = { diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index e4887b20..7add5911 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -217,6 +217,7 @@ def __repr__(self) -> str: content.frequency_penalty, content.top_k, content.top_p, + content.skip_special_tokens, [iter[0] for iter in iter_prompts], bar, use_tool=True, @@ -366,6 +367,7 @@ async def batch_inference(): content.frequency_penalty, content.top_k, content.top_p, + content.skip_special_tokens, prompts, bar, use_tool=False, @@ -401,6 +403,7 @@ async def generate_with_vllm( frequency_penalty, top_k, top_p, + skip_special_tokens, prompts, bar, use_tool, @@ -424,6 +427,7 @@ async def generate_with_vllm( frequency_penalty=frequency_penalty or 0.0, top_k=top_k or -1, top_p=top_p or 1.0, + skip_special_tokens=skip_special_tokens if skip_special_tokens is not None else True, ) results_generator = await engine.add_request( request_id, prompt, sampling_params, None, time.monotonic() diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 0d4a9899..bb3a9604 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -711,6 +711,7 @@ async def test_completion_sync_use_case_success( ): completion_sync_request.include_stop_str_in_output = True completion_sync_request.guided_json = {} + completion_sync_request.skip_special_tokens = False fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( From 51b38a925a3a28553fc46a7f2f56aa54d544b5e9 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 21 Jun 2024 18:51:28 -0700 Subject: [PATCH 320/425] Fix integration test (#546) --- integration_tests/rest_api_utils.py | 2 +- integration_tests/test_completions.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 285e2c1d..8a1c3525 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -176,7 +176,7 @@ def my_model(**keyword_args): "cpus": 20, "gpus": 1, "memory": "20Gi", - "gpu_type": "nvidia-ampere-a10", + "gpu_type": "nvidia-hopper-h100-1g20gb", "storage": "40Gi", "optimize_costs": False, "min_workers": 1, diff --git a/integration_tests/test_completions.py b/integration_tests/test_completions.py index 01dcdc2d..e2530963 100644 --- a/integration_tests/test_completions.py +++ b/integration_tests/test_completions.py @@ -11,6 +11,7 @@ create_llm_streaming_tasks, create_llm_sync_tasks, delete_llm_model_endpoint, + ensure_launch_gateway_healthy, ensure_llm_task_response_is_correct, ensure_n_ready_private_llm_endpoints_short, ensure_nonzero_available_llm_workers, @@ -26,6 +27,7 @@ reason="Skip unless running inference framework tests", ) def test_completions(capsys): + ensure_launch_gateway_healthy() with capsys.disabled(): try: user = USER_ID_0 From d8b5efeace3d1e5b2efef4a29f63021a22c9e502 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 24 Jun 2024 09:45:36 -0700 Subject: [PATCH 321/425] Bump vllm to v0.5.0.post1 (#547) --- .../inference/vllm/Dockerfile | 40 ---- .../inference/vllm/requirements-build.txt | 8 - .../inference/vllm/requirements.txt | 2 +- .../inference/vllm/vllm_server.py | 177 +++++++++--------- 4 files changed, 91 insertions(+), 136 deletions(-) delete mode 100644 model-engine/model_engine_server/inference/vllm/requirements-build.txt diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile index 227b3e16..75b9e1f5 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile @@ -1,37 +1,3 @@ -#################### BASE BUILD IMAGE #################### -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev -RUN apt-get update -y \ - && apt-get install -y python3-pip git -# Workaround for https://github.com/openai/triton/issues/2507 and -# https://github.com/pytorch/pytorch/issues/107960 -- hopefully -# this won't be needed for future versions of this docker image -# or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ -WORKDIR /workspace - -COPY requirements-build.txt requirements-build.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-build.txt -#################### BASE BUILD IMAGE #################### - -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.4.2 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - -#################### Runtime IMAGE #################### FROM nvcr.io/nvidia/pytorch:23.09-py3 RUN apt-get update \ @@ -41,10 +7,6 @@ RUN apt-get update \ && apt-get autoremove -y \ && rm -rf /var/lib/apt/lists/* -# Install flash attention (from pre-built wheel) -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir - RUN pip uninstall torch -y COPY requirements.txt /workspace/requirements.txt RUN pip install -r requirements.txt @@ -53,5 +15,3 @@ RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linu RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz COPY vllm_server.py /workspace/vllm_server.py - -#################### Runtime IMAGE #################### diff --git a/model-engine/model_engine_server/inference/vllm/requirements-build.txt b/model-engine/model_engine_server/inference/vllm/requirements-build.txt deleted file mode 100644 index 020e4532..00000000 --- a/model-engine/model_engine_server/inference/vllm/requirements-build.txt +++ /dev/null @@ -1,8 +0,0 @@ -# Copied from https://github.com/vllm-project/vllm/blob/main/requirements-build.txt -# Needed to build flash-attn into docker image -cmake>=3.21 -ninja -packaging -setuptools>=49.4.0 -torch==2.3.0 -wheel \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 9b106e07..e56693e7 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1,2 @@ -vllm==0.4.2 +vllm==0.5.0.post1 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 365d271d..e32b0834 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -11,7 +11,7 @@ from fastapi import BackgroundTasks, FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.outputs import CompletionOutput @@ -43,97 +43,101 @@ async def generate(request: Request) -> Response: # check health before accepting request and fail fast if engine isn't healthy try: await engine.check_health() - except Exception as e: - print(f"The vllm engine is dead, exiting the pod: {e}") - os.kill(os.getpid(), signal.SIGINT) - request_dict = await request.json() - prompt = request_dict.pop("prompt") - stream = request_dict.pop("stream", False) - guided_json = request_dict.pop("guided_json", None) - guided_regex = request_dict.pop("guided_regex", None) - guided_choice = request_dict.pop("guided_choice", None) - guided_grammar = request_dict.pop("guided_grammar", None) - sampling_params = SamplingParams(**request_dict) + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + guided_json = request_dict.pop("guided_json", None) + guided_regex = request_dict.pop("guided_regex", None) + guided_choice = request_dict.pop("guided_choice", None) + guided_grammar = request_dict.pop("guided_grammar", None) + sampling_params = SamplingParams(**request_dict) + + # Dummy request to get guided decode logit processor + try: + partial_openai_request = OpenAICompletionRequest.model_validate( + { + "model": "", + "prompt": "", + "guided_json": guided_json, + "guided_regex": guided_regex, + "guided_choice": guided_choice, + "guided_grammar": guided_grammar, + } + ) + except Exception: + raise HTTPException( + status_code=400, detail="Bad request: failed to parse guided decoding parameters." + ) - # Dummy request to get guided decode logit processor - try: - partial_openai_request = OpenAICompletionRequest.model_validate( - { - "model": "", - "prompt": "", - "guided_json": guided_json, - "guided_regex": guided_regex, - "guided_choice": guided_choice, - "guided_grammar": guided_grammar, - } - ) - except Exception: - raise HTTPException( - status_code=400, detail="Bad request: failed to parse guided decoding parameters." + guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend + guided_decode_logit_processor = await get_guided_decoding_logits_processor( + guided_decoding_backend, partial_openai_request, await engine.get_tokenizer() ) + if guided_decode_logit_processor is not None: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(guided_decode_logit_processor) - guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend - guided_decode_logit_processor = await get_guided_decoding_logits_processor( - guided_decoding_backend, partial_openai_request, await engine.get_tokenizer() - ) - if guided_decode_logit_processor is not None: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append(guided_decode_logit_processor) - - request_id = random_uuid() - - results_generator = engine.generate(prompt, sampling_params, request_id) - - async def abort_request() -> None: - await engine.abort(request_id) - - if stream: - # Streaming case - async def stream_results() -> AsyncGenerator[str, None]: - last_output_text = "" - async for request_output in results_generator: - log_probs = format_logprobs(request_output) - ret = { - "text": request_output.outputs[-1].text[len(last_output_text) :], - "count_prompt_tokens": len(request_output.prompt_token_ids), - "count_output_tokens": len(request_output.outputs[0].token_ids), - "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None, - "finished": request_output.finished, - } - last_output_text = request_output.outputs[-1].text - yield f"data:{json.dumps(ret)}\n\n" - - background_tasks = BackgroundTasks() - # Abort the request if the client disconnects. - background_tasks.add_task(abort_request) - - return StreamingResponse(stream_results(), background=background_tasks) - - # Non-streaming case - final_output = None - tokens = [] - last_output_text = "" - async for request_output in results_generator: - tokens.append(request_output.outputs[-1].text[len(last_output_text) :]) - last_output_text = request_output.outputs[-1].text - if await request.is_disconnected(): - # Abort the request if the client disconnects. + request_id = random_uuid() + + results_generator = engine.generate(prompt, sampling_params, request_id) + + async def abort_request() -> None: await engine.abort(request_id) - return Response(status_code=499) - final_output = request_output - assert final_output is not None - prompt = final_output.prompt - ret = { - "text": final_output.outputs[0].text, - "count_prompt_tokens": len(final_output.prompt_token_ids), - "count_output_tokens": len(final_output.outputs[0].token_ids), - "log_probs": format_logprobs(final_output), - "tokens": tokens, - } - return Response(content=json.dumps(ret)) + if stream: + # Streaming case + async def stream_results() -> AsyncGenerator[str, None]: + last_output_text = "" + async for request_output in results_generator: + log_probs = format_logprobs(request_output) + ret = { + "text": request_output.outputs[-1].text[len(last_output_text) :], + "count_prompt_tokens": len(request_output.prompt_token_ids), + "count_output_tokens": len(request_output.outputs[0].token_ids), + "log_probs": log_probs[-1] + if log_probs and sampling_params.logprobs + else None, + "finished": request_output.finished, + } + last_output_text = request_output.outputs[-1].text + yield f"data:{json.dumps(ret)}\n\n" + + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. + background_tasks.add_task(abort_request) + + return StreamingResponse(stream_results(), background=background_tasks) + + # Non-streaming case + final_output = None + tokens = [] + last_output_text = "" + async for request_output in results_generator: + tokens.append(request_output.outputs[-1].text[len(last_output_text) :]) + last_output_text = request_output.outputs[-1].text + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + ret = { + "text": final_output.outputs[0].text, + "count_prompt_tokens": len(final_output.prompt_token_ids), + "count_output_tokens": len(final_output.outputs[0].token_ids), + "log_probs": format_logprobs(final_output), + "tokens": tokens, + } + return Response(content=json.dumps(ret)) + + except AsyncEngineDeadError as e: + print(f"The vllm engine is dead, exiting the pod: {e}") + os.kill(os.getpid(), signal.SIGINT) + raise e def get_gpu_free_memory(): @@ -206,7 +210,6 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) - engine.check_health() signal.signal(signal.SIGUSR1, debug) From 4471d196865579d5adf93fce0795a002a5d1d28a Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 24 Jun 2024 11:05:52 -0700 Subject: [PATCH 322/425] Fix integration tests for streaming case (#548) --- integration_tests/rest_api_utils.py | 25 +++++++++++++++++++++++-- integration_tests/test_completions.py | 3 ++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 8a1c3525..f77d96ea 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -168,7 +168,7 @@ def my_model(**keyword_args): CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = { "name": format_name("llama-2-7b-test"), - "model_name": "llama-2-7b", + "model_name": "llama-2-7b-chat", "source": "hugging_face", "inference_framework": "vllm", "inference_framework_image_tag": "latest", @@ -802,7 +802,7 @@ async def create_llm_streaming_task( timeout=LONG_NETWORK_TIMEOUT_SEC, ) as response: assert response.status == 200, (await response.read()).decode() - return await response.json() + return (await response.read()).decode() async def create_sync_tasks( @@ -987,6 +987,27 @@ def ensure_llm_task_response_is_correct( assert re.search(response_text_regex, response["output"]["text"]) +def ensure_llm_task_stream_response_is_correct( + response: str, + required_output_fields: Optional[List[str]], + response_text_regex: Optional[str], +): + # parse response + # data has format "data: \n\ndata: \n\n" + # We want to get a list of dictionaries parsing out the 'data:' field + parsed_response = [ + json.loads(r.split("data: ")[1]) for r in response.split("\n") if "data:" in r.strip() + ] + + # Join the text field of the response + response_text = "".join([r["output"]["text"] for r in parsed_response]) + print("response text: ", response_text) + assert response_text is not None + + if response_text_regex is not None: + assert re.search(response_text_regex, response_text) + + # Wait up to 30 seconds for the tasks to be returned. @retry( stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) diff --git a/integration_tests/test_completions.py b/integration_tests/test_completions.py index e2530963..aac6b213 100644 --- a/integration_tests/test_completions.py +++ b/integration_tests/test_completions.py @@ -13,6 +13,7 @@ delete_llm_model_endpoint, ensure_launch_gateway_healthy, ensure_llm_task_response_is_correct, + ensure_llm_task_stream_response_is_correct, ensure_n_ready_private_llm_endpoints_short, ensure_nonzero_available_llm_workers, ) @@ -86,7 +87,7 @@ def test_completions(capsys): ) ) for response in task_responses: - ensure_llm_task_response_is_correct( + ensure_llm_task_stream_response_is_correct( response, required_output_fields, response_text_regex ) except Exception as e: From f92830b6bfd0830affe753defa1ec5a1343b666e Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 25 Jun 2024 19:30:08 -0700 Subject: [PATCH 323/425] Update vllm batch job to work with vllm > 0.5.0 (#550) * Update vllm batch job to work with vllm > 0.5.0 * Fix test * Add comments --- .../model_engine_server/common/dtos/llms.py | 11 ++ .../inference/batch_inference/dto.py | 165 ++++++++++++++++++ .../batch_inference/requirements.txt | 4 +- .../inference/batch_inference/vllm_batch.py | 36 ++-- model-engine/tests/unit/inference/conftest.py | 30 ++-- 5 files changed, 211 insertions(+), 35 deletions(-) create mode 100644 model-engine/model_engine_server/inference/batch_inference/dto.py diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 9ee8ef24..40d6f2ca 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -1,5 +1,7 @@ """ DTOs for LLM APIs. + +Make sure to keep this in sync with inference/batch_inference/dto.py. """ from typing import Any, Dict, List, Optional @@ -553,6 +555,14 @@ class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): hidden from the DTO exposed to the client. """ + model_cfg: CreateBatchCompletionsModelConfig + """ + Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. + """ + max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0) """ Maximum GPU memory utilization for the batch inference. Default to 90%. @@ -565,6 +575,7 @@ def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsE output_data_path=request.output_data_path, content=request.content, model_config=request.model_config, + model_cfg=request.model_config, data_parallelism=request.data_parallelism, max_runtime_sec=request.max_runtime_sec, tool_config=request.tool_config, diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py new file mode 100644 index 00000000..da63c545 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -0,0 +1,165 @@ +# This is a copy of model_engine_server.common.dtos.llm +# This is done to decouple the pydantic requirements since vllm requires pydantic >2 +# while model engine is on 1.x +from enum import Enum +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field + + +class TokenOutput(BaseModel): + token: str + log_prob: float + + +class CompletionOutput(BaseModel): + text: str + num_prompt_tokens: int + num_completion_tokens: int + tokens: Optional[List[TokenOutput]] = None + + +class CreateBatchCompletionsRequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ + + +class Quantization(str, Enum): + BITSANDBYTES = "bitsandbytes" + AWQ = "awq" + + +class CreateBatchCompletionsModelConfig(BaseModel): + model: str + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + labels: Dict[str, str] + """ + Labels to attach to the batch inference job. + """ + num_shards: Optional[int] = 1 + """ + Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. + System may decide to use a different number than the given value. + """ + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + seed: Optional[int] = None + """ + Random seed for the model. + """ + + +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + +class CreateBatchCompletionsRequest(BaseModel): + """ + Request object for batch completions. + """ + + input_data_path: Optional[str] + output_data_path: str + """ + Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. + """ + content: Optional[CreateBatchCompletionsRequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + + data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) + """ + Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. + """ + max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600) + """ + Maximum runtime of the batch inference in seconds. Default to one day. + """ + tool_config: Optional[ToolConfig] = None + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + +class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): + """ + Internal model for representing request to the llm engine. This contains additional fields that we want + hidden from the DTO exposed to the client. + """ + + model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config") + """ + Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. + + We alias `model_config` for deserialization for backwards compatibility. + """ + + max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0) + """ + Maximum GPU memory utilization for the batch inference. Default to 90%. + """ diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index e83b4ccd..ca6b220f 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,5 +1,5 @@ -vllm==0.2.5 -pydantic==1.10.13 +vllm==0.5.0.post1 +pydantic>=2 boto3==1.34.15 smart-open==6.4.0 ddtrace==2.4.0 diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 7add5911..7881e182 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -13,7 +13,7 @@ import boto3 import smart_open from func_timeout import FunctionTimedOut, func_set_timeout -from model_engine_server.common.dtos.llms import ( +from model_engine_server.inference.batch_inference.dto import ( CompletionOutput, CreateBatchCompletionsEngineRequest, CreateBatchCompletionsRequestContent, @@ -150,9 +150,9 @@ def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest): engine_args = AsyncEngineArgs( model=model, - quantization=request.model_config.quantize, - tensor_parallel_size=request.model_config.num_shards, - seed=request.model_config.seed or 0, + quantization=request.model_cfg.quantize, + tensor_parallel_size=request.model_cfg.num_shards, + seed=request.model_cfg.seed or 0, disable_log_requests=True, gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, ) @@ -316,18 +316,16 @@ async def batch_inference(): request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE) - if request.model_config.checkpoint_path is not None: - download_model(request.model_config.checkpoint_path, MODEL_WEIGHTS_FOLDER) + if request.model_cfg.checkpoint_path is not None: + download_model(request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER) content = request.content if content is None: with smart_open.open(request.input_data_path, "r") as f: content = CreateBatchCompletionsRequestContent.parse_raw(f.read()) - model = ( - MODEL_WEIGHTS_FOLDER if request.model_config.checkpoint_path else request.model_config.model - ) - is_finetuned = request.model_config.checkpoint_path is not None + model = MODEL_WEIGHTS_FOLDER if request.model_cfg.checkpoint_path else request.model_cfg.model + is_finetuned = request.model_cfg.checkpoint_path is not None llm = get_vllm_engine(model, request) @@ -352,7 +350,7 @@ async def batch_inference(): prompts, tool, is_finetuned, - request.model_config.model, + request.model_cfg.model, ) else: bar = tqdm(total=len(prompts), desc="Processed prompts") @@ -372,7 +370,7 @@ async def batch_inference(): bar, use_tool=False, is_finetuned=is_finetuned, - model=request.model_config.model, + model=request.model_cfg.model, ) bar.close() @@ -430,27 +428,25 @@ async def generate_with_vllm( skip_special_tokens=skip_special_tokens if skip_special_tokens is not None else True, ) results_generator = await engine.add_request( - request_id, prompt, sampling_params, None, time.monotonic() + request_id, prompt, sampling_params, time.monotonic(), None ) results_generators.append(results_generator) outputs = [] for generator in results_generators: - last_output_text = "" tokens = [] async for request_output in generator: if request_output.finished: bar.update(1) - token_text = request_output.outputs[-1].text[len(last_output_text) :] - log_probs = request_output.outputs[0].logprobs[-1] if return_token_log_probs else None - last_output_text = request_output.outputs[-1].text - if return_token_log_probs: + output = request_output.outputs[0] + log_probs = output.logprobs[-1] if return_token_log_probs else None + token_id = output.token_ids[-1] tokens.append( TokenOutput( - token=token_text, - log_prob=log_probs[request_output.outputs[0].token_ids[-1]], + token=log_probs[token_id].decoded_token, + log_prob=log_probs[token_id].logprob, ) ) diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py index e8bdca29..870f4075 100644 --- a/model-engine/tests/unit/inference/conftest.py +++ b/model-engine/tests/unit/inference/conftest.py @@ -1,11 +1,10 @@ from unittest.mock import MagicMock import pytest -from model_engine_server.common.dtos.llms import ( +from model_engine_server.inference.batch_inference.dto import ( CompletionOutput, CreateBatchCompletionsEngineRequest, CreateBatchCompletionsModelConfig, - CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, TokenOutput, ToolConfig, @@ -14,16 +13,18 @@ @pytest.fixture def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest: + model_config = CreateBatchCompletionsModelConfig( + model="model", + checkpoint_path="checkpoint_path", + labels={}, + seed=123, + num_shards=4, + ) return CreateBatchCompletionsEngineRequest( input_data_path="input_data_path", output_data_path="output_data_path", - model_config=CreateBatchCompletionsModelConfig( - model="model", - checkpoint_path="checkpoint_path", - labels={}, - seed=123, - num_shards=4, - ), + model_cfg=model_config, + model_config=model_config, data_parallelism=1, max_runtime_sec=86400, max_gpu_memory_utilization=0.95, @@ -32,10 +33,13 @@ def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineReq @pytest.fixture def create_batch_completions_tool_completion_request(): - return CreateBatchCompletionsRequest( - model_config=CreateBatchCompletionsModelConfig( - checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={} - ), + model_config = CreateBatchCompletionsModelConfig( + checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={} + ) + + return CreateBatchCompletionsEngineRequest( + model_cfg=model_config, + model_config=model_config, data_parallelism=1, input_data_path="input_data_path", output_data_path="output_data_path", From c1b521d90c8060662c2a4383eccd23110292acd3 Mon Sep 17 00:00:00 2001 From: Anant Marur <165984904+anant-marur@users.noreply.github.com> Date: Tue, 25 Jun 2024 23:37:26 -0700 Subject: [PATCH 324/425] Modify v1 completions_stream logic to raise most exceptions before async streaming inference response (#534) * consolidate streaming response logic into separate inline function. call execute() synchronously and call inline function async * iterate * refactor: pull inference result status/empty check outside of framework conditionals to dedupe code. put logic for unsuccessful/empty results before other handling logic for readability. add some commenting and other small edits. * formatting fixes * improve commenting * fix and reenable 404 unit test * fix stream success unit test, add async test client fixture * move _response_chunk_generator() from an inline def in execute() to a separate private method for the usecase * fix issue with streaming tests interacting by defining a per-session event loop fixture and reconfiguring test_create_streaming_task_success as an async test * one more unit test * update llm-engine Completions docs with details on streaming error handling --- docs/guides/completions.md | 9 +- .../model_engine_server/api/llms_v1.py | 43 +++++-- .../use_cases/llm_model_endpoint_use_cases.py | 117 +++++++++++------- model-engine/tests/unit/api/conftest.py | 80 ++++++++++++ model-engine/tests/unit/api/test_llms.py | 100 +++++++++------ model-engine/tests/unit/api/test_tasks.py | 36 +++--- .../tests/unit/domain/test_llm_use_cases.py | 8 +- 7 files changed, 280 insertions(+), 113 deletions(-) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index 250bbbd1..56b4538d 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -67,7 +67,11 @@ applications. When streaming, tokens will be sent as data-only To enable token streaming, pass `stream=True` to either [Completion.create](../../api/python_client/#llmengine.completion.Completion.create) or [Completion.acreate](../../api/python_client/#llmengine.completion.Completion.acreate). -Note that errors from streaming calls are returned back to the user as plain-text messages and currently need to be handled by the client. +### Streaming Error Handling + +Note: Error handling semantics are mixed for streaming calls: +- Errors that arise *before* streaming begins are returned back to the user as `HTTP` errors with the appropriate status code. +- Errors that arise *after* streaming begins within a `HTTP 200` response are returned back to the user as plain-text messages and currently need to be handled by the client. An example of token streaming using the synchronous Completions API looks as follows: @@ -78,6 +82,7 @@ import sys from llmengine import Completion +# errors occurring before streaming begins will be thrown here stream = Completion.create( model="llama-2-7b", prompt="Give me a 200 word summary on the current economic events in the US.", @@ -90,7 +95,7 @@ for response in stream: if response.output: print(response.output.text, end="") sys.stdout.flush() - else: # an error occurred + else: # an error occurred after streaming began print(response.error) # print the error message out break ``` diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 6f26ae1e..db4da712 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -405,7 +405,29 @@ async def create_completion_stream_task( llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, tokenizer_repository=external_interfaces.tokenizer_repository, ) - response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request) + + try: + # Call execute() with await, since it needs to handle exceptions before we begin streaming the response below. + # execute() will create a response chunk generator and return a reference to it. + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, detail="Internal error occurred. Our team has been notified." + ) from exc async def event_generator(): try: @@ -427,14 +449,19 @@ async def event_generator(): ), metric_metadata, ) - except (InvalidRequestException, ObjectHasInvalidValueException) as exc: + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: yield handle_streaming_exception(exc, 400, str(exc)) - except ( - ObjectNotFoundException, - ObjectNotAuthorizedException, - EndpointUnsupportedInferenceTypeException, - ) as exc: - yield handle_streaming_exception(exc, 404, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) except Exception as exc: yield handle_streaming_exception( exc, 500, "Internal error occurred. Our team has been notified." diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index e615cdec..cf8b1f55 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -78,7 +78,10 @@ ObjectNotFoundException, UpstreamServiceError, ) -from model_engine_server.domain.gateways import DockerImageBatchJobGateway +from model_engine_server.domain.gateways import ( + DockerImageBatchJobGateway, + StreamingModelEndpointInferenceGateway, +) from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, @@ -1845,6 +1848,9 @@ async def execute( ) -> AsyncIterable[CompletionStreamV1Response]: """ Runs the use case to create a stream inference task. + NOTE: Must be called with await(), since the function is not a generator itself, but rather creates one and + returns a reference to it. This structure allows exceptions that occur before response streaming begins + to propagate to the client as HTTP exceptions with the appropriate code. Args: user: The user who is creating the stream inference task. @@ -1852,11 +1858,17 @@ async def execute( request: The body of the request to forward to the endpoint. Returns: - A response object that contains the status and result of the task. + An asynchronous response chunk generator, containing response objects to be iterated through with 'async for'. + Each response object contains the status and result of the task. Raises: ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectHasInvalidValueException: If there are multiple model endpoints with the given name. ObjectNotAuthorizedException: If the owner does not own the model endpoint. + EndpointUnsupportedInferenceTypeException: If the model endpoint does not support streaming or uses + an unsupported inference framework. + UpstreamServiceError: If an error occurs upstream in the streaming inference API call. + InvalidRequestException: If request validation fails during inference. """ request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) @@ -2020,7 +2032,6 @@ async def execute( model_content.model_name, self.tokenizer_repository, ) - else: raise EndpointUnsupportedInferenceTypeException( f"Unsupported inference framework {model_content.inference_framework}" @@ -2031,15 +2042,55 @@ async def execute( num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) + + return self._response_chunk_generator( + request=request, + request_id=request_id, + model_endpoint=model_endpoint, + model_content=model_content, + inference_gateway=inference_gateway, + inference_request=inference_request, + num_prompt_tokens=num_prompt_tokens, + ) + + async def _response_chunk_generator( + self, + request: CompletionStreamV1Request, + request_id: Optional[str], + model_endpoint: ModelEndpoint, + model_content: GetLLMModelEndpointV1Response, + inference_gateway: StreamingModelEndpointInferenceGateway, + inference_request: SyncEndpointPredictV1Request, + num_prompt_tokens: Optional[int], + ) -> AsyncIterable[CompletionStreamV1Response]: + """ + Async generator yielding tokens to stream for the completions response. Should only be called when + returned directly by execute(). + """ predict_result = inference_gateway.streaming_predict( topic=model_endpoint.record.destination, predict_request=inference_request ) num_completion_tokens = 0 async for res in predict_result: - result = res.result - if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: - if res.status == TaskStatus.SUCCESS and result is not None: + if not res.status == TaskStatus.SUCCESS or res.result is None: + # Raise an UpstreamServiceError if the task has failed + if res.status == TaskStatus.FAILURE: + raise UpstreamServiceError( + status_code=500, + content=( + res.traceback.encode("utf-8") if res.traceback is not None else b"" + ), + ) + # Otherwise, yield empty response chunk for unsuccessful or empty results + yield CompletionStreamV1Response( + request_id=request_id, + output=None, + ) + else: + result = res.result + # DEEPSPEED + if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: if "token" in result["result"]: yield CompletionStreamV1Response( request_id=request_id, @@ -2063,15 +2114,11 @@ async def execute( num_completion_tokens=completion_token_count, ), ) - else: - yield CompletionStreamV1Response( - request_id=request_id, - output=None, - ) - elif ( - model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE - ): - if res.status == TaskStatus.SUCCESS and result is not None: + # TEXT_GENERATION_INTERFACE + elif ( + model_content.inference_framework + == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + ): if result["result"].get("generated_text") is not None: finished = True else: @@ -2108,14 +2155,8 @@ async def execute( raise UpstreamServiceError( status_code=500, content=result.get("error") ) # also change llms_v1.py that will return a 500 HTTPException so user can retry - - else: - yield CompletionStreamV1Response( - request_id=request_id, - output=None, - ) - elif model_content.inference_framework == LLMInferenceFramework.VLLM: - if res.status == TaskStatus.SUCCESS and result is not None: + # VLLM + elif model_content.inference_framework == LLMInferenceFramework.VLLM: token = None if request.return_token_log_probs: token = TokenOutput( @@ -2134,13 +2175,8 @@ async def execute( token=token, ), ) - else: - yield CompletionStreamV1Response( - request_id=request_id, - output=None, - ) - elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: - if res.status == TaskStatus.SUCCESS and result is not None: + # LIGHTLLM + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: token = None num_completion_tokens += 1 if request.return_token_log_probs: @@ -2159,13 +2195,8 @@ async def execute( token=token, ), ) - else: - yield CompletionStreamV1Response( - request_id=request_id, - output=None, - ) - elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: - if res.status == TaskStatus.SUCCESS and result is not None: + # TENSORRT_LLM + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: num_completion_tokens += 1 yield CompletionStreamV1Response( request_id=request_id, @@ -2176,15 +2207,9 @@ async def execute( num_completion_tokens=num_completion_tokens, ), ) - else: - yield CompletionStreamV1Response( - request_id=request_id, - output=None, - ) - else: - raise EndpointUnsupportedInferenceTypeException( - f"Unsupported inference framework {model_content.inference_framework}" - ) + # No else clause needed for an unsupported inference framework, since we check + # model_content.inference_framework in execute() prior to calling _response_chunk_generator, + # raising an exception if it is not one of the frameworks handled above. class ModelDownloadV1UseCase: diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index 725b7795..b312f7eb 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -1,10 +1,13 @@ +import asyncio import datetime from typing import Any, Dict, Iterator, Tuple import pytest +import pytest_asyncio from fastapi import Depends, HTTPException from fastapi.security import HTTPBasicCredentials from fastapi.testclient import TestClient +from httpx import AsyncClient from model_engine_server.api.app import app from model_engine_server.api.dependencies import ( AUTH, @@ -90,6 +93,14 @@ def fake_auth(): app.dependency_overrides[verify_authentication] = {} +@pytest_asyncio.fixture(scope="session", autouse=True) +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + @pytest.fixture def get_test_client_wrapper(get_repositories_generator_wrapper): def get_test_client( @@ -159,6 +170,75 @@ def get_test_client( return get_test_client +@pytest.fixture +def get_async_test_client_wrapper(get_repositories_generator_wrapper): + def get_async_test_client( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents=None, + fake_model_endpoint_record_repository_contents=None, + fake_model_endpoint_infra_gateway_contents=None, + fake_batch_job_record_repository_contents=None, + fake_batch_job_progress_gateway_contents=None, + fake_docker_image_batch_job_bundle_repository_contents=None, + fake_docker_image_batch_job_gateway_contents=None, + fake_llm_fine_tuning_service_contents=None, + fake_file_storage_gateway_contents=None, + fake_file_system_gateway_contents=None, + fake_trigger_repository_contents=None, + fake_cron_job_gateway_contents=None, + fake_sync_inference_content=None, + ) -> AsyncClient: + if fake_docker_image_batch_job_gateway_contents is None: + fake_docker_image_batch_job_gateway_contents = {} + if fake_docker_image_batch_job_bundle_repository_contents is None: + fake_docker_image_batch_job_bundle_repository_contents = {} + if fake_batch_job_progress_gateway_contents is None: + fake_batch_job_progress_gateway_contents = {} + if fake_batch_job_record_repository_contents is None: + fake_batch_job_record_repository_contents = {} + if fake_model_endpoint_infra_gateway_contents is None: + fake_model_endpoint_infra_gateway_contents = {} + if fake_model_endpoint_record_repository_contents is None: + fake_model_endpoint_record_repository_contents = {} + if fake_model_bundle_repository_contents is None: + fake_model_bundle_repository_contents = {} + if fake_llm_fine_tuning_service_contents is None: + fake_llm_fine_tuning_service_contents = {} + if fake_file_storage_gateway_contents is None: + fake_file_storage_gateway_contents = {} + if fake_file_system_gateway_contents is None: + fake_file_system_gateway_contents = {} + if fake_trigger_repository_contents is None: + fake_trigger_repository_contents = {} + if fake_cron_job_gateway_contents is None: + fake_cron_job_gateway_contents = {} + if fake_sync_inference_content is None: + fake_sync_inference_content = {} + app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper( + fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists, + fake_model_bundle_repository_contents=fake_model_bundle_repository_contents, + fake_model_endpoint_record_repository_contents=fake_model_endpoint_record_repository_contents, + fake_model_endpoint_infra_gateway_contents=fake_model_endpoint_infra_gateway_contents, + fake_batch_job_record_repository_contents=fake_batch_job_record_repository_contents, + fake_batch_job_progress_gateway_contents=fake_batch_job_progress_gateway_contents, + fake_docker_image_batch_job_bundle_repository_contents=fake_docker_image_batch_job_bundle_repository_contents, + fake_docker_image_batch_job_gateway_contents=fake_docker_image_batch_job_gateway_contents, + fake_llm_fine_tuning_service_contents=fake_llm_fine_tuning_service_contents, + fake_file_storage_gateway_contents=fake_file_storage_gateway_contents, + fake_file_system_gateway_contents=fake_file_system_gateway_contents, + fake_trigger_repository_contents=fake_trigger_repository_contents, + fake_cron_job_gateway_contents=fake_cron_job_gateway_contents, + fake_sync_inference_content=fake_sync_inference_content, + ) + app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[ + get_external_interfaces + ] + client = AsyncClient(app=app, base_url="http://test") + return client + + return get_async_test_client + + @pytest.fixture def simple_client(get_test_client_wrapper) -> TestClient: """Returns a Client with no initial contents and a Docker repository that always returns True""" diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 4ef3eafe..1c65cea6 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -158,15 +158,13 @@ def test_completion_sync_endpoint_not_found_returns_404( assert response_1.status_code == 404 -# When enabling this test, other tests fail with "RunTumeError got Future attached to a different loop" -# https://github.com/encode/starlette/issues/1315#issuecomment-980784457 -@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") -def test_completion_stream_success( +@pytest.mark.asyncio +async def test_completion_stream_success( llm_model_endpoint_streaming: ModelEndpoint, completion_stream_request: Dict[str, Any], - get_test_client_wrapper, + get_async_test_client_wrapper, ): # pragma: no cover - client = get_test_client_wrapper( + async with get_async_test_client_wrapper( fake_docker_repository_image_always_exists=True, fake_model_bundle_repository_contents={}, fake_model_endpoint_record_repository_contents={ @@ -178,34 +176,35 @@ def test_completion_stream_success( fake_batch_job_record_repository_contents={}, fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, - ) - with mock.patch( - "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", - return_value=5, - ): - response_1 = client.post( - f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", - auth=("no_user", ""), - json=completion_stream_request, - stream=True, - ) - assert response_1.status_code == 200 - count = 0 - for message in response_1: - decoded_message = message.decode("utf-8") - assert decoded_message.startswith("data: "), "SSE does not start with 'data: '" + ) as client: + with mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=5, + ): + async with client.stream( + method="POST", + url=f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + ) as r: + assert r.status_code == 200 + count = 0 + async for message in r.aiter_bytes(): + decoded_message = message.decode("utf-8") + assert decoded_message.startswith( + "data: " + ), f"SSE does not start with 'data: ', message is '{decoded_message}'" - # strip 'data: ' prefix from Server-sent events format - json_str = decoded_message[len("data: ") :] - parsed_data = json.loads(json_str.strip()) - assert parsed_data["request_id"] is not None - assert parsed_data["output"] is None - assert parsed_data["error"] is None - count += 1 - assert count == 1 + # strip 'data: ' prefix from Server-sent events format + json_str = decoded_message[len("data: ") :] + parsed_data = json.loads(json_str.strip()) + assert parsed_data["request_id"] is not None + assert parsed_data["output"] is None + assert parsed_data["error"] is None + count += 1 + assert count == 1 -@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") def test_completion_stream_endpoint_not_found_returns_404( llm_model_endpoint_streaming: ModelEndpoint, completion_stream_request: Dict[str, Any], @@ -222,17 +221,42 @@ def test_completion_stream_endpoint_not_found_returns_404( fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, ) - response_1 = client.post( - f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + with client.stream( + method="POST", + url=f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", auth=("no_user", ""), json=completion_stream_request, - stream=True, - ) + ) as r: + assert r.status_code == 404 - assert response_1.status_code == 200 - for message in response_1: - assert "404" in message.decode("utf-8") +def test_completion_stream_misc_server_error_returns_500( + llm_model_endpoint_streaming: ModelEndpoint, + completion_stream_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + with mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.CompletionStreamV1UseCase.execute", + ) as mock_stream_usecase: + mock_stream_usecase.side_effect = RuntimeError("Some server side runtime error.") + with client.stream( + method="POST", + url=f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + ) as r: + assert r.status_code == 500 @mock.patch( diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index 3d019016..80f21734 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Tuple from unittest.mock import AsyncMock, MagicMock, patch +import pytest from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.domain.entities import ModelBundle, ModelEndpoint from model_engine_server.domain.exceptions import ( @@ -375,15 +376,16 @@ def test_create_sync_task_returns_failure( assert response.json()["status"] == "FAILURE" -def test_create_streaming_task_success( +@pytest.mark.asyncio +async def test_create_streaming_task_success( model_bundle_5: ModelBundle, model_endpoint_streaming: ModelEndpoint, endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]], test_api_key: str, - get_test_client_wrapper, + get_async_test_client_wrapper, ): assert model_endpoint_streaming.infra_state is not None - client = get_test_client_wrapper( + async with get_async_test_client_wrapper( fake_docker_repository_image_always_exists=True, fake_model_bundle_repository_contents={ model_bundle_5.id: model_bundle_5, @@ -397,15 +399,19 @@ def test_create_streaming_task_success( fake_batch_job_record_repository_contents={}, fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, - ) - with client.stream( - method="POST", - url=f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", - auth=(test_api_key, ""), - json=endpoint_predict_request_1[1], - ) as response: - assert response.status_code == 200 - assert ( - response.read() - == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' - ) + ) as client: + async with client.stream( + method="POST", + url=f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", + auth=(test_api_key, ""), + json=endpoint_predict_request_1[1], + ) as response: + assert response.status_code == 200 + count = 0 + async for message in response.aiter_bytes(): + assert ( + message + == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' + ) + count += 1 + assert count == 1 diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index bb3a9604..4b4d39a3 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1257,7 +1257,7 @@ async def test_completion_stream_use_case_success( tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = use_case.execute( + response_1 = await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_streaming.record.name, request=completion_stream_request, @@ -1367,7 +1367,7 @@ async def test_completion_stream_vllm_use_case_success( tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = use_case.execute( + response_1 = await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_stream[0].record.name, request=completion_stream_request, @@ -1434,7 +1434,7 @@ async def test_completion_stream_text_generation_inference_use_case_success( tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = use_case.execute( + response_1 = await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_text_generation_inference.record.name, request=completion_stream_request, @@ -1496,7 +1496,7 @@ async def test_completion_stream_trt_llm_use_case_success( tokenizer_repository=fake_tokenizer_repository, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = use_case.execute( + response_1 = await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_trt_llm.record.name, request=completion_stream_request, From 20c15afc9d15bea80c3f951efe610526bfbdb212 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 2 Jul 2024 17:32:57 -0700 Subject: [PATCH 325/425] Increase default concurrency to 100 for http forwarder (#552) * increase default concurrency to 50 for http forwarder * even more hehe * codecoverage --- .../inference/forwarding/http_forwarder.py | 2 +- .../unit/inference/test_http_forwarder.py | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index 1fdb030b..2f6ad755 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -51,7 +51,7 @@ def get_streaming_forwarder_loader(): @lru_cache() def get_concurrency_limiter(): config = get_config() - concurrency = int(config.get("max_concurrency", 5)) + concurrency = int(config.get("max_concurrency", 100)) return MultiprocessingConcurrencyLimiter( concurrency=concurrency, fail_on_concurrency_limit=True ) diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index bad6e6b4..837812e3 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -10,6 +10,7 @@ from model_engine_server.inference.forwarding.forwarding import Forwarder from model_engine_server.inference.forwarding.http_forwarder import ( MultiprocessingConcurrencyLimiter, + get_concurrency_limiter, predict, ) from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( @@ -57,6 +58,32 @@ def json(self) -> dict: return mocked_static_json() +def mocked_get_config(): + return { + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + }, + "max_concurrency": 42, + } + + @pytest.fixture def post_inference_hooks_handler(): handler = PostInferenceHooksHandler( @@ -108,6 +135,13 @@ def mock_request(): ) +@mock.patch("model_engine_server.inference.forwarding.http_forwarder.get_config", mocked_get_config) +def test_get_concurrency_limiter(): + limiter = get_concurrency_limiter() + assert isinstance(limiter, MultiprocessingConcurrencyLimiter) + assert limiter.concurrency == 42 + + @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) def test_http_service_429(mock_request, post_inference_hooks_handler): From 8860ee3204c4a18125d215cee6464d567bf911d6 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 3 Jul 2024 16:34:15 -0700 Subject: [PATCH 326/425] Use circleci AWS IAM role (#553) * Use circleci role * remove * remove --- .circleci/config.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index abeb67f0..d8aa3f75 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,6 +1,7 @@ version: 2.1 orbs: python: circleci/python@2.1.1 + aws-cli: circleci/aws-cli@3.1.5 workflows: ci: @@ -175,13 +176,14 @@ jobs: pip install -r model-engine/requirements.txt - install_client - install_server + - aws-cli/setup: + role-arn: ${CIRCLECI_ROLE_ARN} + aws-region: AWS_REGION - run: name: Run integration tests command: | pushd $HOME/project kubectl port-forward svc/model-engine 5001:80 & - export AWS_ACCESS_KEY_ID=$CIRCLECI_AWS_ACCESS_KEY - export AWS_SECRET_ACCESS_KEY=$CIRCLECI_AWS_SECRET_KEY export GIT_TAG=$CIRCLE_SHA1 pytest integration_tests From 1f474ba5f4c81ae31d46a3ade40323a77ac46f4d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 5 Jul 2024 11:10:48 -0700 Subject: [PATCH 327/425] Allow hardware infer from client (#555) * Allow hardware infer from client * Fix --- clients/python/llmengine/data_types.py | 6 +-- clients/python/llmengine/model.py | 52 ++++++++++++++++++-------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 965006ec..263d16cb 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -164,9 +164,9 @@ class CreateLLMEndpointRequest(BaseModel): metadata: Dict[str, Any] # TODO: JSON type post_inference_hooks: Optional[List[str]] endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING - cpus: CpuSpecificationType - gpus: int - memory: StorageSpecificationType + cpus: Optional[CpuSpecificationType] + gpus: Optional[int] + memory: Optional[StorageSpecificationType] gpu_type: Optional[GpuType] storage: Optional[StorageSpecificationType] optimize_costs: Optional[bool] = None diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 1e18d3bf..4d5a6bb1 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -43,15 +43,15 @@ def create( quantize: Optional[Quantization] = None, checkpoint_path: Optional[str] = None, # General endpoint fields - cpus: int = 8, - memory: str = "24Gi", - storage: str = "40Gi", - gpus: int = 1, + cpus: Optional[int] = None, + memory: Optional[str] = None, + storage: Optional[str] = None, + gpus: Optional[int] = None, min_workers: int = 0, max_workers: int = 1, per_worker: int = 2, endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING, - gpu_type: Optional[str] = "nvidia-ampere-a10", + gpu_type: Optional[str] = None, high_priority: Optional[bool] = False, post_inference_hooks: Optional[List[PostInferenceHooks]] = None, default_callback_url: Optional[str] = None, @@ -91,21 +91,23 @@ def create( Can be either a folder or a tar file. Folder is preferred since we don't need to untar and model loads faster. For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). - cpus (`int`): + cpus (`Optional[int]`): Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater - than or equal to 1. Recommendation is set it to 8 * GPU count. + than or equal to 1. Recommendation is set it to 8 * GPU count. Can be inferred from the model size. - memory (`str`): + memory (`Optional[str]`): Amount of memory each worker should get, e.g. "4Gi", "512Mi", etc. This must be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. + Can be inferred from the model size. - storage (`str`): + storage (`Optional[str]`): Amount of local ephemeral storage each worker should get, e.g. "4Gi", "512Mi", etc. This must be a positive amount of storage. Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. + Can be inferred from the model size. - gpus (`int`): - Number of gpus each worker should get, e.g. 0, 1, etc. + gpus (`Optional[int]`): + Number of gpus each worker should get, e.g. 0, 1, etc. Can be inferred from the model size. min_workers (`int`): The minimum number of workers. Must be greater than or equal to 0. This @@ -142,15 +144,15 @@ def create( gpu_type (`Optional[str]`): If specifying a non-zero number of gpus, this controls the type of gpu - requested. Here are the supported values: + requested. Can be inferred from the model size. Here are the supported values: - ``nvidia-tesla-t4`` - ``nvidia-ampere-a10`` - ``nvidia-ampere-a100`` - ``nvidia-ampere-a100e`` - ``nvidia-hopper-h100`` - - ``nvidia-hopper-h100-1g20gb`` - - ``nvidia-hopper-h100-3g40gb`` + - ``nvidia-hopper-h100-1g20gb`` # 1 slice of MIG with 1g compute and 20GB memory + - ``nvidia-hopper-h100-3g40gb`` # 1 slice of MIG with 3g compute and 40GB memory high_priority (`Optional[bool]`): Either ``True`` or ``False``. Enabling this will allow the created @@ -173,7 +175,27 @@ def create( Returns: CreateLLMEndpointResponse: creation task ID of the created Model. Currently not used. - === "Create Llama 2 7B model in Python" + === "Create Llama 2 70B model with hardware specs inferred in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-70b-test" + model="llama-2-70b", + inference_framework_image_tag="0.9.4", + inference_framework=LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + num_shards=4, + checkpoint_path="s3://path/to/checkpoint", + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + public_inference=False, + ) + + print(response.json()) + ``` + === "Create Llama 2 7B model with hardware specs specified in Python" ```python from llmengine import Model From 137f88dfea01bbdce69af8417bbcb74ec39506c8 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Fri, 5 Jul 2024 14:01:36 -0700 Subject: [PATCH 328/425] Fix AWS IAM role access (#556) * Fix AWS IAM role access * one more * move --- .circleci/config.yml | 6 +++--- .circleci/resources/.minikube-config-map | 5 +++-- .circleci/resources/.minikube-registry-creds | 6 +++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index d8aa3f75..1525385d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -99,6 +99,9 @@ jobs: executor: ubuntu-large steps: - checkout + - aws-cli/setup: + role-arn: ${CIRCLECI_ROLE_ARN} + aws-region: AWS_REGION - run: name: Build Docker Image command: | @@ -176,9 +179,6 @@ jobs: pip install -r model-engine/requirements.txt - install_client - install_server - - aws-cli/setup: - role-arn: ${CIRCLECI_ROLE_ARN} - aws-region: AWS_REGION - run: name: Run integration tests command: | diff --git a/.circleci/resources/.minikube-config-map b/.circleci/resources/.minikube-config-map index 37ef6f32..620e3ab1 100644 --- a/.circleci/resources/.minikube-config-map +++ b/.circleci/resources/.minikube-config-map @@ -1,4 +1,5 @@ # Configmap for AWS credentials inside minikube. [default] -aws_access_key_id = $CIRCLECI_AWS_ACCESS_KEY -aws_secret_access_key = $CIRCLECI_AWS_SECRET_KEY +aws_access_key_id = $AWS_ACCESS_KEY_ID +aws_secret_access_key = $AWS_SECRET_ACCESS_KEY +aws_session_token = $AWS_SESSION_TOKEN \ No newline at end of file diff --git a/.circleci/resources/.minikube-registry-creds b/.circleci/resources/.minikube-registry-creds index a1ef51f2..37f4b1fa 100644 --- a/.circleci/resources/.minikube-registry-creds +++ b/.circleci/resources/.minikube-registry-creds @@ -3,9 +3,9 @@ # See expect syntax here: https://manpages.ubuntu.com/manpages/trusty/man1/expect.1.html spawn minikube addons configure registry-creds expect "Do you want to enable AWS Elastic Container Registry?" { send "y\r" } -expect "Enter AWS Access Key ID:" { send "$CIRCLECI_AWS_ACCESS_KEY\r" } -expect "Enter AWS Secret Access Key:" { send "$CIRCLECI_AWS_SECRET_KEY\r" } -expect "Enter AWS Session Token:" { send "\r" } +expect "Enter AWS Access Key ID:" { send "$AWS_ACCESS_KEY_ID\r" } +expect "Enter AWS Secret Access Key:" { send "$AWS_SECRET_ACCESS_KEY\r" } +expect "Enter AWS Session Token:" { send "$AWS_SESSION_TOKEN\r" } expect "Enter AWS Region:" { send "us-west-2\r" } expect "Enter 12 digit AWS Account ID (Comma separated list):" { send "$CIRCLECI_AWS_ACCOUNT_ID\r" } expect "Enter ARN of AWS role to assume:" { send "\r" } From d5d91937936f825ef08bb359d26cad73e81a82ee Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 8 Jul 2024 11:38:24 -0700 Subject: [PATCH 329/425] More rigorous endpoint update handling (#558) * Fix metadata update * Update tests --- .../use_cases/llm_model_endpoint_use_cases.py | 43 +++++-- model-engine/tests/unit/domain/conftest.py | 16 +++ .../tests/unit/domain/test_llm_use_cases.py | 118 ++++++++++++++++++ 3 files changed, 166 insertions(+), 11 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index cf8b1f55..f19e5ab5 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -111,6 +111,8 @@ logger = make_logger(logger_name()) +LLM_METADATA_KEY = "_llm" +RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { LLMInferenceFramework.DEEPSPEED: "instant-llm", @@ -279,11 +281,14 @@ async def _get_recommended_hardware_config_map() -> Dict[str, Any]: def _model_endpoint_entity_to_get_llm_model_endpoint_response( model_endpoint: ModelEndpoint, ) -> GetLLMModelEndpointV1Response: - if model_endpoint.record.metadata is None or "_llm" not in model_endpoint.record.metadata: + if ( + model_endpoint.record.metadata is None + or LLM_METADATA_KEY not in model_endpoint.record.metadata + ): raise ObjectHasInvalidValueException( f"Can't translate model entity to response, endpoint {model_endpoint.record.id} does not have LLM metadata." ) - llm_metadata = model_endpoint.record.metadata.get("_llm", {}) + llm_metadata = model_endpoint.record.metadata.get(LLM_METADATA_KEY, {}) response = GetLLMModelEndpointV1Response( id=model_endpoint.record.id, name=model_endpoint.record.name, @@ -962,7 +967,7 @@ async def execute( aws_role = self.authz_module.get_aws_role_for_user(user) results_s3_bucket = self.authz_module.get_s3_bucket_for_user(user) - request.metadata["_llm"] = asdict( + request.metadata[LLM_METADATA_KEY] = asdict( LLMMetadata( model_name=request.model_name, source=request.source, @@ -1088,6 +1093,16 @@ async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndp return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) +def merge_metadata( + request: Optional[Dict[str, Any]], record: Optional[Dict[str, Any]] +) -> Optional[Dict[str, Any]]: + if request is None: + return record + if record is None: + return request + return {**record, **request} + + class UpdateLLMModelEndpointV1UseCase: def __init__( self, @@ -1131,6 +1146,7 @@ async def execute( raise EndpointInfraStateNotFound(error_msg) infra_state = model_endpoint.infra_state + metadata: Optional[Dict[str, Any]] if ( request.model_name @@ -1140,7 +1156,7 @@ async def execute( or request.quantize or request.checkpoint_path ): - llm_metadata = (model_endpoint.record.metadata or {}).get("_llm", {}) + llm_metadata = (model_endpoint.record.metadata or {}).get(LLM_METADATA_KEY, {}) inference_framework = llm_metadata["inference_framework"] if request.inference_framework_image_tag == "latest": @@ -1177,7 +1193,7 @@ async def execute( ) metadata = endpoint_record.metadata or {} - metadata["_llm"] = asdict( + metadata[LLM_METADATA_KEY] = asdict( LLMMetadata( model_name=model_name, source=source, @@ -1188,7 +1204,7 @@ async def execute( checkpoint_path=checkpoint_path, ) ) - request.metadata = metadata + endpoint_record.metadata = metadata # For resources that are not specified in the update endpoint request, pass in resource from # infra_state to make sure that after the update, all resources are valid and in sync. @@ -1209,15 +1225,20 @@ async def execute( endpoint_type=endpoint_record.endpoint_type, ) - if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata: - raise ObjectHasInvalidValueException( - f"{CONVERTED_FROM_ARTIFACT_LIKE_KEY} is a reserved metadata key and cannot be used by user." - ) + if request.metadata is not None: + # If reserved metadata key is provided, throw ObjectHasInvalidValueException + for key in RESERVED_METADATA_KEYS: + if key in request.metadata: + raise ObjectHasInvalidValueException( + f"{key} is a reserved metadata key and cannot be used by user." + ) + + metadata = merge_metadata(request.metadata, endpoint_record.metadata) updated_endpoint_record = await self.model_endpoint_service.update_model_endpoint( model_endpoint_id=model_endpoint_id, model_bundle_id=bundle.id, - metadata=request.metadata, + metadata=metadata, post_inference_hooks=request.post_inference_hooks, cpus=request.cpus, gpus=request.gpus, diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index e9e37cf2..937f3cfc 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -31,6 +31,9 @@ Quantization, StreamingEnhancedRunnableImageFlavor, ) +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, +) @pytest.fixture @@ -265,6 +268,19 @@ def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: ) +@pytest.fixture +def update_llm_model_endpoint_request_only_workers() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request( + min_workers=5, + max_workers=10, + ) + + +@pytest.fixture +def update_llm_model_endpoint_request_bad_metadata() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request(metadata={CONVERTED_FROM_ARTIFACT_LIKE_KEY: {}}) + + @pytest.fixture def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4b4d39a3..8e310211 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -51,6 +51,7 @@ UpdateLLMModelEndpointV1UseCase, _fill_hardware_info, _infer_hardware, + merge_metadata, validate_and_update_completion_params, validate_checkpoint_files, ) @@ -614,6 +615,7 @@ async def test_update_model_endpoint_use_case_success( fake_llm_model_endpoint_service, create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, update_llm_model_endpoint_request: UpdateLLMModelEndpointV1Request, + update_llm_model_endpoint_request_only_workers: UpdateLLMModelEndpointV1Request, ): fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository bundle_use_case = CreateModelBundleV2UseCase( @@ -687,6 +689,102 @@ async def test_update_model_endpoint_use_case_success( == update_llm_model_endpoint_request.max_workers ) + update_response2 = await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request_only_workers, + ) + assert update_response2.endpoint_creation_task_id + + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_streaming.model_name, + "source": create_llm_model_endpoint_request_streaming.source, + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, + "quantize": None, + "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + } + } + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.deployment_state.min_workers + == update_llm_model_endpoint_request_only_workers.min_workers + ) + assert ( + endpoint.infra_state.deployment_state.max_workers + == update_llm_model_endpoint_request_only_workers.max_workers + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag", + mocked__get_latest_tag(), +) +async def test_update_model_endpoint_use_case_failure( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + fake_llm_model_endpoint_service, + create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + update_llm_model_endpoint_request_bad_metadata: UpdateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + create_use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + update_use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + fake_llm_model_endpoint_service.add_model_endpoint(endpoint) + + with pytest.raises(ObjectHasInvalidValueException): + await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request_bad_metadata, + ) + def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa class mocked_encode: @@ -2241,3 +2339,23 @@ async def test_create_batch_completions( "-c", "ddtrace-run python vllm_batch.py", ] + + +def test_merge_metadata(): + request_metadata = { + "key1": "value1", + "key2": "value2", + } + + endpoint_metadata = { + "key1": "value0", + "key3": "value3", + } + + assert merge_metadata(request_metadata, None) == request_metadata + assert merge_metadata(None, endpoint_metadata) == endpoint_metadata + assert merge_metadata(request_metadata, endpoint_metadata) == { + "key1": "value1", + "key2": "value2", + "key3": "value3", + } From 0bacaa562c2ca6ed4cbc64a1aaf456f1da6c89d4 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 9 Jul 2024 14:55:17 -0700 Subject: [PATCH 330/425] Update vllm server to be openai compatible (#560) * Update vllm engine to be openai compatible * Bump vllm to 0.5.1 * Revert 0.5.1 -- need some CUDA version upgrade * Small cleanup --- .../inference/vllm/vllm_server.py | 103 ++++++++++++++++-- 1 file changed, 92 insertions(+), 11 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index e32b0834..68a9a263 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -1,33 +1,86 @@ import argparse +import asyncio import code import json +import logging import os import signal import subprocess import traceback +from logging import Logger from typing import AsyncGenerator, Dict, List, Optional import uvicorn from fastapi import BackgroundTasks, FastAPI, HTTPException, Request -from fastapi.responses import Response, StreamingResponse +from fastapi.responses import JSONResponse, Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.protocol import ChatCompletionRequest as OpenAIChatCompletionRequest +from vllm.entrypoints.openai.protocol import ChatCompletionResponse as OpenAIChatCompletionResponse from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.utils import random_uuid +from vllm.version import __version__ as VLLM_VERSION + +logging.basicConfig( + format="%(asctime)s | %(levelname)s: %(message)s", + datefmt="%b/%d %H:%M:%S", + level=logging.INFO, +) + +logger = Logger("vllm_server") TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds app = FastAPI() +openai_serving_chat: OpenAIServingChat +openai_serving_completion: OpenAIServingCompletion +openai_serving_embedding: OpenAIServingEmbedding + @app.get("/healthz") @app.get("/health") -def healthcheck(): - return "OK" +async def healthcheck(): + await openai_serving_chat.engine.check_health() + return Response(status_code=200) + + +@app.get("/v1/models") +async def show_available_models(): + models = await openai_serving_chat.show_available_models() + return JSONResponse(content=models.model_dump()) + + +@app.post("/v1/chat/completions") +async def create_chat_completion(request: OpenAIChatCompletionRequest, raw_request: Request): + generator = await openai_serving_chat.create_chat_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, media_type="text/event-stream") + else: + assert isinstance(generator, OpenAIChatCompletionResponse) + return JSONResponse(content=generator.model_dump()) + + +@app.post("/v1/completions") +async def create_completion(request: OpenAICompletionRequest, raw_request: Request): + generator = await openai_serving_completion.create_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) @app.post("/predict") @@ -135,7 +188,7 @@ async def stream_results() -> AsyncGenerator[str, None]: return Response(content=json.dumps(ret)) except AsyncEngineDeadError as e: - print(f"The vllm engine is dead, exiting the pod: {e}") + logger.error(f"The vllm engine is dead, exiting the pod: {e}") os.kill(os.getpid(), signal.SIGINT) raise e @@ -151,7 +204,7 @@ def get_gpu_free_memory(): gpu_memory = [int(x) for x in output.strip().split("\n")] return gpu_memory except Exception as e: - print(f"Error getting GPU memory: {e}") + logger.warn(f"Error getting GPU memory: {e}") return None @@ -162,7 +215,7 @@ def check_unknown_startup_memory_usage(): min_mem = min(gpu_free_memory) max_mem = max(gpu_free_memory) if max_mem - min_mem > 10: - print( + logger.warn( f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) try: @@ -170,9 +223,9 @@ def check_unknown_startup_memory_usage(): output = subprocess.run( ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True ).stdout - print(f"Processes using GPU: {output}") + logger.info(f"Processes using GPU: {output}") except Exception as e: - print(f"Error getting processes using GPU: {e}") + logger.error(f"Error getting processes using GPU: {e}") def debug(sig, frame): @@ -200,23 +253,51 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: return [extract_logprobs(logprobs) for logprobs in output_logprobs] +def parse_args(): + parser = make_arg_parser() + return parser.parse_args() + + if __name__ == "__main__": check_unknown_startup_memory_usage() + parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) # None == IPv4 / IPv6 dualstack parser.add_argument("--port", type=int, default=5005) parser = AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() + args = parse_args() + + logger.info("vLLM version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + signal.signal(signal.SIGUSR1, debug) engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) - signal.signal(signal.SIGUSR1, debug) + model_config = asyncio.run(engine.get_model_config()) + + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + served_model_names, + args.response_role, + args.lora_modules, + args.chat_template, + ) + openai_serving_completion = OpenAIServingCompletion( + engine, model_config, served_model_names, args.lora_modules + ) uvicorn.run( app, host=args.host, port=args.port, - log_level="debug", + log_level=args.uvicorn_log_level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, ) From 72a2b5ae9a0461d2abf9e96166f613cae6434bd6 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Wed, 10 Jul 2024 10:18:39 -0700 Subject: [PATCH 331/425] Fix healthcheck_route and predict_route for async endpoints (#554) --- .../templates/service_template_config_map.yaml | 4 ++-- .../service_template_config_map_circleci.yaml | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 9300637a..7c2477e8 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -220,9 +220,9 @@ data: - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" {{- if eq $celery_broker_type "sqs" }} - --sqs-url - "${SQS_QUEUE_URL}" diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index bfe5c492..1c0cb8e2 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -122,9 +122,9 @@ data: - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" env: @@ -393,9 +393,9 @@ data: - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" env: @@ -1343,9 +1343,9 @@ data: - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" env: @@ -1621,9 +1621,9 @@ data: - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" env: From 3ff1196429c84925b3651ac9b82b499fb93c96ff Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 10 Jul 2024 12:30:21 -0700 Subject: [PATCH 332/425] Fix some oddities in the client (#562) --- clients/python/llmengine/errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/llmengine/errors.py b/clients/python/llmengine/errors.py index 3383878a..27008024 100644 --- a/clients/python/llmengine/errors.py +++ b/clients/python/llmengine/errors.py @@ -81,7 +81,7 @@ def parse_error(status_code: int, content: bytes) -> Exception: try: payload = json.loads(content) message = payload["detail"] - except json.JSONDecodeError: + except (json.JSONDecodeError, KeyError): message = content.decode("utf-8") # Try to parse a APIInference error @@ -93,7 +93,7 @@ def parse_error(status_code: int, content: bytes) -> Exception: return NotFoundError(message) if status_code == 429: return RateLimitExceededError(message) - if 600 < status_code <= 500: + if 500 <= status_code < 600: return ServerError(status_code, message) # Fallback to an unknown error From c0cea6070868abde6f3c4a7ad090ec7a97f10c5b Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 11 Jul 2024 01:32:56 -0700 Subject: [PATCH 333/425] Bump pydantic to 2.8.2 (#561) * Bump pydantic to 2.8.2 * Run pydantic-bump * More updates * Fix openapi generation * Fix tests * black formatting * Fix integration tests --- .circleci/config.yml | 40 +++---- clients/python/llmengine/data_types.py | 5 +- integration_tests/rest_api_utils.py | 7 +- integration_tests/test_endpoints.py | 2 +- .../common/dtos/batch_jobs.py | 49 ++++----- .../model_engine_server/common/dtos/core.py | 17 +++ .../common/dtos/docker_repository.py | 4 +- .../common/dtos/endpoint_builder.py | 16 +-- .../model_engine_server/common/dtos/llms.py | 101 +++++++++--------- .../common/dtos/model_bundles.py | 46 ++++---- .../common/dtos/model_endpoints.py | 54 +++++----- .../model_engine_server/common/dtos/tasks.py | 10 +- .../common/dtos/triggers.py | 14 ++- .../domain/entities/batch_job_entity.py | 16 +-- .../docker_image_batch_job_bundle_entity.py | 19 ++-- .../domain/entities/llm_fine_tune_entity.py | 6 +- .../domain/entities/model_bundle_entity.py | 49 ++++----- .../domain/entities/model_endpoint_entity.py | 36 +++---- .../domain/entities/trigger_entity.py | 9 +- .../gateways/monitoring_metrics_gateway.py | 2 +- .../use_cases/llm_model_endpoint_use_cases.py | 20 ++-- .../inference/batch_inference/dto.py | 2 +- .../model_engine_server/inference/common.py | 2 +- .../inference/post_inference_hooks.py | 4 +- .../inference/requirements_base.txt | 4 +- ..._async_model_endpoint_inference_gateway.py | 4 +- .../live_model_endpoints_schema_gateway.py | 56 +++++++--- .../k8s_endpoint_resource_delegate.py | 2 +- ...ocker_image_batch_job_bundle_repository.py | 2 +- .../repositories/db_trigger_repository.py | 2 +- .../services/live_endpoint_builder_service.py | 1 + model-engine/requirements.in | 2 +- model-engine/requirements.txt | 8 +- .../tests/integration/inference/conftest.py | 4 +- .../inference/test_async_inference.py | 2 +- model-engine/tests/unit/api/conftest.py | 6 +- model-engine/tests/unit/api/test_tasks.py | 3 +- model-engine/tests/unit/conftest.py | 42 ++++---- model-engine/tests/unit/domain/conftest.py | 2 +- .../tests/unit/domain/test_entities.py | 4 +- .../tests/unit/domain/test_llm_use_cases.py | 36 +++---- ...test_live_async_model_inference_gateway.py | 2 +- .../test_live_batch_job_progress_gateway.py | 2 +- requirements-docs.txt | 2 +- 44 files changed, 373 insertions(+), 343 deletions(-) create mode 100644 model-engine/model_engine_server/common/dtos/core.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 1525385d..5b636bee 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -74,7 +74,7 @@ jobs: resource_class: small parallelism: 1 steps: - - add_ssh_keys: # gives write access to CircleCI worker + - add_ssh_keys: # gives write access to CircleCI worker fingerprints: - "76:0c:1b:9e:e3:6a:c3:5c:6f:24:91:ef:7c:54:d2:7a" - checkout # checkout source code to working directory @@ -157,10 +157,10 @@ jobs: DOCKER_BUILDKIT=1 docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile \ --build-arg BASE_IMAGE=temp:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1 \ --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-requirements.txt" \ - -t $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-021694 . + -t $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-b8c25b . rm $CIRCLE_SHA1-requirements.txt - minikube --logtostderr -v 1 image load $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-021694 + minikube --logtostderr -v 1 image load $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-b8c25b - run: name: Install helm chart command: | @@ -207,23 +207,23 @@ commands: install_server: description: Installs LLM Engine server steps: - - python/install-packages: - pkg-manager: pip - app-dir: model-engine - - python/install-packages: - pkg-manager: pip - app-dir: model-engine - pip-dependency-file: requirements-test.txt - - python/install-packages: - pkg-manager: pip - app-dir: model-engine - pip-dependency-file: requirements_override.txt - - run: - name: Install Server - command: | - pushd model-engine - pip install -e . - popd + - python/install-packages: + pkg-manager: pip + app-dir: model-engine + - python/install-packages: + pkg-manager: pip + app-dir: model-engine + pip-dependency-file: requirements-test.txt + - python/install-packages: + pkg-manager: pip + app-dir: model-engine + pip-dependency-file: requirements_override.txt + - run: + name: Install Server + command: | + pushd model-engine + pip install -e . + popd install_client: description: Install LLM Engine client steps: diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 263d16cb..7f775e5a 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -6,9 +6,10 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -import pydantic +from pydantic.version import VERSION as PYDANTIC_VERSION -if int(pydantic.__version__.split(".")[0]) > 1: +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +if PYDANTIC_V2: from pydantic.v1 import BaseModel, Field, HttpUrl else: from pydantic import BaseModel, Field, HttpUrl # type: ignore diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index f77d96ea..db087992 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -59,7 +59,12 @@ def my_model(**keyword_args): "framework_type": "pytorch", "pytorch_image_tag": "1.7.1-cuda11.0-cudnn8-runtime", }, - "requirements": ["cloudpickle==2.1.0", "pyyaml==6.0"], + "requirements": [ + "cloudpickle==2.1.0", + "pyyaml==6.0", + "pydantic==2.8.2", + "fastapi==0.110.0", + ], "location": "s3://model-engine-integration-tests/model_bundles/echo_bundle", }, } diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py index 5b0a6404..5d7eae2a 100644 --- a/integration_tests/test_endpoints.py +++ b/integration_tests/test_endpoints.py @@ -232,7 +232,7 @@ def test_sync_streaming_model_endpoint(capsys): for response in task_responses: assert ( response.strip() - == 'data: {"status": "SUCCESS", "result": {"result": {"y": 1}}, "traceback": null}' + == 'data: {"status":"SUCCESS","result":{"result":{"y":1}},"traceback":null}' ) finally: delete_model_endpoint(create_endpoint_request["name"], user) diff --git a/model-engine/model_engine_server/common/dtos/batch_jobs.py b/model-engine/model_engine_server/common/dtos/batch_jobs.py index ce1af0c8..0600df22 100644 --- a/model-engine/model_engine_server/common/dtos/batch_jobs.py +++ b/model-engine/model_engine_server/common/dtos/batch_jobs.py @@ -13,20 +13,21 @@ GpuType, StorageSpecificationType, ) -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, ConfigDict, model_validator class CreateBatchJobResourceRequests(BaseModel): - cpus: Optional[CpuSpecificationType] - memory: Optional[StorageSpecificationType] - gpus: Optional[int] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - max_workers: Optional[int] - per_worker: Optional[int] + cpus: Optional[CpuSpecificationType] = None + memory: Optional[StorageSpecificationType] = None + gpus: Optional[int] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None class CreateBatchJobV1Request(BaseModel): + model_config = ConfigDict(protected_namespaces=()) model_bundle_id: str input_path: str serialization_format: BatchJobSerializationFormat @@ -41,10 +42,10 @@ class CreateBatchJobV1Response(BaseModel): class GetBatchJobV1Response(BaseModel): status: BatchJobStatus - result: Optional[str] + result: Optional[str] = None duration: timedelta - num_tasks_pending: Optional[int] - num_tasks_completed: Optional[int] + num_tasks_pending: Optional[int] = None + num_tasks_completed: Optional[int] = None class UpdateBatchJobV1Request(BaseModel): @@ -64,9 +65,7 @@ class CreateDockerImageBatchJobResourceRequests(BaseModel): gpus: Optional[int] = None gpu_type: Optional[GpuType] = None storage: Optional[StorageSpecificationType] = None - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) @classmethod def merge_requests( @@ -93,7 +92,7 @@ def common_requests( class CreateDockerImageBatchJobV1Request(BaseModel): docker_image_batch_job_bundle_name: Optional[str] = None docker_image_batch_job_bundle_id: Optional[str] = None - job_config: Optional[Dict[str, Any]] + job_config: Optional[Dict[str, Any]] = None # TODO also expose a separate argument to pass an s3file to the job, as opposed to job_config labels: Dict[str, str] # TODO this probably should go in the bundle @@ -103,7 +102,7 @@ class CreateDockerImageBatchJobV1Request(BaseModel): override_job_max_runtime_s: Optional[int] = None - @root_validator + @model_validator(mode="before") def exactly_one_name_or_id(cls, values): bundle_name = values.get("docker_image_batch_job_bundle_name") bundle_id = values.get("docker_image_batch_job_bundle_id") @@ -166,16 +165,14 @@ class DockerImageBatchJobBundleV1Response(BaseModel): image_tag: str command: List[str] env: Dict[str, str] - mount_location: Optional[str] - cpus: Optional[str] - memory: Optional[str] - storage: Optional[str] - gpus: Optional[int] - gpu_type: Optional[str] - public: Optional[bool] - - class Config: - orm_mode = True + mount_location: Optional[str] = None + cpus: Optional[str] = None + memory: Optional[str] = None + storage: Optional[str] = None + gpus: Optional[int] = None + gpu_type: Optional[str] = None + public: Optional[bool] = None + model_config = ConfigDict(from_attributes=True) class ListDockerImageBatchJobBundleV1Response(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/core.py b/model-engine/model_engine_server/common/dtos/core.py new file mode 100644 index 00000000..ad709658 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/core.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, BeforeValidator, ConfigDict, HttpUrl, TypeAdapter +from typing_extensions import Annotated + +# See: https://github.com/pydantic/pydantic/issues/7186 +# pydantic v2 doesn't treat HttpUrl the same way as in v1 which causes various issue +# This is an attempt to make it behave as similar as possible +HttpUrlTypeAdapter = TypeAdapter(HttpUrl) +HttpUrlStr = Annotated[ + str, + BeforeValidator(lambda value: HttpUrlTypeAdapter.validate_python(value) and value), +] + + +class LLMEngineModel(BaseModel): + """Common pydantic configurations for model engine""" + + model_config = ConfigDict(protected_namespaces=()) diff --git a/model-engine/model_engine_server/common/dtos/docker_repository.py b/model-engine/model_engine_server/common/dtos/docker_repository.py index 6e4651d9..694c4098 100644 --- a/model-engine/model_engine_server/common/dtos/docker_repository.py +++ b/model-engine/model_engine_server/common/dtos/docker_repository.py @@ -10,8 +10,8 @@ class BuildImageRequest(BaseModel): base_path: str dockerfile: str base_image: str - requirements_folder: Optional[str] - substitution_args: Optional[Dict[str, str]] + requirements_folder: Optional[str] = None + substitution_args: Optional[Dict[str, str]] = None class BuildImageResponse(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py index 0edbeaaf..8ec2d2f9 100644 --- a/model-engine/model_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -20,19 +20,19 @@ class BuildEndpointRequest(BaseModel): cpus: CpuSpecificationType gpus: int memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None optimize_costs: bool aws_role: str results_s3_bucket: str - child_fn_info: Optional[Dict[str, Any]] # TODO: remove this if we don't need it. - post_inference_hooks: Optional[List[str]] + child_fn_info: Optional[Dict[str, Any]] = None # TODO: remove this if we don't need it. + post_inference_hooks: Optional[List[str]] = None labels: Dict[str, str] - billing_tags: Optional[Dict[str, Any]] + billing_tags: Optional[Dict[str, Any]] = None prewarm: bool = True - high_priority: Optional[bool] - default_callback_url: Optional[str] - default_callback_auth: Optional[CallbackAuth] + high_priority: Optional[bool] = None + default_callback_url: Optional[str] = None + default_callback_auth: Optional[CallbackAuth] = None class BuildEndpointStatus(str, Enum): diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 40d6f2ca..b35bff36 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional +from model_engine_server.common.dtos.core import HttpUrlStr from model_engine_server.common.dtos.model_endpoints import ( CpuSpecificationType, GetModelEndpointV1Response, @@ -23,7 +24,7 @@ ModelEndpointStatus, Quantization, ) -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, ConfigDict, Field class CreateLLMModelEndpointV1Request(BaseModel): @@ -51,23 +52,23 @@ class CreateLLMModelEndpointV1Request(BaseModel): # General endpoint fields metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] + post_inference_hooks: Optional[List[str]] = None endpoint_type: ModelEndpointType = ModelEndpointType.SYNC - cpus: Optional[CpuSpecificationType] - gpus: Optional[int] - memory: Optional[StorageSpecificationType] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None min_workers: int max_workers: int per_worker: int labels: Dict[str, str] - prewarm: Optional[bool] - high_priority: Optional[bool] - billing_tags: Optional[Dict[str, Any]] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = True # LLM endpoints are public by default. @@ -99,43 +100,43 @@ class ListLLMModelEndpointsV1Response(BaseModel): class UpdateLLMModelEndpointV1Request(BaseModel): # LLM specific fields - model_name: Optional[str] - source: Optional[LLMSource] - inference_framework_image_tag: Optional[str] - num_shards: Optional[int] + model_name: Optional[str] = None + source: Optional[LLMSource] = None + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None """ Number of shards to distribute the model onto GPUs. """ - quantize: Optional[Quantization] + quantize: Optional[Quantization] = None """ Whether to quantize the model. """ - checkpoint_path: Optional[str] + checkpoint_path: Optional[str] = None """ Path to the checkpoint to load the model from. """ # General endpoint fields - metadata: Optional[Dict[str, Any]] - post_inference_hooks: Optional[List[str]] - cpus: Optional[CpuSpecificationType] - gpus: Optional[int] - memory: Optional[StorageSpecificationType] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] - min_workers: Optional[int] - max_workers: Optional[int] - per_worker: Optional[int] - labels: Optional[Dict[str, str]] - prewarm: Optional[bool] - high_priority: Optional[bool] - billing_tags: Optional[Dict[str, Any]] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] + metadata: Optional[Dict[str, Any]] = None + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None + min_workers: Optional[int] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None class UpdateLLMModelEndpointV1Response(BaseModel): @@ -225,7 +226,7 @@ class CompletionSyncV1Response(BaseModel): Response object for a synchronous prompt completion task. """ - request_id: Optional[str] + request_id: Optional[str] = None output: Optional[CompletionOutput] = None @@ -323,7 +324,7 @@ class CompletionStreamV1Response(BaseModel): Response object for a stream prompt completion task. """ - request_id: Optional[str] + request_id: Optional[str] = None output: Optional[CompletionStreamOutput] = None error: Optional[StreamError] = None """Error of the response (if any).""" @@ -520,7 +521,9 @@ class CreateBatchCompletionsRequest(BaseModel): Request object for batch completions. """ - input_data_path: Optional[str] + model_config = ConfigDict(protected_namespaces=()) + + input_data_path: Optional[str] = None output_data_path: str """ Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. @@ -530,10 +533,14 @@ class CreateBatchCompletionsRequest(BaseModel): Either `input_data_path` or `content` needs to be provided. When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. """ - model_config: CreateBatchCompletionsModelConfig + model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config") """ Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. """ + data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) """ Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. @@ -555,14 +562,6 @@ class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): hidden from the DTO exposed to the client. """ - model_cfg: CreateBatchCompletionsModelConfig - """ - Model configuration for the batch inference. Hardware configurations are inferred. - - We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which - reserves model_config as a keyword. - """ - max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0) """ Maximum GPU memory utilization for the batch inference. Default to 90%. @@ -574,8 +573,8 @@ def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsE input_data_path=request.input_data_path, output_data_path=request.output_data_path, content=request.content, - model_config=request.model_config, - model_cfg=request.model_config, + model_config=request.model_cfg, + model_cfg=request.model_cfg, data_parallelism=request.data_parallelism, max_runtime_sec=request.max_runtime_sec, tool_config=request.tool_config, diff --git a/model-engine/model_engine_server/common/dtos/model_bundles.py b/model-engine/model_engine_server/common/dtos/model_bundles.py index 778b2942..d49537c4 100644 --- a/model-engine/model_engine_server/common/dtos/model_bundles.py +++ b/model-engine/model_engine_server/common/dtos/model_bundles.py @@ -10,7 +10,7 @@ ModelBundleFlavors, ModelBundlePackagingType, ) -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class CreateModelBundleV1Request(BaseModel): @@ -23,9 +23,9 @@ class CreateModelBundleV1Request(BaseModel): requirements: List[str] env_params: ModelBundleEnvironmentParams packaging_type: ModelBundlePackagingType - metadata: Optional[Dict[str, Any]] - app_config: Optional[Dict[str, Any]] - schema_location: Optional[str] + metadata: Optional[Dict[str, Any]] = None + app_config: Optional[Dict[str, Any]] = None + schema_location: Optional[str] = None class CloneModelBundleV1Request(BaseModel): @@ -38,7 +38,7 @@ class CloneModelBundleV1Request(BaseModel): The ID of the ModelBundle to copy from. """ - new_app_config: Optional[Dict[str, Any]] + new_app_config: Optional[Dict[str, Any]] = None """ The app_config of the new ModelBundle. If not specified, then the new ModelBundle will use the same app_config as the original. @@ -50,6 +50,8 @@ class CreateModelBundleV1Response(BaseModel): Response object for creating a Model Bundle. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundle_id: str @@ -58,6 +60,8 @@ class ModelBundleV1Response(BaseModel): Response object for a single Model Bundle. """ + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + id: str name: str location: str @@ -65,17 +69,10 @@ class ModelBundleV1Response(BaseModel): env_params: ModelBundleEnvironmentParams packaging_type: ModelBundlePackagingType metadata: Dict[str, Any] - app_config: Optional[Dict[str, Any]] + app_config: Optional[Dict[str, Any]] = None created_at: datetime.datetime model_artifact_ids: List[str] - schema_location: Optional[str] - - class Config: - """ - ModelBundleResponse Config class. - """ - - orm_mode = True + schema_location: Optional[str] = None class ListModelBundlesV1Response(BaseModel): @@ -83,6 +80,8 @@ class ListModelBundlesV1Response(BaseModel): Response object for listing Model Bundles. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundles: List[ModelBundleV1Response] @@ -92,7 +91,7 @@ class CreateModelBundleV2Request(BaseModel): """ name: str - metadata: Optional[Dict[str, Any]] + metadata: Optional[Dict[str, Any]] = None schema_location: str flavor: ModelBundleFlavors = Field(..., discriminator="flavor") @@ -107,7 +106,7 @@ class CloneModelBundleV2Request(BaseModel): The ID of the ModelBundle to copy from. """ - new_app_config: Optional[Dict[str, Any]] + new_app_config: Optional[Dict[str, Any]] = None """ The app_config of the new ModelBundle. If not specified, then the new ModelBundle will use the same app_config as the original. @@ -119,6 +118,8 @@ class CreateModelBundleV2Response(BaseModel): Response object for creating a Model Bundle. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundle_id: str @@ -127,27 +128,24 @@ class ModelBundleV2Response(BaseModel): Response object for a single Model Bundle. """ + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + id: str name: str metadata: Dict[str, Any] created_at: datetime.datetime model_artifact_ids: List[str] - schema_location: Optional[str] + schema_location: Optional[str] = None flavor: ModelBundleFlavors = Field(..., discriminator="flavor") - class Config: - """ - ModelBundleResponse Config class. - """ - - orm_mode = True - class ListModelBundlesV2Response(BaseModel): """ Response object for listing Model Bundles. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundles: List[ModelBundleV2Response] diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index 06073ada..cfeb44bf 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -10,6 +10,7 @@ from enum import Enum from typing import Any, Dict, List, Optional +from model_engine_server.common.dtos.core import HttpUrlStr from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, @@ -21,7 +22,7 @@ ModelEndpointType, StorageSpecificationType, ) -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, ConfigDict, Field class BrokerType(str, Enum): @@ -51,22 +52,22 @@ class CreateModelEndpointV1Request(BaseModel): model_bundle_id: str endpoint_type: ModelEndpointType metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] + post_inference_hooks: Optional[List[str]] = None cpus: CpuSpecificationType gpus: int = Field(..., ge=0) memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None min_workers: int = Field(..., ge=0) max_workers: int = Field(..., ge=0) per_worker: int = Field(..., gt=0) labels: Dict[str, str] - prewarm: Optional[bool] - high_priority: Optional[bool] - billing_tags: Optional[Dict[str, Any]] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = Field(default=False) @@ -75,25 +76,25 @@ class CreateModelEndpointV1Response(BaseModel): class UpdateModelEndpointV1Request(BaseModel): - model_bundle_id: Optional[str] - metadata: Optional[Dict[str, Any]] # TODO: JSON type - post_inference_hooks: Optional[List[str]] - cpus: Optional[CpuSpecificationType] + model_bundle_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None # TODO: JSON type + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None gpus: Optional[int] = Field(default=None, ge=0) - memory: Optional[StorageSpecificationType] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None min_workers: Optional[int] = Field(default=None, ge=0) max_workers: Optional[int] = Field(default=None, ge=0) per_worker: Optional[int] = Field(default=None, gt=0) - labels: Optional[Dict[str, str]] - prewarm: Optional[bool] - high_priority: Optional[bool] - billing_tags: Optional[Dict[str, Any]] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None class UpdateModelEndpointV1Response(BaseModel): @@ -110,7 +111,7 @@ class GetModelEndpointV1Response(BaseModel): bundle_name: str status: ModelEndpointStatus post_inference_hooks: Optional[List[str]] = Field(default=None) - default_callback_url: Optional[HttpUrl] = Field(default=None) + default_callback_url: Optional[HttpUrlStr] = Field(default=None) default_callback_auth: Optional[CallbackAuth] = Field(default=None) labels: Optional[Dict[str, str]] = Field(default=None) aws_role: Optional[str] = Field(default=None) @@ -143,6 +144,7 @@ class ModelEndpointOrderBy(str, Enum): class GetModelEndpointsSchemaV1Response(BaseModel): + model_config = ConfigDict(protected_namespaces=()) model_endpoints_schema: ModelEndpointsSchema diff --git a/model-engine/model_engine_server/common/dtos/tasks.py b/model-engine/model_engine_server/common/dtos/tasks.py index 36c20903..b9919f68 100644 --- a/model-engine/model_engine_server/common/dtos/tasks.py +++ b/model-engine/model_engine_server/common/dtos/tasks.py @@ -6,15 +6,15 @@ from typing import Any, Optional from model_engine_server.domain.entities import CallbackAuth -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel -class ResponseSchema(BaseModel): - __root__: Any +class ResponseSchema(RootModel): + root: Any -class RequestSchema(BaseModel): - __root__: Any +class RequestSchema(RootModel): + root: Any class TaskStatus(str, Enum): diff --git a/model-engine/model_engine_server/common/dtos/triggers.py b/model-engine/model_engine_server/common/dtos/triggers.py index ee4d2121..3d75376e 100644 --- a/model-engine/model_engine_server/common/dtos/triggers.py +++ b/model-engine/model_engine_server/common/dtos/triggers.py @@ -4,15 +4,15 @@ import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class CreateTriggerV1Request(BaseModel): name: str cron_schedule: str bundle_id: str - default_job_config: Optional[Dict[str, Any]] - default_job_metadata: Optional[Dict[str, str]] + default_job_config: Optional[Dict[str, Any]] = None + default_job_metadata: Optional[Dict[str, str]] = None class CreateTriggerV1Response(BaseModel): @@ -29,9 +29,7 @@ class GetTriggerV1Response(BaseModel): docker_image_batch_job_bundle_id: str default_job_config: Optional[Dict[str, Any]] = Field(default=None) default_job_metadata: Optional[Dict[str, str]] = Field(default=None) - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class ListTriggersV1Response(BaseModel): @@ -39,8 +37,8 @@ class ListTriggersV1Response(BaseModel): class UpdateTriggerV1Request(BaseModel): - cron_schedule: Optional[str] - suspend: Optional[bool] + cron_schedule: Optional[str] = None + suspend: Optional[bool] = None class UpdateTriggerV1Response(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/batch_job_entity.py b/model-engine/model_engine_server/domain/entities/batch_job_entity.py index 6bf51b0d..62238d66 100644 --- a/model-engine/model_engine_server/domain/entities/batch_job_entity.py +++ b/model-engine/model_engine_server/domain/entities/batch_job_entity.py @@ -26,24 +26,24 @@ class BatchJobSerializationFormat(str, Enum): class BatchJobRecord(OwnedEntity): id: str created_at: datetime - completed_at: Optional[datetime] + completed_at: Optional[datetime] = None status: BatchJobStatus created_by: str owner: str model_bundle: ModelBundle - model_endpoint_id: Optional[str] - task_ids_location: Optional[str] - result_location: Optional[str] + model_endpoint_id: Optional[str] = None + task_ids_location: Optional[str] = None + result_location: Optional[str] = None class BatchJobProgress(BaseModel): - num_tasks_pending: Optional[int] - num_tasks_completed: Optional[int] + num_tasks_pending: Optional[int] = None + num_tasks_completed: Optional[int] = None class BatchJob(BaseModel): record: BatchJobRecord - model_endpoint: Optional[ModelEndpoint] + model_endpoint: Optional[ModelEndpoint] = None progress: BatchJobProgress @@ -57,7 +57,7 @@ class DockerImageBatchJob(BaseModel): created_by: str owner: str created_at: datetime - completed_at: Optional[datetime] + completed_at: Optional[datetime] = None status: BatchJobStatus # the status map relatively nicely onto BatchJobStatus annotations: Optional[Dict[str, str]] = None override_job_max_runtime_s: Optional[int] = None diff --git a/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py index 1ed2838d..9213af13 100644 --- a/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py @@ -3,6 +3,7 @@ from model_engine_server.domain.entities import GpuType from model_engine_server.domain.entities.owned_entity import OwnedEntity +from pydantic import ConfigDict class DockerImageBatchJobBundle(OwnedEntity): @@ -15,13 +16,11 @@ class DockerImageBatchJobBundle(OwnedEntity): image_tag: str command: List[str] env: Dict[str, str] - mount_location: Optional[str] - cpus: Optional[str] - memory: Optional[str] - storage: Optional[str] - gpus: Optional[int] - gpu_type: Optional[GpuType] - public: Optional[bool] - - class Config: - orm_mode = True + mount_location: Optional[str] = None + cpus: Optional[str] = None + memory: Optional[str] = None + storage: Optional[str] = None + gpus: Optional[int] = None + gpu_type: Optional[GpuType] = None + public: Optional[bool] = None + model_config = ConfigDict(from_attributes=True) diff --git a/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py index 13188c06..b18bbdd2 100644 --- a/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class LLMFineTuneTemplate(BaseModel): @@ -8,9 +8,7 @@ class LLMFineTuneTemplate(BaseModel): launch_endpoint_config: Dict[str, Any] default_hparams: Dict[str, Any] required_params: List[str] - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) class LLMFineTuneEvent(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py index 247539d0..e3ceb836 100644 --- a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -5,7 +5,7 @@ from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Literal @@ -38,12 +38,12 @@ class ModelBundleEnvironmentParams(BaseModel): """ framework_type: ModelBundleFrameworkType - pytorch_image_tag: Optional[str] # for pytorch - tensorflow_version: Optional[str] # for tensorflow - ecr_repo: Optional[str] # for custom base image - image_tag: Optional[str] # for custom base image + pytorch_image_tag: Optional[str] = None # for pytorch + tensorflow_version: Optional[str] = None # for tensorflow + ecr_repo: Optional[str] = None # for custom base image + image_tag: Optional[str] = None # for custom base image - @root_validator + @model_validator(mode="before") @classmethod def validate_fields_present_for_framework_type(cls, field_values): """ @@ -72,12 +72,7 @@ def validate_fields_present_for_framework_type(cls, field_values): ) return field_values - class Config: - """ - Model Bundle Environment Params Config class. - """ - - orm_mode = True + model_config = ConfigDict(from_attributes=True) class PytorchFramework(BaseModel): @@ -127,7 +122,7 @@ class ArtifactLike(BaseModel, ABC): framework: Union[PytorchFramework, TensorflowFramework, CustomFramework] = Field( ..., discriminator="framework_type" ) - app_config: Optional[Dict[str, Any]] + app_config: Optional[Dict[str, Any]] = None location: str @@ -159,7 +154,7 @@ class RunnableImageLike(BaseModel, ABC): command: List[str] predict_route: str = "/predict" healthcheck_route: str = "/readyz" - env: Optional[Dict[str, str]] + env: Optional[Dict[str, str]] = None protocol: Literal["http"] # TODO: add support for other protocols (e.g. grpc) readiness_initial_delay_seconds: int = 120 @@ -177,11 +172,11 @@ class TritonEnhancedRunnableImageFlavor(RunnableImageLike): flavor: Literal[ModelBundleFlavorType.TRITON_ENHANCED_RUNNABLE_IMAGE] triton_model_repository: str - triton_model_replicas: Optional[Dict[str, str]] + triton_model_replicas: Optional[Dict[str, str]] = None triton_num_cpu: float triton_commit_tag: str - triton_storage: Optional[str] - triton_memory: Optional[str] + triton_storage: Optional[str] = None + triton_memory: Optional[str] = None triton_readiness_initial_delay_seconds: int = 300 # will default to 300 seconds @@ -217,23 +212,17 @@ class ModelBundle(OwnedEntity): created_at: datetime.datetime metadata: Dict[str, Any] model_artifact_ids: List[str] - schema_location: Optional[str] + schema_location: Optional[str] = None owner: str flavor: ModelBundleFlavors = Field(..., discriminator="flavor") # LEGACY FIELDS - requirements: Optional[List[str]] # FIXME: Delete - location: Optional[str] # FIXME: Delete - env_params: Optional[ModelBundleEnvironmentParams] # FIXME: Delete - packaging_type: Optional[ModelBundlePackagingType] # FIXME: Delete - app_config: Optional[Dict[str, Any]] # FIXME: Delete - - class Config: - """ - Model Bundle Config class. - """ - - orm_mode = True + requirements: Optional[List[str]] = None # FIXME: Delete + location: Optional[str] = None # FIXME: Delete + env_params: Optional[ModelBundleEnvironmentParams] = None # FIXME: Delete + packaging_type: Optional[ModelBundlePackagingType] = None # FIXME: Delete + app_config: Optional[Dict[str, Any]] = None # FIXME: Delete + model_config = ConfigDict(from_attributes=True) def is_runnable(self) -> bool: """True iff the model bundle calls for it. diff --git a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py index cb6277f6..a0f84c4e 100644 --- a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py @@ -12,7 +12,7 @@ from model_engine_server.domain.entities.gpu_type import GpuType from model_engine_server.domain.entities.model_bundle_entity import ModelBundle from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel from typing_extensions import Literal ModelEndpointsSchema = OpenAPI @@ -42,9 +42,9 @@ class ModelEndpointResourceState(BaseModel): cpus: CpuSpecificationType # TODO(phil): try to use decimal.Decimal gpus: int = Field(..., ge=0) memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None class ModelEndpointDeploymentState(BaseModel): @@ -71,8 +71,8 @@ class CallbackmTLSAuth(BaseModel): key: str -class CallbackAuth(BaseModel): - __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") +class CallbackAuth(RootModel): + root: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") class ModelEndpointConfig(BaseModel): @@ -82,14 +82,14 @@ class ModelEndpointConfig(BaseModel): endpoint_name: str bundle_name: str - post_inference_hooks: Optional[List[str]] + post_inference_hooks: Optional[List[str]] = None user_id: Optional[str] = None billing_queue: Optional[str] = None billing_tags: Optional[Dict[str, Any]] = None default_callback_url: Optional[str] = None - default_callback_auth: Optional[CallbackAuth] + default_callback_auth: Optional[CallbackAuth] = None endpoint_id: Optional[str] = None - endpoint_type: Optional[ModelEndpointType] + endpoint_type: Optional[ModelEndpointType] = None bundle_id: Optional[str] = None labels: Optional[Dict[str, str]] = None @@ -102,8 +102,8 @@ def deserialize(serialized_config: str) -> "ModelEndpointConfig": class ModelEndpointUserConfigState(BaseModel): - app_config: Optional[Dict[str, Any]] - endpoint_config: Optional[ModelEndpointConfig] + app_config: Optional[Dict[str, Any]] = None + endpoint_config: Optional[ModelEndpointConfig] = None class ModelEndpointRecord(OwnedEntity): @@ -117,15 +117,15 @@ class ModelEndpointRecord(OwnedEntity): name: str created_by: str created_at: datetime.datetime - last_updated_at: Optional[datetime.datetime] - metadata: Optional[Dict[str, Any]] + last_updated_at: Optional[datetime.datetime] = None + metadata: Optional[Dict[str, Any]] = None creation_task_id: Optional[str] = Field(default=None) endpoint_type: ModelEndpointType destination: str status: ModelEndpointStatus current_model_bundle: ModelBundle owner: str - public_inference: Optional[bool] + public_inference: Optional[bool] = None class ModelEndpointInfraState(BaseModel): @@ -136,14 +136,14 @@ class ModelEndpointInfraState(BaseModel): deployment_name: str aws_role: str results_s3_bucket: str - child_fn_info: Optional[Dict[str, Any]] + child_fn_info: Optional[Dict[str, Any]] = None labels: Dict[str, str] deployment_state: ModelEndpointDeploymentState resource_state: ModelEndpointResourceState user_config_state: ModelEndpointUserConfigState prewarm: Optional[bool] = None - high_priority: Optional[bool] - num_queued_items: Optional[int] + high_priority: Optional[bool] = None + num_queued_items: Optional[int] = None image: str @@ -153,4 +153,4 @@ class ModelEndpoint(BaseModel): """ record: ModelEndpointRecord - infra_state: Optional[ModelEndpointInfraState] + infra_state: Optional[ModelEndpointInfraState] = None diff --git a/model-engine/model_engine_server/domain/entities/trigger_entity.py b/model-engine/model_engine_server/domain/entities/trigger_entity.py index ac515865..0d68ec92 100644 --- a/model-engine/model_engine_server/domain/entities/trigger_entity.py +++ b/model-engine/model_engine_server/domain/entities/trigger_entity.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional from model_engine_server.domain.entities.owned_entity import OwnedEntity +from pydantic import ConfigDict class Trigger(OwnedEntity): @@ -13,8 +14,6 @@ class Trigger(OwnedEntity): cron_schedule: str docker_image_batch_job_bundle_id: str - default_job_config: Optional[Dict[str, Any]] - default_job_metadata: Optional[Dict[str, str]] - - class Config: - orm_mode = True + default_job_config: Optional[Dict[str, Any]] = None + default_job_metadata: Optional[Dict[str, str]] = None + model_config = ConfigDict(from_attributes=True) diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index 38861ade..23759911 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -16,7 +16,7 @@ class MetricMetadata(BaseModel): user: User - model_name: Optional[str] + model_name: Optional[str] = None class MonitoringMetricsGateway(ABC): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index f19e5ab5..a46b04bb 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2408,7 +2408,7 @@ async def create_batch_job_bundle( hardware: CreateDockerImageBatchJobResourceRequests, ) -> DockerImageBatchJobBundle: bundle_name = ( - f"{request.model_config.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" + f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" ) image_tag = self.docker_repository.get_latest_image_tag( @@ -2448,22 +2448,22 @@ async def create_batch_job_bundle( async def execute( self, user: User, request: CreateBatchCompletionsRequest ) -> CreateBatchCompletionsResponse: - request.model_config.checkpoint_path = get_checkpoint_path( - request.model_config.model, request.model_config.checkpoint_path + request.model_cfg.checkpoint_path = get_checkpoint_path( + request.model_cfg.model, request.model_cfg.checkpoint_path ) hardware = await _infer_hardware( self.llm_artifact_gateway, - request.model_config.model, - request.model_config.checkpoint_path, + request.model_cfg.model, + request.model_cfg.checkpoint_path, is_batch_job=True, ) # Reconcile gpus count with num_shards from request assert hardware.gpus is not None - if request.model_config.num_shards: - hardware.gpus = max(hardware.gpus, request.model_config.num_shards) + if request.model_cfg.num_shards: + hardware.gpus = max(hardware.gpus, request.model_cfg.num_shards) engine_request = CreateBatchCompletionsEngineRequest.from_api(request) - engine_request.model_config.num_shards = hardware.gpus + engine_request.model_cfg.num_shards = hardware.gpus if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator": raise ObjectHasInvalidValueException( @@ -2471,7 +2471,7 @@ async def execute( ) additional_engine_args = infer_addition_engine_args_from_model_name( - engine_request.model_config.model + engine_request.model_cfg.model ) if additional_engine_args.gpu_memory_utilization is not None: @@ -2502,7 +2502,7 @@ async def execute( repo=batch_bundle.image_repository, tag=batch_bundle.image_tag, resource_requests=hardware, - labels=engine_request.model_config.labels, + labels=engine_request.model_cfg.labels, mount_location=batch_bundle.mount_location, override_job_max_runtime_s=engine_request.max_runtime_sec, num_workers=engine_request.data_parallelism, diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py index da63c545..109050c2 100644 --- a/model-engine/model_engine_server/inference/batch_inference/dto.py +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -117,7 +117,7 @@ class CreateBatchCompletionsRequest(BaseModel): Request object for batch completions. """ - input_data_path: Optional[str] + input_data_path: Optional[str] = None output_data_path: str """ Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. diff --git a/model-engine/model_engine_server/inference/common.py b/model-engine/model_engine_server/inference/common.py index 2655eb12..b8ddfea0 100644 --- a/model-engine/model_engine_server/inference/common.py +++ b/model-engine/model_engine_server/inference/common.py @@ -198,7 +198,7 @@ def predict_on_url(predict_fn: Callable, request_url: str, return_pickled: bool) def predict_on_args( predict_fn: Callable, inputs: RequestSchema, return_pickled: bool ) -> Dict[str, str]: - inputs_kwargs = inputs.__root__ + inputs_kwargs = inputs.root output = predict_fn(**inputs_kwargs) if return_pickled: diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 3295c3b4..5d45b5cb 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -76,8 +76,8 @@ def handle( response["task_id"] = task_id auth = request_payload.callback_auth or self._default_callback_auth - if auth and isinstance(auth.__root__, CallbackBasicAuth): - auth_tuple = (auth.__root__.username, auth.__root__.password) + if auth and isinstance(auth.root, CallbackBasicAuth): + auth_tuple = (auth.root.username, auth.root.password) else: auth_tuple = (self._user_id, "") diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt index 4561bd06..aeeb5efd 100644 --- a/model-engine/model_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -4,7 +4,7 @@ boto3~=1.34.33 celery[redis,sqs,tblib]==5.3.1 datadog-api-client==2.11.0 datadog~=0.47.0 -fastapi==0.78.0 +fastapi~=0.110.0 # Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it importlib-metadata<5.0;python_version<"3.8" scale-launch>=0.1.0 @@ -21,3 +21,5 @@ json-log-formatter~=0.3 # model_engine_server/core/loggers.py tenacity>=6.0.0,<=6.2.0 # model_engine_server/core/loggers.py tqdm~=4.64 # model_engine_server/common/service_requests.py gunicorn~=20.0 +pydantic==2.8.2 + diff --git a/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py index 3c0408c8..f1c8b4f9 100644 --- a/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py @@ -31,8 +31,8 @@ def create_task( *, task_name: str = DEFAULT_CELERY_TASK_NAME, ) -> CreateAsyncTaskV1Response: - # Use json.loads instead of predict_request.dict() because we have overridden the '__root__' - # key in some fields, and __root__ overriding only reflects in the json() output. + # Use json.loads instead of predict_request.dict() because we have overridden the 'root' + # key in some fields, and root overriding only reflects in the json() output. predict_args = json.loads(predict_request.json()) send_task_response = self.task_queue_gateway.send_task( diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py index 5fac2841..f6f51d9c 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py @@ -1,11 +1,11 @@ import json from enum import Enum -from typing import Any, Callable, Dict, Sequence, Set, Type, Union +from typing import Any, Callable, Dict, List, Sequence, Set, Type, Union from fastapi import routing -from fastapi._compat import GenerateJsonSchema, get_model_definitions +from fastapi._compat import GenerateJsonSchema, get_definitions from fastapi.openapi.constants import REF_TEMPLATE -from fastapi.openapi.utils import get_openapi_path +from fastapi.openapi.utils import get_fields_from_routes, get_openapi_path from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, @@ -72,6 +72,7 @@ def get_model_endpoints_schema( methods=["POST"], ) routes.append(route) + definitions = self.get_schemas_from_model_endpoint_record(record) definitions = LiveModelEndpointsSchemaGateway.update_model_definitions_with_prefix( prefix=record.name, model_definitions=definitions @@ -121,12 +122,19 @@ def get_openapi( prefix = model_endpoint_name model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) + all_fields = get_fields_from_routes([route]) + field_mapping, _ = get_definitions( + fields=all_fields, + schema_generator=schema_generator, + model_name_map=model_name_map, + ) + result = get_openapi_path( route=route, - model_name_map=model_name_map, operation_ids=operation_ids, schema_generator=schema_generator, - field_mapping={}, + model_name_map=model_name_map, + field_mapping=field_mapping, ) if result: path, security_schemes, path_definitions = result @@ -156,19 +164,17 @@ def update_model_definitions_with_prefix( Returns: Dict[str, Any]: The updated model definitions. """ - models: Set[Union[Type[BaseModel], Type[Enum]]] = { - CallbackAuth, - CallbackBasicAuth, - CallbackmTLSAuth, - TaskStatus, + models: List[Type[BaseModel]] = [ EndpointPredictV1Request, GetAsyncTaskV1Response, SyncEndpointPredictV1Response, - } - definitions = get_model_definitions( - flat_models=models, - model_name_map=LiveModelEndpointsSchemaGateway.get_model_name_map(prefix), + ] + + model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) + definitions: Dict[str, Any] = LiveModelEndpointsSchemaGateway.get_model_definitions( + models=models, model_name_map=model_name_map ) + user_definitions = {} for k, v in model_definitions.items(): LiveModelEndpointsSchemaGateway.update_schema_refs_with_prefix(v, prefix) @@ -236,8 +242,8 @@ def get_default_model_definitions() -> Dict[str, Any]: global _default_model_definitions if _default_model_definitions is None: - _default_model_definitions = get_model_definitions( - flat_models={RequestSchema, ResponseSchema}, + _default_model_definitions = LiveModelEndpointsSchemaGateway.get_model_definitions( + models=[RequestSchema, ResponseSchema], model_name_map={ RequestSchema: "RequestSchema", ResponseSchema: "ResponseSchema", @@ -245,3 +251,21 @@ def get_default_model_definitions() -> Dict[str, Any]: ) return _default_model_definitions + + @staticmethod + def get_model_definitions( + models: Sequence[Type[BaseModel]], + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], + ) -> Dict[str, Any]: + """Get OpenAPI definitions for provided models using the name provided in model_name_map""" + + definitions = {} + for model in models: + schema = model.model_json_schema( + schema_generator=GenerateJsonSchema, ref_template=REF_TEMPLATE + ) + m_defs = schema.pop("$defs", {}) + definitions.update(m_defs) + model_name = model_name_map.get(model, model.__name__) + definitions[model_name] = schema + return definitions diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 024ca99e..af054dba 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -50,7 +50,7 @@ get_endpoint_resource_arguments_from_request, ) from packaging import version -from pydantic.utils import deep_update +from pydantic.v1.utils import deep_update logger = make_logger(logger_name()) diff --git a/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py index 4fa1948c..9e3cd17d 100644 --- a/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py @@ -15,7 +15,7 @@ DbRepositoryMixin, raise_if_read_only, ) -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError class DbDockerImageBatchJobBundleRepository(DockerImageBatchJobBundleRepository, DbRepositoryMixin): diff --git a/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py index bb9cb5a3..b4114358 100644 --- a/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py @@ -12,7 +12,7 @@ DbRepositoryMixin, raise_if_read_only, ) -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError from sqlalchemy.exc import IntegrityError diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index aecef2c7..9f9f257d 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -307,6 +307,7 @@ async def build_endpoint( user_config_state=ModelEndpointUserConfigState( app_config=build_endpoint_request.model_endpoint_record.current_model_bundle.app_config, endpoint_config=ModelEndpointConfig( + endpoint_type=build_endpoint_request.model_endpoint_record.endpoint_type, endpoint_name=build_endpoint_request.model_endpoint_record.name, bundle_name=build_endpoint_request.model_endpoint_record.current_model_bundle.name, post_inference_hooks=build_endpoint_request.post_inference_hooks, diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 2ef63150..f70d4503 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -37,7 +37,7 @@ protobuf~=3.20 psycopg2-binary==2.9.3 py-xid==0.3.0 pycurl~=7.44 # For celery[sqs] -pydantic==1.10.14 +pydantic==2.8.2 python-multipart~=0.0.7 quart==0.18.3 requests-auth-aws-sigv4~=0.7 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 71e7440d..fb0d4d24 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -18,6 +18,8 @@ alembic==1.8.1 # via -r model-engine/requirements.in amqp==5.1.1 # via kombu +annotated-types==0.7.0 + # via pydantic anyio==3.7.1 # via # azure-core @@ -362,10 +364,12 @@ pycurl==7.45.2 # -r model-engine/requirements.in # celery # kombu -pydantic==1.10.14 +pydantic==2.8.2 # via # -r model-engine/requirements.in # fastapi +pydantic-core==2.20.1 + # via pydantic pygments==2.15.1 # via # readme-renderer @@ -530,6 +534,7 @@ types-s3transfer==0.6.1 typing-extensions==4.10.0 # via # aioredis + # annotated-types # asgiref # azure-core # azure-keyvault-secrets @@ -552,6 +557,7 @@ typing-extensions==4.10.0 # mypy-boto3-s3 # mypy-boto3-sqs # pydantic + # pydantic-core # rich # sqlalchemy # starlette diff --git a/model-engine/tests/integration/inference/conftest.py b/model-engine/tests/integration/inference/conftest.py index fcd63dfc..07e900b0 100644 --- a/model-engine/tests/integration/inference/conftest.py +++ b/model-engine/tests/integration/inference/conftest.py @@ -47,7 +47,7 @@ def test_user_id() -> str: @pytest.fixture(scope="session") def test_default_callback_auth() -> CallbackAuth: return CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="test_user", password="test_password") + root=CallbackBasicAuth(kind="basic", username="test_user", password="test_password") ) @@ -100,7 +100,7 @@ def launch_celery_app( f"--loglevel=INFO --concurrency=1 --queues={queue}" ) # Wait up to 10 seconds for process to start and be ready. - with subprocess.Popen( + with subprocess.Popen( # nosemgrep command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) as process: for attempt in Retrying( diff --git a/model-engine/tests/integration/inference/test_async_inference.py b/model-engine/tests/integration/inference/test_async_inference.py index db9bc9a7..d1d7f7c5 100644 --- a/model-engine/tests/integration/inference/test_async_inference.py +++ b/model-engine/tests/integration/inference/test_async_inference.py @@ -42,7 +42,7 @@ def redis_available() -> bool: @pytest.mark.parametrize( "task_args,cloudpickle,expected_status,expected_result", [ - ({"y": 1}, False, TaskStatus.SUCCESS, ResponseSchema(__root__={"result": "1"})), + ({"y": 1}, False, TaskStatus.SUCCESS, ResponseSchema(root={"result": "1"})), ({"x": False, "y": 1}, False, TaskStatus.FAILURE, None), ], ) diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index b312f7eb..2ca38500 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -789,7 +789,7 @@ def model_endpoint_1( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -831,7 +831,7 @@ def model_endpoint_1( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -1363,7 +1363,7 @@ def create_batch_completions_request() -> Dict[str, Any]: "model_config": { "model": "mpt-7b", "checkpoint_path": "s3://test_checkpoint_path", - "labels": [], + "labels": {}, "num_shards": 2, }, "data_parallelism": 1, diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py index 80f21734..f2cee1bb 100644 --- a/model-engine/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -410,8 +410,7 @@ async def test_create_streaming_task_success( count = 0 async for message in response.aiter_bytes(): assert ( - message - == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' + message == b'data: {"status":"SUCCESS","result":null,"traceback":null}\r\n\r\n' ) count += 1 assert count == 1 diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 2366019a..ec3850af 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3197,7 +3197,7 @@ def build_endpoint_request_async_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3240,7 +3240,7 @@ def build_endpoint_request_streaming_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3283,7 +3283,7 @@ def build_endpoint_request_sync_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3326,7 +3326,7 @@ def build_endpoint_request_sync_pytorch( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3368,7 +3368,7 @@ def build_endpoint_request_async_tensorflow( optimize_costs=False, default_callback_url="https://example.com/path", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3513,9 +3513,7 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An args=["test_arg_1", "test_arg_2"], callback_url="http://test_callback_url.xyz", callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( - kind="basic", username="test_username", password="test_password" - ) + root=CallbackBasicAuth(kind="basic", username="test_username", password="test_password") ), return_pickled=True, ) @@ -3594,7 +3592,7 @@ def llm_model_endpoint_async( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -3653,7 +3651,7 @@ def llm_model_endpoint_async( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -3726,7 +3724,7 @@ def llm_model_endpoint_sync( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -3785,7 +3783,7 @@ def llm_model_endpoint_sync( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -3858,7 +3856,7 @@ def llm_model_endpoint_stream( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -3917,7 +3915,7 @@ def llm_model_endpoint_stream( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -3990,7 +3988,7 @@ def llm_model_endpoint_sync_tgi( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -4049,7 +4047,7 @@ def llm_model_endpoint_sync_tgi( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -4122,7 +4120,7 @@ def llm_model_endpoint_sync_lightllm( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -4181,7 +4179,7 @@ def llm_model_endpoint_sync_lightllm( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -4254,7 +4252,7 @@ def llm_model_endpoint_sync_trt_llm( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -4313,7 +4311,7 @@ def llm_model_endpoint_sync_trt_llm( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", @@ -4451,7 +4449,7 @@ def llm_model_endpoint_text_generation_inference( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -4524,7 +4522,7 @@ def llm_model_endpoint_trt_llm( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 937f3cfc..0a10cb77 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -533,7 +533,7 @@ def create_batch_completions_request() -> CreateBatchCompletionsRequest: model_config=CreateBatchCompletionsModelConfig( model="mpt-7b", checkpoint_path="s3://test_checkpoint_path", - labels=[], + labels={}, num_shards=2, ), data_parallelism=2, diff --git a/model-engine/tests/unit/domain/test_entities.py b/model-engine/tests/unit/domain/test_entities.py index 41533afc..cd0ab507 100644 --- a/model-engine/tests/unit/domain/test_entities.py +++ b/model-engine/tests/unit/domain/test_entities.py @@ -25,9 +25,7 @@ user_id="test_user", billing_queue="test_queue", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( - kind="basic", username="test_user", password="test_password" - ) + root=CallbackBasicAuth(kind="basic", username="test_user", password="test_password") ), ), ], diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 8e310211..125aeab9 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1961,7 +1961,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32000, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "") - assert hardware.cpus == "40" + assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -1970,7 +1970,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "mixtral-8x7b", "", is_batch_job=True ) - assert hardware.cpus == "40" + assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -2001,7 +2001,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32000, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "") - assert hardware.cpus == "160" + assert hardware.cpus == 160 assert hardware.gpus == 8 assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" @@ -2010,7 +2010,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "mixtral-8x22b", "", is_batch_job=True ) - assert hardware.cpus == "160" + assert hardware.cpus == 160 assert hardware.gpus == 8 assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" @@ -2037,14 +2037,14 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32000, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "") - assert hardware.cpus == "5" + assert hardware.cpus == 5 assert hardware.gpus == 1 assert hardware.memory == "20Gi" assert hardware.storage == "40Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) - assert hardware.cpus == "10" + assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" @@ -2072,14 +2072,14 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 128256, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "") - assert hardware.cpus == "5" + assert hardware.cpus == 5 assert hardware.gpus == 1 assert hardware.memory == "20Gi" assert hardware.storage == "40Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) - assert hardware.cpus == "10" + assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" @@ -2106,7 +2106,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32000, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "") - assert hardware.cpus == "10" + assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" @@ -2115,7 +2115,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "llama-2-13b", "", is_batch_job=True ) - assert hardware.cpus == "20" + assert hardware.cpus == 20 assert hardware.gpus == 1 assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" @@ -2142,7 +2142,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32000, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "") - assert hardware.cpus == "20" + assert hardware.cpus == 20 assert hardware.gpus == 1 assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" @@ -2151,7 +2151,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "codellama-34b", "", is_batch_job=True ) - assert hardware.cpus == "40" + assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -2178,7 +2178,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32000, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "") - assert hardware.cpus == "40" + assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -2187,7 +2187,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "llama-2-70b", "", is_batch_job=True ) - assert hardware.cpus == "80" + assert hardware.cpus == 80 assert hardware.gpus == 4 assert hardware.memory == "320Gi" assert hardware.storage == "320Gi" @@ -2215,7 +2215,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 128256, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "") - assert hardware.cpus == "40" + assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -2224,7 +2224,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "llama-3-70b", "", is_batch_job=True ) - assert hardware.cpus == "80" + assert hardware.cpus == 80 assert hardware.gpus == 4 assert hardware.memory == "320Gi" assert hardware.storage == "320Gi" @@ -2253,7 +2253,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 128256, } hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") - assert hardware.cpus == "40" + assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" @@ -2283,7 +2283,7 @@ async def test_fill_hardware_info(fake_llm_artifact_gateway): labels={}, ) await _fill_hardware_info(fake_llm_artifact_gateway, request) - assert request.cpus == "40" + assert request.cpus == 40 assert request.gpus == 2 assert request.memory == "160Gi" assert request.storage == "160Gi" diff --git a/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py index e86f0f1f..1d38c223 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py @@ -52,7 +52,7 @@ def test_task_create_get_args_callback( task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway assert len(task_queue_gateway.queue) == 1 assert task_queue_gateway.queue[task_id]["args"][0] == { - "args": endpoint_predict_request_2[0].args.__root__, + "args": endpoint_predict_request_2[0].args.root, "url": None, "cloudpickle": None, "callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()), diff --git a/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py index 2a3fe197..4112ac8b 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py @@ -34,4 +34,4 @@ def test_update_progress(test_api_key: str, fake_filesystem_gateway): progress=BatchJobProgress(num_tasks_pending=4, num_tasks_completed=5), ) handle = fake_filesystem_gateway.mock_open() - handle.write.assert_called_once_with('{"num_tasks_pending": 4, "num_tasks_completed": 5}') + handle.write.assert_called_once_with('{"num_tasks_pending":4,"num_tasks_completed":5}') diff --git a/requirements-docs.txt b/requirements-docs.txt index 51d81c23..fdc1a843 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -6,7 +6,7 @@ mkdocs-render-swagger-plugin~=0.0.4 mkdocs-simple-hooks~=0.1.5 mkdocs-video~=1.5.0 mkdocstrings[python]~=0.20.0 -pydantic~=1.10.0 +pydantic==2.8.2 neoteroi-mkdocs~=1.0.0 tabulate~=0.9.0 scale-llm-engine \ No newline at end of file From b5e4daf646747e43c2b869fef0acb2e007204551 Mon Sep 17 00:00:00 2001 From: Nicolas Tomeo Date: Thu, 11 Jul 2024 12:26:04 +0200 Subject: [PATCH 334/425] fix: Use env AWS_REGION in sqs_client or default to us-west-2 (#563) --- .../model_engine_server/core/celery/celery_autoscaler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py index d8782e35..3abd6d4d 100644 --- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -54,6 +54,7 @@ def excluded_namespaces(): autoscaler_broker = os.environ.get("BROKER_NAME", SQS_BROKER) aws_profile = os.environ.get("AWS_PROFILE") +aws_region = os.environ.get("AWS_REGION", "us-west-2") @dataclasses.dataclass @@ -403,7 +404,7 @@ async def get_broker_metrics( class SQSBroker(AutoscalerBroker): @staticmethod def _get_sqs_queue_size(queue_name: str): - sqs_client = session(aws_profile).client("sqs", region_name="us-west-2") + sqs_client = session(aws_profile).client("sqs", region_name=aws_region) try: total_start_time = time.time() queue_size_hist = [] From 8acb52f6e44b3ae9b9d4fb7e491fa3810dbee241 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 15 Jul 2024 17:46:02 -0700 Subject: [PATCH 335/425] Add support for phi 3 models (#564) * Add support for phi 3 models * Add parameter count for phi 3 * Update tests --- docs/model_zoo.md | 2 + .../use_cases/llm_model_endpoint_use_cases.py | 12 ++ .../repositories/live_tokenizer_repository.py | 6 + .../tests/unit/domain/test_llm_use_cases.py | 152 ++++++++++++++++++ 4 files changed, 172 insertions(+) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index a8f4ae63..7705b633 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -42,6 +42,8 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `gemma-2b-instruct` | ✅ | | vllm | 8192 | | `gemma-7b` | ✅ | | vllm | 8192 | | `gemma-7b-instruct` | ✅ | | vllm | 8192 | +| `phi-3-mini-4k-instruct` | ✅ | | vllm | 4096 | + ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index a46b04bb..e15e1002 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -210,6 +210,12 @@ "gemma-2b-instruct", "gemma-7b", "gemma-7b-instruct", + "phi-3-mini-4k-instruct", + "phi-3-mini-128k-instruct", + "phi-3-small-8k-instruct", + "phi-3-small-128k-instruct", + "phi-3-medium-4-instruct", + "phi-3-medium-128k-instruct", ] ), LLMInferenceFramework.LIGHTLLM: set( @@ -2324,6 +2330,12 @@ async def _infer_hardware( model_param_count_b = 47 elif "mixtral-8x22b" in model_name: model_param_count_b = 140 + elif "phi-3-mini" in model_name: + model_param_count_b = 4 + elif "phi-3-small" in model_name: + model_param_count_b = 8 + elif "phi-3-medium" in model_name: + model_param_count_b = 15 else: numbers = re.findall(r"(\d+)b", model_name) if len(numbers) == 0: diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 8dff922b..a5783393 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -85,6 +85,12 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "gemma-2b-instruct": ModelInfo("google/gemma-2b-it", None), "gemma-7b": ModelInfo("google/gemma-7b", None), "gemma-7b-instruct": ModelInfo("google/gemma-7b-it", None), + "phi-3-mini-4k-instruct": ModelInfo("microsoft/phi-3-mini-4k-instruct", None), + "phi-3-mini-128k-instruct": ModelInfo("microsoft/phi-3-mini-128k-instruct", None), + "phi-3-small-8k-instruct": ModelInfo("microsoft/phi-3-small-8k-instruct", None), + "phi-3-small-128k-instruct": ModelInfo("microsoft/phi-3-small-128k-instruct", None), + "phi-3-medium-4-instruct": ModelInfo("microsoft/phi-3-medium-4k-instruct", None), + "phi-3-medium-128k-instruct": ModelInfo("microsoft/phi-3-medium-128k-instruct", None), } diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 125aeab9..e4d107f4 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1937,6 +1937,158 @@ async def test_validate_checkpoint_files_safetensors_with_other_files(): mocked__get_recommended_hardware_config_map(), ) async def test_infer_hardware(fake_llm_artifact_gateway): + # Phi 3 mini from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.2", + "use_cache": True, + "attention_bias": False, + "vocab_size": 32064, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "") + assert hardware.cpus == 5 + assert hardware.gpus == 1 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + + # Phi 3 small from https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["Phi3SmallForCausalLM"], + "attention_dropout_prob": 0.0, + "blocksparse_block_size": 64, + "blocksparse_homo_head_pattern": False, + "blocksparse_num_local_blocks": 16, + "blocksparse_triton_kernel_block_size": 64, + "blocksparse_vert_stride": 8, + "bos_token_id": 100257, + "dense_attention_every_n_layers": 2, + "embedding_dropout_prob": 0.1, + "eos_token_id": 100257, + "ff_dim_multiplier": None, + "ff_intermediate_size": 14336, + "ffn_dropout_prob": 0.1, + "gegelu_limit": 20.0, + "gegelu_pad_to_256": True, + "hidden_act": "gegelu", + "hidden_size": 4096, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 8192, + "model_type": "phi3small", + "mup_attn_multiplier": 1.0, + "mup_embedding_multiplier": 10.0, + "mup_use_scaling": True, + "mup_width_multiplier": 8.0, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pad_sequence_to_multiple_of_64": True, + "reorder_and_upcast_attn": False, + "rope_embedding_base": 1000000, + "rope_position_scale": 1.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.1", + "use_cache": True, + "attention_bias": False, + "vocab_size": 100352, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "") + print(hardware) + assert hardware.cpus == 5 + assert hardware.gpus == 1 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "", is_batch_job=True + ) + print(hardware) + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + + fake_llm_artifact_gateway.model_config = { + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 10, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "attention_bias": False, + "vocab_size": 32064, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "") + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { "architectures": ["MixtralForCausalLM"], "attention_dropout": 0.0, From 6132a3e570a8ea62ef0bf6b1cc4dbe634141c7b3 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 15 Jul 2024 23:00:10 -0700 Subject: [PATCH 336/425] Parse wrapped sync endpoint error (#566) * Parse wrapped sync endpoint error * Fix * More cases * simplify --- ...eaming_model_endpoint_inference_gateway.py | 2 +- ...e_sync_model_endpoint_inference_gateway.py | 23 ++++++--- ...e_sync_model_endpoint_inference_gateway.py | 50 +++++++++++++++++++ 3 files changed, 67 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 9c2ff9b7..ae790eef 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -208,7 +208,7 @@ async def streaming_predict( async for item in response: yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) except UpstreamServiceError as exc: - logger.error(f"Service error on sync task: {exc.content!r}") + logger.error(f"Service error on streaming task: {exc.content!r}") try: error_json = orjson.loads(exc.content.decode("utf-8")) result_traceback = ( diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index add25b7b..48ae3410 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -21,7 +21,6 @@ SyncModelEndpointInferenceGateway, ) from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port -from orjson import JSONDecodeError from tenacity import ( AsyncRetrying, RetryError, @@ -186,17 +185,27 @@ async def predict( except UpstreamServiceError as exc: logger.error(f"Service error on sync task: {exc.content!r}") try: + # Try to parse traceback from the response, fallback to just return all the content if failed. + # Three cases considered: + # detail.traceback + # result."detail.traceback" + # result."detail[]" error_json = orjson.loads(exc.content.decode("utf-8")) - result_traceback = ( - error_json.get("detail", {}).get("traceback") - if isinstance(error_json, dict) - else None - ) + if "result" in error_json: + error_json = orjson.loads(error_json["result"]) + detail = error_json.get("detail", {}) + if not isinstance(detail, dict): + result_traceback = orjson.dumps(error_json) + else: + result_traceback = error_json.get("detail", {}).get( + "traceback", "Failed to parse traceback" + ) return SyncEndpointPredictV1Response( status=TaskStatus.FAILURE, traceback=result_traceback, ) - except JSONDecodeError: + except Exception as e: + logger.error(f"Failed to parse error: {e}") return SyncEndpointPredictV1Response( status=TaskStatus.FAILURE, traceback=exc.content.decode() ) diff --git a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py index 806ee93f..afa1aee5 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -179,3 +179,53 @@ async def test_predict_raises_traceback_not_json( "result": None, "traceback": "Test traceback content", } + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_wrapped( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] +): + gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + + content = json.dumps( + {"result": json.dumps({"detail": {"traceback": "test_traceback"}})} + ).encode("utf-8") + fake_response = FakeResponse(status=500, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "FAILURE", + "result": None, + "traceback": "test_traceback", + } + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_wrapped_detail_array( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] +): + gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + + content = json.dumps({"result": json.dumps({"detail": [{"error": "error"}]})}).encode("utf-8") + fake_response = FakeResponse(status=500, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "FAILURE", + "result": None, + "traceback": """{"detail":[{"error":"error"}]}""", + } From 8baaefb951cbb1f564a1694b8390e6d630b4fb46 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 15 Jul 2024 23:26:39 -0700 Subject: [PATCH 337/425] Allow request deserialization using alias (#567) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 +- .../model_engine_server/inference/batch_inference/dto.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index e15e1002..8ca999c7 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2508,7 +2508,7 @@ async def execute( job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=user.user_id, owner=user.team_id, - job_config=engine_request.dict(), + job_config=engine_request.model_dump(by_alias=True), env=batch_bundle.env, command=batch_bundle.command, repo=batch_bundle.image_repository, diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py index 109050c2..c682e14f 100644 --- a/model-engine/model_engine_server/inference/batch_inference/dto.py +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class TokenOutput(BaseModel): @@ -149,6 +149,8 @@ class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): hidden from the DTO exposed to the client. """ + model_config = ConfigDict(populate_by_name=True) + model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config") """ Model configuration for the batch inference. Hardware configurations are inferred. From b3d9200c0bfe3fc725e848c94900191c30733389 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 16 Jul 2024 10:09:11 -0700 Subject: [PATCH 338/425] Add earliest log (#568) --- model-engine/model_engine_server/api/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 851f0183..362a5686 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -46,6 +46,7 @@ class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): try: + logger.debug(f"Received request at {request.url.path}") LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) # we intentionally exclude healthcheck routes from the concurrency limiter From 7670d7b4f3852ac3a3fbb4fce52740f2e12f8154 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 16 Jul 2024 17:41:25 -0700 Subject: [PATCH 339/425] Log info instead of debug (#569) --- model-engine/model_engine_server/api/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 362a5686..c45eedaf 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -46,7 +46,7 @@ class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): try: - logger.debug(f"Received request at {request.url.path}") + logger.info(f"Received request at {request.url.path}") LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) # we intentionally exclude healthcheck routes from the concurrency limiter From adc6c379900392349772cdf99a4132228f9c32e9 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 22 Jul 2024 11:08:27 -0700 Subject: [PATCH 340/425] Disable data parallelism for batch completions (#570) * Disable data parallelism for batch completions * fix test * ignore coverage for this case * ignore coverage for this case --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 7 +++++++ model-engine/tests/unit/domain/conftest.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 8ca999c7..b0bcbfcd 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2460,6 +2460,13 @@ async def create_batch_job_bundle( async def execute( self, user: User, request: CreateBatchCompletionsRequest ) -> CreateBatchCompletionsResponse: + if ( + request.data_parallelism is not None and request.data_parallelism > 1 + ): # pragma: no cover + raise ObjectHasInvalidValueException( + "Data parallelism is disabled for batch completions." + ) + request.model_cfg.checkpoint_path = get_checkpoint_path( request.model_cfg.model, request.model_cfg.checkpoint_path ) diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 0a10cb77..aaad807c 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -534,7 +534,7 @@ def create_batch_completions_request() -> CreateBatchCompletionsRequest: model="mpt-7b", checkpoint_path="s3://test_checkpoint_path", labels={}, - num_shards=2, + num_shards=1, ), - data_parallelism=2, + data_parallelism=1, ) From 2558f7d6ee37a8f7a13925ba457a0e1aa172e75d Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 22 Jul 2024 13:57:59 -0700 Subject: [PATCH 341/425] bump vllm batch version (#571) --- .../inference/batch_inference/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index ca6b220f..e4ad0c3b 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,4 +1,4 @@ -vllm==0.5.0.post1 +vllm==0.5.1 pydantic>=2 boto3==1.34.15 smart-open==6.4.0 From 758f7bb3d8cd145bcfb4808e4310063298537ed1 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 22 Jul 2024 15:46:17 -0700 Subject: [PATCH 342/425] Add deepseek models (#572) * Add tests * Remove log --- charts/model-engine/values_circleci.yaml | 26 +++- charts/model-engine/values_sample.yaml | 26 +++- docs/model_zoo.md | 88 +++++++------ .../use_cases/llm_model_endpoint_use_cases.py | 8 ++ .../repositories/live_tokenizer_repository.py | 6 + model-engine/tests/unit/conftest.py | 24 ++++ .../tests/unit/domain/test_llm_use_cases.py | 124 ++++++++++++++++++ 7 files changed, 258 insertions(+), 44 deletions(-) diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index b29b18e9..fa3a21b3 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -270,7 +270,7 @@ recommendedHardware: cpus: 80 gpus: 8 memory: 800Gi - storage: 460Gi + storage: 640Gi gpu_type: nvidia-hopper-h100 byModelName: - name: llama-3-8b-instruct-262k @@ -278,4 +278,28 @@ recommendedHardware: gpus: 2 memory: 40Gi storage: 40Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-lite + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-lite-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi gpu_type: nvidia-hopper-h100 \ No newline at end of file diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 38f631e0..296f3075 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -346,7 +346,7 @@ recommendedHardware: cpus: 80 gpus: 8 memory: 800Gi - storage: 460Gi + storage: 640Gi gpu_type: nvidia-hopper-h100 byModelName: - name: llama-3-8b-instruct-262k @@ -355,3 +355,27 @@ recommendedHardware: memory: 40Gi storage: 40Gi gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-lite + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-lite-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 \ No newline at end of file diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 7705b633..d8892627 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -2,48 +2,52 @@ Scale hosts the following models in the LLM Engine Model Zoo: -| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | Inference max total tokens (prompt + response) | -| ------------------------ | ------------------------ | -------------------------- | ------------------------------------------ | ---------------------------------------------- | -| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | 2048 | -| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | -| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | 4096 | -| `llama-2-13b` | ✅ | | text-generation-inference, vllm | 4096 | -| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096 | -| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | -| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096 | -| `llama-3-8b` | ✅ | | vllm | 8192 | -| `llama-3-8b-instruct` | ✅ | | vllm | 8192 | -| `llama-3-70b` | ✅ | | vllm | 8192 | -| `llama-3-70b-instruct` | ✅ | | vllm | 8192 | -| `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | -| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | -| `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | -| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | -| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | 2048 | -| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | 2048 | -| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | 2048 | -| `mistral-7b` | ✅ | ✅ | vllm | 8000 | -| `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | -| `mixtral-8x7b` | ✅ | | vllm | 32768 | -| `mixtral-8x7b-instruct` | ✅ | | vllm | 32768 | -| `mixtral-8x22b` | ✅ | | vllm | 65536 | -| `mixtral-8x22b-instruct` | ✅ | | vllm | 65536 | -| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | -| `codellama-70b` | ✅ | | vllm | 16384 | -| `codellama-70b-instruct` | ✅ | | vllm | 4096 | -| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | -| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | -| `gemma-2b` | ✅ | | vllm | 8192 | -| `gemma-2b-instruct` | ✅ | | vllm | 8192 | -| `gemma-7b` | ✅ | | vllm | 8192 | -| `gemma-7b-instruct` | ✅ | | vllm | 8192 | -| `phi-3-mini-4k-instruct` | ✅ | | vllm | 4096 | - +| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | Inference max total tokens (prompt + response) | +| --------------------------------- | ------------------------ | -------------------------- | ------------------------------------------ | ---------------------------------------------- | +| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | 2048 | +| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | +| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-13b` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | +| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-3-8b` | ✅ | | vllm | 8192 | +| `llama-3-8b-instruct` | ✅ | | vllm | 8192 | +| `llama-3-70b` | ✅ | | vllm | 8192 | +| `llama-3-70b-instruct` | ✅ | | vllm | 8192 | +| `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | 2048 | +| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | 2048 | +| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | 2048 | +| `mistral-7b` | ✅ | ✅ | vllm | 8000 | +| `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | +| `mixtral-8x7b` | ✅ | | vllm | 32768 | +| `mixtral-8x7b-instruct` | ✅ | | vllm | 32768 | +| `mixtral-8x22b` | ✅ | | vllm | 65536 | +| `mixtral-8x22b-instruct` | ✅ | | vllm | 65536 | +| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-70b` | ✅ | | vllm | 16384 | +| `codellama-70b-instruct` | ✅ | | vllm | 4096 | +| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | +| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | +| `gemma-2b` | ✅ | | vllm | 8192 | +| `gemma-2b-instruct` | ✅ | | vllm | 8192 | +| `gemma-7b` | ✅ | | vllm | 8192 | +| `gemma-7b-instruct` | ✅ | | vllm | 8192 | +| `phi-3-mini-4k-instruct` | ✅ | | vllm | 4096 | +| `deepseek-coder-v2` | ✅ | | vllm | 131072 | +| `deepseek-coder-v2-instruct` | ✅ | | vllm | 131072 | +| `deepseek-coder-v2-lite` | ✅ | | vllm | 131072 | +| `deepseek-coder-v2-lite-instruct` | ✅ | | vllm | 131072 | + ## Usage diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index b0bcbfcd..0bca59e3 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -216,6 +216,10 @@ "phi-3-small-128k-instruct", "phi-3-medium-4-instruct", "phi-3-medium-128k-instruct", + "deepseek-coder-v2", + "deepseek-coder-v2-instruct", + "deepseek-coder-v2-lite", + "deepseek-coder-v2-lite-instruct", ] ), LLMInferenceFramework.LIGHTLLM: set( @@ -2336,6 +2340,10 @@ async def _infer_hardware( model_param_count_b = 8 elif "phi-3-medium" in model_name: model_param_count_b = 15 + elif "deepseek-coder-v2-lite" in model_name: + model_param_count_b = 16 + elif "deepseek-coder-v2" in model_name: + model_param_count_b = 237 else: numbers = re.findall(r"(\d+)b", model_name) if len(numbers) == 0: diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index a5783393..1132795f 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -91,6 +91,12 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "phi-3-small-128k-instruct": ModelInfo("microsoft/phi-3-small-128k-instruct", None), "phi-3-medium-4-instruct": ModelInfo("microsoft/phi-3-medium-4k-instruct", None), "phi-3-medium-128k-instruct": ModelInfo("microsoft/phi-3-medium-128k-instruct", None), + "deepseek-coder-v2": ModelInfo("deepseek-ai/DeepSeek-Coder-V2-Base", None), + "deepseek-coder-v2-instruct": ModelInfo("deepseek-ai/DeepSeek-Coder-V2-Instruct", None), + "deepseek-coder-v2-lite": ModelInfo("deepseek-ai/DeepSeek-Coder-V2-Lite-Base", None), + "deepseek-coder-v2-lite-instruct": ModelInfo( + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", None + ), } diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index ec3850af..ee58bdc8 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -4583,6 +4583,30 @@ async def async_mock(*args, **kwargs): # noqa gpus: 2 memory: 160Gi storage: 160Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-lite + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-lite-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + - name: deepseek-coder-v2-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi gpu_type: nvidia-hopper-h100 """, } diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index e4d107f4..4c3e6ce8 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1937,6 +1937,130 @@ async def test_validate_checkpoint_files_safetensors_with_other_files(): mocked__get_recommended_hardware_config_map(), ) async def test_infer_hardware(fake_llm_artifact_gateway): + # deepseek from https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct/raw/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["DeepseekV2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "aux_loss_alpha": 0.001, + "bos_token_id": 100000, + "eos_token_id": 100001, + "ep_size": 1, + "first_k_dense_replace": 1, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 12288, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "moe_intermediate_size": 1536, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 160, + "n_shared_experts": 2, + "norm_topk_prob": False, + "num_attention_heads": 128, + "num_experts_per_tok": 6, + "num_hidden_layers": 60, + "num_key_value_heads": 128, + "pretraining_tp": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_theta": 10000, + "routed_scaling_factor": 16.0, + "scoring_func": "softmax", + "seq_aux": True, + "tie_word_embeddings": False, + "topk_group": 3, + "topk_method": "group_limited_greedy", + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "v_head_dim": 128, + "vocab_size": 102400, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "") + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + + # deepseek lite https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct/raw/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["DeepseekV2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "aux_loss_alpha": 0.001, + "bos_token_id": 100000, + "eos_token_id": 100001, + "first_k_dense_replace": 1, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 10944, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "moe_intermediate_size": 1408, + "moe_layer_freq": 1, + "n_group": 1, + "n_routed_experts": 64, + "n_shared_experts": 2, + "norm_topk_prob": False, + "num_attention_heads": 16, + "num_experts_per_tok": 6, + "num_hidden_layers": 27, + "num_key_value_heads": 16, + "pretraining_tp": 1, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_theta": 10000, + "routed_scaling_factor": 1.0, + "scoring_func": "softmax", + "seq_aux": True, + "tie_word_embeddings": False, + "topk_group": 1, + "topk_method": "greedy", + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "v_head_dim": 128, + "vocab_size": 102400, + } + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "deepseek-coder-v2-lite-instruct", "" + ) + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "deepseek-coder-v2-lite-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + # Phi 3 mini from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json fake_llm_artifact_gateway.model_config = { "architectures": ["Phi3ForCausalLM"], From 04e5818b883a180cfc5709e4e05fb87f8343d085 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 22 Jul 2024 16:52:19 -0700 Subject: [PATCH 343/425] Reduce hardware requirement for deepseek coder lite (#573) --- charts/model-engine/values_circleci.yaml | 12 ------------ charts/model-engine/values_sample.yaml | 12 ------------ model-engine/tests/unit/conftest.py | 12 ------------ model-engine/tests/unit/domain/test_llm_use_cases.py | 8 ++++---- 4 files changed, 4 insertions(+), 40 deletions(-) diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index fa3a21b3..e4f03888 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -279,18 +279,6 @@ recommendedHardware: memory: 40Gi storage: 40Gi gpu_type: nvidia-hopper-h100 - - name: deepseek-coder-v2-lite - cpus: 160 - gpus: 8 - memory: 800Gi - storage: 640Gi - gpu_type: nvidia-hopper-h100 - - name: deepseek-coder-v2-lite-instruct - cpus: 160 - gpus: 8 - memory: 800Gi - storage: 640Gi - gpu_type: nvidia-hopper-h100 - name: deepseek-coder-v2 cpus: 160 gpus: 8 diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 296f3075..c9365466 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -355,18 +355,6 @@ recommendedHardware: memory: 40Gi storage: 40Gi gpu_type: nvidia-hopper-h100 - - name: deepseek-coder-v2-lite - cpus: 160 - gpus: 8 - memory: 800Gi - storage: 640Gi - gpu_type: nvidia-hopper-h100 - - name: deepseek-coder-v2-lite-instruct - cpus: 160 - gpus: 8 - memory: 800Gi - storage: 640Gi - gpu_type: nvidia-hopper-h100 - name: deepseek-coder-v2 cpus: 160 gpus: 8 diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index ee58bdc8..e1763aba 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -4584,18 +4584,6 @@ async def async_mock(*args, **kwargs): # noqa memory: 160Gi storage: 160Gi gpu_type: nvidia-hopper-h100 - - name: deepseek-coder-v2-lite - cpus: 160 - gpus: 8 - memory: 800Gi - storage: 640Gi - gpu_type: nvidia-hopper-h100 - - name: deepseek-coder-v2-lite-instruct - cpus: 160 - gpus: 8 - memory: 800Gi - storage: 640Gi - gpu_type: nvidia-hopper-h100 - name: deepseek-coder-v2 cpus: 160 gpus: 8 diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4c3e6ce8..319e64bd 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -2046,10 +2046,10 @@ async def test_infer_hardware(fake_llm_artifact_gateway): hardware = await _infer_hardware( fake_llm_artifact_gateway, "deepseek-coder-v2-lite-instruct", "" ) - assert hardware.cpus == 160 - assert hardware.gpus == 8 - assert hardware.memory == "800Gi" - assert hardware.storage == "640Gi" + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 hardware = await _infer_hardware( From 7d4ac8664513cb5ae863873d84ce4622ac3c99f8 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:04:56 -0700 Subject: [PATCH 344/425] Bump vllm version to 0.5.3post1 (#576) * bump vllm version to 0.5.3post1 * bump batch image also --- .../inference/batch_inference/requirements.txt | 2 +- .../model_engine_server/inference/vllm/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt index e4ad0c3b..89413c72 100644 --- a/model-engine/model_engine_server/inference/batch_inference/requirements.txt +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -1,4 +1,4 @@ -vllm==0.5.1 +vllm==0.5.3.post1 pydantic>=2 boto3==1.34.15 smart-open==6.4.0 diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index e56693e7..8fa7cb6e 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1,2 @@ -vllm==0.5.0.post1 +vllm==0.5.3.post1 pydantic>=2.0 From 9818676023a1ca9a99ebbb8a39e3b1c0ec668a07 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Tue, 23 Jul 2024 16:02:47 -0700 Subject: [PATCH 345/425] Azure compatibility work for LLM engine (#551) --- charts/model-engine/templates/_helpers.tpl | 14 +- .../model-engine/templates/istio-metrics.yaml | 2 + .../populate_fine_tuning_repository_job.yaml | 58 +++ .../templates/restart_keda_operator.yaml | 57 +++ .../templates/service_account.yaml | 7 + .../templates/service_account_inference.yaml | 7 + .../service_template_config_map.yaml | 16 +- .../templates/trigger_authentication.yaml | 11 + charts/model-engine/values_circleci.yaml | 2 +- .../model_engine_server/api/dependencies.py | 35 +- .../model_engine_server/common/config.py | 11 +- model-engine/model_engine_server/db/base.py | 261 ++++++---- .../inference_autoscaling_metrics_gateway.py | 14 + .../domain/services/model_endpoint_service.py | 2 +- .../use_cases/llm_model_endpoint_use_cases.py | 77 ++- .../streaming_inference_use_cases.py | 2 +- .../use_cases/sync_inference_use_cases.py | 2 +- .../entrypoints/k8s_cache.py | 7 +- ...populate_llm_fine_tuning_job_repository.py | 448 ++++++++++++++++++ .../start_batch_job_orchestration.py | 18 +- .../infra/gateways/__init__.py | 2 + ...b_inference_autoscaling_metrics_gateway.py | 72 +++ ...s_inference_autoscaling_metrics_gateway.py | 6 + .../gateways/resources/k8s_resource_types.py | 8 + .../live_endpoint_resource_gateway.py | 16 +- .../service_template_config_map_circleci.yaml | 4 +- ...bs_file_llm_fine_tune_events_repository.py | 70 ++- .../abs_file_llm_fine_tune_repository.py | 40 +- ...s3_file_llm_fine_tune_events_repository.py | 4 +- ...image_batch_job_llm_fine_tuning_service.py | 7 + .../services/live_model_endpoint_service.py | 2 +- .../service_builder/tasks_v1.py | 19 +- .../service_config_circleci.yaml | 2 +- model-engine/setup.cfg | 2 + model-engine/tests/unit/conftest.py | 8 +- .../tests/unit/domain/test_llm_use_cases.py | 24 +- 36 files changed, 1183 insertions(+), 154 deletions(-) create mode 100644 charts/model-engine/templates/populate_fine_tuning_repository_job.yaml create mode 100644 charts/model-engine/templates/restart_keda_operator.yaml create mode 100644 charts/model-engine/templates/trigger_authentication.yaml create mode 100644 model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py create mode 100644 model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index d13af358..e0766b26 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -50,6 +50,9 @@ team: infra app.kubernetes.io/version: {{ .Values.tag }} tags.datadoghq.com/version: {{ .Values.tag }} tags.datadoghq.com/env: {{ .Values.context }} +{{- if .Values.azure }} +azure.workload.identity/use: "true" +{{- end }} {{- end }} {{/* @@ -91,6 +94,9 @@ managed-by: {{- include "modelEngine.fullname" . | printf " %s\n" -}} use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: {{- .Values.context | printf " %s" }} tags.datadoghq.com/version: ${GIT_TAG} +{{- if .Values.azure }} +azure.workload.identity/use: "true" +{{- end }} {{- end }} {{- define "modelEngine.serviceTemplateLabels" -}} @@ -246,6 +252,8 @@ env: value: {{ .Values.azure.object_id }} - name: ABS_ACCOUNT_NAME value: {{ .Values.azure.abs_account_name }} + - name: ABS_CONTAINER_NAME + value: {{ .Values.azure.abs_container_name }} {{- end }} {{- end }} @@ -268,6 +276,8 @@ env: {{- if .Values.azure}} - name: ABS_ACCOUNT_NAME value: {{ .Values.azure.abs_account_name }} + - name: ABS_CONTAINER_NAME + value: {{ .Values.azure.abs_container_name }} - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} {{- end }} @@ -341,12 +351,12 @@ env: value: {{ .Values.azure.client_id }} - name: AZURE_OBJECT_ID value: {{ .Values.azure.object_id }} - - name: AZURE_KEYVAULT_IDENTITY_CLIENT_ID - value: {{ .Values.azure.keyvault_identity_client_id }} - name: KEYVAULT_NAME value: {{ .Values.azure.keyvault_name }} - name: ABS_ACCOUNT_NAME value: {{ .Values.azure.abs_account_name }} + - name: ABS_CONTAINER_NAME + value: {{ .Values.azure.abs_container_name }} - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} {{- end }} diff --git a/charts/model-engine/templates/istio-metrics.yaml b/charts/model-engine/templates/istio-metrics.yaml index 7020e793..4f19cf73 100644 --- a/charts/model-engine/templates/istio-metrics.yaml +++ b/charts/model-engine/templates/istio-metrics.yaml @@ -1,3 +1,4 @@ +{{- if empty .Values.azure }} apiVersion: telemetry.istio.io/v1alpha1 kind: Telemetry metadata: @@ -32,3 +33,4 @@ spec: matchLabels: {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} url: https://storage.googleapis.com/istio-build/proxy/attributegen-359dcd3a19f109c50e97517fe6b1e2676e870c4d.wasm +{{- end }} diff --git a/charts/model-engine/templates/populate_fine_tuning_repository_job.yaml b/charts/model-engine/templates/populate_fine_tuning_repository_job.yaml new file mode 100644 index 00000000..080f21e6 --- /dev/null +++ b/charts/model-engine/templates/populate_fine_tuning_repository_job.yaml @@ -0,0 +1,58 @@ +{{- if .Values.populateFineTuningRepository }} +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "modelEngine.fullname" . }}-populate-fine-tuning-repository + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": post-install + "helm.sh/hook-weight": "1" + "helm.sh/hook-delete-policy": hook-succeeded +spec: + backoffLimit: 0 + activeDeadlineSeconds: 600 + template: + metadata: + labels: + sidecar.istio.io/inject: "false" + {{- include "modelEngine.labels" . | nindent 8 }} + spec: + restartPolicy: Never + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: {{ include "modelEngine.fullname" . }} + image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: + - dumb-init + - -- + args: + - python + - -m + - model_engine_server.entrypoints.populate_llm_fine_tuning_job_repository + {{- if .Values.azure }} + - --cloud-provider + - azure + {{- end }} + - --initialize-repository + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} +{{- end }} diff --git a/charts/model-engine/templates/restart_keda_operator.yaml b/charts/model-engine/templates/restart_keda_operator.yaml new file mode 100644 index 00000000..8937ea82 --- /dev/null +++ b/charts/model-engine/templates/restart_keda_operator.yaml @@ -0,0 +1,57 @@ +# needed for the Azure bicep deployment due to using the default keda installation and a workload identity for auth +# see note in https://learn.microsoft.com/en-us/azure/aks/keda-deploy-add-on-arm +# keda-operator pods need AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE, and AZURE_AUTHORITY_HOST env vars injected +{{- if .Values.restartKedaOperator }} +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "modelEngine.fullname" . }}-restart-keda-operator + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": post-install + "helm.sh/hook-weight": "1" + "helm.sh/hook-delete-policy": hook-succeeded +spec: + backoffLimit: 0 + activeDeadlineSeconds: 600 + template: + metadata: + labels: + sidecar.istio.io/inject: "false" + {{- include "modelEngine.labels" . | nindent 8 }} + spec: + restartPolicy: Never + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: {{ include "modelEngine.fullname" . }} + image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: + - kubectl + - rollout + - restart + - deployment + - keda-operator + - -n + - kube-system + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} +{{- end }} diff --git a/charts/model-engine/templates/service_account.yaml b/charts/model-engine/templates/service_account.yaml index 1d0d7d3b..dc41c998 100644 --- a/charts/model-engine/templates/service_account.yaml +++ b/charts/model-engine/templates/service_account.yaml @@ -13,6 +13,13 @@ metadata: {{- with $annotations }} annotations: {{- toYaml . | nindent 4 }} + {{- if $.Values.azure }} + azure.workload.identity/client-id: {{ $.Values.azure.client_id }} + {{- end }} {{- end }} +{{- if $.Values.azure }} +imagePullSecrets: + - name: egp-ecr-regcred +{{- end }} --- {{- end }} diff --git a/charts/model-engine/templates/service_account_inference.yaml b/charts/model-engine/templates/service_account_inference.yaml index 9be37377..c9fa94fb 100644 --- a/charts/model-engine/templates/service_account_inference.yaml +++ b/charts/model-engine/templates/service_account_inference.yaml @@ -13,6 +13,13 @@ metadata: {{- with $annotations }} annotations: {{- toYaml . | nindent 4 }} + {{- if $.Values.azure }} + azure.workload.identity/client-id: {{ $.Values.azure.client_id }} + {{- end }} {{- end }} +{{- if $.Values.azure }} +imagePullSecrets: + - name: egp-ecr-regcred +{{- end }} --- {{- end }} \ No newline at end of file diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 7c2477e8..f721eb46 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -23,6 +23,7 @@ {{- $node_selector := .Values.nodeSelector }} {{- $require_aws_config := not (empty .Values.aws) }} {{- $enable_datadog := .Values.datadog.enabled }} +{{- $azure_cloud_provider := not (empty .Values.azure) }} {{- if .Values.message }} {{- .Values.message }} @@ -431,8 +432,8 @@ data: apiVersion: keda.sh/v1alpha1 kind: ScaledObject metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} labels: {{- $service_template_labels | nindent 8 }} spec: @@ -446,6 +447,16 @@ data: failureThreshold: 3 replicas: ${MIN_WORKERS} triggers: + {{- if $azure_cloud_provider }} + - type: azure-servicebus + metadata: + queueName: "launch-endpoint-autoscaling.${ENDPOINT_ID}" + namespace: ${SERVICEBUS_NAMESPACE} + messageCount: "100" + activationMessageCount: "0" + authenticationRef: + name: "${AUTHENTICATION_REF}" + {{- else }} - type: redis metadata: address: ${REDIS_HOST_PORT} # Format must be host:port @@ -456,6 +467,7 @@ data: enableTLS: "false" unsafeSsl: "false" databaseIndex: "${REDIS_DB_INDEX}" + {{- end }} service.yaml: |- apiVersion: v1 kind: Service diff --git a/charts/model-engine/templates/trigger_authentication.yaml b/charts/model-engine/templates/trigger_authentication.yaml new file mode 100644 index 00000000..63209f68 --- /dev/null +++ b/charts/model-engine/templates/trigger_authentication.yaml @@ -0,0 +1,11 @@ +{{- if .Values.azure }} +apiVersion: keda.sh/v1alpha1 +kind: TriggerAuthentication +metadata: + name: azure-workload-identity + namespace: {{ .Values.config.values.launch.endpoint_namespace }} +spec: + podIdentity: + provider: azure-workload + identityId: {{ .Values.azure.client_id }} +{{- end }} \ No newline at end of file diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index e4f03888..f897f4df 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -152,7 +152,7 @@ config: billing_queue_arn: none cache_redis_aws_url: redis://redis-message-broker-master.default/15 - s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository" + cloud_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository" dd_trace_enabled: false istio_enabled: true sensitive_log_mode: false diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index eb2ee227..cb664583 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -21,7 +21,7 @@ logger_name, make_logger, ) -from model_engine_server.db.base import SessionAsync, SessionReadOnlyAsync +from model_engine_server.db.base import get_session_async, get_session_read_only_async from model_engine_server.domain.gateways import ( CronJobGateway, DockerImageBatchJobGateway, @@ -55,6 +55,7 @@ ABSFileStorageGateway, ABSFilesystemGateway, ABSLLMArtifactGateway, + ASBInferenceAutoscalingMetricsGateway, CeleryTaskQueueGateway, DatadogMonitoringMetricsGateway, FakeMonitoringMetricsGateway, @@ -220,8 +221,16 @@ def _get_external_interfaces( else: inference_task_queue_gateway = sqs_task_queue_gateway infra_task_queue_gateway = sqs_task_queue_gateway - resource_gateway = LiveEndpointResourceGateway(queue_delegate=queue_delegate) redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool()) + inference_autoscaling_metrics_gateway = ( + ASBInferenceAutoscalingMetricsGateway() + if infra_config().cloud_provider == "azure" + else RedisInferenceAutoscalingMetricsGateway(redis_client=redis_client) + ) # we can just reuse the existing redis client, we shouldn't get key collisions because of the prefix + resource_gateway = LiveEndpointResourceGateway( + queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + ) model_endpoint_cache_repo = RedisModelEndpointCacheRepository( redis_client=redis_client, ) @@ -252,9 +261,6 @@ def _get_external_interfaces( model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) - inference_autoscaling_metrics_gateway = RedisInferenceAutoscalingMetricsGateway( - redis_client=redis_client - ) # we can just reuse the existing redis client, we shouldn't get key collisions because of the prefix model_endpoint_service = LiveModelEndpointService( model_endpoint_record_repository=model_endpoint_record_repo, model_endpoint_infra_gateway=model_endpoint_infra_gateway, @@ -290,14 +296,17 @@ def _get_external_interfaces( cron_job_gateway = LiveCronJobGateway() llm_fine_tune_repository: LLMFineTuneRepository + file_path = os.getenv( + "CLOUD_FILE_LLM_FINE_TUNE_REPOSITORY", + hmi_config.cloud_file_llm_fine_tune_repository, + ) if infra_config().cloud_provider == "azure": - llm_fine_tune_repository = ABSFileLLMFineTuneRepository("not supported yet") + llm_fine_tune_repository = ABSFileLLMFineTuneRepository( + file_path=file_path, + ) else: llm_fine_tune_repository = S3FileLLMFineTuneRepository( - file_path=os.getenv( - "S3_FILE_LLM_FINE_TUNE_REPOSITORY", - hmi_config.s3_file_llm_fine_tune_repository, - ), + file_path=file_path, ) llm_fine_tune_events_repository = ( ABSFileLLMFineTuneEventsRepository() @@ -319,7 +328,7 @@ def _get_external_interfaces( docker_repository: DockerRepository if CIRCLECI: docker_repository = FakeDockerRepository() - elif infra_config().cloud_provider == "azure": + elif infra_config().docker_repo_prefix.endswith("azurecr.io"): docker_repository = ACRDockerRepository() else: docker_repository = ECRDockerRepository() @@ -356,13 +365,13 @@ def _get_external_interfaces( def get_default_external_interfaces() -> ExternalInterfaces: - session = async_scoped_session(SessionAsync, scopefunc=asyncio.current_task) # type: ignore + session = async_scoped_session(get_session_async(), scopefunc=asyncio.current_task) # type: ignore return _get_external_interfaces(read_only=False, session=session) def get_default_external_interfaces_read_only() -> ExternalInterfaces: session = async_scoped_session( # type: ignore - SessionReadOnlyAsync, scopefunc=asyncio.current_task # type: ignore + get_session_read_only_async(), scopefunc=asyncio.current_task # type: ignore ) return _get_external_interfaces(read_only=True, session=session) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index c66b8df8..16a4dd00 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -34,7 +34,7 @@ redis_cache_expiration_timestamp = None -# duplicated from llm/ia3_finetune +# duplicated from llm/finetune_pipeline def get_model_cache_directory_name(model_name: str): """How huggingface maps model names to directory names in their cache for model files. We adopt this when storing model cache files in s3. @@ -55,7 +55,7 @@ class HostedModelInferenceServiceConfig: sqs_queue_policy_template: str sqs_queue_tag_template: str model_primitive_host: str - s3_file_llm_fine_tune_repository: str + cloud_file_llm_fine_tune_repository: str hf_user_fine_tuned_weights_prefix: str istio_enabled: bool dd_trace_enabled: bool @@ -116,14 +116,17 @@ def cache_redis_host_port(self) -> str: # redis://redis.url:6379/ # -> redis.url:6379 if "rediss://" in self.cache_redis_url: - return self.cache_redis_url.split("rediss://")[1].split("/")[0] + return self.cache_redis_url.split("rediss://")[1].split("@")[-1].split("/")[0] return self.cache_redis_url.split("redis://")[1].split("/")[0] @property def cache_redis_db_index(self) -> int: # redis://redis.url:6379/ # -> - return int(self.cache_redis_url.split("/")[-1]) + try: + return int(self.cache_redis_url.split("/")[-1]) + except ValueError: + return 0 # 0 is the default index used by redis if it's not specified def read_default_config(): diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index 9acf95c0..d6beefe9 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -1,10 +1,11 @@ import asyncio import os import sys +import time from typing import Iterator, Optional import sqlalchemy -from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from azure.identity import DefaultAzureCredential from azure.keyvault.secrets import SecretClient from model_engine_server.core.aws.secrets import get_key_file from model_engine_server.core.config import infra_config @@ -17,6 +18,9 @@ logger = make_logger(logger_name()) +database_credential_expiration_timestamp = time.time() +EXPIRATION_BUFFER = 300 # 5 minutes + def get_key_file_name(environment: str) -> str: if infra_config().cloud_provider == "azure": @@ -24,7 +28,12 @@ def get_key_file_name(environment: str) -> str: return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "") -def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool = True) -> str: +def get_engine_url( + env: Optional[str] = None, + read_only: bool = True, + sync: bool = True, + reset_expiration_timestamp: bool = False, +) -> str: """Gets the URL of the Postgresql engine depending on the environment.""" if os.getenv("ML_INFRA_DATABASE_URL"): # In CircleCI environment, we set up a test in another container and specify the URL. @@ -50,17 +59,17 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool if infra_config().cloud_provider == "azure": client = SecretClient( vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net", - credential=ManagedIdentityCredential( - client_id=os.getenv("AZURE_KEYVAULT_IDENTITY_CLIENT_ID") - ), # uses a different managed identity than the default + credential=DefaultAzureCredential(), ) db = client.get_secret(key_file).value user = os.environ.get("AZURE_IDENTITY_NAME") - password = ( - DefaultAzureCredential() - .get_token("https://ossrdbms-aad.database.windows.net/.default") - .token + token = DefaultAzureCredential().get_token( + "https://ossrdbms-aad.database.windows.net/.default" ) + password = token.token + if reset_expiration_timestamp: + global database_credential_expiration_timestamp + database_credential_expiration_timestamp = token.expires_on logger.info(f"Connecting to db {db} as user {user}") engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require" @@ -87,97 +96,168 @@ def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool return engine_url -# Try pool_pre_ping=True, see -# https://docs.sqlalchemy.org/en/14/core/engines.html -# ?highlight=create_engine#sqlalchemy.create_engine.params.pool_pre_ping -# tl;dr is hopefully it stops the psycopg errors from happening -# Another probably less impactful (ie it shouldn't increase latency by as much, -# but also shouldn't get rid of as many errors e.g. 95% compared to 99.9%) -# option is to try pool_recycle = something kinda short e.g. a minute -# pool_pre_ping=True seems to not increase latency by very much -# (I profiled 2.7 ms -> 3.3 ms on GET model_bundles/) -# but hopefully should completely eliminate -# any of the postgres connection errors we've been seeing. - -pg_engine = create_engine( - get_engine_url(read_only=False, sync=True), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=20, - max_overflow=30, -) -pg_engine_read_only = create_engine( - get_engine_url(read_only=True, sync=True), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=20, - max_overflow=30, -) -pg_engine_async = create_async_engine( - get_engine_url(read_only=False, sync=False), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=20, - max_overflow=30, -) -pg_engine_read_only_async = create_async_engine( - get_engine_url(read_only=True, sync=False), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=20, - max_overflow=30, -) -pg_engine_async_null_pool = create_async_engine( - get_engine_url(read_only=False, sync=False), - echo=False, - future=True, - poolclass=NullPool, - pool_pre_ping=True, -) - # Synchronous sessions (Session and SessionReadOnly) are fairly straightforward, and both # can be used at any time. To use asynchronous sqlalchemy, use the SessionAsyncNullPool # if you're running a synchronous program where concurrency of database connections is not # super important (e.g. Celery workers that use long-standing connections, and Celery is currently # synchronous). Use SessionAsync and SessionReadOnlyAsync in ASGI applications. -Session = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine) -SessionReadOnly = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine_read_only) -SessionAsync = async_scoped_session( - session_factory=async_sessionmaker( - autocommit=False, - autoflush=False, - bind=pg_engine_async, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, -) -SessionAsyncNullPool = async_scoped_session( - session_factory=async_sessionmaker( - autocommit=False, - autoflush=False, - bind=pg_engine_async_null_pool, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, -) -SessionReadOnlyAsync = async_scoped_session( - async_sessionmaker( - autocommit=False, - autoflush=False, - bind=pg_engine_read_only_async, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, -) + +_Session: Optional[sessionmaker] = None +_SessionReadOnly: Optional[sessionmaker] = None +_SessionAsync: Optional[async_scoped_session] = None +_SessionAsyncNullPool: Optional[async_scoped_session] = None +_SessionReadOnlyAsync: Optional[async_scoped_session] = None + + +def refresh_sessions(): + # Try pool_pre_ping=True, see + # https://docs.sqlalchemy.org/en/14/core/engines.html + # ?highlight=create_engine#sqlalchemy.create_engine.params.pool_pre_ping + # tl;dr is hopefully it stops the psycopg errors from happening + # Another probably less impactful (ie it shouldn't increase latency by as much, + # but also shouldn't get rid of as many errors e.g. 95% compared to 99.9%) + # option is to try pool_recycle = something kinda short e.g. a minute + # pool_pre_ping=True seems to not increase latency by very much + # (I profiled 2.7 ms -> 3.3 ms on GET model_bundles/) + # but hopefully should completely eliminate + # any of the postgres connection errors we've been seeing. + + pg_engine = create_engine( + get_engine_url(read_only=False, sync=True, reset_expiration_timestamp=True), + echo=False, + future=True, + pool_pre_ping=True, + pool_size=20, + max_overflow=30, + ) + pg_engine_read_only = create_engine( + get_engine_url(read_only=True, sync=True), + echo=False, + future=True, + pool_pre_ping=True, + pool_size=20, + max_overflow=30, + ) + pg_engine_async = create_async_engine( + get_engine_url(read_only=False, sync=False), + echo=False, + future=True, + pool_pre_ping=True, + pool_size=20, + max_overflow=30, + ) + pg_engine_read_only_async = create_async_engine( + get_engine_url(read_only=True, sync=False), + echo=False, + future=True, + pool_pre_ping=True, + pool_size=20, + max_overflow=30, + ) + pg_engine_async_null_pool = create_async_engine( + get_engine_url(read_only=False, sync=False), + echo=False, + future=True, + poolclass=NullPool, + pool_pre_ping=True, + ) + + global _Session + global _SessionReadOnly + global _SessionAsync + global _SessionAsyncNullPool + global _SessionReadOnlyAsync + + _Session = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine) + _SessionReadOnly = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine_read_only) + _SessionAsync = async_scoped_session( + session_factory=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async, + expire_on_commit=False, + ), + scopefunc=asyncio.current_task, + ) + _SessionAsyncNullPool = async_scoped_session( + session_factory=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async_null_pool, + expire_on_commit=False, + ), + scopefunc=asyncio.current_task, + ) + _SessionReadOnlyAsync = async_scoped_session( + async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_read_only_async, + expire_on_commit=False, + ), + scopefunc=asyncio.current_task, + ) + + +refresh_sessions() + + +def get_session(): + global _Session + global database_credential_expiration_timestamp + + if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: + refresh_sessions() + + return _Session + + +def get_session_read_only(): + global _SessionReadOnly + global database_credential_expiration_timestamp + + if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: + refresh_sessions() + + return _SessionReadOnly + + +def get_session_async(): + global _SessionAsync + global database_credential_expiration_timestamp + + if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: + refresh_sessions() + + return _SessionAsync + + +def get_session_async_null_pool(): + global _SessionAsyncNullPool + global database_credential_expiration_timestamp + + if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: + refresh_sessions() + + return _SessionAsyncNullPool + + +def get_session_read_only_async(): + global _SessionReadOnlyAsync + global database_credential_expiration_timestamp + + if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: + refresh_sessions() + + return _SessionReadOnlyAsync + + Base = declarative_base() def get_session_iterator() -> Iterator[sqlalchemy.orm.Session]: """Utility to return an iterator with an instantiated session in the ML Infra database.""" + Session = get_session() session = Session() try: yield session @@ -187,6 +267,7 @@ def get_session_iterator() -> Iterator[sqlalchemy.orm.Session]: def get_read_only_session_iterator() -> Iterator[sqlalchemy.orm.Session]: """Utility to return an iterator with an instantiated session in the ML Infra database.""" + SessionReadOnly = get_session_read_only() session = SessionReadOnly() try: yield session diff --git a/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py index da603b4d..862b6e05 100644 --- a/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py @@ -20,3 +20,17 @@ async def emit_prewarm_metric(self, endpoint_id: str): If you want to prewarm an endpoint, emit a metric here """ pass + + @abstractmethod + async def create_or_update_resources(self, endpoint_id: str): + """ + Create necessary resources for autoscaling metrics + """ + pass + + @abstractmethod + async def delete_resources(self, endpoint_id: str): + """ + Delete necessary resources for autoscaling metrics + """ + pass diff --git a/model-engine/model_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py index 8492ae45..4c3471b4 100644 --- a/model-engine/model_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -53,7 +53,7 @@ def get_streaming_model_endpoint_inference_gateway( """ @abstractmethod - def get_inference_auto_scaling_metrics_gateway( + def get_inference_autoscaling_metrics_gateway( self, ) -> InferenceAutoscalingMetricsGateway: """ diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 0bca59e3..215dd7de 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -348,9 +348,13 @@ def validate_quantization( def validate_checkpoint_path_uri(checkpoint_path: str) -> None: - if not checkpoint_path.startswith("s3://"): + if ( + not checkpoint_path.startswith("s3://") + and not checkpoint_path.startswith("azure://") + and "blob.core.windows.net" not in checkpoint_path + ): raise ObjectHasInvalidValueException( - f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}." + f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}." ) if checkpoint_path.endswith(".tar"): raise ObjectHasInvalidValueException( @@ -553,6 +557,22 @@ async def create_text_generation_inference_bundle( def load_model_weights_sub_commands( self, framework, framework_image_tag, checkpoint_path, final_weights_folder + ): + if checkpoint_path.startswith("s3://"): + return self.load_model_weights_sub_commands_s3( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path: + return self.load_model_weights_sub_commands_abs( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + else: + raise ObjectHasInvalidValueException( + f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}." + ) + + def load_model_weights_sub_commands_s3( + self, framework, framework_image_tag, checkpoint_path, final_weights_folder ): subcommands = [] s5cmd = "s5cmd" @@ -577,6 +597,38 @@ def load_model_weights_sub_commands( ) return subcommands + def load_model_weights_sub_commands_abs( + self, framework, framework_image_tag, checkpoint_path, final_weights_folder + ): + subcommands = [] + + subcommands.extend( + [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + ] + ) + + base_path = checkpoint_path.split("/")[-1] + if base_path.endswith(".tar"): + # If the checkpoint file is a tar file, extract it into final_weights_folder + subcommands.extend( + [ + f"azcopy copy {checkpoint_path} .", + f"mkdir -p {final_weights_folder}", + f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}", + ] + ) + else: + file_selection_str = ( + '--include-pattern "*.model;*.json;*.safetensors" --exclude-pattern "optimizer*"' + ) + subcommands.append( + f"azcopy copy --recursive {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + + return subcommands + def load_model_files_sub_commands_trt_llm( self, checkpoint_path, @@ -588,9 +640,18 @@ def load_model_files_sub_commands_trt_llm( See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt """ - subcommands = [ - f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" - ] + if checkpoint_path.startswith("s3://"): + subcommands = [ + f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" + ] + else: + subcommands.extend( + [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + f"azcopy copy --recursive {os.path.join(checkpoint_path, '*')} ./", + ] + ) return subcommands async def create_deepspeed_bundle( @@ -1022,7 +1083,7 @@ async def execute( post_inference_hooks=request.post_inference_hooks, ) - await self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway().emit_prewarm_metric( + await self.model_endpoint_service.get_inference_autoscaling_metrics_gateway().emit_prewarm_metric( model_endpoint_record.id ) @@ -1607,7 +1668,7 @@ async def execute( inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() autoscaling_metrics_gateway = ( - self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id @@ -1935,7 +1996,7 @@ async def execute( self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() ) autoscaling_metrics_gateway = ( - self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id diff --git a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index e17b512e..b358ea04 100644 --- a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -62,7 +62,7 @@ async def execute( self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() ) autoscaling_metrics_gateway = ( - self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint_id diff --git a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index 16196ab6..6835f74a 100644 --- a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -66,7 +66,7 @@ async def execute( inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() autoscaling_metrics_gateway = ( - self.model_endpoint_service.get_inference_auto_scaling_metrics_gateway() + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint_id diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index df1e9df2..98dcd9b3 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -16,7 +16,7 @@ from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger -from model_engine_server.db.base import SessionAsyncNullPool +from model_engine_server.db.base import get_session_async_null_pool from model_engine_server.domain.repositories import DockerRepository from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( ASBQueueEndpointResourceDelegate, @@ -102,7 +102,7 @@ async def main(args: Any): monitoring_metrics_gateway = get_monitoring_metrics_gateway() endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, - session=SessionAsyncNullPool, + session=get_session_async_null_pool(), read_only=True, ) @@ -118,12 +118,13 @@ async def main(args: Any): k8s_resource_manager = LiveEndpointResourceGateway( queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=None, ) image_cache_gateway = ImageCacheGateway() docker_repo: DockerRepository if CIRCLECI: docker_repo = FakeDockerRepository() - elif infra_config().cloud_provider == "azure": + elif infra_config().docker_repo_prefix.endswith("azurecr.io"): docker_repo = ACRDockerRepository() else: docker_repo = ECRDockerRepository() diff --git a/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py new file mode 100644 index 00000000..0b971b06 --- /dev/null +++ b/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py @@ -0,0 +1,448 @@ +""" +This script initializes the file backing the LLMFineTuneRepository and adds a test template to it + +FOR TESTING: +To get the bundle id, print the result of calling +`get_or_create_docker_image_batch_job_bundle(CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, users[0])` +from e2e_test_v1.py + +FOR ACTUAL CREATION: +You will need a docker image from the fine-tuning repo. Refer to llm/finetune_pipeline/README.md for instructions. + +""" +import argparse +import asyncio + +import requests +from model_engine_server.common.config import hmi_config +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.repositories import ( + ABSFileLLMFineTuneRepository, + S3FileLLMFineTuneRepository, +) + +FT_IMAGE_TAG = "00f0edae308d9cd5d9fc24fbd4ee0180e8edc738" + +BUNDLE_NAME_BY_MODEL = { + "7b_or_13b": "fine-tune-upload-safetensors", + "llama_2_34b": "fine-tune-upload-safetensors-34b", + "llama_2_70b": "fine-tune-upload-safetensors-70b", +} + +DEFAULT_7B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 1, + "quantize": None, + "cpus": 8, + "memory": "24Gi", + "storage": "40Gi", + "gpus": 1, + "gpu_type": "nvidia-ampere-a10", + "min_workers": 0, + "max_workers": 1, + "per_worker": 10, + "endpoint_type": "streaming", +} + +DEFAULT_13B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 2, + "quantize": None, + "cpus": 16, + "memory": "48Gi", + "storage": "80Gi", + "gpus": 2, + "gpu_type": "nvidia-ampere-a10", + "min_workers": 0, + "max_workers": 1, + "per_worker": 10, + "endpoint_type": "streaming", +} + +# DEFAULT_34B_MODEL_CONFIG defined below because it depends on cloud_provider + +DEFAULT_70B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 2, + "quantize": None, + "cpus": 20, + "memory": "160Gi", + "storage": "200Gi", + "gpus": 2, + "gpu_type": "nvidia-ampere-a100e", + "min_workers": 0, + "max_workers": 1, + "per_worker": 30, + "endpoint_type": "streaming", +} + + +def create_model_bundle(cloud_provider, url, user, model_type, image_tag): + RESOURCE_REQUESTS_BY_MODEL = { + "7b_or_13b": { + "cpus": 40, + "memory": "160Gi", + "storage": "94Gi", + "gpus": 2 if cloud_provider == "azure" else 4, + "gpu_type": "nvidia-ampere-a10", + }, + "llama_2_34b": { + "cpus": 60, + "memory": "400Gi", + "storage": "300Gi", + "gpus": 4, + "gpu_type": "nvidia-ampere-a100e", + }, + "llama_2_70b": { + "cpus": 80, + "memory": "1000Gi", + "storage": "500Gi", + "gpus": 8, + "gpu_type": "nvidia-ampere-a100e", + }, + } + + name = BUNDLE_NAME_BY_MODEL[model_type] + resource_requests = RESOURCE_REQUESTS_BY_MODEL[model_type] + + response = requests.post( + f"{url}/v1/docker-image-batch-job-bundles", + json={ + "name": name, + "image_repository": "spellbook-finetune", + "image_tag": image_tag, + "command": [ + "dumb-init", + "--", + "ddtrace-run", + "python", + "llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py", + "--config-file", + "/launch_reserved/config_file.json", + ], + "mount_location": "/launch_reserved/config_file.json", + "resource_requests": resource_requests, + "public": True, + }, + headers={"Content-Type": "application/json"}, + auth=requests.auth.HTTPBasicAuth(user, ""), + ).json() + return response["docker_image_batch_job_bundle_id"] + + +async def main(args): + cloud_provider = args.cloud_provider + url = args.url or f"http://model-engine.{hmi_config.gateway_namespace}.svc.cluster.local" + repository = args.repository or hmi_config.cloud_file_llm_fine_tune_repository + user = args.user or "test-user" + initialize_repository = args.initialize_repository + + if repository.startswith("s3://"): + repo = S3FileLLMFineTuneRepository(file_path=repository) + elif repository.startswith("azure://") or "blob.core.windows.net" in repository: + repo = ABSFileLLMFineTuneRepository(file_path=repository) + else: + raise ValueError(f"LLM fine-tune repository must be S3 or ABS file; got {repository}") + + # Clears the file. Needed the first time we're populating data + if initialize_repository: + await repo.initialize_data() + + lora_7b_or_13b_bun = create_model_bundle(cloud_provider, url, user, "7b_or_13b", FT_IMAGE_TAG) + print(f"lora_7b_or_13b bundle id: {lora_7b_or_13b_bun}") + + lora_llama_2_34b_bun = create_model_bundle( + cloud_provider, url, user, "llama_2_34b", FT_IMAGE_TAG + ) + print(f"lora_34b_bun bundle id: {lora_llama_2_34b_bun}") + + lora_llama_2_70b_bun = create_model_bundle( + cloud_provider, url, user, "llama_2_70b", FT_IMAGE_TAG + ) + print(f"llama_2_70b bundle id: {lora_llama_2_70b_bun}") + + await repo.write_job_template_for_model( + "mpt-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mosaicml/mpt-7b", + "_BASE_MODEL_SHORT": "mpt-7b", + }, + required_params=[], + ), + ) + print("Wrote mpt-7b with lora") + + await repo.write_job_template_for_model( + "mpt-7b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mosaicml/mpt-7b-instruct", + "_BASE_MODEL_SHORT": "mpt-7b-instruct", + }, + required_params=[], + ), + ) + print("Wrote mpt-7b-instruct with lora") + + await repo.write_job_template_for_model( + "llama-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-7b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-7b with lora") + + await repo.write_job_template_for_model( + "llama-2-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-7b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-7b with lora") + + await repo.write_job_template_for_model( + "llama-2-7b-chat", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-7b-chat", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-7b-chat", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-7b-chat with lora") + + await repo.write_job_template_for_model( + "llama-2-13b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-13b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-13b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-13b with lora") + + await repo.write_job_template_for_model( + "llama-2-13b-chat", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-13b-chat", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-13b-chat", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-13b-chat with lora") + + await repo.write_job_template_for_model( + "llama-2-70b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_llama_2_70b_bun, + launch_endpoint_config=DEFAULT_70B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-70b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-70b", # == create llm endpoint request's model_name + "max_length": 1024, # To prevent OOM on 8xA100e + }, + required_params=[], + ), + ) + print("Wrote llama-2-70b with lora") + + await repo.write_job_template_for_model( + "mistral-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mistralai/mistral-7b-v0.1", # == model_name inside of training script + "_BASE_MODEL_SHORT": "mistral-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote mistral-7b with lora") + + await repo.write_job_template_for_model( + "mistral-7b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mistralai/mistral-7b-instruct-v0.1", # == model_name inside of training script + "_BASE_MODEL_SHORT": "mistral-7b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote mistral-7b-instruct with lora") + await repo.write_job_template_for_model( + "codellama-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-7b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-7b with lora") + + await repo.write_job_template_for_model( + "codellama-7b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-7b-instruct", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-7b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-7b-instruct with lora") + + await repo.write_job_template_for_model( + "codellama-13b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-13b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-13b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-13b with lora") + + await repo.write_job_template_for_model( + "codellama-13b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-13b-instruct", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-13b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-13b-instruct with lora") + + DEFAULT_34B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 2 if cloud_provider == "azure" else 4, + "quantize": None, + "cpus": 32, + "memory": "80Gi", + "storage": "100Gi", + "gpus": 2 if cloud_provider == "azure" else 4, + "gpu_type": "nvidia-ampere-a10", + "min_workers": 0, + "max_workers": 1, + "per_worker": 10, + "endpoint_type": "streaming", + } + + await repo.write_job_template_for_model( + "codellama-34b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_llama_2_34b_bun, + launch_endpoint_config=DEFAULT_34B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-34b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-34b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-34b with lora") + + await repo.write_job_template_for_model( + "codellama-34b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_llama_2_34b_bun, + launch_endpoint_config=DEFAULT_34B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-34b-instruct", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-34b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-34b-instruct with lora") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process command line arguments.") + parser.add_argument( + "--cloud-provider", + choices=["aws", "azure"], + help="Cloud provider", + required=False, + default="aws", + ) + parser.add_argument("--url", help="Url to the model-engine gateway", required=False) + parser.add_argument( + "--repository", help="Url to the LLM fine-tuning job repository", required=False + ) + parser.add_argument( + "--user", help="User ID to create Docker image batch job bundles with", required=False + ) + parser.add_argument( + "--initialize-repository", action="store_true", required=False, default=False + ) + args = parser.parse_args() + asyncio.run(main(args)) diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index c9abea51..c059a9eb 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -9,11 +9,12 @@ from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.core.config import infra_config -from model_engine_server.db.base import SessionAsyncNullPool +from model_engine_server.db.base import get_session_async_null_pool from model_engine_server.domain.entities import BatchJobSerializationFormat from model_engine_server.domain.gateways import TaskQueueGateway from model_engine_server.infra.gateways import ( ABSFilesystemGateway, + ASBInferenceAutoscalingMetricsGateway, CeleryTaskQueueGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobProgressGateway, @@ -57,7 +58,7 @@ async def run_batch_job( serialization_format: BatchJobSerializationFormat, timeout_seconds: float, ): - session = SessionAsyncNullPool + session = get_session_async_null_pool() pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) redis = aioredis.Redis(connection_pool=pool) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) @@ -78,7 +79,15 @@ async def run_batch_job( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - resource_gateway = LiveEndpointResourceGateway(queue_delegate=queue_delegate) + inference_autoscaling_metrics_gateway = ( + ASBInferenceAutoscalingMetricsGateway() + if infra_config().cloud_provider == "azure" + else RedisInferenceAutoscalingMetricsGateway(redis_client=redis) + ) + resource_gateway = LiveEndpointResourceGateway( + queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + ) inference_task_queue_gateway: TaskQueueGateway infra_task_queue_gateway: TaskQueueGateway @@ -113,9 +122,6 @@ async def run_batch_job( model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) - inference_autoscaling_metrics_gateway = RedisInferenceAutoscalingMetricsGateway( - redis_client=redis, - ) model_endpoint_service = LiveModelEndpointService( model_endpoint_record_repository=model_endpoint_record_repo, model_endpoint_infra_gateway=model_endpoint_infra_gateway, diff --git a/model-engine/model_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py index 5a0d7a90..f8a3ee6e 100644 --- a/model-engine/model_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -3,6 +3,7 @@ from .abs_file_storage_gateway import ABSFileStorageGateway from .abs_filesystem_gateway import ABSFilesystemGateway from .abs_llm_artifact_gateway import ABSLLMArtifactGateway +from .asb_inference_autoscaling_metrics_gateway import ASBInferenceAutoscalingMetricsGateway from .batch_job_orchestration_gateway import BatchJobOrchestrationGateway from .batch_job_progress_gateway import BatchJobProgressGateway from .celery_task_queue_gateway import CeleryTaskQueueGateway @@ -29,6 +30,7 @@ "ABSFileStorageGateway", "ABSFilesystemGateway", "ABSLLMArtifactGateway", + "ASBInferenceAutoscalingMetricsGateway", "BatchJobOrchestrationGateway", "BatchJobProgressGateway", "CeleryTaskQueueGateway", diff --git a/model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py new file mode 100644 index 00000000..6ab06a27 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py @@ -0,0 +1,72 @@ +import os +from datetime import timedelta + +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.servicebus import ServiceBusClient, ServiceBusMessage +from azure.servicebus.management import ServiceBusAdministrationClient +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) + +EXPIRY_SECONDS = 60 # 1 minute; this gets added to the cooldown time present in the keda ScaledObject to get total +# scaledown time. This also needs to be larger than the keda ScaledObject's refresh rate. +PREWARM_EXPIRY_SECONDS = 60 * 60 # 1 hour + + +def _get_servicebus_administration_client() -> ServiceBusAdministrationClient: + return ServiceBusAdministrationClient( + f"{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) + + +class ASBInferenceAutoscalingMetricsGateway(InferenceAutoscalingMetricsGateway): + @staticmethod + def _find_queue_name(endpoint_id: str): + # Keep in line with keda scaled object yaml + return f"launch-endpoint-autoscaling.{endpoint_id}" + + async def _emit_metric(self, endpoint_id: str, expiry_time: int): + queue_name = self._find_queue_name(endpoint_id) + + servicebus_namespace = os.getenv("SERVICEBUS_NAMESPACE") + if servicebus_namespace is None: + raise ValueError("SERVICEBUS_NAMESPACE env var must be set in Azure") + + with ServiceBusClient( + fully_qualified_namespace=f"{servicebus_namespace}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) as servicebus_client: + sender = servicebus_client.get_queue_sender(queue_name=queue_name) + with sender: + message = ServiceBusMessage( + "message", time_to_live=timedelta(seconds=expiry_time) + ) # we only care about the length of the queue, not the message values + sender.send_messages(message=message) + + receiver = servicebus_client.get_queue_receiver(queue_name=queue_name) + with receiver: + receiver.peek_messages(max_message_count=1, timeout=1) + + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, EXPIRY_SECONDS) + + async def emit_prewarm_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, PREWARM_EXPIRY_SECONDS) + + async def create_or_update_resources(self, endpoint_id: str): + queue_name = self._find_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.create_queue(queue_name=queue_name) + except ResourceExistsError: + pass + + async def delete_resources(self, endpoint_id: str): + queue_name = self._find_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.delete_queue(queue_name=queue_name) + except ResourceNotFoundError: + pass diff --git a/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py index a5bcc31e..027442e8 100644 --- a/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py @@ -47,3 +47,9 @@ async def emit_inference_autoscaling_metric(self, endpoint_id: str): async def emit_prewarm_metric(self, endpoint_id: str): await self._emit_metric(endpoint_id, PREWARM_EXPIRY_SECONDS) + + async def create_or_update_resources(self, endpoint_id: str): + pass # no extra resources needed + + async def delete_resources(self, endpoint_id: str): + pass diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 6b958920..b83f404e 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -1,4 +1,5 @@ import hashlib +import os from datetime import datetime from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union @@ -295,6 +296,8 @@ class KedaScaledObjectArguments(_BaseEndpointArguments): # CONCURRENCY: float # TODO add in when we scale from 1 -> N pods REDIS_HOST_PORT: str REDIS_DB_INDEX: str + SERVICEBUS_NAMESPACE: Optional[str] + AUTHENTICATION_REF: str class UserConfigArguments(_BaseEndpointArguments): @@ -528,6 +531,9 @@ def get_endpoint_resource_arguments_from_request( main_env.append({"name": "AWS_PROFILE", "value": build_endpoint_request.aws_role}) # NOTE: /opt/.aws/config is where service_template_config_map.yaml mounts the AWS config file, point to the mount for boto clients main_env.append({"name": "AWS_CONFIG_FILE", "value": "/opt/.aws/config"}) + abs_account_name = os.getenv("ABS_ACCOUNT_NAME") + if abs_account_name is not None: + main_env.append({"name": "ABS_ACCOUNT_NAME", "value": abs_account_name}) infra_service_config_volume_mount_path = "/infra-config" forwarder_config_file_name = "service--forwarder.yaml" @@ -1146,6 +1152,8 @@ def get_endpoint_resource_arguments_from_request( # CONCURRENCY=build_endpoint_request.concurrency, REDIS_HOST_PORT=hmi_config.cache_redis_host_port, REDIS_DB_INDEX=hmi_config.cache_redis_db_index, + SERVICEBUS_NAMESPACE=os.getenv("SERVICEBUS_NAMESPACE"), + AUTHENTICATION_REF="azure-workload-identity", ) elif endpoint_resource_name == "service": # Use ClusterIP by default for sync endpoint. diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 516470ba..fb637c10 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -8,6 +8,7 @@ ModelEndpointType, ) from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways import InferenceAutoscalingMetricsGateway from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, EndpointResourceGatewayCreateOrUpdateResourcesResponse, @@ -24,9 +25,14 @@ class LiveEndpointResourceGateway(EndpointResourceGateway[QueueInfo]): - def __init__(self, queue_delegate: QueueEndpointResourceDelegate): + def __init__( + self, + queue_delegate: QueueEndpointResourceDelegate, + inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway], + ): self.k8s_delegate = K8SEndpointResourceDelegate() self.queue_delegate = queue_delegate + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway async def create_queue( self, @@ -59,6 +65,11 @@ async def create_or_update_resources( queue_name = None queue_url = None + if self.inference_autoscaling_metrics_gateway is not None: + await self.inference_autoscaling_metrics_gateway.create_or_update_resources( + endpoint_record.id + ) + await self.k8s_delegate.create_or_update_resources( request=request, sqs_queue_name=queue_name, @@ -109,4 +120,7 @@ async def delete_resources( logger.warning("Could not delete SQS resources", exc_info=e) sqs_result = False + if self.inference_autoscaling_metrics_gateway is not None: + await self.inference_autoscaling_metrics_gateway.delete_resources(endpoint_id) + return k8s_result and sqs_result diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 1c0cb8e2..3311a509 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -2567,8 +2567,8 @@ data: apiVersion: keda.sh/v1alpha1 kind: ScaledObject metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} labels: user_id: ${OWNER} team: ${TEAM} diff --git a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py index 9d33585b..8a221c9f 100644 --- a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py +++ b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py @@ -1,19 +1,83 @@ -from typing import List +import json +import os +from json.decoder import JSONDecodeError +from typing import IO, List +import smart_open +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent +from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( LLMFineTuneEventsRepository, ) +# Echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py +ABS_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = ( + f"azure://{os.getenv('ABS_CONTAINER_NAME')}/hosted-model-inference/fine_tuned_weights" +) + class ABSFileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): def __init__(self): pass + def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + # echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py + def _get_model_cache_directory_name(self, model_name: str): + """How huggingface maps model names to directory names in their cache for model files. + We adopt this when storing model cache files in ABS. + + Args: + model_name (str): Name of the huggingface model + """ + name = "models--" + model_name.replace("/", "--") + return name + + def _get_file_location(self, user_id: str, model_endpoint_name: str): + model_cache_name = self._get_model_cache_directory_name(model_endpoint_name) + abs_file_location = ( + f"{ABS_HF_USER_FINE_TUNED_WEIGHTS_PREFIX}/{user_id}/{model_cache_name}.jsonl" + ) + return abs_file_location + async def get_fine_tune_events( self, user_id: str, model_endpoint_name: str ) -> List[LLMFineTuneEvent]: - raise NotImplementedError("ABS not supported yet") + abs_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + try: + with self._open(abs_file_location, "r") as f: + lines = f.readlines() + final_events = [] + for line in lines: + try: + event_dict = json.loads(line) + event = LLMFineTuneEvent( + timestamp=event_dict["timestamp"], + message=str(event_dict["message"]), + level=event_dict.get("level", "info"), + ) + except JSONDecodeError: + event = LLMFineTuneEvent( + message=line, + level="info", + ) + final_events.append(event) + return final_events + except Exception as exc: # TODO better exception + raise ObjectNotFoundException from exc async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: - raise NotImplementedError("ABS not supported yet") + abs_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + self._open(abs_file_location, "w") diff --git a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py index a205fd83..fc8860f2 100644 --- a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py +++ b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py @@ -1,5 +1,10 @@ -from typing import Optional +import json +import os +from typing import IO, Dict, Optional +import smart_open +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository @@ -8,12 +13,41 @@ class ABSFileLLMFineTuneRepository(LLMFineTuneRepository): def __init__(self, file_path: str): self.file_path = file_path + def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + @staticmethod + def _get_key(model_name, fine_tuning_method): + return f"{model_name}-{fine_tuning_method}" # possible for collisions but we control these names + async def get_job_template_for_model( self, model_name: str, fine_tuning_method: str ) -> Optional[LLMFineTuneTemplate]: - raise NotImplementedError("ABS not supported yet") + with self._open(self.file_path, "r") as f: + data = json.load(f) + key = self._get_key(model_name, fine_tuning_method) + job_template_dict = data.get(key, None) + if job_template_dict is None: + return None + return LLMFineTuneTemplate.parse_obj(job_template_dict) async def write_job_template_for_model( self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate ): - raise NotImplementedError("ABS not supported yet") + # Use locally in script + with self._open(self.file_path, "r") as f: + data: Dict = json.load(f) + key = self._get_key(model_name, fine_tuning_method) + data[key] = dict(job_template) + with self._open(self.file_path, "w") as f: + json.dump(data, f) + + async def initialize_data(self): + # Use locally in script + with self._open(self.file_path, "w") as f: + json.dump({}, f) diff --git a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py index 6993d1d0..2dfcbc76 100644 --- a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py @@ -12,7 +12,7 @@ LLMFineTuneEventsRepository, ) -# Echoes llm/ia3_finetune/docker_image_fine_tuning_entrypoint.py +# Echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = ( f"s3://{infra_config().s3_bucket}/hosted-model-inference/fine_tuned_weights" ) @@ -36,7 +36,7 @@ def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) - # echoes llm/ia3_finetune/docker_image_fine_tuning_entrypoint.py + # echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py def _get_model_cache_directory_name(self, model_name: str): """How huggingface maps model names to directory names in their cache for model files. We adopt this when storing model cache files in s3. diff --git a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index fd60966f..d35ef21a 100644 --- a/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import FineTuneHparamValueType from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob @@ -88,6 +89,12 @@ async def create_fine_tune( job_config=dict( **labels, gateway_url=os.getenv("GATEWAY_URL"), + cloud_provider=infra_config().cloud_provider, + aws_profile=infra_config().profile_ml_worker, + s3_bucket=infra_config().s3_bucket, + azure_client_id=os.getenv("AZURE_CLIENT_ID"), + abs_account_name=os.getenv("ABS_ACCOUNT_NAME"), + abs_container_name=os.getenv("ABS_CONTAINER_NAME"), user_id=owner, training_file=training_file, validation_file=validation_file, diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index b9ffc260..475fbca8 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -80,7 +80,7 @@ def get_streaming_model_endpoint_inference_gateway( ) -> StreamingModelEndpointInferenceGateway: return self.streaming_model_endpoint_inference_gateway - def get_inference_auto_scaling_metrics_gateway( + def get_inference_autoscaling_metrics_gateway( self, ) -> InferenceAutoscalingMetricsGateway: return self.inference_autoscaling_metrics_gateway diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index 7615e123..e9eca9a6 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -14,9 +14,14 @@ from model_engine_server.common.env_vars import CIRCLECI from model_engine_server.core.config import infra_config from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from model_engine_server.db.base import SessionAsyncNullPool +from model_engine_server.db.base import get_session_async_null_pool from model_engine_server.domain.repositories import DockerRepository -from model_engine_server.infra.gateways import ABSFilesystemGateway, S3FilesystemGateway +from model_engine_server.infra.gateways import ( + ABSFilesystemGateway, + ASBInferenceAutoscalingMetricsGateway, + RedisInferenceAutoscalingMetricsGateway, + S3FilesystemGateway, +) from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( ASBQueueEndpointResourceDelegate, ) @@ -69,14 +74,20 @@ def get_live_endpoint_builder_service( docker_repository: DockerRepository if CIRCLECI: docker_repository = FakeDockerRepository() - elif infra_config().cloud_provider == "azure": + elif infra_config().docker_repo_prefix.endswith("azurecr.io"): docker_repository = ACRDockerRepository() else: docker_repository = ECRDockerRepository() + inference_autoscaling_metrics_gateway = ( + ASBInferenceAutoscalingMetricsGateway() + if infra_config().cloud_provider == "azure" + else RedisInferenceAutoscalingMetricsGateway(redis_client=redis) + ) service = LiveEndpointBuilderService( docker_repository=docker_repository, resource_gateway=LiveEndpointResourceGateway( queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, ), monitoring_metrics_gateway=monitoring_metrics_gateway, model_endpoint_record_repository=DbModelEndpointRecordRepository( @@ -96,7 +107,7 @@ def get_live_endpoint_builder_service( async def _build_endpoint( build_endpoint_request: BuildEndpointRequest, ) -> BuildEndpointResponse: - session = SessionAsyncNullPool + session = get_session_async_null_pool() pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) redis = aioredis.Redis(connection_pool=pool) service: LiveEndpointBuilderService = get_live_endpoint_builder_service(session, redis) diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml index de998c27..31644ad2 100644 --- a/model-engine/service_configs/service_config_circleci.yaml +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -53,7 +53,7 @@ billing_queue_arn: none # There's a separate piece of infra that caches k8s state onto redis, so we need a url to it cache_redis_aws_url: redis://127.0.0.1:6379/15 -s3_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune_repository/circleci" +cloud_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune_repository/circleci" dd_trace_enabled: false istio_enabled: true diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg index 76fa54d1..6b2273a8 100644 --- a/model-engine/setup.cfg +++ b/model-engine/setup.cfg @@ -14,6 +14,8 @@ omit = model_engine_server/infra/gateways/abs_file_storage_gateway.py model_engine_server/infra/gateways/abs_filesystem_gateway.py model_engine_server/infra/gateways/abs_llm_artifact_gateway.py + model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py + model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py model_engine_server/infra/gateways/resources/k8s_resource_types.py diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index e1763aba..0e8b59d9 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1589,6 +1589,12 @@ async def emit_inference_autoscaling_metric(self, endpoint_id: str): async def emit_prewarm_metric(self, endpoint_id: str): pass + async def create_or_update_resources(self, endpoint_id: str): + pass + + async def delete_resources(self, endpoint_id: str): + pass + class FakeStreamingStorageGateway(StreamingStorageGateway): def put_record(self, stream_name: str, record: Dict[str, Any]): @@ -1661,7 +1667,7 @@ def get_sync_model_endpoint_inference_gateway( ) -> SyncModelEndpointInferenceGateway: return self.sync_model_endpoint_inference_gateway - def get_inference_auto_scaling_metrics_gateway( + def get_inference_autoscaling_metrics_gateway( self, ) -> InferenceAutoscalingMetricsGateway: return self.inference_autoscaling_metrics_gateway diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 319e64bd..d36165d0 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -406,7 +406,7 @@ def test_load_model_weights_sub_commands( framework = LLMInferenceFramework.VLLM framework_image_tag = "0.2.7" - checkpoint_path = "fake-checkpoint" + checkpoint_path = "s3://fake-checkpoint" final_weights_folder = "test_folder" subcommands = llm_bundle_use_case.load_model_weights_sub_commands( @@ -414,13 +414,13 @@ def test_load_model_weights_sub_commands( ) expected_result = [ - "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder", + "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://fake-checkpoint/* test_folder", ] assert expected_result == subcommands framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE framework_image_tag = "1.0.0" - checkpoint_path = "fake-checkpoint" + checkpoint_path = "s3://fake-checkpoint" final_weights_folder = "test_folder" subcommands = llm_bundle_use_case.load_model_weights_sub_commands( @@ -429,7 +429,23 @@ def test_load_model_weights_sub_commands( expected_result = [ "s5cmd > /dev/null || conda install -c conda-forge -y s5cmd", - "s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' fake-checkpoint/* test_folder", + "s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://fake-checkpoint/* test_folder", + ] + assert expected_result == subcommands + + framework = LLMInferenceFramework.VLLM + framework_image_tag = "0.2.7" + checkpoint_path = "azure://fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + 'azcopy copy --recursive --include-pattern "*.model;*.json;*.safetensors" --exclude-pattern "optimizer*" azure://fake-checkpoint/* test_folder', ] assert expected_result == subcommands From afbb98a892d12088344d17205a25cb136a8d2d5d Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 23 Jul 2024 19:21:10 -0700 Subject: [PATCH 346/425] Add Llama 3.1 models (#577) * add llama 3.1 * oops * fix vllm server * try again * try again * try again * try again * try again * docs.md * run precommit * delete comments --- docs/model_zoo.md | 4 +++ .../use_cases/llm_model_endpoint_use_cases.py | 6 ++++ .../inference/vllm/vllm_server.py | 29 +++++++++++-------- .../repositories/live_tokenizer_repository.py | 6 ++++ 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index d8892627..63c5bd1f 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -15,6 +15,10 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `llama-3-8b-instruct` | ✅ | | vllm | 8192 | | `llama-3-70b` | ✅ | | vllm | 8192 | | `llama-3-70b-instruct` | ✅ | | vllm | 8192 | +| `llama-3-1-8b` | ✅ | | vllm | 131072 | +| `llama-3-1-8b-instruct` | ✅ | | vllm | 131072 | +| `llama-3-1-70b` | ✅ | | vllm | 131072 | +| `llama-3-1-70b-instruct` | ✅ | | vllm | 131072 | | `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | | `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | | `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 215dd7de..7bd5a878 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -181,6 +181,12 @@ "llama-3-8b-instruct-262k", "llama-3-70b", "llama-3-70b-instruct", + "llama-3-1-8b", + "llama-3-1-8b-instruct", + "llama-3-1-70b", + "llama-3-1-70b-instruct", + "llama-3-1-405b", + "llama-3-1-405b-instruct", "falcon-7b", "falcon-7b-instruct", "falcon-40b", diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 68a9a263..a30d4a25 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -1,4 +1,3 @@ -import argparse import asyncio import code import json @@ -27,7 +26,7 @@ from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob -from vllm.utils import random_uuid +from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION logging.basicConfig( @@ -253,19 +252,18 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: return [extract_logprobs(logprobs) for logprobs in output_logprobs] -def parse_args(): - parser = make_arg_parser() +def parse_args(parser: FlexibleArgumentParser): + parser = make_arg_parser(parser) return parser.parse_args() if __name__ == "__main__": check_unknown_startup_memory_usage() - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None) # None == IPv4 / IPv6 dualstack - parser.add_argument("--port", type=int, default=5005) - parser = AsyncEngineArgs.add_cli_args(parser) - args = parse_args() + parser = FlexibleArgumentParser() + # host, port, and AsyncEngineArgs are already given by make_arg_parser() in parse_args() + # host == None -> IPv4 / IPv6 dualstack + args = parse_args(parser) logger.info("vLLM version %s", VLLM_VERSION) logger.info("args: %s", args) @@ -287,11 +285,18 @@ def parse_args(): model_config, served_model_names, args.response_role, - args.lora_modules, - args.chat_template, + lora_modules=args.lora_modules, + chat_template=args.chat_template, + prompt_adapters=args.prompt_adapters, + request_logger=None, ) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules + engine, + model_config, + served_model_names, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=None, ) uvicorn.run( diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 1132795f..9f55217c 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -46,6 +46,12 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "llama-3-8b-instruct-262k": ModelInfo("gradientai/Llama-3-8B-Instruct-262k", None), "llama-3-70b": ModelInfo("meta-llama/Meta-Llama-3-70B", None), "llama-3-70b-instruct": ModelInfo("meta-llama/Meta-Llama-3-70B-Instruct", None), + "llama-3-1-8b": ModelInfo("meta-llama/Meta-Llama-3.1-8B", None), + "llama-3-1-8b-instruct": ModelInfo("meta-llama/Meta-Llama-3.1-8B-Instruct", None), + "llama-3-1-70b": ModelInfo("meta-llama/Meta-Llama-3.1-70B", None), + "llama-3-1-70b-instruct": ModelInfo("meta-llama/Meta-Llama-3.1-70B-Instruct", None), + "llama-3-1-405b": ModelInfo("meta-llama/Meta-Llama-3.1-405B", None), + "llama-3-1-405b-instruct": ModelInfo("meta-llama/Meta-Llama-3.1-405B-Instruct", None), "falcon-7b": ModelInfo("tiiuae/falcon-7b", None), "falcon-7b-instruct": ModelInfo("tiiuae/falcon-7b-instruct", None), "falcon-40b": ModelInfo("tiiuae/falcon-40b", None), From 42f1de11b16ff18419348984ebec1ef480e4f555 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 26 Jul 2024 13:00:55 -0700 Subject: [PATCH 347/425] Shared pydantic configs (#578) --- .../model_engine_server/common/dtos/batch_jobs.py | 3 +-- model-engine/model_engine_server/common/dtos/core.py | 8 +------- .../common/dtos/docker_repository.py | 2 +- .../common/dtos/endpoint_builder.py | 2 +- .../model_engine_server/common/dtos/files.py | 2 +- model-engine/model_engine_server/common/dtos/llms.py | 2 +- .../model_engine_server/common/dtos/model_bundles.py | 2 +- .../common/dtos/model_endpoints.py | 2 +- .../common/dtos/resource_manager.py | 2 +- .../model_engine_server/common/dtos/tasks.py | 2 +- .../model_engine_server/common/dtos/triggers.py | 2 +- .../model_engine_server/common/pydantic_types.py | 8 ++++++++ .../pydantic_types/endpoint_predict_payload.py | 10 ---------- .../domain/entities/batch_job_entity.py | 2 +- .../entities/docker_image_batch_job_bundle_entity.py | 2 +- .../domain/entities/file_entity.py | 2 +- .../domain/entities/llm_fine_tune_entity.py | 2 +- .../domain/entities/model_bundle_entity.py | 2 +- .../domain/entities/model_endpoint_entity.py | 2 +- .../domain/entities/owned_entity.py | 2 +- .../domain/entities/trigger_entity.py | 2 +- .../domain/gateways/monitoring_metrics_gateway.py | 2 +- .../gateways/live_model_endpoints_schema_gateway.py | 12 ++++++------ .../gateways/resources/endpoint_resource_gateway.py | 2 +- .../db_docker_image_batch_job_bundle_repository.py | 2 +- .../infra/repositories/db_trigger_repository.py | 2 +- 26 files changed, 37 insertions(+), 46 deletions(-) create mode 100644 model-engine/model_engine_server/common/pydantic_types.py delete mode 100644 model-engine/model_engine_server/common/pydantic_types/endpoint_predict_payload.py diff --git a/model-engine/model_engine_server/common/dtos/batch_jobs.py b/model-engine/model_engine_server/common/dtos/batch_jobs.py index 0600df22..12225537 100644 --- a/model-engine/model_engine_server/common/dtos/batch_jobs.py +++ b/model-engine/model_engine_server/common/dtos/batch_jobs.py @@ -5,6 +5,7 @@ from typing import Any, Collection, Dict, List, Optional from model_engine_server.common import dict_not_none +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, model_validator from model_engine_server.domain.entities import ( BatchJobSerializationFormat, BatchJobStatus, @@ -13,7 +14,6 @@ GpuType, StorageSpecificationType, ) -from pydantic import BaseModel, ConfigDict, model_validator class CreateBatchJobResourceRequests(BaseModel): @@ -27,7 +27,6 @@ class CreateBatchJobResourceRequests(BaseModel): class CreateBatchJobV1Request(BaseModel): - model_config = ConfigDict(protected_namespaces=()) model_bundle_id: str input_path: str serialization_format: BatchJobSerializationFormat diff --git a/model-engine/model_engine_server/common/dtos/core.py b/model-engine/model_engine_server/common/dtos/core.py index ad709658..c8d2ee22 100644 --- a/model-engine/model_engine_server/common/dtos/core.py +++ b/model-engine/model_engine_server/common/dtos/core.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, BeforeValidator, ConfigDict, HttpUrl, TypeAdapter +from pydantic import BeforeValidator, HttpUrl, TypeAdapter from typing_extensions import Annotated # See: https://github.com/pydantic/pydantic/issues/7186 @@ -9,9 +9,3 @@ str, BeforeValidator(lambda value: HttpUrlTypeAdapter.validate_python(value) and value), ] - - -class LLMEngineModel(BaseModel): - """Common pydantic configurations for model engine""" - - model_config = ConfigDict(protected_namespaces=()) diff --git a/model-engine/model_engine_server/common/dtos/docker_repository.py b/model-engine/model_engine_server/common/dtos/docker_repository.py index 694c4098..a5ddc1cf 100644 --- a/model-engine/model_engine_server/common/dtos/docker_repository.py +++ b/model-engine/model_engine_server/common/dtos/docker_repository.py @@ -1,6 +1,6 @@ from typing import Dict, Optional -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel class BuildImageRequest(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py index 8ec2d2f9..64ea43d0 100644 --- a/model-engine/model_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Optional +from model_engine_server.common.pydantic_types import BaseModel from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, @@ -8,7 +9,6 @@ ModelEndpointRecord, StorageSpecificationType, ) -from pydantic import BaseModel class BuildEndpointRequest(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/files.py b/model-engine/model_engine_server/common/dtos/files.py index 94b54474..8f09d9a3 100644 --- a/model-engine/model_engine_server/common/dtos/files.py +++ b/model-engine/model_engine_server/common/dtos/files.py @@ -3,7 +3,7 @@ """ from typing import List -from pydantic import BaseModel, Field +from model_engine_server.common.pydantic_types import BaseModel, Field class UploadFileResponse(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index b35bff36..232b8ba6 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -14,6 +14,7 @@ ModelEndpointType, StorageSpecificationType, ) +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field from model_engine_server.domain.entities import ( BatchJobStatus, CallbackAuth, @@ -24,7 +25,6 @@ ModelEndpointStatus, Quantization, ) -from pydantic import BaseModel, ConfigDict, Field class CreateLLMModelEndpointV1Request(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/model_bundles.py b/model-engine/model_engine_server/common/dtos/model_bundles.py index d49537c4..cd6f7f30 100644 --- a/model-engine/model_engine_server/common/dtos/model_bundles.py +++ b/model-engine/model_engine_server/common/dtos/model_bundles.py @@ -5,12 +5,12 @@ from enum import Enum from typing import Any, Dict, List, Optional +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field from model_engine_server.domain.entities import ( ModelBundleEnvironmentParams, ModelBundleFlavors, ModelBundlePackagingType, ) -from pydantic import BaseModel, ConfigDict, Field class CreateModelBundleV1Request(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index cfeb44bf..e8620890 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional from model_engine_server.common.dtos.core import HttpUrlStr +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, @@ -22,7 +23,6 @@ ModelEndpointType, StorageSpecificationType, ) -from pydantic import BaseModel, ConfigDict, Field class BrokerType(str, Enum): diff --git a/model-engine/model_engine_server/common/dtos/resource_manager.py b/model-engine/model_engine_server/common/dtos/resource_manager.py index e156f77e..f173a1a8 100644 --- a/model-engine/model_engine_server/common/dtos/resource_manager.py +++ b/model-engine/model_engine_server/common/dtos/resource_manager.py @@ -1,5 +1,5 @@ from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel class CreateOrUpdateResourcesRequest(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/tasks.py b/model-engine/model_engine_server/common/dtos/tasks.py index b9919f68..874c50a8 100644 --- a/model-engine/model_engine_server/common/dtos/tasks.py +++ b/model-engine/model_engine_server/common/dtos/tasks.py @@ -5,8 +5,8 @@ from enum import Enum from typing import Any, Optional +from model_engine_server.common.pydantic_types import BaseModel, Field, RootModel from model_engine_server.domain.entities import CallbackAuth -from pydantic import BaseModel, Field, RootModel class ResponseSchema(RootModel): diff --git a/model-engine/model_engine_server/common/dtos/triggers.py b/model-engine/model_engine_server/common/dtos/triggers.py index 3d75376e..ed8d45cf 100644 --- a/model-engine/model_engine_server/common/dtos/triggers.py +++ b/model-engine/model_engine_server/common/dtos/triggers.py @@ -4,7 +4,7 @@ import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, ConfigDict, Field +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field class CreateTriggerV1Request(BaseModel): diff --git a/model-engine/model_engine_server/common/pydantic_types.py b/model-engine/model_engine_server/common/pydantic_types.py new file mode 100644 index 00000000..19fc99c0 --- /dev/null +++ b/model-engine/model_engine_server/common/pydantic_types.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ConfigDict, Field, RootModel, ValidationError, model_validator # noqa: F401 + + +class BaseModel(PydanticBaseModel): + """Common pydantic configurations for model engine""" + + model_config = ConfigDict(protected_namespaces=()) diff --git a/model-engine/model_engine_server/common/pydantic_types/endpoint_predict_payload.py b/model-engine/model_engine_server/common/pydantic_types/endpoint_predict_payload.py deleted file mode 100644 index 218099a1..00000000 --- a/model-engine/model_engine_server/common/pydantic_types/endpoint_predict_payload.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any, Optional - -from pydantic import BaseModel - - -class EndpointPredictPayload(BaseModel): - url: Optional[str] = None - args: Optional[Any] = None - cloudpickle: Optional[str] = None - return_pickled: bool diff --git a/model-engine/model_engine_server/domain/entities/batch_job_entity.py b/model-engine/model_engine_server/domain/entities/batch_job_entity.py index 62238d66..a1b2ea1b 100644 --- a/model-engine/model_engine_server/domain/entities/batch_job_entity.py +++ b/model-engine/model_engine_server/domain/entities/batch_job_entity.py @@ -2,10 +2,10 @@ from enum import Enum from typing import Dict, Optional +from model_engine_server.common.pydantic_types import BaseModel from model_engine_server.domain.entities.model_bundle_entity import ModelBundle from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpoint from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel class BatchJobStatus(str, Enum): diff --git a/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py index 9213af13..a3914e3f 100644 --- a/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py @@ -1,9 +1,9 @@ import datetime from typing import Dict, List, Optional +from model_engine_server.common.pydantic_types import ConfigDict from model_engine_server.domain.entities import GpuType from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import ConfigDict class DockerImageBatchJobBundle(OwnedEntity): diff --git a/model-engine/model_engine_server/domain/entities/file_entity.py b/model-engine/model_engine_server/domain/entities/file_entity.py index f21314eb..f4d5a1f4 100644 --- a/model-engine/model_engine_server/domain/entities/file_entity.py +++ b/model-engine/model_engine_server/domain/entities/file_entity.py @@ -1,6 +1,6 @@ from datetime import datetime -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel class FileMetadata(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py index b18bbdd2..14a0c97a 100644 --- a/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, ConfigDict +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict class LLMFineTuneTemplate(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py index e3ceb836..d5d0a5f3 100644 --- a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field, model_validator from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Literal diff --git a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py index a0f84c4e..814c8683 100644 --- a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py @@ -4,6 +4,7 @@ from fastapi.openapi.models import OpenAPI from model_engine_server.common import dict_not_none +from model_engine_server.common.pydantic_types import BaseModel, Field, RootModel from model_engine_server.common.serialization_utils import b64_to_python_json, python_json_to_b64 from model_engine_server.domain.entities.common_types import ( CpuSpecificationType, @@ -12,7 +13,6 @@ from model_engine_server.domain.entities.gpu_type import GpuType from model_engine_server.domain.entities.model_bundle_entity import ModelBundle from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel, Field, RootModel from typing_extensions import Literal ModelEndpointsSchema = OpenAPI diff --git a/model-engine/model_engine_server/domain/entities/owned_entity.py b/model-engine/model_engine_server/domain/entities/owned_entity.py index 7ea79a0d..6eaf0737 100644 --- a/model-engine/model_engine_server/domain/entities/owned_entity.py +++ b/model-engine/model_engine_server/domain/entities/owned_entity.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel class OwnedEntity(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/trigger_entity.py b/model-engine/model_engine_server/domain/entities/trigger_entity.py index 0d68ec92..989b44cf 100644 --- a/model-engine/model_engine_server/domain/entities/trigger_entity.py +++ b/model-engine/model_engine_server/domain/entities/trigger_entity.py @@ -1,8 +1,8 @@ import datetime from typing import Any, Dict, Optional +from model_engine_server.common.pydantic_types import ConfigDict from model_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import ConfigDict class Trigger(OwnedEntity): diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index 23759911..c17e5b09 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -10,8 +10,8 @@ from typing import Optional from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.common.pydantic_types import BaseModel from model_engine_server.core.auth.authentication_repository import User -from pydantic import BaseModel class MetricMetadata(BaseModel): diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py index f6f51d9c..883335bf 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Sequence, Set, Type, Union +import pydantic from fastapi import routing from fastapi._compat import GenerateJsonSchema, get_definitions from fastapi.openapi.constants import REF_TEMPLATE @@ -25,7 +26,6 @@ ) from model_engine_server.domain.gateways import ModelEndpointsSchemaGateway from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway -from pydantic import BaseModel from starlette.routing import BaseRoute # Caches the default model definition so we don't need to recompute every time @@ -57,7 +57,7 @@ def get_model_endpoints_schema( model_endpoint_names = [] model_definitions = {} for record in model_endpoint_records: - response_model: Type[BaseModel] = GetAsyncTaskV1Response + response_model: Type[pydantic.BaseModel] = GetAsyncTaskV1Response predict_stub: Callable[[EndpointPredictV1Request], Any] = predict_stub_async base_route = "/v1/async-tasks" if record.endpoint_type == ModelEndpointType.SYNC: @@ -164,7 +164,7 @@ def update_model_definitions_with_prefix( Returns: Dict[str, Any]: The updated model definitions. """ - models: List[Type[BaseModel]] = [ + models: List[Type[pydantic.BaseModel]] = [ EndpointPredictV1Request, GetAsyncTaskV1Response, SyncEndpointPredictV1Response, @@ -198,7 +198,7 @@ def update_schema_refs_with_prefix(schema: Dict[str, Any], prefix: str) -> None: LiveModelEndpointsSchemaGateway.update_schema_refs_with_prefix(item, prefix) @staticmethod - def get_model_name_map(prefix: str) -> Dict[Union[Type[BaseModel], Type[Enum]], str]: + def get_model_name_map(prefix: str) -> Dict[Union[Type[pydantic.BaseModel], Type[Enum]], str]: return { CallbackAuth: "CallbackAuth", CallbackBasicAuth: "CallbackBasicAuth", @@ -254,8 +254,8 @@ def get_default_model_definitions() -> Dict[str, Any]: @staticmethod def get_model_definitions( - models: Sequence[Type[BaseModel]], - model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], + models: Sequence[Type[pydantic.BaseModel]], + model_name_map: Dict[Union[Type[pydantic.BaseModel], Type[Enum]], str], ) -> Dict[str, Any]: """Get OpenAPI definitions for provided models using the name provided in model_name_map""" diff --git a/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py index 1e6a3f6d..8c2779b3 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py @@ -2,13 +2,13 @@ from typing import Dict, Generic, Sequence, Tuple, TypeVar from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.pydantic_types import BaseModel from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, ModelEndpointType, ) from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import QueueInfo -from pydantic import BaseModel __all__: Sequence[str] = ( "EndpointResourceGateway", diff --git a/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py index 9e3cd17d..b97f57f4 100644 --- a/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py @@ -2,6 +2,7 @@ from model_engine_server.common import dict_not_none from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.common.pydantic_types import ValidationError from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle from model_engine_server.domain.entities import GpuType from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( @@ -15,7 +16,6 @@ DbRepositoryMixin, raise_if_read_only, ) -from pydantic import ValidationError class DbDockerImageBatchJobBundleRepository(DockerImageBatchJobBundleRepository, DbRepositoryMixin): diff --git a/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py index b4114358..367942f9 100644 --- a/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional, Sequence from model_engine_server.common import dict_not_none +from model_engine_server.common.pydantic_types import ValidationError from model_engine_server.db.models import Trigger as OrmTrigger from model_engine_server.domain.entities.trigger_entity import Trigger from model_engine_server.domain.exceptions import ( @@ -12,7 +13,6 @@ DbRepositoryMixin, raise_if_read_only, ) -from pydantic import ValidationError from sqlalchemy.exc import IntegrityError From 87d816eba799e2ce56bb9c47395fdef4bbed790d Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 26 Jul 2024 15:00:02 -0700 Subject: [PATCH 348/425] Add autogenerated openai spec (#579) --- model-engine/README.md | 6 + .../common/types/gen/openai.py | 5995 +++++++ model-engine/mypy.ini | 3 + requirements-dev.txt | 3 +- scripts/generate-openai-types.sh | 16 + scripts/openai-spec.yaml | 14267 ++++++++++++++++ 6 files changed, 20289 insertions(+), 1 deletion(-) create mode 100644 model-engine/model_engine_server/common/types/gen/openai.py create mode 100755 scripts/generate-openai-types.sh create mode 100644 scripts/openai-spec.yaml diff --git a/model-engine/README.md b/model-engine/README.md index 7d87e120..3f6b9579 100644 --- a/model-engine/README.md +++ b/model-engine/README.md @@ -41,3 +41,9 @@ Run `mypy . --install-types` to set up mypy. Most of the business logic in Model Engine should contain unit tests, located in [`tests/unit`](./tests/unit). To run the tests, run `pytest`. + +## Generating OpenAI types +We've decided to make our V2 APIs OpenAI compatible. We generate the +corresponding Pydantic models: +1. Fetch the OpenAPI spec from https://github.com/openai/openai-openapi/blob/master/openapi.yaml +2. Run scripts/generate-openai-types.sh diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py new file mode 100644 index 00000000..2c58f13b --- /dev/null +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -0,0 +1,5995 @@ +# generated by datamodel-codegen: +# filename: openai-spec.yaml +# timestamp: 2024-07-26T21:34:42+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel +from typing_extensions import Annotated + + +class Error(BaseModel): + code: str + message: str + param: str + type: str + + +class ErrorResponse(BaseModel): + error: Error + + +class Object(Enum): + list = "list" + + +class DeleteModelResponse(BaseModel): + id: str + deleted: bool + object: str + + +class Model1(Enum): + gpt_3_5_turbo_instruct = "gpt-3.5-turbo-instruct" + davinci_002 = "davinci-002" + babbage_002 = "babbage-002" + + +class Prompt(RootModel[Optional[List[int]]]): + root: Annotated[ + Optional[List[int]], + Field( + "<|endoftext|>", + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + examples=["[1212, 318, 257, 1332, 13]"], + min_length=1, + ), + ] = "<|endoftext|>" + + +class Prompt1Item(RootModel[List[int]]): + root: Annotated[List[int], Field(min_length=1)] + + +class Prompt1(RootModel[Optional[List[Prompt1Item]]]): + root: Annotated[ + Optional[List[Prompt1Item]], + Field( + "<|endoftext|>", + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + examples=["[[1212, 318, 257, 1332, 13]]"], + min_length=1, + ), + ] = "<|endoftext|>" + + +class Stop(RootModel[Optional[List[str]]]): + root: Annotated[ + Optional[List[str]], + Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + max_length=4, + min_length=1, + ), + ] = None + + +class FinishReason(Enum): + stop = "stop" + length = "length" + content_filter = "content_filter" + + +class Logprobs(BaseModel): + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +class Choice(BaseModel): + finish_reason: Annotated[ + FinishReason, + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n" + ), + ] + index: int + logprobs: Logprobs + text: str + + +class Object1(Enum): + text_completion = "text_completion" + + +class Type(Enum): + image_url = "image_url" + + +class Detail(Enum): + auto = "auto" + low = "low" + high = "high" + + +class ImageUrl(BaseModel): + url: Annotated[ + AnyUrl, + Field(description="Either a URL of the image or the base64 encoded image data."), + ] + detail: Annotated[ + Optional[Detail], + Field( + "auto", + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", + ), + ] + + +class ChatCompletionRequestMessageContentPartImage(BaseModel): + type: Annotated[Type, Field(description="The type of the content part.")] + image_url: ImageUrl + + +class Type1(Enum): + text = "text" + + +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Annotated[Type1, Field(description="The type of the content part.")] + text: Annotated[str, Field(description="The text content.")] + + +class Role(Enum): + system = "system" + + +class ChatCompletionRequestSystemMessage(BaseModel): + content: Annotated[str, Field(description="The contents of the system message.")] + role: Annotated[ + Role, + Field(description="The role of the messages author, in this case `system`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] + + +class Role1(Enum): + user = "user" + + +class Role2(Enum): + assistant = "assistant" + + +class FunctionCall(BaseModel): + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class Weight(Enum): + integer_0 = 0 + integer_1 = 1 + + +class Role3(Enum): + tool = "tool" + + +class ChatCompletionRequestToolMessage(BaseModel): + role: Annotated[ + Role3, + Field(description="The role of the messages author, in this case `tool`."), + ] + content: Annotated[str, Field(description="The contents of the tool message.")] + tool_call_id: Annotated[str, Field(description="Tool call that this message is responding to.")] + + +class Role4(Enum): + function = "function" + + +class ChatCompletionRequestFunctionMessage(BaseModel): + role: Annotated[ + Role4, + Field(description="The role of the messages author, in this case `function`."), + ] + content: Annotated[str, Field(description="The contents of the function message.")] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionParameters(BaseModel): + pass + model_config = ConfigDict( + extra="allow", + ) + + +class ChatCompletionFunctions(BaseModel): + description: Annotated[ + Optional[str], + Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ), + ] + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionFunctionCallOption(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class Type2(Enum): + function = "function" + + +class FunctionObject(BaseModel): + description: Annotated[ + Optional[str], + Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ), + ] + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionToolChoiceOption1(Enum): + none = "none" + auto = "auto" + required = "required" + + +class Function(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Annotated[ + Type2, + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Function + + +class ParallelToolCalls(RootModel[bool]): + root: Annotated[ + bool, + Field( + description="Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use." + ), + ] + + +class Function1(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + + +class ChatCompletionMessageToolCall(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Type2, + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Annotated[Function1, Field(description="The function that the model called.")] + + +class Function2(BaseModel): + name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + arguments: Annotated[ + Optional[str], + Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ), + ] + + +class ChatCompletionMessageToolCallChunk(BaseModel): + index: int + id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + type: Annotated[ + Optional[Type2], + Field( + None, + description="The type of the tool. Currently, only `function` is supported.", + ), + ] + function: Optional[Function2] = None + + +class ChatCompletionRole(Enum): + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + function = "function" + + +class ChatCompletionStreamOptions(BaseModel): + include_usage: Annotated[ + Optional[bool], + Field( + None, + description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n", + ), + ] + + +class Role5(Enum): + assistant = "assistant" + + +class FunctionCall2(BaseModel): + arguments: Annotated[ + Optional[str], + Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ), + ] + name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + + +class Role6(Enum): + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + + +class ChatCompletionStreamResponseDelta(BaseModel): + content: Annotated[Optional[str], Field(None, description="The contents of the chunk message.")] + function_call: Annotated[ + Optional[FunctionCall2], + Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ), + ] + tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None + role: Annotated[ + Optional[Role6], + Field(None, description="The role of the author of this message."), + ] + + +class Model2(Enum): + gpt_4o = "gpt-4o" + gpt_4o_2024_05_13 = "gpt-4o-2024-05-13" + gpt_4o_mini = "gpt-4o-mini" + gpt_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18" + gpt_4_turbo = "gpt-4-turbo" + gpt_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" + gpt_4_0125_preview = "gpt-4-0125-preview" + gpt_4_turbo_preview = "gpt-4-turbo-preview" + gpt_4_1106_preview = "gpt-4-1106-preview" + gpt_4_vision_preview = "gpt-4-vision-preview" + gpt_4 = "gpt-4" + gpt_4_0314 = "gpt-4-0314" + gpt_4_0613 = "gpt-4-0613" + gpt_4_32k = "gpt-4-32k" + gpt_4_32k_0314 = "gpt-4-32k-0314" + gpt_4_32k_0613 = "gpt-4-32k-0613" + gpt_3_5_turbo = "gpt-3.5-turbo" + gpt_3_5_turbo_16k = "gpt-3.5-turbo-16k" + gpt_3_5_turbo_0301 = "gpt-3.5-turbo-0301" + gpt_3_5_turbo_0613 = "gpt-3.5-turbo-0613" + gpt_3_5_turbo_1106 = "gpt-3.5-turbo-1106" + gpt_3_5_turbo_0125 = "gpt-3.5-turbo-0125" + gpt_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" + + +class Type6(Enum): + text = "text" + json_object = "json_object" + + +class ResponseFormat(BaseModel): + type: Annotated[ + Optional[Type6], + Field( + "text", + description="Must be one of `text` or `json_object`.", + examples=["json_object"], + ), + ] + + +class ServiceTier(Enum): + auto = "auto" + default = "default" + + +class Stop1(RootModel[List[str]]): + root: Annotated[ + List[str], + Field( + description="Up to 4 sequences where the API will stop generating further tokens.\n", + max_length=4, + min_length=1, + ), + ] + + +class FunctionCall3(Enum): + none = "none" + auto = "auto" + + +class FinishReason1(Enum): + stop = "stop" + length = "length" + tool_calls = "tool_calls" + content_filter = "content_filter" + function_call = "function_call" + + +class ServiceTier1(Enum): + scale = "scale" + default = "default" + + +class Object2(Enum): + chat_completion = "chat.completion" + + +class FinishReason2(Enum): + stop = "stop" + length = "length" + function_call = "function_call" + content_filter = "content_filter" + + +class TopLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + List[int], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + + +class ChatCompletionTokenLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + List[int], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + top_logprobs: Annotated[ + List[TopLogprob], + Field( + description="List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned." + ), + ] + + +class Object4(Enum): + list = "list" + + +class Logprobs2(BaseModel): + content: Annotated[ + List[ChatCompletionTokenLogprob], + Field(description="A list of message content tokens with log probability information."), + ] + + +class FinishReason3(Enum): + stop = "stop" + length = "length" + tool_calls = "tool_calls" + content_filter = "content_filter" + function_call = "function_call" + + +class Choice3(BaseModel): + delta: ChatCompletionStreamResponseDelta + logprobs: Annotated[ + Optional[Logprobs2], + Field(None, description="Log probability information for the choice."), + ] + finish_reason: Annotated[ + FinishReason3, + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + + +class Object5(Enum): + chat_completion_chunk = "chat.completion.chunk" + + +class Usage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class CreateChatCompletionStreamResponse(BaseModel): + id: Annotated[ + str, + Field( + description="A unique identifier for the chat completion. Each chunk has the same ID." + ), + ] + choices: Annotated[ + List[Choice3], + Field( + description='A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the\nlast chunk if you set `stream_options: {"include_usage": true}`.\n' + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp." + ), + ] + model: Annotated[str, Field(description="The model to generate the completion.")] + service_tier: Annotated[ + Optional[ServiceTier1], + Field( + None, + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + examples=["scale"], + ), + ] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Object5, + Field(description="The object type, which is always `chat.completion.chunk`."), + ] + usage: Annotated[ + Optional[Usage], + Field( + None, + description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n', + ), + ] + + +class CreateChatCompletionImageResponse(BaseModel): + pass + + +class Model3(Enum): + dall_e_2 = "dall-e-2" + dall_e_3 = "dall-e-3" + + +class Quality(Enum): + standard = "standard" + hd = "hd" + + +class ResponseFormat1(Enum): + url = "url" + b64_json = "b64_json" + + +class Size(Enum): + field_256x256 = "256x256" + field_512x512 = "512x512" + field_1024x1024 = "1024x1024" + field_1792x1024 = "1792x1024" + field_1024x1792 = "1024x1792" + + +class Style(Enum): + vivid = "vivid" + natural = "natural" + + +class CreateImageRequest(BaseModel): + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`.", + examples=["A cute baby sea otter"], + ), + ] + model: Annotated[ + Optional[Union[str, Model3]], + Field( + "dall-e-2", + description="The model to use for image generation.", + examples=["dall-e-3"], + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + examples=[1], + ge=1, + le=10, + ), + ] + quality: Annotated[ + Optional[Quality], + Field( + "standard", + description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", + examples=["standard"], + ), + ] + response_format: Annotated[ + Optional[ResponseFormat1], + Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] + size: Annotated[ + Optional[Size], + Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", + examples=["1024x1024"], + ), + ] + style: Annotated[ + Optional[Style], + Field( + "vivid", + description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", + examples=["vivid"], + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class Image(BaseModel): + b64_json: Annotated[ + Optional[str], + Field( + None, + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`.", + ), + ] + url: Annotated[ + Optional[str], + Field( + None, + description="The URL of the generated image, if `response_format` is `url` (default).", + ), + ] + revised_prompt: Annotated[ + Optional[str], + Field( + None, + description="The prompt that was used to generate the image, if there was any revision to the prompt.", + ), + ] + + +class Model4(Enum): + dall_e_2 = "dall-e-2" + + +class Size1(Enum): + field_256x256 = "256x256" + field_512x512 = "512x512" + field_1024x1024 = "1024x1024" + + +class CreateImageEditRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask." + ), + ] + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters.", + examples=["A cute baby sea otter wearing a beret"], + ), + ] + mask: Annotated[ + Optional[bytes], + Field( + None, + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`.", + ), + ] + model: Annotated[ + Optional[Union[str, Model4]], + Field( + "dall-e-2", + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + examples=["dall-e-2"], + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="The number of images to generate. Must be between 1 and 10.", + examples=[1], + ge=1, + le=10, + ), + ] + size: Annotated[ + Optional[Size1], + Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + examples=["1024x1024"], + ), + ] + response_format: Annotated[ + Optional[ResponseFormat1], + Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class CreateImageVariationRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square." + ), + ] + model: Annotated[ + Optional[Union[str, Model4]], + Field( + "dall-e-2", + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + examples=["dall-e-2"], + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + examples=[1], + ge=1, + le=10, + ), + ] + response_format: Annotated[ + Optional[ResponseFormat1], + Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] + size: Annotated[ + Optional[Size1], + Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + examples=["1024x1024"], + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class Model6(Enum): + text_moderation_latest = "text-moderation-latest" + text_moderation_stable = "text-moderation-stable" + + +class CreateModerationRequest(BaseModel): + input: Annotated[Union[str, List[str]], Field(description="The input text to classify")] + model: Annotated[ + Optional[Union[str, Model6]], + Field( + "text-moderation-latest", + description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", + examples=["text-moderation-stable"], + ), + ] + + +class Categories(BaseModel): + hate: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment." + ), + ] + hate_threatening: Annotated[ + bool, + Field( + alias="hate/threatening", + description="Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste.", + ), + ] + harassment: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes harassing language towards any target." + ), + ] + harassment_threatening: Annotated[ + bool, + Field( + alias="harassment/threatening", + description="Harassment content that also includes violence or serious harm towards any target.", + ), + ] + self_harm: Annotated[ + bool, + Field( + alias="self-harm", + description="Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_intent: Annotated[ + bool, + Field( + alias="self-harm/intent", + description="Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_instructions: Annotated[ + bool, + Field( + alias="self-harm/instructions", + description="Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts.", + ), + ] + sexual: Annotated[ + bool, + Field( + description="Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness)." + ), + ] + sexual_minors: Annotated[ + bool, + Field( + alias="sexual/minors", + description="Sexual content that includes an individual who is under 18 years old.", + ), + ] + violence: Annotated[ + bool, + Field(description="Content that depicts death, violence, or physical injury."), + ] + violence_graphic: Annotated[ + bool, + Field( + alias="violence/graphic", + description="Content that depicts death, violence, or physical injury in graphic detail.", + ), + ] + + +class CategoryScores(BaseModel): + hate: Annotated[float, Field(description="The score for the category 'hate'.")] + hate_threatening: Annotated[ + float, + Field( + alias="hate/threatening", + description="The score for the category 'hate/threatening'.", + ), + ] + harassment: Annotated[float, Field(description="The score for the category 'harassment'.")] + harassment_threatening: Annotated[ + float, + Field( + alias="harassment/threatening", + description="The score for the category 'harassment/threatening'.", + ), + ] + self_harm: Annotated[ + float, + Field(alias="self-harm", description="The score for the category 'self-harm'."), + ] + self_harm_intent: Annotated[ + float, + Field( + alias="self-harm/intent", + description="The score for the category 'self-harm/intent'.", + ), + ] + self_harm_instructions: Annotated[ + float, + Field( + alias="self-harm/instructions", + description="The score for the category 'self-harm/instructions'.", + ), + ] + sexual: Annotated[float, Field(description="The score for the category 'sexual'.")] + sexual_minors: Annotated[ + float, + Field( + alias="sexual/minors", + description="The score for the category 'sexual/minors'.", + ), + ] + violence: Annotated[float, Field(description="The score for the category 'violence'.")] + violence_graphic: Annotated[ + float, + Field( + alias="violence/graphic", + description="The score for the category 'violence/graphic'.", + ), + ] + + +class Result(BaseModel): + flagged: Annotated[bool, Field(description="Whether any of the below categories are flagged.")] + categories: Annotated[ + Categories, + Field(description="A list of the categories, and whether they are flagged or not."), + ] + category_scores: Annotated[ + CategoryScores, + Field( + description="A list of the categories along with their scores as predicted by model." + ), + ] + + +class CreateModerationResponse(BaseModel): + id: Annotated[str, Field(description="The unique identifier for the moderation request.")] + model: Annotated[str, Field(description="The model used to generate the moderation results.")] + results: Annotated[List[Result], Field(description="A list of moderation objects.")] + + +class Object6(Enum): + list = "list" + + +class Purpose(Enum): + assistants = "assistants" + batch = "batch" + fine_tune = "fine-tune" + vision = "vision" + + +class CreateFileRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[bytes, Field(description="The File object (not file name) to be uploaded.\n")] + purpose: Annotated[ + Purpose, + Field( + description='The intended purpose of the uploaded file.\n\nUse "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning).\n' + ), + ] + + +class Object7(Enum): + file = "file" + + +class DeleteFileResponse(BaseModel): + id: str + object: Object7 + deleted: bool + + +class CreateUploadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + filename: Annotated[str, Field(description="The name of the file to upload.\n")] + purpose: Annotated[ + Purpose, + Field( + description="The intended purpose of the uploaded file.\n\nSee the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose).\n" + ), + ] + bytes: Annotated[int, Field(description="The number of bytes in the file you are uploading.\n")] + mime_type: Annotated[ + str, + Field( + description="The MIME type of the file.\n\nThis must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision.\n" + ), + ] + + +class AddUploadPartRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + data: Annotated[bytes, Field(description="The chunk of bytes for this Part.\n")] + + +class CompleteUploadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + part_ids: Annotated[List[str], Field(description="The ordered list of Part IDs.\n")] + md5: Annotated[ + Optional[str], + Field( + None, + description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n", + ), + ] + + +class CancelUploadRequest(BaseModel): + pass + model_config = ConfigDict( + extra="forbid", + ) + + +class Model7(Enum): + babbage_002 = "babbage-002" + davinci_002 = "davinci-002" + gpt_3_5_turbo = "gpt-3.5-turbo" + + +class BatchSize(Enum): + auto = "auto" + + +class BatchSize1(RootModel[int]): + root: Annotated[ + int, + Field( + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ge=1, + le=256, + ), + ] + + +class LearningRateMultiplier(Enum): + auto = "auto" + + +class LearningRateMultiplier1(RootModel[float]): + root: Annotated[ + float, + Field( + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + gt=0.0, + ), + ] + + +class NEpochs(Enum): + auto = "auto" + + +class NEpochs1(RootModel[int]): + root: Annotated[ + int, + Field( + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + ge=1, + le=50, + ), + ] + + +class Hyperparameters(BaseModel): + batch_size: Annotated[ + Optional[Union[BatchSize, BatchSize1]], + Field( + "auto", + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ), + ] + learning_rate_multiplier: Annotated[ + Optional[Union[LearningRateMultiplier, LearningRateMultiplier1]], + Field( + "auto", + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + ), + ] + n_epochs: Annotated[ + Optional[Union[NEpochs, NEpochs1]], + Field( + "auto", + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + ), + ] + + +class Type7(Enum): + wandb = "wandb" + + +class Wandb(BaseModel): + project: Annotated[ + str, + Field( + description="The name of the project that the new run will be created under.\n", + examples=["my-wandb-project"], + ), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="A display name to set for the run. If not set, we will use the Job ID as the name.\n", + ), + ] + entity: Annotated[ + Optional[str], + Field( + None, + description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n", + ), + ] + tags: Annotated[ + Optional[List[str]], + Field( + None, + description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n', + ), + ] + + +class Integration(BaseModel): + type: Annotated[ + Type7, + Field( + description='The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.\n' + ), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class CreateFineTuningJobRequest(BaseModel): + model: Annotated[ + Union[str, Model7], + Field( + description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/what-models-can-be-fine-tuned).\n", + examples=["gpt-3.5-turbo"], + ), + ] + training_file: Annotated[ + str, + Field( + description="The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/create) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.\n\nThe contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + examples=["file-abc123"], + ), + ] + hyperparameters: Annotated[ + Optional[Hyperparameters], + Field(None, description="The hyperparameters used for the fine-tuning job."), + ] + suffix: Annotated[ + Optional[str], + Field( + None, + description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`.\n', + max_length=40, + min_length=1, + ), + ] + validation_file: Annotated[ + Optional[str], + Field( + None, + description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + examples=["file-abc123"], + ), + ] + integrations: Annotated[ + Optional[List[Integration]], + Field( + None, + description="A list of integrations to enable for your fine-tuning job.", + ), + ] + seed: Annotated[ + Optional[int], + Field( + None, + description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.\nIf a seed is not specified, one will be generated for you.\n", + examples=[42], + ge=0, + le=2147483647, + ), + ] + + +class Object8(Enum): + list = "list" + + +class Input(RootModel[List[str]]): + root: Annotated[ + List[str], + Field( + description="The array of strings that will be turned into an embedding.", + examples=["The quick brown fox jumped over the lazy dog"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Input1(RootModel[List[int]]): + root: Annotated[ + List[int], + Field( + description="The array of integers that will be turned into an embedding.", + examples=["[1212, 318, 257, 1332, 13]"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Input2Item(RootModel[List[int]]): + root: Annotated[List[int], Field(min_length=1)] + + +class Input2(RootModel[List[Input2Item]]): + root: Annotated[ + List[Input2Item], + Field( + description="The array of arrays containing integers that will be turned into an embedding.", + examples=["[[1212, 318, 257, 1332, 13]]"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Model8(Enum): + text_embedding_ada_002 = "text-embedding-ada-002" + text_embedding_3_small = "text-embedding-3-small" + text_embedding_3_large = "text-embedding-3-large" + + +class EncodingFormat(Enum): + float = "float" + base64 = "base64" + + +class CreateEmbeddingRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + input: Annotated[ + Union[str, Input, Input1, Input2], + Field( + description="Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=["The quick brown fox jumped over the lazy dog"], + ), + ] + model: Annotated[ + Union[str, Model8], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + examples=["text-embedding-3-small"], + ), + ] + encoding_format: Annotated[ + Optional[EncodingFormat], + Field( + "float", + description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", + examples=["float"], + ), + ] + dimensions: Annotated[ + Optional[int], + Field( + None, + description="The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.\n", + ge=1, + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class Usage1(BaseModel): + prompt_tokens: Annotated[int, Field(description="The number of tokens used by the prompt.")] + total_tokens: Annotated[ + int, Field(description="The total number of tokens used by the request.") + ] + + +class Model9(Enum): + whisper_1 = "whisper-1" + + +class ResponseFormat4(Enum): + json = "json" + text = "text" + srt = "srt" + verbose_json = "verbose_json" + vtt = "vtt" + + +class TimestampGranularity(Enum): + word = "word" + segment = "segment" + + +class CreateTranscriptionRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Model9], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + examples=["whisper-1"], + ), + ] + language: Annotated[ + Optional[str], + Field( + None, + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n", + ), + ] + prompt: Annotated[ + Optional[str], + Field( + None, + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n", + ), + ] + response_format: Annotated[ + Optional[ResponseFormat4], + Field( + "json", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 0, + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + ), + ] + timestamp_granularities__: Annotated[ + Optional[List[TimestampGranularity]], + Field( + ["segment"], + alias="timestamp_granularities[]", + description="The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.\n", + ), + ] + + +class CreateTranscriptionResponseJson(BaseModel): + text: Annotated[str, Field(description="The transcribed text.")] + + +class TranscriptionSegment(BaseModel): + id: Annotated[int, Field(description="Unique identifier of the segment.")] + seek: Annotated[int, Field(description="Seek offset of the segment.")] + start: Annotated[float, Field(description="Start time of the segment in seconds.")] + end: Annotated[float, Field(description="End time of the segment in seconds.")] + text: Annotated[str, Field(description="Text content of the segment.")] + tokens: Annotated[List[int], Field(description="Array of token IDs for the text content.")] + temperature: Annotated[ + float, + Field(description="Temperature parameter used for generating the segment."), + ] + avg_logprob: Annotated[ + float, + Field( + description="Average logprob of the segment. If the value is lower than -1, consider the logprobs failed." + ), + ] + compression_ratio: Annotated[ + float, + Field( + description="Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed." + ), + ] + no_speech_prob: Annotated[ + float, + Field( + description="Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent." + ), + ] + + +class TranscriptionWord(BaseModel): + word: Annotated[str, Field(description="The text content of the word.")] + start: Annotated[float, Field(description="Start time of the word in seconds.")] + end: Annotated[float, Field(description="End time of the word in seconds.")] + + +class CreateTranscriptionResponseVerboseJson(BaseModel): + language: Annotated[str, Field(description="The language of the input audio.")] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The transcribed text.")] + words: Annotated[ + Optional[List[TranscriptionWord]], + Field(None, description="Extracted words and their corresponding timestamps."), + ] + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field( + None, + description="Segments of the transcribed text and their corresponding details.", + ), + ] + + +class CreateTranslationRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Model9], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + examples=["whisper-1"], + ), + ] + prompt: Annotated[ + Optional[str], + Field( + None, + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n", + ), + ] + response_format: Annotated[ + Optional[str], + Field( + "json", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 0, + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + ), + ] + + +class CreateTranslationResponseJson(BaseModel): + text: str + + +class CreateTranslationResponseVerboseJson(BaseModel): + language: Annotated[ + str, + Field(description="The language of the output translation (always `english`)."), + ] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The translated text.")] + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field( + None, + description="Segments of the translated text and their corresponding details.", + ), + ] + + +class Model11(Enum): + tts_1 = "tts-1" + tts_1_hd = "tts-1-hd" + + +class Voice(Enum): + alloy = "alloy" + echo = "echo" + fable = "fable" + onyx = "onyx" + nova = "nova" + shimmer = "shimmer" + + +class ResponseFormat5(Enum): + mp3 = "mp3" + opus = "opus" + aac = "aac" + flac = "flac" + wav = "wav" + pcm = "pcm" + + +class CreateSpeechRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Union[str, Model11], + Field( + description="One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd`\n" + ), + ] + input: Annotated[ + str, + Field( + description="The text to generate audio for. The maximum length is 4096 characters.", + max_length=4096, + ), + ] + voice: Annotated[ + Voice, + Field( + description="The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options)." + ), + ] + response_format: Annotated[ + Optional[ResponseFormat5], + Field( + "mp3", + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.", + ), + ] + speed: Annotated[ + Optional[float], + Field( + 1.0, + description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", + ge=0.25, + le=4.0, + ), + ] + + +class Object11(Enum): + model = "model" + + +class Model(BaseModel): + id: Annotated[ + str, + Field(description="The model identifier, which can be referenced in the API endpoints."), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) when the model was created."), + ] + object: Annotated[Object11, Field(description='The object type, which is always "model".')] + owned_by: Annotated[str, Field(description="The organization that owns the model.")] + + +class Object12(Enum): + file = "file" + + +class Purpose2(Enum): + assistants = "assistants" + assistants_output = "assistants_output" + batch = "batch" + batch_output = "batch_output" + fine_tune = "fine-tune" + fine_tune_results = "fine-tune-results" + vision = "vision" + + +class Status(Enum): + uploaded = "uploaded" + processed = "processed" + error = "error" + + +class OpenAIFile(BaseModel): + id: Annotated[ + str, + Field(description="The file identifier, which can be referenced in the API endpoints."), + ] + bytes: Annotated[int, Field(description="The size of the file, in bytes.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the file was created."), + ] + filename: Annotated[str, Field(description="The name of the file.")] + object: Annotated[Object12, Field(description="The object type, which is always `file`.")] + purpose: Annotated[ + Purpose2, + Field( + description="The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`." + ), + ] + status: Annotated[ + Status, + Field( + description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`." + ), + ] + status_details: Annotated[ + Optional[str], + Field( + None, + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`.", + ), + ] + + +class Status1(Enum): + pending = "pending" + completed = "completed" + cancelled = "cancelled" + expired = "expired" + + +class Object13(Enum): + upload = "upload" + + +class Upload(BaseModel): + id: Annotated[ + str, + Field( + description="The Upload unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + filename: Annotated[str, Field(description="The name of the file to be uploaded.")] + bytes: Annotated[int, Field(description="The intended number of bytes to be uploaded.")] + purpose: Annotated[ + str, + Field( + description="The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values." + ), + ] + status: Annotated[Status1, Field(description="The status of the Upload.")] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + object: Annotated[ + Optional[Object13], + Field(None, description='The object type, which is always "upload".'), + ] + file: Annotated[ + Optional[OpenAIFile], + Field(None, description="The ready File object after the Upload is completed."), + ] + + +class Object14(Enum): + upload_part = "upload.part" + + +class UploadPart(BaseModel): + id: Annotated[ + str, + Field( + description="The upload Part unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Part was created."), + ] + upload_id: Annotated[ + str, + Field(description="The ID of the Upload object that this Part was added to."), + ] + object: Annotated[ + Object14, Field(description="The object type, which is always `upload.part`.") + ] + + +class Object15(Enum): + embedding = "embedding" + + +class Embedding(BaseModel): + index: Annotated[ + int, Field(description="The index of the embedding in the list of embeddings.") + ] + embedding: Annotated[ + List[float], + Field( + description="The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).\n" + ), + ] + object: Annotated[Object15, Field(description='The object type, which is always "embedding".')] + + +class Error1(BaseModel): + code: Annotated[str, Field(description="A machine-readable error code.")] + message: Annotated[str, Field(description="A human-readable error message.")] + param: Annotated[ + str, + Field( + description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific." + ), + ] + + +class NEpochs2(Enum): + auto = "auto" + + +class NEpochs3(RootModel[int]): + root: Annotated[ + int, + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.', + ge=1, + le=50, + ), + ] + + +class Hyperparameters1(BaseModel): + n_epochs: Annotated[ + Union[NEpochs2, NEpochs3], + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.' + ), + ] + + +class Object16(Enum): + fine_tuning_job = "fine_tuning.job" + + +class Status2(Enum): + validating_files = "validating_files" + queued = "queued" + running = "running" + succeeded = "succeeded" + failed = "failed" + cancelled = "cancelled" + + +class FineTuningIntegration(BaseModel): + type: Annotated[ + Type7, + Field(description="The type of the integration being enabled for the fine-tuning job"), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class Level(Enum): + info = "info" + warn = "warn" + error = "error" + + +class Object17(Enum): + fine_tuning_job_event = "fine_tuning.job.event" + + +class FineTuningJobEvent(BaseModel): + id: str + created_at: int + level: Level + message: str + object: Object17 + + +class Metrics(BaseModel): + step: Optional[float] = None + train_loss: Optional[float] = None + train_mean_token_accuracy: Optional[float] = None + valid_loss: Optional[float] = None + valid_mean_token_accuracy: Optional[float] = None + full_valid_loss: Optional[float] = None + full_valid_mean_token_accuracy: Optional[float] = None + + +class Object18(Enum): + fine_tuning_job_checkpoint = "fine_tuning.job.checkpoint" + + +class FineTuningJobCheckpoint(BaseModel): + id: Annotated[ + str, + Field( + description="The checkpoint identifier, which can be referenced in the API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the checkpoint was created."), + ] + fine_tuned_model_checkpoint: Annotated[ + str, + Field(description="The name of the fine-tuned checkpoint model that is created."), + ] + step_number: Annotated[ + int, Field(description="The step number that the checkpoint was created at.") + ] + metrics: Annotated[ + Metrics, + Field(description="Metrics at the step number during the fine-tuning job."), + ] + fine_tuning_job_id: Annotated[ + str, + Field(description="The name of the fine-tuning job that this checkpoint was created from."), + ] + object: Annotated[ + Object18, + Field(description='The object type, which is always "fine_tuning.job.checkpoint".'), + ] + + +class FinetuneCompletionRequestInput(BaseModel): + prompt: Annotated[ + Optional[str], + Field(None, description="The input prompt for this training example."), + ] + completion: Annotated[ + Optional[str], + Field(None, description="The desired completion for this training example."), + ] + + +class CompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class RunCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class RunStepCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run step."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run step."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class AssistantsApiResponseFormatOption1(Enum): + none = "none" + auto = "auto" + + +class Type9(Enum): + text = "text" + json_object = "json_object" + + +class AssistantsApiResponseFormat(BaseModel): + type: Annotated[ + Optional[Type9], + Field( + "text", + description="Must be one of `text` or `json_object`.", + examples=["json_object"], + ), + ] + + +class Object19(Enum): + assistant = "assistant" + + +class CodeInterpreter(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class FileSearch(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources(BaseModel): + code_interpreter: Optional[CodeInterpreter] = None + file_search: Optional[FileSearch] = None + + +class Model12(Enum): + gpt_4o = "gpt-4o" + gpt_4o_2024_05_13 = "gpt-4o-2024-05-13" + gpt_4o_mini = "gpt-4o-mini" + gpt_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18" + gpt_4_turbo = "gpt-4-turbo" + gpt_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" + gpt_4_0125_preview = "gpt-4-0125-preview" + gpt_4_turbo_preview = "gpt-4-turbo-preview" + gpt_4_1106_preview = "gpt-4-1106-preview" + gpt_4_vision_preview = "gpt-4-vision-preview" + gpt_4 = "gpt-4" + gpt_4_0314 = "gpt-4-0314" + gpt_4_0613 = "gpt-4-0613" + gpt_4_32k = "gpt-4-32k" + gpt_4_32k_0314 = "gpt-4-32k-0314" + gpt_4_32k_0613 = "gpt-4-32k-0613" + gpt_3_5_turbo = "gpt-3.5-turbo" + gpt_3_5_turbo_16k = "gpt-3.5-turbo-16k" + gpt_3_5_turbo_0613 = "gpt-3.5-turbo-0613" + gpt_3_5_turbo_1106 = "gpt-3.5-turbo-1106" + gpt_3_5_turbo_0125 = "gpt-3.5-turbo-0125" + gpt_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" + + +class CodeInterpreter1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class Type10(Enum): + auto = "auto" + + +class ChunkingStrategy(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type10, Field(description="Always `auto`.")] + + +class Type11(Enum): + static = "static" + + +class Static(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class ChunkingStrategy1(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type11, Field(description="Always `static`.")] + static: Static + + +class VectorStore(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy, ChunkingStrategy1]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch1(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore]], + Field( + None, + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class Type12(Enum): + auto = "auto" + + +class ChunkingStrategy2(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type12, Field(description="Always `auto`.")] + + +class Type13(Enum): + static = "static" + + +class ChunkingStrategy3(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type13, Field(description="Always `static`.")] + static: Static + + +class VectorStore1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy2, ChunkingStrategy3]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch2(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + List[VectorStore1], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources1(BaseModel): + code_interpreter: Optional[CodeInterpreter1] = None + file_search: Optional[Union[FileSearch1, FileSearch2]] = None + + +class CodeInterpreter2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class FileSearch3(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources2(BaseModel): + code_interpreter: Optional[CodeInterpreter2] = None + file_search: Optional[FileSearch3] = None + + +class Object20(Enum): + assistant_deleted = "assistant.deleted" + + +class DeleteAssistantResponse(BaseModel): + id: str + deleted: bool + object: Object20 + + +class Type14(Enum): + code_interpreter = "code_interpreter" + + +class AssistantToolsCode(BaseModel): + type: Annotated[Type14, Field(description="The type of tool being defined: `code_interpreter`")] + + +class Type15(Enum): + file_search = "file_search" + + +class FileSearch4(BaseModel): + max_num_results: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of results the file search tool should output. The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", + ge=1, + le=50, + ), + ] + + +class AssistantToolsFileSearch(BaseModel): + type: Annotated[Type15, Field(description="The type of tool being defined: `file_search`")] + file_search: Annotated[ + Optional[FileSearch4], + Field(None, description="Overrides for the file search tool."), + ] + + +class AssistantToolsFileSearchTypeOnly(BaseModel): + type: Annotated[Type15, Field(description="The type of tool being defined: `file_search`")] + + +class Type17(Enum): + function = "function" + + +class AssistantToolsFunction(BaseModel): + type: Annotated[Type17, Field(description="The type of tool being defined: `function`")] + function: FunctionObject + + +class Type18(Enum): + auto = "auto" + last_messages = "last_messages" + + +class TruncationObject(BaseModel): + type: Annotated[ + Type18, + Field( + description="The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`." + ), + ] + last_messages: Annotated[ + Optional[int], + Field( + None, + description="The number of most recent messages from the thread when constructing the context for the run.", + ge=1, + ), + ] + + +class AssistantsApiToolChoiceOption1(Enum): + none = "none" + auto = "auto" + required = "required" + + +class Type19(Enum): + function = "function" + code_interpreter = "code_interpreter" + file_search = "file_search" + + +class Function3(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class AssistantsNamedToolChoice(BaseModel): + type: Annotated[ + Type19, + Field( + description="The type of the tool. If type is `function`, the function name must be set" + ), + ] + function: Optional[Function3] = None + + +class Object21(Enum): + thread_run = "thread.run" + + +class Status3(Enum): + queued = "queued" + in_progress = "in_progress" + requires_action = "requires_action" + cancelling = "cancelling" + cancelled = "cancelled" + failed = "failed" + completed = "completed" + incomplete = "incomplete" + expired = "expired" + + +class Type20(Enum): + submit_tool_outputs = "submit_tool_outputs" + + +class Code(Enum): + server_error = "server_error" + rate_limit_exceeded = "rate_limit_exceeded" + invalid_prompt = "invalid_prompt" + + +class LastError(BaseModel): + code: Annotated[ + Code, + Field(description="One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class Reason(Enum): + max_completion_tokens = "max_completion_tokens" + max_prompt_tokens = "max_prompt_tokens" + + +class IncompleteDetails(BaseModel): + reason: Annotated[ + Optional[Reason], + Field( + None, + description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run.", + ), + ] + + +class ModifyRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class ToolOutput(BaseModel): + tool_call_id: Annotated[ + Optional[str], + Field( + None, + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for.", + ), + ] + output: Annotated[ + Optional[str], + Field( + None, + description="The output of the tool call to be submitted to continue the run.", + ), + ] + + +class SubmitToolOutputsRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + tool_outputs: Annotated[ + List[ToolOutput], + Field(description="A list of tools for which the outputs are being submitted."), + ] + stream: Annotated[ + Optional[bool], + Field( + None, + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + ), + ] + + +class Type21(Enum): + function = "function" + + +class Function4(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[ + str, + Field(description="The arguments that the model expects you to pass to the function."), + ] + + +class RunToolCallObject(BaseModel): + id: Annotated[ + str, + Field( + description="The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint." + ), + ] + type: Annotated[ + Type21, + Field( + description="The type of tool call the output is required for. For now, this is always `function`." + ), + ] + function: Annotated[Function4, Field(description="The function definition.")] + + +class CodeInterpreter3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class FileSearch5(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources3(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch5] = None + + +class Object22(Enum): + thread = "thread" + + +class FileSearch6(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources4(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch6] = None + + +class ThreadObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[Object22, Field(description="The object type, which is always `thread`.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the thread was created."), + ] + tool_resources: Annotated[ + ToolResources4, + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class Type22(Enum): + auto = "auto" + + +class ChunkingStrategy4(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type22, Field(description="Always `auto`.")] + + +class Type23(Enum): + static = "static" + + +class ChunkingStrategy5(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type23, Field(description="Always `static`.")] + static: Static + + +class VectorStore2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy4, ChunkingStrategy5]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch7(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore2]], + Field( + None, + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class Type24(Enum): + auto = "auto" + + +class ChunkingStrategy6(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type24, Field(description="Always `auto`.")] + + +class Type25(Enum): + static = "static" + + +class ChunkingStrategy7(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type25, Field(description="Always `static`.")] + static: Static + + +class VectorStore3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy6, ChunkingStrategy7]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch8(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + List[VectorStore3], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources5(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[Union[FileSearch7, FileSearch8]] = None + + +class FileSearch9(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources6(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch9] = None + + +class ModifyThreadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + tool_resources: Annotated[ + Optional[ToolResources6], + Field( + None, + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class Object23(Enum): + thread_deleted = "thread.deleted" + + +class DeleteThreadResponse(BaseModel): + id: str + deleted: bool + object: Object23 + + +class ListThreadsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[ThreadObject] + first_id: Annotated[str, Field(examples=["asst_abc123"])] + last_id: Annotated[str, Field(examples=["asst_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class Object24(Enum): + thread_message = "thread.message" + + +class Status4(Enum): + in_progress = "in_progress" + incomplete = "incomplete" + completed = "completed" + + +class Reason1(Enum): + content_filter = "content_filter" + max_tokens = "max_tokens" + run_cancelled = "run_cancelled" + run_expired = "run_expired" + run_failed = "run_failed" + + +class IncompleteDetails1(BaseModel): + reason: Annotated[Reason1, Field(description="The reason the message is incomplete.")] + + +class Role7(Enum): + user = "user" + assistant = "assistant" + + +class Attachment(BaseModel): + file_id: Annotated[ + Optional[str], + Field(None, description="The ID of the file to attach to the message."), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearchTypeOnly]]], + Field(None, description="The tools to add this file to."), + ] + + +class Object25(Enum): + thread_message_delta = "thread.message.delta" + + +class ModifyMessageRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class Object26(Enum): + thread_message_deleted = "thread.message.deleted" + + +class DeleteMessageResponse(BaseModel): + id: str + deleted: bool + object: Object26 + + +class Type26(Enum): + image_file = "image_file" + + +class ImageFile(BaseModel): + file_id: Annotated[ + str, + Field( + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' + ), + ] + detail: Annotated[ + Optional[Detail], + Field( + "auto", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + ), + ] + + +class MessageContentImageFileObject(BaseModel): + type: Annotated[Type26, Field(description="Always `image_file`.")] + image_file: ImageFile + + +class ImageFile1(BaseModel): + file_id: Annotated[ + Optional[str], + Field( + None, + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.', + ), + ] + detail: Annotated[ + Optional[Detail], + Field( + "auto", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + ), + ] + + +class MessageDeltaContentImageFileObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Type26, Field(description="Always `image_file`.")] + image_file: Optional[ImageFile1] = None + + +class Type28(Enum): + image_url = "image_url" + + +class ImageUrl1(BaseModel): + url: Annotated[ + AnyUrl, + Field( + description="The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + ), + ] + detail: Annotated[ + Optional[Detail], + Field( + "auto", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`", + ), + ] + + +class MessageContentImageUrlObject(BaseModel): + type: Annotated[Type28, Field(description="The type of the content part.")] + image_url: ImageUrl1 + + +class ImageUrl2(BaseModel): + url: Annotated[ + Optional[str], + Field( + None, + description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp.", + ), + ] + detail: Annotated[ + Optional[Detail], + Field( + "auto", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + ), + ] + + +class MessageDeltaContentImageUrlObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Type28, Field(description="Always `image_url`.")] + image_url: Optional[ImageUrl2] = None + + +class Type30(Enum): + text = "text" + + +class MessageRequestContentTextObject(BaseModel): + type: Annotated[Type30, Field(description="Always `text`.")] + text: Annotated[str, Field(description="Text content to be sent to the model")] + + +class Type32(Enum): + file_citation = "file_citation" + + +class FileCitation(BaseModel): + file_id: Annotated[str, Field(description="The ID of the specific File the citation is from.")] + + +class MessageContentTextAnnotationsFileCitationObject(BaseModel): + type: Annotated[Type32, Field(description="Always `file_citation`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_citation: FileCitation + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class Type33(Enum): + file_path = "file_path" + + +class FilePath(BaseModel): + file_id: Annotated[str, Field(description="The ID of the file that was generated.")] + + +class MessageContentTextAnnotationsFilePathObject(BaseModel): + type: Annotated[Type33, Field(description="Always `file_path`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_path: FilePath + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class Type34(Enum): + text = "text" + + +class Type35(Enum): + file_citation = "file_citation" + + +class FileCitation1(BaseModel): + file_id: Annotated[ + Optional[str], + Field(None, description="The ID of the specific File the citation is from."), + ] + quote: Annotated[Optional[str], Field(None, description="The specific quote in the file.")] + + +class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Type35, Field(description="Always `file_citation`.")] + text: Annotated[ + Optional[str], + Field( + None, + description="The text in the message content that needs to be replaced.", + ), + ] + file_citation: Optional[FileCitation1] = None + start_index: Annotated[Optional[int], Field(None, ge=0)] + end_index: Annotated[Optional[int], Field(None, ge=0)] + + +class Type36(Enum): + file_path = "file_path" + + +class FilePath1(BaseModel): + file_id: Annotated[ + Optional[str], Field(None, description="The ID of the file that was generated.") + ] + + +class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Type36, Field(description="Always `file_path`.")] + text: Annotated[ + Optional[str], + Field( + None, + description="The text in the message content that needs to be replaced.", + ), + ] + file_path: Optional[FilePath1] = None + start_index: Annotated[Optional[int], Field(None, ge=0)] + end_index: Annotated[Optional[int], Field(None, ge=0)] + + +class Object27(Enum): + thread_run_step = "thread.run.step" + + +class Type37(Enum): + message_creation = "message_creation" + tool_calls = "tool_calls" + + +class Status5(Enum): + in_progress = "in_progress" + cancelled = "cancelled" + failed = "failed" + completed = "completed" + expired = "expired" + + +class Code1(Enum): + server_error = "server_error" + rate_limit_exceeded = "rate_limit_exceeded" + + +class LastError1(BaseModel): + code: Annotated[Code1, Field(description="One of `server_error` or `rate_limit_exceeded`.")] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class Object28(Enum): + thread_run_step_delta = "thread.run.step.delta" + + +class Type38(Enum): + message_creation = "message_creation" + + +class MessageCreation(BaseModel): + message_id: Annotated[ + str, + Field(description="The ID of the message that was created by this run step."), + ] + + +class RunStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Type38, Field(description="Always `message_creation`.")] + message_creation: MessageCreation + + +class MessageCreation1(BaseModel): + message_id: Annotated[ + Optional[str], + Field(None, description="The ID of the message that was created by this run step."), + ] + + +class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Type38, Field(description="Always `message_creation`.")] + message_creation: Optional[MessageCreation1] = None + + +class Type40(Enum): + tool_calls = "tool_calls" + + +class Type42(Enum): + code_interpreter = "code_interpreter" + + +class Type44(Enum): + logs = "logs" + + +class RunStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + type: Annotated[Type44, Field(description="Always `logs`.")] + logs: Annotated[str, Field(description="The text output from the Code Interpreter tool call.")] + + +class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Type44, Field(description="Always `logs`.")] + logs: Annotated[ + Optional[str], + Field(None, description="The text output from the Code Interpreter tool call."), + ] + + +class Type46(Enum): + image = "image" + + +class Image1(BaseModel): + file_id: Annotated[ + str, Field(description="The [file](/docs/api-reference/files) ID of the image.") + ] + + +class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): + type: Annotated[Type46, Field(description="Always `image`.")] + image: Image1 + + +class Image2(BaseModel): + file_id: Annotated[ + Optional[str], + Field(None, description="The [file](/docs/api-reference/files) ID of the image."), + ] + + +class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Type46, Field(description="Always `image`.")] + image: Optional[Image2] = None + + +class Type48(Enum): + file_search = "file_search" + + +class RunStepDetailsToolCallsFileSearchObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Type48, + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + type: Annotated[ + Type48, + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class Type50(Enum): + function = "function" + + +class Function5(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[str, Field(description="The arguments passed to the function.")] + output: Annotated[ + str, + Field( + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." + ), + ] + + +class RunStepDetailsToolCallsFunctionObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Type50, + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Function5, Field(description="The definition of the function that was called.") + ] + + +class Function6(BaseModel): + name: Annotated[Optional[str], Field(None, description="The name of the function.")] + arguments: Annotated[ + Optional[str], Field(None, description="The arguments passed to the function.") + ] + output: Annotated[ + Optional[str], + Field( + None, + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet.", + ), + ] + + +class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + type: Annotated[ + Type50, + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Optional[Function6], + Field(None, description="The definition of the function that was called."), + ] + + +class Anchor(Enum): + last_active_at = "last_active_at" + + +class VectorStoreExpirationAfter(BaseModel): + anchor: Annotated[ + Anchor, + Field( + description="Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." + ), + ] + days: Annotated[ + int, + Field( + description="The number of days after the anchor time that the vector store will expire.", + ge=1, + le=365, + ), + ] + + +class Object29(Enum): + vector_store = "vector_store" + + +class FileCounts(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[ + int, + Field(description="The number of files that have been successfully processed."), + ] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that were cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class Status6(Enum): + expired = "expired" + in_progress = "in_progress" + completed = "completed" + + +class VectorStoreObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Object29, Field(description="The object type, which is always `vector_store`.") + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the vector store was created."), + ] + name: Annotated[str, Field(description="The name of the vector store.")] + usage_bytes: Annotated[ + int, + Field(description="The total number of bytes used by the files in the vector store."), + ] + file_counts: FileCounts + status: Annotated[ + Status6, + Field( + description="The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use." + ), + ] + expires_after: Optional[VectorStoreExpirationAfter] = None + expires_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the vector store will expire.", + ), + ] + last_active_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store was last active." + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class UpdateVectorStoreRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + expires_after: Optional[VectorStoreExpirationAfter] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class ListVectorStoresResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[VectorStoreObject] + first_id: Annotated[str, Field(examples=["vs_abc123"])] + last_id: Annotated[str, Field(examples=["vs_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class Object30(Enum): + vector_store_deleted = "vector_store.deleted" + + +class DeleteVectorStoreResponse(BaseModel): + id: str + deleted: bool + object: Object30 + + +class Object31(Enum): + vector_store_file = "vector_store.file" + + +class Status7(Enum): + in_progress = "in_progress" + completed = "completed" + cancelled = "cancelled" + failed = "failed" + + +class Code2(Enum): + internal_error = "internal_error" + file_not_found = "file_not_found" + parsing_error = "parsing_error" + unhandled_mime_type = "unhandled_mime_type" + + +class LastError2(BaseModel): + code: Annotated[Code2, Field(description="One of `server_error` or `rate_limit_exceeded`.")] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class Type52(Enum): + other = "other" + + +class OtherChunkingStrategyResponseParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type52, Field(description="Always `other`.")] + + +class Type53(Enum): + static = "static" + + +class StaticChunkingStrategy(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class Type54(Enum): + auto = "auto" + + +class AutoChunkingStrategyRequestParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type54, Field(description="Always `auto`.")] + + +class Type55(Enum): + static = "static" + + +class StaticChunkingStrategyRequestParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type55, Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class ChunkingStrategyRequestParam( + RootModel[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]] +): + root: Annotated[ + Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] + + +class CreateVectorStoreFileRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_id: Annotated[ + str, + Field( + description="A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files." + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class Object32(Enum): + vector_store_file_deleted = "vector_store.file.deleted" + + +class DeleteVectorStoreFileResponse(BaseModel): + id: str + deleted: bool + object: Object32 + + +class Object33(Enum): + vector_store_files_batch = "vector_store.files_batch" + + +class FileCounts1(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[int, Field(description="The number of files that have been processed.")] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that where cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreFileBatchObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Object33, + Field(description="The object type, which is always `vector_store.file_batch`."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store files batch was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Status7, + Field( + description="The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`." + ), + ] + file_counts: FileCounts1 + + +class CreateVectorStoreFileBatchRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_ids: Annotated[ + List[str], + Field( + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_length=500, + min_length=1, + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class Event(Enum): + thread_created = "thread.created" + + +class ThreadStreamEvent1(BaseModel): + event: Event + data: ThreadObject + + +class ThreadStreamEvent(RootModel[ThreadStreamEvent1]): + root: ThreadStreamEvent1 + + +class Event1(Enum): + thread_run_created = "thread.run.created" + + +class Event2(Enum): + thread_run_queued = "thread.run.queued" + + +class Event3(Enum): + thread_run_in_progress = "thread.run.in_progress" + + +class Event4(Enum): + thread_run_requires_action = "thread.run.requires_action" + + +class Event5(Enum): + thread_run_completed = "thread.run.completed" + + +class Event6(Enum): + thread_run_incomplete = "thread.run.incomplete" + + +class Event7(Enum): + thread_run_failed = "thread.run.failed" + + +class Event8(Enum): + thread_run_cancelling = "thread.run.cancelling" + + +class Event9(Enum): + thread_run_cancelled = "thread.run.cancelled" + + +class Event10(Enum): + thread_run_expired = "thread.run.expired" + + +class Event11(Enum): + thread_run_step_created = "thread.run.step.created" + + +class Event12(Enum): + thread_run_step_in_progress = "thread.run.step.in_progress" + + +class Event13(Enum): + thread_run_step_delta = "thread.run.step.delta" + + +class Event14(Enum): + thread_run_step_completed = "thread.run.step.completed" + + +class Event15(Enum): + thread_run_step_failed = "thread.run.step.failed" + + +class Event16(Enum): + thread_run_step_cancelled = "thread.run.step.cancelled" + + +class Event17(Enum): + thread_run_step_expired = "thread.run.step.expired" + + +class Event18(Enum): + thread_message_created = "thread.message.created" + + +class Event19(Enum): + thread_message_in_progress = "thread.message.in_progress" + + +class Event20(Enum): + thread_message_delta = "thread.message.delta" + + +class Event21(Enum): + thread_message_completed = "thread.message.completed" + + +class Event22(Enum): + thread_message_incomplete = "thread.message.incomplete" + + +class Event23(Enum): + error = "error" + + +class ErrorEvent(BaseModel): + event: Event23 + data: Error + + +class Event24(Enum): + done = "done" + + +class Data(Enum): + field_DONE_ = "[DONE]" + + +class DoneEvent(BaseModel): + event: Event24 + data: Data + + +class Object34(Enum): + batch = "batch" + + +class Datum(BaseModel): + code: Annotated[ + Optional[str], + Field(None, description="An error code identifying the error type."), + ] + message: Annotated[ + Optional[str], + Field( + None, + description="A human-readable message providing more details about the error.", + ), + ] + param: Annotated[ + Optional[str], + Field( + None, + description="The name of the parameter that caused the error, if applicable.", + ), + ] + line: Annotated[ + Optional[int], + Field( + None, + description="The line number of the input file where the error occurred, if applicable.", + ), + ] + + +class Errors(BaseModel): + object: Annotated[ + Optional[str], + Field(None, description="The object type, which is always `list`."), + ] + data: Optional[List[Datum]] = None + + +class Status9(Enum): + validating = "validating" + failed = "failed" + in_progress = "in_progress" + finalizing = "finalizing" + completed = "completed" + expired = "expired" + cancelling = "cancelling" + cancelled = "cancelled" + + +class RequestCounts(BaseModel): + total: Annotated[int, Field(description="Total number of requests in the batch.")] + completed: Annotated[ + int, + Field(description="Number of requests that have been completed successfully."), + ] + failed: Annotated[int, Field(description="Number of requests that have failed.")] + + +class Batch(BaseModel): + id: str + object: Annotated[Object34, Field(description="The object type, which is always `batch`.")] + endpoint: Annotated[str, Field(description="The OpenAI API endpoint used by the batch.")] + errors: Optional[Errors] = None + input_file_id: Annotated[str, Field(description="The ID of the input file for the batch.")] + completion_window: Annotated[ + str, + Field(description="The time frame within which the batch should be processed."), + ] + status: Annotated[Status9, Field(description="The current status of the batch.")] + output_file_id: Annotated[ + Optional[str], + Field( + None, + description="The ID of the file containing the outputs of successfully executed requests.", + ), + ] + error_file_id: Annotated[ + Optional[str], + Field( + None, + description="The ID of the file containing the outputs of requests with errors.", + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the batch was created."), + ] + in_progress_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch started processing.", + ), + ] + expires_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch will expire.", + ), + ] + finalizing_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch started finalizing.", + ), + ] + completed_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch was completed.", + ), + ] + failed_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch failed.", + ), + ] + expired_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch expired.", + ), + ] + cancelling_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch started cancelling.", + ), + ] + cancelled_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch was cancelled.", + ), + ] + request_counts: Annotated[ + Optional[RequestCounts], + Field( + None, + description="The request counts for different statuses within the batch.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class Method(Enum): + POST = "POST" + + +class BatchRequestInput(BaseModel): + custom_id: Annotated[ + Optional[str], + Field( + None, + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch.", + ), + ] + method: Annotated[ + Optional[Method], + Field( + None, + description="The HTTP method to be used for the request. Currently only `POST` is supported.", + ), + ] + url: Annotated[ + Optional[str], + Field( + None, + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.", + ), + ] + + +class Response(BaseModel): + status_code: Annotated[ + Optional[int], Field(None, description="The HTTP status code of the response") + ] + request_id: Annotated[ + Optional[str], + Field( + None, + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support.", + ), + ] + body: Annotated[ + Optional[Dict[str, Any]], + Field(None, description="The JSON body of the response"), + ] + + +class Error2(BaseModel): + code: Annotated[Optional[str], Field(None, description="A machine-readable error code.")] + message: Annotated[Optional[str], Field(None, description="A human-readable error message.")] + + +class BatchRequestOutput(BaseModel): + id: Optional[str] = None + custom_id: Annotated[ + Optional[str], + Field( + None, + description="A developer-provided per-request id that will be used to match outputs to inputs.", + ), + ] + response: Optional[Response] = None + error: Annotated[ + Optional[Error2], + Field( + None, + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure.", + ), + ] + + +class Object35(Enum): + list = "list" + + +class ListBatchesResponse(BaseModel): + data: List[Batch] + first_id: Annotated[Optional[str], Field(None, examples=["batch_abc123"])] + last_id: Annotated[Optional[str], Field(None, examples=["batch_abc456"])] + has_more: bool + object: Object35 + + +class ListModelsResponse(BaseModel): + object: Object + data: List[Model] + + +class CreateCompletionRequest(BaseModel): + model: Annotated[ + Union[str, Model1], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + prompt: Annotated[ + Union[str, List[str], Prompt, Prompt1], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n" + ), + ] + best_of: Annotated[ + Optional[int], + Field( + 1, + description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', + ge=0, + le=20, + ), + ] + echo: Annotated[ + Optional[bool], + Field(False, description="Echo back the prompt in addition to the completion\n"), + ] + frequency_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + None, + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n', + ), + ] + logprobs: Annotated[ + Optional[int], + Field( + None, + description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", + ge=0, + le=5, + ), + ] + max_tokens: Annotated[ + Optional[int], + Field( + 16, + description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=[16], + ge=0, + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", + examples=[1], + ge=1, + le=128, + ), + ] + presence_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + seed: Annotated[ + Optional[int], + Field( + None, + description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] + stop: Annotated[ + Optional[Union[str, Stop]], + Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + ), + ] + stream: Annotated[ + Optional[bool], + Field( + False, + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + stream_options: Optional[ChatCompletionStreamOptions] = None + suffix: Annotated[ + Optional[str], + Field( + None, + description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", + examples=["test."], + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class CreateCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the completion.")] + choices: Annotated[ + List[Choice], + Field( + description="The list of completion choices the model generated for the input prompt." + ), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the completion was created."), + ] + model: Annotated[str, Field(description="The model used for completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Object1, Field(description='The object type, which is always "text_completion"') + ] + usage: Optional[CompletionUsage] = None + + +class ChatCompletionRequestMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class Content(RootModel[List[ChatCompletionRequestMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestMessageContentPart], + Field( + description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Annotated[ + Union[str, Content], Field(description="The contents of the user message.\n") + ] + role: Annotated[ + Role1, + Field(description="The role of the messages author, in this case `user`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] + + +class ChatCompletionTool(BaseModel): + type: Annotated[ + Type2, + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: FunctionObject + + +class ChatCompletionToolChoiceOption( + RootModel[Union[ChatCompletionToolChoiceOption1, ChatCompletionNamedToolChoice]] +): + root: Annotated[ + Union[ChatCompletionToolChoiceOption1, ChatCompletionNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n' + ), + ] + + +class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCall]]): + root: Annotated[ + List[ChatCompletionMessageToolCall], + Field(description="The tool calls generated by the model, such as function calls."), + ] + + +class ChatCompletionResponseMessage(BaseModel): + content: Annotated[str, Field(description="The contents of the message.")] + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + role: Annotated[Role5, Field(description="The role of the author of this message.")] + function_call: Annotated[ + Optional[FunctionCall], + Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ), + ] + + +class Choice1(BaseModel): + finish_reason: Annotated[ + FinishReason1, + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + logprobs: Annotated[Logprobs2, Field(description="Log probability information for the choice.")] + + +class CreateChatCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice1], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + service_tier: Annotated[ + Optional[ServiceTier1], + Field( + None, + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + examples=["scale"], + ), + ] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Object2, + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class Choice2(BaseModel): + finish_reason: Annotated[ + FinishReason2, + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + + +class CreateChatCompletionFunctionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice2], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Object2, + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class ImagesResponse(BaseModel): + created: int + data: List[Image] + + +class ListFilesResponse(BaseModel): + data: List[OpenAIFile] + object: Object6 + + +class ListFineTuningJobEventsResponse(BaseModel): + data: List[FineTuningJobEvent] + object: Object8 + + +class ListFineTuningJobCheckpointsResponse(BaseModel): + data: List[FineTuningJobCheckpoint] + object: Object8 + first_id: Optional[str] = None + last_id: Optional[str] = None + has_more: bool + + +class CreateEmbeddingResponse(BaseModel): + data: Annotated[ + List[Embedding], + Field(description="The list of embeddings generated by the model."), + ] + model: Annotated[ + str, Field(description="The name of the model used to generate the embedding.") + ] + object: Annotated[Object8, Field(description='The object type, which is always "list".')] + usage: Annotated[Usage1, Field(description="The usage information for the request.")] + + +class FineTuningJob(BaseModel): + id: Annotated[ + str, + Field(description="The object identifier, which can be referenced in the API endpoints."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was created." + ), + ] + error: Annotated[ + Error1, + Field( + description="For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure." + ), + ] + fine_tuned_model: Annotated[ + str, + Field( + description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running." + ), + ] + finished_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running." + ), + ] + hyperparameters: Annotated[ + Hyperparameters1, + Field( + description="The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details." + ), + ] + model: Annotated[str, Field(description="The base model that is being fine-tuned.")] + object: Annotated[ + Object16, + Field(description='The object type, which is always "fine_tuning.job".'), + ] + organization_id: Annotated[ + str, Field(description="The organization that owns the fine-tuning job.") + ] + result_files: Annotated[ + List[str], + Field( + description="The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + status: Annotated[ + Status2, + Field( + description="The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`." + ), + ] + trained_tokens: Annotated[ + int, + Field( + description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running." + ), + ] + training_file: Annotated[ + str, + Field( + description="The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + validation_file: Annotated[ + str, + Field( + description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + integrations: Annotated[ + Optional[List[FineTuningIntegration]], + Field( + None, + description="A list of integrations to enable for this fine-tuning job.", + max_length=5, + ), + ] + seed: Annotated[int, Field(description="The seed used for the fine-tuning job.")] + estimated_finish: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.", + ), + ] + + +class AssistantsApiResponseFormatOption( + RootModel[Union[AssistantsApiResponseFormatOption1, AssistantsApiResponseFormat]] +): + root: Annotated[ + Union[AssistantsApiResponseFormatOption1, AssistantsApiResponseFormat], + Field( + description='Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' + ), + ] + + +class AssistantObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[Object19, Field(description="The object type, which is always `assistant`.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the assistant was created."), + ] + name: Annotated[ + str, + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + str, + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + model: Annotated[ + str, + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + instructions: Annotated[ + str, + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateAssistantRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Union[str, Model12], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + examples=["gpt-4-turbo"], + ), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + Optional[str], + Field( + None, + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + [], + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources1], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ModifyAssistantRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Optional[str], + Field( + None, + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + Optional[str], + Field( + None, + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + [], + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources2], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ListAssistantsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[AssistantObject] + first_id: Annotated[str, Field(examples=["asst_abc123"])] + last_id: Annotated[str, Field(examples=["asst_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class AssistantsApiToolChoiceOption( + RootModel[Union[AssistantsApiToolChoiceOption1, AssistantsNamedToolChoice]] +): + root: Annotated[ + Union[AssistantsApiToolChoiceOption1, AssistantsNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tools and instead generates a message.\n`auto` is the default value and means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools before responding to the user.\nSpecifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n' + ), + ] + + +class SubmitToolOutputs(BaseModel): + tool_calls: Annotated[ + List[RunToolCallObject], Field(description="A list of the relevant tool calls.") + ] + + +class RequiredAction(BaseModel): + type: Annotated[Type20, Field(description="For now, this is always `submit_tool_outputs`.")] + submit_tool_outputs: Annotated[ + SubmitToolOutputs, + Field(description="Details on the tool outputs needed for this run to continue."), + ] + + +class RunObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[Object21, Field(description="The object type, which is always `thread.run`.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run." + ), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run." + ), + ] + status: Annotated[ + Status3, + Field( + description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`." + ), + ] + required_action: Annotated[ + RequiredAction, + Field( + description="Details on the action required to continue the run. Will be `null` if no action is required." + ), + ] + last_error: Annotated[ + LastError, + Field( + description="The last error associated with this run. Will be `null` if there are no errors." + ), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run will expire."), + ] + started_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was started."), + ] + cancelled_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was cancelled."), + ] + failed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run failed."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was completed."), + ] + incomplete_details: Annotated[ + IncompleteDetails, + Field( + description="Details on why the run is incomplete. Will be `null` if the run is not incomplete." + ), + ] + model: Annotated[ + str, + Field( + description="The model that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + instructions: Annotated[ + str, + Field( + description="The instructions that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", + max_length=20, + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunCompletionUsage + temperature: Annotated[ + Optional[float], + Field( + None, + description="The sampling temperature used for this run. If not set, defaults to 1.", + ), + ] + top_p: Annotated[ + Optional[float], + Field( + None, + description="The nucleus sampling value used for this run. If not set, defaults to 1.", + ), + ] + max_prompt_tokens: Annotated[ + int, + Field( + description="The maximum number of prompt tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + int, + Field( + description="The maximum number of completion tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] + truncation_strategy: TruncationObject + tool_choice: AssistantsApiToolChoiceOption + parallel_tool_calls: ParallelToolCalls + response_format: AssistantsApiResponseFormatOption + + +class ListRunsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[RunObject] + first_id: Annotated[str, Field(examples=["run_abc123"])] + last_id: Annotated[str, Field(examples=["run_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class Content1( + RootModel[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ] + ] +): + root: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ], + Field( + description="An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview).", + min_length=1, + title="Array of content parts", + ), + ] + + +class CreateMessageRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + role: Annotated[ + Role7, + Field( + description="The role of the entity that is creating the message. Allowed values include:\n- `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages.\n- `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation.\n" + ), + ] + content: Union[str, Content1] + attachments: Annotated[ + Optional[List[Attachment]], + Field( + None, + description="A list of files attached to the message, and the tools they should be added to.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class Text(BaseModel): + value: Annotated[str, Field(description="The data that makes up the text.")] + annotations: List[ + Union[ + MessageContentTextAnnotationsFileCitationObject, + MessageContentTextAnnotationsFilePathObject, + ] + ] + + +class MessageContentTextObject(BaseModel): + type: Annotated[Type30, Field(description="Always `text`.")] + text: Text + + +class Text1(BaseModel): + value: Annotated[Optional[str], Field(None, description="The data that makes up the text.")] + annotations: Optional[ + List[ + Union[ + MessageDeltaContentTextAnnotationsFileCitationObject, + MessageDeltaContentTextAnnotationsFilePathObject, + ] + ] + ] = None + + +class MessageDeltaContentTextObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Type34, Field(description="Always `text`.")] + text: Optional[Text1] = None + + +class CodeInterpreter7(BaseModel): + input: Annotated[str, Field(description="The input to the Code Interpreter tool call.")] + outputs: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeOutputLogsObject, + RunStepDetailsToolCallsCodeOutputImageObject, + ] + ], + Field( + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." + ), + ] + + +class RunStepDetailsToolCallsCodeObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Type42, + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + CodeInterpreter7, + Field(description="The Code Interpreter tool call definition."), + ] + + +class CodeInterpreter8(BaseModel): + input: Annotated[ + Optional[str], + Field(None, description="The input to the Code Interpreter tool call."), + ] + outputs: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject, + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject, + ] + ] + ], + Field( + None, + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type.", + ), + ] + + +class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + type: Annotated[ + Type42, + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + Optional[CodeInterpreter8], + Field(None, description="The Code Interpreter tool call definition."), + ] + + +class CreateVectorStoreRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_length=500, + ), + ] + name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + expires_after: Optional[VectorStoreExpirationAfter] = None + chunking_strategy: Annotated[ + Optional[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class StaticChunkingStrategyResponseParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Type53, Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class RunStreamEvent1(BaseModel): + event: Event1 + data: RunObject + + +class RunStreamEvent2(BaseModel): + event: Event2 + data: RunObject + + +class RunStreamEvent3(BaseModel): + event: Event3 + data: RunObject + + +class RunStreamEvent4(BaseModel): + event: Event4 + data: RunObject + + +class RunStreamEvent5(BaseModel): + event: Event5 + data: RunObject + + +class RunStreamEvent6(BaseModel): + event: Event6 + data: RunObject + + +class RunStreamEvent7(BaseModel): + event: Event7 + data: RunObject + + +class RunStreamEvent8(BaseModel): + event: Event8 + data: RunObject + + +class RunStreamEvent9(BaseModel): + event: Event9 + data: RunObject + + +class RunStreamEvent10(BaseModel): + event: Event10 + data: RunObject + + +class RunStreamEvent( + RootModel[ + Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + ] +): + root: Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + + +class ChatCompletionRequestAssistantMessage(BaseModel): + content: Annotated[ + Optional[str], + Field( + None, + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", + ), + ] + role: Annotated[ + Role2, + Field(description="The role of the messages author, in this case `assistant`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + function_call: Annotated[ + Optional[FunctionCall], + Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ), + ] + + +class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): + weight: Annotated[ + Optional[Weight], + Field( + None, + description="Controls whether the assistant message is trained against (0 or 1)", + ), + ] + role: Annotated[ + Role2, + Field(description="The role of the messages author, in this case `assistant`."), + ] + + +class ListPaginatedFineTuningJobsResponse(BaseModel): + data: List[FineTuningJob] + has_more: bool + object: Object4 + + +class FinetuneChatRequestInput(BaseModel): + messages: Annotated[ + Optional[ + List[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + FineTuneChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] + ], + Field(None, min_length=1), + ] + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field(None, description="A list of tools the model may generate JSON inputs for."), + ] + parallel_tool_calls: Optional[ParallelToolCalls] = None + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + None, + description="A list of functions the model may generate JSON inputs for.", + max_length=128, + min_length=1, + ), + ] + + +class CreateRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + model: Annotated[ + Optional[Union[str, Model12]], + Field( + None, + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + examples=["gpt-4-turbo"], + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis.", + ), + ] + additional_instructions: Annotated[ + Optional[str], + Field( + None, + description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions.", + ), + ] + additional_messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field( + None, + description="Adds additional messages to the thread before creating the run.", + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + None, + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_length=20, + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + stream: Annotated[ + Optional[bool], + Field( + None, + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + ), + ] + max_prompt_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateThreadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field( + None, + description="A list of [messages](/docs/api-reference/messages) to start the thread with.", + ), + ] + tool_resources: Annotated[ + Optional[ToolResources5], + Field( + None, + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class MessageObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Object24, + Field(description="The object type, which is always `thread.message`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The [thread](/docs/api-reference/threads) ID that this message belongs to." + ), + ] + status: Annotated[ + Status4, + Field( + description="The status of the message, which can be either `in_progress`, `incomplete`, or `completed`." + ), + ] + incomplete_details: Annotated[ + IncompleteDetails1, + Field(description="On an incomplete message, details about why the message is incomplete."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was completed."), + ] + incomplete_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the message was marked as incomplete." + ), + ] + role: Annotated[ + Role7, + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] + content: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageContentTextObject, + ] + ], + Field(description="The content of the message in array of text and/or images."), + ] + assistant_id: Annotated[ + str, + Field( + description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message." + ), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints." + ), + ] + attachments: Annotated[ + List[Attachment], + Field( + description="A list of files attached to the message, and the tools they were added to." + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class Delta(BaseModel): + role: Annotated[ + Optional[Role7], + Field( + None, + description="The entity that produced the message. One of `user` or `assistant`.", + ), + ] + content: Annotated[ + Optional[ + List[ + Union[ + MessageDeltaContentImageFileObject, + MessageDeltaContentTextObject, + MessageDeltaContentImageUrlObject, + ] + ] + ], + Field( + None, + description="The content of the message in array of text and/or images.", + ), + ] + + +class MessageDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the message, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Object25, + Field(description="The object type, which is always `thread.message.delta`."), + ] + delta: Annotated[ + Delta, + Field(description="The delta containing the fields that have changed on the Message."), + ] + + +class ListMessagesResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[MessageObject] + first_id: Annotated[str, Field(examples=["msg_abc123"])] + last_id: Annotated[str, Field(examples=["msg_abc123"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class RunStepDetailsToolCallsObject(BaseModel): + type: Annotated[Type40, Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeObject, + RunStepDetailsToolCallsFileSearchObject, + RunStepDetailsToolCallsFunctionObject, + ] + ], + Field( + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" + ), + ] + + +class RunStepDeltaStepDetailsToolCallsObject(BaseModel): + type: Annotated[Type40, Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeObject, + RunStepDeltaStepDetailsToolCallsFileSearchObject, + RunStepDeltaStepDetailsToolCallsFunctionObject, + ] + ] + ], + Field( + None, + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n", + ), + ] + + +class VectorStoreFileObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Object31, + Field(description="The object type, which is always `vector_store.file`."), + ] + usage_bytes: Annotated[ + int, + Field( + description="The total vector store usage in bytes. Note that this may be different from the original file size." + ), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store file was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Status7, + Field( + description="The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use." + ), + ] + last_error: Annotated[ + LastError2, + Field( + description="The last error associated with this vector store file. Will be `null` if there are no errors." + ), + ] + chunking_strategy: Annotated[ + Optional[Union[StaticChunkingStrategyResponseParam, OtherChunkingStrategyResponseParam]], + Field(None, description="The strategy used to chunk the file."), + ] + + +class ListVectorStoreFilesResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[VectorStoreFileObject] + first_id: Annotated[str, Field(examples=["file-abc123"])] + last_id: Annotated[str, Field(examples=["file-abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class MessageStreamEvent1(BaseModel): + event: Event18 + data: MessageObject + + +class MessageStreamEvent2(BaseModel): + event: Event19 + data: MessageObject + + +class MessageStreamEvent3(BaseModel): + event: Event20 + data: MessageDeltaObject + + +class MessageStreamEvent4(BaseModel): + event: Event21 + data: MessageObject + + +class MessageStreamEvent5(BaseModel): + event: Event22 + data: MessageObject + + +class MessageStreamEvent( + RootModel[ + Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + ] +): + root: Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + + +class ChatCompletionRequestMessage( + RootModel[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] +): + root: Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + + +class CreateChatCompletionRequest(BaseModel): + messages: Annotated[ + List[ChatCompletionRequestMessage], + Field( + description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", + min_length=1, + ), + ] + model: Annotated[ + Union[str, Model2], + Field( + description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", + examples=["gpt-4-turbo"], + ), + ] + frequency_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + None, + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n", + ), + ] + logprobs: Annotated[ + Optional[bool], + Field( + False, + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.", + ), + ] + top_logprobs: Annotated[ + Optional[int], + Field( + None, + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", + ge=0, + le=20, + ), + ] + max_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", + examples=[1], + ge=1, + le=128, + ), + ] + presence_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + response_format: Annotated[ + Optional[ResponseFormat], + Field( + None, + description='An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + ), + ] + seed: Annotated[ + Optional[int], + Field( + None, + description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] + service_tier: Annotated[ + Optional[ServiceTier], + Field( + None, + description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n", + ), + ] + stop: Annotated[ + Optional[Union[str, Stop1]], + Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens.\n", + ), + ] + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + stream_options: Optional[ChatCompletionStreamOptions] = None + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field( + None, + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n", + ), + ] + tool_choice: Optional[ChatCompletionToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + function_call: Annotated[ + Optional[Union[FunctionCall3, ChatCompletionFunctionCallOption]], + Field( + None, + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n', + ), + ] + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + None, + description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", + max_length=128, + min_length=1, + ), + ] + + +class CreateThreadAndRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + thread: Annotated[ + Optional[CreateThreadRequest], + Field( + None, + description="If no thread is provided, an empty thread will be created.", + ), + ] + model: Annotated[ + Optional[Union[str, Model12]], + Field( + None, + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + examples=["gpt-4-turbo"], + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis.", + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + None, + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_length=20, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources3], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + stream: Annotated[ + Optional[bool], + Field( + None, + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + ), + ] + max_prompt_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class RunStepObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Object27, + Field(description="The object type, which is always `thread.run.step`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was created."), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) associated with the run step." + ), + ] + thread_id: Annotated[ + str, + Field(description="The ID of the [thread](/docs/api-reference/threads) that was run."), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) that this run step is a part of." + ), + ] + type: Annotated[ + Type37, + Field( + description="The type of run step, which can be either `message_creation` or `tool_calls`." + ), + ] + status: Annotated[ + Status5, + Field( + description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`." + ), + ] + step_details: Annotated[ + Union[RunStepDetailsMessageCreationObject, RunStepDetailsToolCallsObject], + Field(description="The details of the run step."), + ] + last_error: Annotated[ + LastError1, + Field( + description="The last error associated with this run step. Will be `null` if there are no errors." + ), + ] + expired_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired." + ), + ] + cancelled_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was cancelled."), + ] + failed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step failed."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step completed."), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunStepCompletionUsage + + +class Delta1(BaseModel): + step_details: Annotated[ + Optional[ + Union[ + RunStepDeltaStepDetailsMessageCreationObject, + RunStepDeltaStepDetailsToolCallsObject, + ] + ], + Field(None, description="The details of the run step."), + ] + + +class RunStepDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Object28, + Field(description="The object type, which is always `thread.run.step.delta`."), + ] + delta: Annotated[ + Delta1, + Field(description="The delta containing the fields that have changed on the run step."), + ] + + +class ListRunStepsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[RunStepObject] + first_id: Annotated[str, Field(examples=["step_abc123"])] + last_id: Annotated[str, Field(examples=["step_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class RunStepStreamEvent1(BaseModel): + event: Event11 + data: RunStepObject + + +class RunStepStreamEvent2(BaseModel): + event: Event12 + data: RunStepObject + + +class RunStepStreamEvent3(BaseModel): + event: Event13 + data: RunStepDeltaObject + + +class RunStepStreamEvent4(BaseModel): + event: Event14 + data: RunStepObject + + +class RunStepStreamEvent5(BaseModel): + event: Event15 + data: RunStepObject + + +class RunStepStreamEvent6(BaseModel): + event: Event16 + data: RunStepObject + + +class RunStepStreamEvent7(BaseModel): + event: Event17 + data: RunStepObject + + +class RunStepStreamEvent( + RootModel[ + Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + ] +): + root: Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + + +class AssistantStreamEvent( + RootModel[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ] + ] +): + root: Annotated[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ], + Field( + description='Represents an event emitted when streaming a Run.\n\nEach event in a server-sent events stream has an `event` and `data` property:\n\n```\nevent: thread.created\ndata: {"id": "thread_123", "object": "thread", ...}\n```\n\nWe emit events whenever a new object is created, transitions to a new state, or is being\nstreamed in parts (deltas). For example, we emit `thread.run.created` when a new run\nis created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses\nto create a message during a run, we emit a `thread.message.created event`, a\n`thread.message.in_progress` event, many `thread.message.delta` events, and finally a\n`thread.message.completed` event.\n\nWe may add additional events over time, so we recommend handling unknown events gracefully\nin your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to\nintegrate the Assistants API with streaming.\n' + ), + ] diff --git a/model-engine/mypy.ini b/model-engine/mypy.ini index 82c6107a..cfd7e38d 100644 --- a/model-engine/mypy.ini +++ b/model-engine/mypy.ini @@ -25,3 +25,6 @@ ignore_errors = True [mypy-tests.*] ignore_errors = True + +[mypy-model_engine_server.common.types.gen.openai] +ignore_errors = True \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index f785f8de..87959381 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,10 @@ # Make sure to update .pre-commit-config.yaml to match versions! black[jupyter]==22.12.0 +datamodel-code-generator>=0.25.8 ruff==0.0.278 ipython==8.12.0 # 8.12.0 is the last version to support Python 3.8 isort==5.12.0 mypy==1.3.0 pip-tools==7.0.0 poetry==1.8.2 -pre-commit==3.3.3 +pre-commit==3.3.3 \ No newline at end of file diff --git a/scripts/generate-openai-types.sh b/scripts/generate-openai-types.sh new file mode 100755 index 00000000..cbe3b323 --- /dev/null +++ b/scripts/generate-openai-types.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BASE_DIR=${SCRIPT_DIR}/.. + +DEST_DIR=${BASE_DIR}/model-engine/model_engine_server/common/types/gen +OPENAI_SPEC=${SCRIPT_DIR}/openai-spec.yaml + +# Generate OpenAPI types +datamodel-codegen \ + --input ${OPENAI_SPEC} \ + --input-file-type openapi \ + --output ${DEST_DIR}/openai.py \ + --output-model-type pydantic_v2.BaseModel \ + --field-constraints \ + --use-annotated \ No newline at end of file diff --git a/scripts/openai-spec.yaml b/scripts/openai-spec.yaml new file mode 100644 index 00000000..aebd38a5 --- /dev/null +++ b/scripts/openai-spec.yaml @@ -0,0 +1,14267 @@ +openapi: 3.0.0 +info: + title: OpenAI API + description: The OpenAI REST API. Please see https://platform.openai.com/docs/api-reference for more details. + version: "2.1.0" + termsOfService: https://openai.com/policies/terms-of-use + contact: + name: OpenAI Support + url: https://help.openai.com/ + license: + name: MIT + url: https://github.com/openai/openai-openapi/blob/master/LICENSE +servers: + - url: https://api.openai.com/v1 +tags: + - name: Assistants + description: Build Assistants that can call models and use tools. + - name: Audio + description: Turn audio into text or text into audio. + - name: Chat + description: Given a list of messages comprising a conversation, the model will return a response. + - name: Completions + description: Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. + - name: Embeddings + description: Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. + - name: Fine-tuning + description: Manage fine-tuning jobs to tailor a model to your specific training data. + - name: Batch + description: Create large batches of API requests to run asynchronously. + - name: Files + description: Files are used to upload documents that can be used with features like Assistants and Fine-tuning. + - name: Uploads + description: Use Uploads to upload large files in multiple parts. + - name: Images + description: Given a prompt and/or an input image, the model will generate a new image. + - name: Models + description: List and describe the various models available in the API. + - name: Moderations + description: Given a input text, outputs if the model classifies it as potentially harmful. +paths: + # Note: When adding an endpoint, make sure you also add it in the `groups` section, in the end of this file, + # under the appropriate group + /chat/completions: + post: + operationId: createChatCompletion + tags: + - Chat + summary: Creates a model response for the given chat conversation. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionResponse" + + x-oaiMeta: + name: Create chat completion + group: chat + returns: | + Returns a [chat completion](/docs/api-reference/chat/object) object, or a streamed sequence of [chat completion chunk](/docs/api-reference/chat/streaming) objects if the request is streamed. + path: create + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello!" + } + ] + }' + python: | + from openai import OpenAI + client = OpenAI() + + completion = client.chat.completions.create( + model="VAR_model_id", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ] + ) + + print(completion.choices[0].message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.chat.completions.create({ + messages: [{ role: "system", content: "You are a helpful assistant." }], + model: "VAR_model_id", + }); + + console.log(completion.choices[0]); + } + + main(); + response: &chat_completion_example | + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nHello there, how may I assist you today?", + }, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + - title: Image input + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "gpt-4-turbo", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What'\''s in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + } + ] + } + ], + "max_tokens": 300 + }' + python: | + from openai import OpenAI + + client = OpenAI() + + response = client.chat.completions.create( + model="gpt-4-turbo", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + ], + } + ], + max_tokens=300, + ) + + print(response.choices[0]) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const response = await openai.chat.completions.create({ + model: "gpt-4-turbo", + messages: [ + { + role: "user", + content: [ + { type: "text", text: "What's in this image?" }, + { + type: "image_url", + image_url: + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + ], + }, + ], + }); + console.log(response.choices[0]); + } + main(); + response: &chat_completion_image_example | + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nThis image shows a wooden boardwalk extending through a lush green marshland.", + }, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello!" + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + completion = client.chat.completions.create( + model="VAR_model_id", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ], + stream=True + ) + + for chunk in completion: + print(chunk.choices[0].delta) + + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.chat.completions.create({ + model: "VAR_model_id", + messages: [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ], + stream: true, + }); + + for await (const chunk of completion) { + console.log(chunk.choices[0].delta.content); + } + } + + main(); + response: &chat_completion_chunk_example | + {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + + {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + + .... + + {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + - title: Functions + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "gpt-4-turbo", + "messages": [ + { + "role": "user", + "content": "What'\''s the weather like in Boston today?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "tool_choice": "auto" + }' + python: | + from openai import OpenAI + client = OpenAI() + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + completion = client.chat.completions.create( + model="VAR_model_id", + messages=messages, + tools=tools, + tool_choice="auto" + ) + + print(completion) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]; + const tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ]; + + const response = await openai.chat.completions.create({ + model: "gpt-4-turbo", + messages: messages, + tools: tools, + tool_choice: "auto", + }); + + console.log(response); + } + + main(); + response: &chat_completion_function_example | + { + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } + } + - title: Logprobs + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "messages": [ + { + "role": "user", + "content": "Hello!" + } + ], + "logprobs": true, + "top_logprobs": 2 + }' + python: | + from openai import OpenAI + client = OpenAI() + + completion = client.chat.completions.create( + model="VAR_model_id", + messages=[ + {"role": "user", "content": "Hello!"} + ], + logprobs=True, + top_logprobs=2 + ) + + print(completion.choices[0].message) + print(completion.choices[0].logprobs) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.chat.completions.create({ + messages: [{ role: "user", content: "Hello!" }], + model: "VAR_model_id", + logprobs: true, + top_logprobs: 2, + }); + + console.log(completion.choices[0]); + } + + main(); + response: | + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1702685778, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + "logprobs": { + "content": [ + { + "token": "Hello", + "logprob": -0.31725305, + "bytes": [72, 101, 108, 108, 111], + "top_logprobs": [ + { + "token": "Hello", + "logprob": -0.31725305, + "bytes": [72, 101, 108, 108, 111] + }, + { + "token": "Hi", + "logprob": -1.3190403, + "bytes": [72, 105] + } + ] + }, + { + "token": "!", + "logprob": -0.02380986, + "bytes": [ + 33 + ], + "top_logprobs": [ + { + "token": "!", + "logprob": -0.02380986, + "bytes": [33] + }, + { + "token": " there", + "logprob": -3.787621, + "bytes": [32, 116, 104, 101, 114, 101] + } + ] + }, + { + "token": " How", + "logprob": -0.000054669687, + "bytes": [32, 72, 111, 119], + "top_logprobs": [ + { + "token": " How", + "logprob": -0.000054669687, + "bytes": [32, 72, 111, 119] + }, + { + "token": "<|end|>", + "logprob": -10.953937, + "bytes": null + } + ] + }, + { + "token": " can", + "logprob": -0.015801601, + "bytes": [32, 99, 97, 110], + "top_logprobs": [ + { + "token": " can", + "logprob": -0.015801601, + "bytes": [32, 99, 97, 110] + }, + { + "token": " may", + "logprob": -4.161023, + "bytes": [32, 109, 97, 121] + } + ] + }, + { + "token": " I", + "logprob": -3.7697225e-6, + "bytes": [ + 32, + 73 + ], + "top_logprobs": [ + { + "token": " I", + "logprob": -3.7697225e-6, + "bytes": [32, 73] + }, + { + "token": " assist", + "logprob": -13.596657, + "bytes": [32, 97, 115, 115, 105, 115, 116] + } + ] + }, + { + "token": " assist", + "logprob": -0.04571125, + "bytes": [32, 97, 115, 115, 105, 115, 116], + "top_logprobs": [ + { + "token": " assist", + "logprob": -0.04571125, + "bytes": [32, 97, 115, 115, 105, 115, 116] + }, + { + "token": " help", + "logprob": -3.1089056, + "bytes": [32, 104, 101, 108, 112] + } + ] + }, + { + "token": " you", + "logprob": -5.4385737e-6, + "bytes": [32, 121, 111, 117], + "top_logprobs": [ + { + "token": " you", + "logprob": -5.4385737e-6, + "bytes": [32, 121, 111, 117] + }, + { + "token": " today", + "logprob": -12.807695, + "bytes": [32, 116, 111, 100, 97, 121] + } + ] + }, + { + "token": " today", + "logprob": -0.0040071653, + "bytes": [32, 116, 111, 100, 97, 121], + "top_logprobs": [ + { + "token": " today", + "logprob": -0.0040071653, + "bytes": [32, 116, 111, 100, 97, 121] + }, + { + "token": "?", + "logprob": -5.5247097, + "bytes": [63] + } + ] + }, + { + "token": "?", + "logprob": -0.0008108172, + "bytes": [63], + "top_logprobs": [ + { + "token": "?", + "logprob": -0.0008108172, + "bytes": [63] + }, + { + "token": "?\n", + "logprob": -7.184561, + "bytes": [63, 10] + } + ] + } + ] + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 9, + "total_tokens": 18 + }, + "system_fingerprint": null + } + + /completions: + post: + operationId: createCompletion + tags: + - Completions + summary: Creates a completion for the provided prompt and parameters. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateCompletionRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateCompletionResponse" + x-oaiMeta: + name: Create completion + group: completions + returns: | + Returns a [completion](/docs/api-reference/completions/object) object, or a sequence of completion objects if the request is streamed. + legacy: true + examples: + - title: No streaming + request: + curl: | + curl https://api.openai.com/v1/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "prompt": "Say this is a test", + "max_tokens": 7, + "temperature": 0 + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.completions.create( + model="VAR_model_id", + prompt="Say this is a test", + max_tokens=7, + temperature=0 + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.completions.create({ + model: "VAR_model_id", + prompt: "Say this is a test.", + max_tokens: 7, + temperature: 0, + }); + + console.log(completion); + } + main(); + response: | + { + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "VAR_model_id", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "text": "\n\nThis is indeed a test", + "index": 0, + "logprobs": null, + "finish_reason": "length" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } + } + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "prompt": "Say this is a test", + "max_tokens": 7, + "temperature": 0, + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + for chunk in client.completions.create( + model="VAR_model_id", + prompt="Say this is a test", + max_tokens=7, + temperature=0, + stream=True + ): + print(chunk.choices[0].text) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.completions.create({ + model: "VAR_model_id", + prompt: "Say this is a test.", + stream: true, + }); + + for await (const chunk of stream) { + console.log(chunk.choices[0].text) + } + } + main(); + response: | + { + "id": "cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe", + "object": "text_completion", + "created": 1690759702, + "choices": [ + { + "text": "This", + "index": 0, + "logprobs": null, + "finish_reason": null + } + ], + "model": "gpt-3.5-turbo-instruct" + "system_fingerprint": "fp_44709d6fcb", + } + + /images/generations: + post: + operationId: createImage + tags: + - Images + summary: Creates an image given a prompt. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateImageRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ImagesResponse" + x-oaiMeta: + name: Create image + group: images + returns: Returns a list of [image](/docs/api-reference/images/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/images/generations \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "dall-e-3", + "prompt": "A cute baby sea otter", + "n": 1, + "size": "1024x1024" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.images.generate( + model="dall-e-3", + prompt="A cute baby sea otter", + n=1, + size="1024x1024" + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const image = await openai.images.generate({ model: "dall-e-3", prompt: "A cute baby sea otter" }); + + console.log(image.data); + } + main(); + response: | + { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] + } + /images/edits: + post: + operationId: createImageEdit + tags: + - Images + summary: Creates an edited or extended image given an original image and a prompt. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateImageEditRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ImagesResponse" + x-oaiMeta: + name: Create image edit + group: images + returns: Returns a list of [image](/docs/api-reference/images/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/images/edits \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -F image="@otter.png" \ + -F mask="@mask.png" \ + -F prompt="A cute baby sea otter wearing a beret" \ + -F n=2 \ + -F size="1024x1024" + python: | + from openai import OpenAI + client = OpenAI() + + client.images.edit( + image=open("otter.png", "rb"), + mask=open("mask.png", "rb"), + prompt="A cute baby sea otter wearing a beret", + n=2, + size="1024x1024" + ) + node.js: |- + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const image = await openai.images.edit({ + image: fs.createReadStream("otter.png"), + mask: fs.createReadStream("mask.png"), + prompt: "A cute baby sea otter wearing a beret", + }); + + console.log(image.data); + } + main(); + response: | + { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] + } + /images/variations: + post: + operationId: createImageVariation + tags: + - Images + summary: Creates a variation of a given image. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateImageVariationRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ImagesResponse" + x-oaiMeta: + name: Create image variation + group: images + returns: Returns a list of [image](/docs/api-reference/images/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/images/variations \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -F image="@otter.png" \ + -F n=2 \ + -F size="1024x1024" + python: | + from openai import OpenAI + client = OpenAI() + + response = client.images.create_variation( + image=open("image_edit_original.png", "rb"), + n=2, + size="1024x1024" + ) + node.js: |- + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const image = await openai.images.createVariation({ + image: fs.createReadStream("otter.png"), + }); + + console.log(image.data); + } + main(); + response: | + { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] + } + + /embeddings: + post: + operationId: createEmbedding + tags: + - Embeddings + summary: Creates an embedding vector representing the input text. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateEmbeddingRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateEmbeddingResponse" + x-oaiMeta: + name: Create embeddings + group: embeddings + returns: A list of [embedding](/docs/api-reference/embeddings/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/embeddings \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The food was delicious and the waiter...", + "model": "text-embedding-ada-002", + "encoding_format": "float" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.embeddings.create( + model="text-embedding-ada-002", + input="The food was delicious and the waiter...", + encoding_format="float" + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const embedding = await openai.embeddings.create({ + model: "text-embedding-ada-002", + input: "The quick brown fox jumped over the lazy dog", + encoding_format: "float", + }); + + console.log(embedding); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + + /audio/speech: + post: + operationId: createSpeech + tags: + - Audio + summary: Generates audio from the input text. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateSpeechRequest" + responses: + "200": + description: OK + headers: + Transfer-Encoding: + schema: + type: string + description: chunked + content: + application/octet-stream: + schema: + type: string + format: binary + x-oaiMeta: + name: Create speech + group: audio + returns: The audio file content. + examples: + request: + curl: | + curl https://api.openai.com/v1/audio/speech \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "tts-1", + "input": "The quick brown fox jumped over the lazy dog.", + "voice": "alloy" + }' \ + --output speech.mp3 + python: | + from pathlib import Path + import openai + + speech_file_path = Path(__file__).parent / "speech.mp3" + response = openai.audio.speech.create( + model="tts-1", + voice="alloy", + input="The quick brown fox jumped over the lazy dog." + ) + response.stream_to_file(speech_file_path) + node: | + import fs from "fs"; + import path from "path"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + const speechFile = path.resolve("./speech.mp3"); + + async function main() { + const mp3 = await openai.audio.speech.create({ + model: "tts-1", + voice: "alloy", + input: "Today is a wonderful day to build something people love!", + }); + console.log(speechFile); + const buffer = Buffer.from(await mp3.arrayBuffer()); + await fs.promises.writeFile(speechFile, buffer); + } + main(); + /audio/transcriptions: + post: + operationId: createTranscription + tags: + - Audio + summary: Transcribes audio into the input language. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateTranscriptionRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/CreateTranscriptionResponseJson" + - $ref: "#/components/schemas/CreateTranscriptionResponseVerboseJson" + x-oaiMeta: + name: Create transcription + group: audio + returns: The [transcription object](/docs/api-reference/audio/json-object) or a [verbose transcription object](/docs/api-reference/audio/verbose-json-object). + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/audio/transcriptions \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ + -F model="whisper-1" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.transcriptions.create( + model="whisper-1", + file=audio_file + ) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const transcription = await openai.audio.transcriptions.create({ + file: fs.createReadStream("audio.mp3"), + model: "whisper-1", + }); + + console.log(transcription.text); + } + main(); + response: &basic_transcription_response_example | + { + "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that." + } + - title: Word timestamps + request: + curl: | + curl https://api.openai.com/v1/audio/transcriptions \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ + -F "timestamp_granularities[]=word" \ + -F model="whisper-1" \ + -F response_format="verbose_json" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=["word"] + ) + + print(transcript.words) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const transcription = await openai.audio.transcriptions.create({ + file: fs.createReadStream("audio.mp3"), + model: "whisper-1", + response_format: "verbose_json", + timestamp_granularities: ["word"] + }); + + console.log(transcription.text); + } + main(); + response: | + { + "task": "transcribe", + "language": "english", + "duration": 8.470000267028809, + "text": "The beach was a popular spot on a hot summer day. People were swimming in the ocean, building sandcastles, and playing beach volleyball.", + "words": [ + { + "word": "The", + "start": 0.0, + "end": 0.23999999463558197 + }, + ... + { + "word": "volleyball", + "start": 7.400000095367432, + "end": 7.900000095367432 + } + ] + } + - title: Segment timestamps + request: + curl: | + curl https://api.openai.com/v1/audio/transcriptions \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ + -F "timestamp_granularities[]=segment" \ + -F model="whisper-1" \ + -F response_format="verbose_json" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=["segment"] + ) + + print(transcript.words) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const transcription = await openai.audio.transcriptions.create({ + file: fs.createReadStream("audio.mp3"), + model: "whisper-1", + response_format: "verbose_json", + timestamp_granularities: ["segment"] + }); + + console.log(transcription.text); + } + main(); + response: &verbose_transcription_response_example | + { + "task": "transcribe", + "language": "english", + "duration": 8.470000267028809, + "text": "The beach was a popular spot on a hot summer day. People were swimming in the ocean, building sandcastles, and playing beach volleyball.", + "segments": [ + { + "id": 0, + "seek": 0, + "start": 0.0, + "end": 3.319999933242798, + "text": " The beach was a popular spot on a hot summer day.", + "tokens": [ + 50364, 440, 7534, 390, 257, 3743, 4008, 322, 257, 2368, 4266, 786, 13, 50530 + ], + "temperature": 0.0, + "avg_logprob": -0.2860786020755768, + "compression_ratio": 1.2363636493682861, + "no_speech_prob": 0.00985979475080967 + }, + ... + ] + } + /audio/translations: + post: + operationId: createTranslation + tags: + - Audio + summary: Translates audio into English. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateTranslationRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/CreateTranslationResponseJson" + - $ref: "#/components/schemas/CreateTranslationResponseVerboseJson" + x-oaiMeta: + name: Create translation + group: audio + returns: The translated text. + examples: + request: + curl: | + curl https://api.openai.com/v1/audio/translations \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/german.m4a" \ + -F model="whisper-1" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.translations.create( + model="whisper-1", + file=audio_file + ) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const translation = await openai.audio.translations.create({ + file: fs.createReadStream("speech.mp3"), + model: "whisper-1", + }); + + console.log(translation.text); + } + main(); + response: | + { + "text": "Hello, my name is Wolfgang and I come from Germany. Where are you heading today?" + } + + /files: + get: + operationId: listFiles + tags: + - Files + summary: Returns a list of files that belong to the user's organization. + parameters: + - in: query + name: purpose + required: false + schema: + type: string + description: Only return files with the given purpose. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListFilesResponse" + x-oaiMeta: + name: List files + group: files + returns: A list of [File](/docs/api-reference/files/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.list() + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.files.list(); + + for await (const file of list) { + console.log(file); + } + } + + main(); + response: | + { + "data": [ + { + "id": "file-abc123", + "object": "file", + "bytes": 175, + "created_at": 1613677385, + "filename": "salesOverview.pdf", + "purpose": "assistants", + }, + { + "id": "file-abc123", + "object": "file", + "bytes": 140, + "created_at": 1613779121, + "filename": "puppy.jsonl", + "purpose": "fine-tune", + } + ], + "object": "list" + } + post: + operationId: createFile + tags: + - Files + summary: | + Upload a file that can be used across various endpoints. Individual files can be up to 512 MB, and the size of all files uploaded by one organization can be up to 100 GB. + + The Assistants API supports files up to 2 million tokens and of specific file types. See the [Assistants Tools guide](/docs/assistants/tools) for details. + + The Fine-tuning API only supports `.jsonl` files. The input also has certain required formats for fine-tuning [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) models. + + The Batch API only supports `.jsonl` files up to 100 MB in size. The input also has a specific required [format](/docs/api-reference/batch/request-input). + + Please [contact us](https://help.openai.com/) if you need to increase these storage limits. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateFileRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/OpenAIFile" + x-oaiMeta: + name: Upload file + group: files + returns: The uploaded [File](/docs/api-reference/files/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -F purpose="fine-tune" \ + -F file="@mydata.jsonl" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.create( + file=open("mydata.jsonl", "rb"), + purpose="fine-tune" + ) + node.js: |- + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.create({ + file: fs.createReadStream("mydata.jsonl"), + purpose: "fine-tune", + }); + + console.log(file); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "mydata.jsonl", + "purpose": "fine-tune", + } + /files/{file_id}: + delete: + operationId: deleteFile + tags: + - Files + summary: Delete a file. + parameters: + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to use for this request. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteFileResponse" + x-oaiMeta: + name: Delete file + group: files + returns: Deletion status. + examples: + request: + curl: | + curl https://api.openai.com/v1/files/file-abc123 \ + -X DELETE \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.delete("file-abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.del("file-abc123"); + + console.log(file); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "file", + "deleted": true + } + get: + operationId: retrieveFile + tags: + - Files + summary: Returns information about a specific file. + parameters: + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to use for this request. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/OpenAIFile" + x-oaiMeta: + name: Retrieve file + group: files + returns: The [File](/docs/api-reference/files/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/files/file-abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.retrieve("file-abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.retrieve("file-abc123"); + + console.log(file); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "mydata.jsonl", + "purpose": "fine-tune", + } + /files/{file_id}/content: + get: + operationId: downloadFile + tags: + - Files + summary: Returns the contents of the specified file. + parameters: + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to use for this request. + responses: + "200": + description: OK + content: + application/json: + schema: + type: string + x-oaiMeta: + name: Retrieve file content + group: files + returns: The file content. + examples: + request: + curl: | + curl https://api.openai.com/v1/files/file-abc123/content \ + -H "Authorization: Bearer $OPENAI_API_KEY" > file.jsonl + python: | + from openai import OpenAI + client = OpenAI() + + content = client.files.content("file-abc123") + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.content("file-abc123"); + + console.log(file); + } + + main(); + /uploads: + post: + operationId: createUpload + tags: + - Uploads + summary: | + Creates an intermediate [Upload](/docs/api-reference/uploads/object) object that you can add [Parts](/docs/api-reference/uploads/part-object) to. Currently, an Upload can accept at most 8 GB in total and expires after an hour after you create it. + + Once you complete the Upload, we will create a [File](/docs/api-reference/files/object) object that contains all the parts you uploaded. This File is usable in the rest of our platform as a regular File object. + + For certain `purpose`s, the correct `mime_type` must be specified. Please refer to documentation for the supported MIME types for your use case: + - [Assistants](/docs/assistants/tools/file-search/supported-files) + + For guidance on the proper filename extensions for each purpose, please follow the documentation on [creating a File](/docs/api-reference/files/create). + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateUploadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Upload" + x-oaiMeta: + name: Create upload + group: uploads + returns: The [Upload](/docs/api-reference/uploads/object) object with status `pending`. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "purpose": "fine-tune", + "filename": "training_examples.jsonl", + "bytes": 2147483648, + "mime_type": "text/jsonl" + }' + response: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "pending", + "expires_at": 1719127296 + } + + /uploads/{upload_id}/parts: + post: + operationId: addUploadPart + tags: + - Uploads + summary: | + Adds a [Part](/docs/api-reference/uploads/part-object) to an [Upload](/docs/api-reference/uploads/object) object. A Part represents a chunk of bytes from the file you are trying to upload. + + Each Part can be at most 64 MB, and you can add Parts until you hit the Upload maximum of 8 GB. + + It is possible to add multiple Parts in parallel. You can decide the intended order of the Parts when you [complete the Upload](/docs/api-reference/uploads/complete). + parameters: + - in: path + name: upload_id + required: true + schema: + type: string + example: upload_abc123 + description: | + The ID of the Upload. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/AddUploadPartRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/UploadPart" + x-oaiMeta: + name: Add upload part + group: uploads + returns: The upload [Part](/docs/api-reference/uploads/part-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads/upload_abc123/parts + -F data="aHR0cHM6Ly9hcGkub3BlbmFpLmNvbS92MS91cGxvYWRz..." + response: | + { + "id": "part_def456", + "object": "upload.part", + "created_at": 1719185911, + "upload_id": "upload_abc123" + } + + /uploads/{upload_id}/complete: + post: + operationId: completeUpload + tags: + - Uploads + summary: | + Completes the [Upload](/docs/api-reference/uploads/object). + + Within the returned Upload object, there is a nested [File](/docs/api-reference/files/object) object that is ready to use in the rest of the platform. + + You can specify the order of the Parts by passing in an ordered list of the Part IDs. + + The number of bytes uploaded upon completion must match the number of bytes initially specified when creating the Upload object. No Parts may be added after an Upload is completed. + parameters: + - in: path + name: upload_id + required: true + schema: + type: string + example: upload_abc123 + description: | + The ID of the Upload. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CompleteUploadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Upload" + x-oaiMeta: + name: Complete upload + group: uploads + returns: The [Upload](/docs/api-reference/uploads/object) object with status `completed` with an additional `file` property containing the created usable File object. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads/upload_abc123/complete + -d '{ + "part_ids": ["part_def456", "part_ghi789"] + }' + response: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "completed", + "expires_at": 1719127296, + "file": { + "id": "file-xyz321", + "object": "file", + "bytes": 2147483648, + "created_at": 1719186911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + } + } + + /uploads/{upload_id}/cancel: + post: + operationId: cancelUpload + tags: + - Uploads + summary: | + Cancels the Upload. No Parts may be added after an Upload is cancelled. + parameters: + - in: path + name: upload_id + required: true + schema: + type: string + example: upload_abc123 + description: | + The ID of the Upload. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Upload" + x-oaiMeta: + name: Cancel upload + group: uploads + returns: The [Upload](/docs/api-reference/uploads/object) object with status `cancelled`. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads/upload_abc123/cancel + response: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "cancelled", + "expires_at": 1719127296 + } + + /fine_tuning/jobs: + post: + operationId: createFineTuningJob + tags: + - Fine-tuning + summary: | + Creates a fine-tuning job which begins the process of creating a new model from a given dataset. + + Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete. + + [Learn more about fine-tuning](/docs/guides/fine-tuning) + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateFineTuningJobRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/FineTuningJob" + x-oaiMeta: + name: Create fine-tuning job + group: fine-tuning + returns: A [fine-tuning.job](/docs/api-reference/fine-tuning/object) object. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-BK7bzQj3FfZFXr7DbL6xJwfo", + "model": "gpt-3.5-turbo" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.create( + training_file="file-abc123", + model="gpt-3.5-turbo" + ) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.create({ + training_file: "file-abc123" + }); + + console.log(fineTune); + } + + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-3.5-turbo-0125", + "created_at": 1614807352, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": null, + "training_file": "file-abc123", + } + - title: Epochs + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-abc123", + "model": "gpt-3.5-turbo", + "hyperparameters": { + "n_epochs": 2 + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.create( + training_file="file-abc123", + model="gpt-3.5-turbo", + hyperparameters={ + "n_epochs":2 + } + ) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.create({ + training_file: "file-abc123", + model: "gpt-3.5-turbo", + hyperparameters: { n_epochs: 2 } + }); + + console.log(fineTune); + } + + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-3.5-turbo-0125", + "created_at": 1614807352, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": null, + "training_file": "file-abc123", + "hyperparameters": {"n_epochs": 2}, + } + - title: Validation file + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-abc123", + "validation_file": "file-abc123", + "model": "gpt-3.5-turbo" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.create( + training_file="file-abc123", + validation_file="file-def456", + model="gpt-3.5-turbo" + ) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.create({ + training_file: "file-abc123", + validation_file: "file-abc123" + }); + + console.log(fineTune); + } + + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-3.5-turbo-0125", + "created_at": 1614807352, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": "file-abc123", + "training_file": "file-abc123", + } + - title: W&B Integration + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-abc123", + "validation_file": "file-abc123", + "model": "gpt-3.5-turbo", + "integrations": [ + { + "type": "wandb", + "wandb": { + "project": "my-wandb-project", + "name": "ft-run-display-name" + "tags": [ + "first-experiment", "v2" + ] + } + } + ] + }' + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-3.5-turbo-0125", + "created_at": 1614807352, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": "file-abc123", + "training_file": "file-abc123", + "integrations": [ + { + "type": "wandb", + "wandb": { + "project": "my-wandb-project", + "entity": None, + "run_id": "ftjob-abc123" + } + } + ] + } + get: + operationId: listPaginatedFineTuningJobs + tags: + - Fine-tuning + summary: | + List your organization's fine-tuning jobs + parameters: + - name: after + in: query + description: Identifier for the last job from the previous pagination request. + required: false + schema: + type: string + - name: limit + in: query + description: Number of fine-tuning jobs to retrieve. + required: false + schema: + type: integer + default: 20 + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListPaginatedFineTuningJobsResponse" + x-oaiMeta: + name: List fine-tuning jobs + group: fine-tuning + returns: A list of paginated [fine-tuning job](/docs/api-reference/fine-tuning/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs?limit=2 \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.list() + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.fineTuning.jobs.list(); + + for await (const fineTune of list) { + console.log(fineTune); + } + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "object": "fine_tuning.job.event", + "id": "ft-event-TjX0lMfOniCZX64t9PUQT5hn", + "created_at": 1689813489, + "level": "warn", + "message": "Fine tuning process stopping due to job cancellation", + "data": null, + "type": "message" + }, + { ... }, + { ... } + ], "has_more": true + } + /fine_tuning/jobs/{fine_tuning_job_id}: + get: + operationId: retrieveFineTuningJob + tags: + - Fine-tuning + summary: | + Get info about a fine-tuning job. + + [Learn more about fine-tuning](/docs/guides/fine-tuning) + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/FineTuningJob" + x-oaiMeta: + name: Retrieve fine-tuning job + group: fine-tuning + returns: The [fine-tuning](/docs/api-reference/fine-tuning/object) object with the given ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs/ft-AF1WoRqd3aJAHsqc9NY7iL8F \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.retrieve("ftjob-abc123") + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.retrieve("ftjob-abc123"); + + console.log(fineTune); + } + + main(); + response: &fine_tuning_example | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "davinci-002", + "created_at": 1692661014, + "finished_at": 1692661190, + "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + "organization_id": "org-123", + "result_files": [ + "file-abc123" + ], + "status": "succeeded", + "validation_file": null, + "training_file": "file-abc123", + "hyperparameters": { + "n_epochs": 4, + "batch_size": 1, + "learning_rate_multiplier": 1.0 + }, + "trained_tokens": 5768, + "integrations": [], + "seed": 0, + "estimated_finish": 0 + } + /fine_tuning/jobs/{fine_tuning_job_id}/events: + get: + operationId: listFineTuningEvents + tags: + - Fine-tuning + summary: | + Get status updates for a fine-tuning job. + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job to get events for. + - name: after + in: query + description: Identifier for the last event from the previous pagination request. + required: false + schema: + type: string + - name: limit + in: query + description: Number of events to retrieve. + required: false + schema: + type: integer + default: 20 + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListFineTuningJobEventsResponse" + x-oaiMeta: + name: List fine-tuning events + group: fine-tuning + returns: A list of fine-tuning event objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs/ftjob-abc123/events \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.list_events( + fine_tuning_job_id="ftjob-abc123", + limit=2 + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.fineTuning.list_events(id="ftjob-abc123", limit=2); + + for await (const fineTune of list) { + console.log(fineTune); + } + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "object": "fine_tuning.job.event", + "id": "ft-event-ddTJfwuMVpfLXseO0Am0Gqjm", + "created_at": 1692407401, + "level": "info", + "message": "Fine tuning job successfully completed", + "data": null, + "type": "message" + }, + { + "object": "fine_tuning.job.event", + "id": "ft-event-tyiGuB72evQncpH87xe505Sv", + "created_at": 1692407400, + "level": "info", + "message": "New fine-tuned model created: ft:gpt-3.5-turbo:openai::7p4lURel", + "data": null, + "type": "message" + } + ], + "has_more": true + } + /fine_tuning/jobs/{fine_tuning_job_id}/cancel: + post: + operationId: cancelFineTuningJob + tags: + - Fine-tuning + summary: | + Immediately cancel a fine-tune job. + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job to cancel. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/FineTuningJob" + x-oaiMeta: + name: Cancel fine-tuning + group: fine-tuning + returns: The cancelled [fine-tuning](/docs/api-reference/fine-tuning/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/fine_tuning/jobs/ftjob-abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.cancel("ftjob-abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.cancel("ftjob-abc123"); + + console.log(fineTune); + } + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-3.5-turbo-0125", + "created_at": 1689376978, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "hyperparameters": { + "n_epochs": "auto" + }, + "status": "cancelled", + "validation_file": "file-abc123", + "training_file": "file-abc123" + } + /fine_tuning/jobs/{fine_tuning_job_id}/checkpoints: + get: + operationId: listFineTuningJobCheckpoints + tags: + - Fine-tuning + summary: | + List checkpoints for a fine-tuning job. + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job to get checkpoints for. + - name: after + in: query + description: Identifier for the last checkpoint ID from the previous pagination request. + required: false + schema: + type: string + - name: limit + in: query + description: Number of checkpoints to retrieve. + required: false + schema: + type: integer + default: 10 + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListFineTuningJobCheckpointsResponse" + x-oaiMeta: + name: List fine-tuning checkpoints + group: fine-tuning + returns: A list of fine-tuning [checkpoint objects](/docs/api-reference/fine-tuning/checkpoint-object) for a fine-tuning job. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs/ftjob-abc123/checkpoints \ + -H "Authorization: Bearer $OPENAI_API_KEY" + response: | + { + "object": "list" + "data": [ + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB", + "created_at": 1519129973, + "fine_tuned_model_checkpoint": "ft:gpt-3.5-turbo-0125:my-org:custom-suffix:96olL566:ckpt-step-2000", + "metrics": { + "full_valid_loss": 0.134, + "full_valid_mean_token_accuracy": 0.874 + }, + "fine_tuning_job_id": "ftjob-abc123", + "step_number": 2000, + }, + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_enQCFmOTGj3syEpYVhBRLTSy", + "created_at": 1519129833, + "fine_tuned_model_checkpoint": "ft:gpt-3.5-turbo-0125:my-org:custom-suffix:7q8mpxmy:ckpt-step-1000", + "metrics": { + "full_valid_loss": 0.167, + "full_valid_mean_token_accuracy": 0.781 + }, + "fine_tuning_job_id": "ftjob-abc123", + "step_number": 1000, + }, + ], + "first_id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB", + "last_id": "ftckpt_enQCFmOTGj3syEpYVhBRLTSy", + "has_more": true + } + + /models: + get: + operationId: listModels + tags: + - Models + summary: Lists the currently available models, and provides basic information about each one such as the owner and availability. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListModelsResponse" + x-oaiMeta: + name: List models + group: models + returns: A list of [model](/docs/api-reference/models/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/models \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.models.list() + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.models.list(); + + for await (const model of list) { + console.log(model); + } + } + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "model-id-0", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner" + }, + { + "id": "model-id-1", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner", + }, + { + "id": "model-id-2", + "object": "model", + "created": 1686935002, + "owned_by": "openai" + }, + ], + "object": "list" + } + /models/{model}: + get: + operationId: retrieveModel + tags: + - Models + summary: Retrieves a model instance, providing basic information about the model such as the owner and permissioning. + parameters: + - in: path + name: model + required: true + schema: + type: string + # ideally this will be an actual ID, so this will always work from browser + example: gpt-3.5-turbo + description: The ID of the model to use for this request + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Model" + x-oaiMeta: + name: Retrieve model + group: models + returns: The [model](/docs/api-reference/models/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/models/VAR_model_id \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.models.retrieve("VAR_model_id") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const model = await openai.models.retrieve("VAR_model_id"); + + console.log(model); + } + + main(); + response: &retrieve_model_response | + { + "id": "VAR_model_id", + "object": "model", + "created": 1686935002, + "owned_by": "openai" + } + delete: + operationId: deleteModel + tags: + - Models + summary: Delete a fine-tuned model. You must have the Owner role in your organization to delete a model. + parameters: + - in: path + name: model + required: true + schema: + type: string + example: ft:gpt-3.5-turbo:acemeco:suffix:abc123 + description: The model to delete + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteModelResponse" + x-oaiMeta: + name: Delete a fine-tuned model + group: models + returns: Deletion status. + examples: + request: + curl: | + curl https://api.openai.com/v1/models/ft:gpt-3.5-turbo:acemeco:suffix:abc123 \ + -X DELETE \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.models.delete("ft:gpt-3.5-turbo:acemeco:suffix:abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const model = await openai.models.del("ft:gpt-3.5-turbo:acemeco:suffix:abc123"); + + console.log(model); + } + main(); + response: | + { + "id": "ft:gpt-3.5-turbo:acemeco:suffix:abc123", + "object": "model", + "deleted": true + } + + /moderations: + post: + operationId: createModeration + tags: + - Moderations + summary: Classifies if text is potentially harmful. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateModerationRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateModerationResponse" + x-oaiMeta: + name: Create moderation + group: moderations + returns: A [moderation](/docs/api-reference/moderations/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/moderations \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "input": "I want to kill them." + }' + python: | + from openai import OpenAI + client = OpenAI() + + moderation = client.moderations.create(input="I want to kill them.") + print(moderation) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const moderation = await openai.moderations.create({ input: "I want to kill them." }); + + console.log(moderation); + } + main(); + response: &moderation_example | + { + "id": "modr-XXXXX", + "model": "text-moderation-005", + "results": [ + { + "flagged": true, + "categories": { + "sexual": false, + "hate": false, + "harassment": false, + "self-harm": false, + "sexual/minors": false, + "hate/threatening": false, + "violence/graphic": false, + "self-harm/intent": false, + "self-harm/instructions": false, + "harassment/threatening": true, + "violence": true, + }, + "category_scores": { + "sexual": 1.2282071e-06, + "hate": 0.010696256, + "harassment": 0.29842457, + "self-harm": 1.5236925e-08, + "sexual/minors": 5.7246268e-08, + "hate/threatening": 0.0060676364, + "violence/graphic": 4.435014e-06, + "self-harm/intent": 8.098441e-10, + "self-harm/instructions": 2.8498655e-11, + "harassment/threatening": 0.63055265, + "violence": 0.99011886, + } + } + ] + } + + /assistants: + get: + operationId: listAssistants + tags: + - Assistants + summary: Returns a list of assistants. + parameters: + - name: limit + in: query + description: &pagination_limit_param_description | + A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: &pagination_order_param_description | + Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order. + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: &pagination_after_param_description | + A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list. + schema: + type: string + - name: before + in: query + description: &pagination_before_param_description | + A cursor for use in pagination. `before` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list. + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListAssistantsResponse" + x-oaiMeta: + name: List assistants + group: assistants + beta: true + returns: A list of [assistant](/docs/api-reference/assistants/object) objects. + examples: + request: + curl: | + curl "https://api.openai.com/v1/assistants?order=desc&limit=20" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + my_assistants = client.beta.assistants.list( + order="desc", + limit="20", + ) + print(my_assistants.data) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistants = await openai.beta.assistants.list({ + order: "desc", + limit: "20", + }); + + console.log(myAssistants.data); + } + + main(); + response: &list_assistants_example | + { + "object": "list", + "data": [ + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1698982736, + "name": "Coding Tutor", + "description": null, + "model": "gpt-4-turbo", + "instructions": "You are a helpful assistant designed to make me better at coding!", + "tools": [], + "tool_resources": {}, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + }, + { + "id": "asst_abc456", + "object": "assistant", + "created_at": 1698982718, + "name": "My Assistant", + "description": null, + "model": "gpt-4-turbo", + "instructions": "You are a helpful assistant designed to make me better at coding!", + "tools": [], + "tool_resources": {}, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + }, + { + "id": "asst_abc789", + "object": "assistant", + "created_at": 1698982643, + "name": null, + "description": null, + "model": "gpt-4-turbo", + "instructions": null, + "tools": [], + "tool_resources": {}, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + ], + "first_id": "asst_abc123", + "last_id": "asst_abc789", + "has_more": false + } + post: + operationId: createAssistant + tags: + - Assistants + summary: Create an assistant with a model and instructions. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateAssistantRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/AssistantObject" + x-oaiMeta: + name: Create assistant + group: assistants + beta: true + returns: An [assistant](/docs/api-reference/assistants/object) object. + examples: + - title: Code Interpreter + request: + curl: | + curl "https://api.openai.com/v1/assistants" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "name": "Math Tutor", + "tools": [{"type": "code_interpreter"}], + "model": "gpt-4-turbo" + }' + + python: | + from openai import OpenAI + client = OpenAI() + + my_assistant = client.beta.assistants.create( + instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + name="Math Tutor", + tools=[{"type": "code_interpreter"}], + model="gpt-4-turbo", + ) + print(my_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistant = await openai.beta.assistants.create({ + instructions: + "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + name: "Math Tutor", + tools: [{ type: "code_interpreter" }], + model: "gpt-4-turbo", + }); + + console.log(myAssistant); + } + + main(); + response: &create_assistants_example | + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1698984975, + "name": "Math Tutor", + "description": null, + "model": "gpt-4-turbo", + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "tools": [ + { + "type": "code_interpreter" + } + ], + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + - title: Files + request: + curl: | + curl https://api.openai.com/v1/assistants \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", + "tools": [{"type": "file_search"}], + "tool_resources": {"file_search": {"vector_store_ids": ["vs_123"]}}, + "model": "gpt-4-turbo" + }' + python: | + from openai import OpenAI + client = OpenAI() + + my_assistant = client.beta.assistants.create( + instructions="You are an HR bot, and you have access to files to answer employee questions about company policies.", + name="HR Helper", + tools=[{"type": "file_search"}], + tool_resources={"file_search": {"vector_store_ids": ["vs_123"]}}, + model="gpt-4-turbo" + ) + print(my_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistant = await openai.beta.assistants.create({ + instructions: + "You are an HR bot, and you have access to files to answer employee questions about company policies.", + name: "HR Helper", + tools: [{ type: "file_search" }], + tool_resources: { + file_search: { + vector_store_ids: ["vs_123"] + } + }, + model: "gpt-4-turbo" + }); + + console.log(myAssistant); + } + + main(); + response: | + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1699009403, + "name": "HR Helper", + "description": null, + "model": "gpt-4-turbo", + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", + "tools": [ + { + "type": "file_search" + } + ], + "tool_resources": { + "file_search": { + "vector_store_ids": ["vs_123"] + } + }, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + + /assistants/{assistant_id}: + get: + operationId: getAssistant + tags: + - Assistants + summary: Retrieves an assistant. + parameters: + - in: path + name: assistant_id + required: true + schema: + type: string + description: The ID of the assistant to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/AssistantObject" + x-oaiMeta: + name: Retrieve assistant + group: assistants + beta: true + returns: The [assistant](/docs/api-reference/assistants/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/assistants/asst_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + my_assistant = client.beta.assistants.retrieve("asst_abc123") + print(my_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistant = await openai.beta.assistants.retrieve( + "asst_abc123" + ); + + console.log(myAssistant); + } + + main(); + response: | + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1699009709, + "name": "HR Helper", + "description": null, + "model": "gpt-4-turbo", + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", + "tools": [ + { + "type": "file_search" + } + ], + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + post: + operationId: modifyAssistant + tags: + - Assistants + summary: Modifies an assistant. + parameters: + - in: path + name: assistant_id + required: true + schema: + type: string + description: The ID of the assistant to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyAssistantRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/AssistantObject" + x-oaiMeta: + name: Modify assistant + group: assistants + beta: true + returns: The modified [assistant](/docs/api-reference/assistants/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/assistants/asst_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + "tools": [{"type": "file_search"}], + "model": "gpt-4-turbo" + }' + python: | + from openai import OpenAI + client = OpenAI() + + my_updated_assistant = client.beta.assistants.update( + "asst_abc123", + instructions="You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + name="HR Helper", + tools=[{"type": "file_search"}], + model="gpt-4-turbo" + ) + + print(my_updated_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myUpdatedAssistant = await openai.beta.assistants.update( + "asst_abc123", + { + instructions: + "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + name: "HR Helper", + tools: [{ type: "file_search" }], + model: "gpt-4-turbo" + } + ); + + console.log(myUpdatedAssistant); + } + + main(); + response: | + { + "id": "asst_123", + "object": "assistant", + "created_at": 1699009709, + "name": "HR Helper", + "description": null, + "model": "gpt-4-turbo", + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + "tools": [ + { + "type": "file_search" + } + ], + "tool_resources": { + "file_search": { + "vector_store_ids": [] + } + }, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + delete: + operationId: deleteAssistant + tags: + - Assistants + summary: Delete an assistant. + parameters: + - in: path + name: assistant_id + required: true + schema: + type: string + description: The ID of the assistant to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteAssistantResponse" + x-oaiMeta: + name: Delete assistant + group: assistants + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/assistants/asst_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + response = client.beta.assistants.delete("asst_abc123") + print(response) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const response = await openai.beta.assistants.del("asst_abc123"); + + console.log(response); + } + main(); + response: | + { + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + } + + /threads: + post: + operationId: createThread + tags: + - Assistants + summary: Create a thread. + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/CreateThreadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ThreadObject" + x-oaiMeta: + name: Create thread + group: threads + beta: true + returns: A [thread](/docs/api-reference/threads) object. + examples: + - title: Empty + request: + curl: | + curl https://api.openai.com/v1/threads \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '' + python: | + from openai import OpenAI + client = OpenAI() + + empty_thread = client.beta.threads.create() + print(empty_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const emptyThread = await openai.beta.threads.create(); + + console.log(emptyThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699012949, + "metadata": {}, + "tool_resources": {} + } + - title: Messages + request: + curl: | + curl https://api.openai.com/v1/threads \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "messages": [{ + "role": "user", + "content": "Hello, what is AI?" + }, { + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }] + }' + python: | + from openai import OpenAI + client = OpenAI() + + message_thread = client.beta.threads.create( + messages=[ + { + "role": "user", + "content": "Hello, what is AI?" + }, + { + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }, + ] + ) + + print(message_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const messageThread = await openai.beta.threads.create({ + messages: [ + { + role: "user", + content: "Hello, what is AI?" + }, + { + role: "user", + content: "How does AI work? Explain it in simple terms.", + }, + ], + }); + + console.log(messageThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699014083, + "metadata": {}, + "tool_resources": {} + } + + /threads/{thread_id}: + get: + operationId: getThread + tags: + - Assistants + summary: Retrieves a thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ThreadObject" + x-oaiMeta: + name: Retrieve thread + group: threads + beta: true + returns: The [thread](/docs/api-reference/threads/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + my_thread = client.beta.threads.retrieve("thread_abc123") + print(my_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myThread = await openai.beta.threads.retrieve( + "thread_abc123" + ); + + console.log(myThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699014083, + "metadata": {}, + "tool_resources": { + "code_interpreter": { + "file_ids": [] + } + } + } + post: + operationId: modifyThread + tags: + - Assistants + summary: Modifies a thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to modify. Only the `metadata` can be modified. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyThreadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ThreadObject" + x-oaiMeta: + name: Modify thread + group: threads + beta: true + returns: The modified [thread](/docs/api-reference/threads/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "metadata": { + "modified": "true", + "user": "abc123" + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + my_updated_thread = client.beta.threads.update( + "thread_abc123", + metadata={ + "modified": "true", + "user": "abc123" + } + ) + print(my_updated_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const updatedThread = await openai.beta.threads.update( + "thread_abc123", + { + metadata: { modified: "true", user: "abc123" }, + } + ); + + console.log(updatedThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699014083, + "metadata": { + "modified": "true", + "user": "abc123" + }, + "tool_resources": {} + } + delete: + operationId: deleteThread + tags: + - Assistants + summary: Delete a thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteThreadResponse" + x-oaiMeta: + name: Delete thread + group: threads + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + response = client.beta.threads.delete("thread_abc123") + print(response) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const response = await openai.beta.threads.del("thread_abc123"); + + console.log(response); + } + main(); + response: | + { + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + } + + /threads/{thread_id}/messages: + get: + operationId: listMessages + tags: + - Assistants + summary: Returns a list of messages for a given thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) the messages belong to. + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + - name: run_id + in: query + description: | + Filter messages by the run ID that generated them. + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListMessagesResponse" + x-oaiMeta: + name: List messages + group: threads + beta: true + returns: A list of [message](/docs/api-reference/messages) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + thread_messages = client.beta.threads.messages.list("thread_abc123") + print(thread_messages.data) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const threadMessages = await openai.beta.threads.messages.list( + "thread_abc123" + ); + + console.log(threadMessages.data); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1699016383, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + }, + { + "id": "msg_abc456", + "object": "thread.message", + "created_at": 1699016383, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "Hello, what is AI?", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + } + ], + "first_id": "msg_abc123", + "last_id": "msg_abc456", + "has_more": false + } + post: + operationId: createMessage + tags: + - Assistants + summary: Create a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) to create a message for. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateMessageRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/MessageObject" + x-oaiMeta: + name: Create message + group: threads + beta: true + returns: A [message](/docs/api-reference/messages/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }' + python: | + from openai import OpenAI + client = OpenAI() + + thread_message = client.beta.threads.messages.create( + "thread_abc123", + role="user", + content="How does AI work? Explain it in simple terms.", + ) + print(thread_message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const threadMessages = await openai.beta.threads.messages.create( + "thread_abc123", + { role: "user", content: "How does AI work? Explain it in simple terms." } + ); + + console.log(threadMessages); + } + + main(); + response: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1713226573, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + } + + /threads/{thread_id}/messages/{message_id}: + get: + operationId: getMessage + tags: + - Assistants + summary: Retrieve a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) to which this message belongs. + - in: path + name: message_id + required: true + schema: + type: string + description: The ID of the message to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/MessageObject" + x-oaiMeta: + name: Retrieve message + group: threads + beta: true + returns: The [message](/docs/api-reference/messages/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages/msg_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + message = client.beta.threads.messages.retrieve( + message_id="msg_abc123", + thread_id="thread_abc123", + ) + print(message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const message = await openai.beta.threads.messages.retrieve( + "thread_abc123", + "msg_abc123" + ); + + console.log(message); + } + + main(); + response: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1699017614, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + } + post: + operationId: modifyMessage + tags: + - Assistants + summary: Modifies a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which this message belongs. + - in: path + name: message_id + required: true + schema: + type: string + description: The ID of the message to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyMessageRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/MessageObject" + x-oaiMeta: + name: Modify message + group: threads + beta: true + returns: The modified [message](/docs/api-reference/messages/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages/msg_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "metadata": { + "modified": "true", + "user": "abc123" + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + message = client.beta.threads.messages.update( + message_id="msg_abc12", + thread_id="thread_abc123", + metadata={ + "modified": "true", + "user": "abc123", + }, + ) + print(message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const message = await openai.beta.threads.messages.update( + "thread_abc123", + "msg_abc123", + { + metadata: { + modified: "true", + user: "abc123", + }, + } + }' + response: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1699017614, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "file_ids": [], + "metadata": { + "modified": "true", + "user": "abc123" + } + } + delete: + operationId: deleteMessage + tags: + - Assistants + summary: Deletes a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which this message belongs. + - in: path + name: message_id + required: true + schema: + type: string + description: The ID of the message to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteMessageResponse" + x-oaiMeta: + name: Delete message + group: threads + beta: true + returns: Deletion status + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/threads/thread_abc123/messages/msg_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + deleted_message = client.beta.threads.messages.delete( + message_id="msg_abc12", + thread_id="thread_abc123", + ) + print(deleted_message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const deletedMessage = await openai.beta.threads.messages.del( + "thread_abc123", + "msg_abc123" + ); + + console.log(deletedMessage); + } + response: | + { + "id": "msg_abc123", + "object": "thread.message.deleted", + "deleted": true + } + + /threads/runs: + post: + operationId: createThreadAndRun + tags: + - Assistants + summary: Create a thread and run it in one request. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateThreadAndRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Create thread and run + group: threads + beta: true + returns: A [run](/docs/api-reference/runs/object) object. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/threads/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123", + "thread": { + "messages": [ + {"role": "user", "content": "Explain deep learning to a 5 year old."} + ] + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.create_and_run( + assistant_id="asst_abc123", + thread={ + "messages": [ + {"role": "user", "content": "Explain deep learning to a 5 year old."} + ] + } + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.createAndRun({ + assistant_id: "asst_abc123", + thread: { + messages: [ + { role: "user", content: "Explain deep learning to a 5 year old." }, + ], + }, + }); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699076792, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "queued", + "started_at": null, + "expires_at": 1699077392, + "cancelled_at": null, + "failed_at": null, + "completed_at": null, + "required_action": null, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": "You are a helpful assistant.", + "tools": [], + "tool_resources": {}, + "metadata": {}, + "temperature": 1.0, + "top_p": 1.0, + "max_completion_tokens": null, + "max_prompt_tokens": null, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "incomplete_details": null, + "usage": null, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/threads/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_123", + "thread": { + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + stream = client.beta.threads.create_and_run( + assistant_id="asst_123", + thread={ + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.beta.threads.createAndRun({ + assistant_id: "asst_123", + thread: { + messages: [ + { role: "user", content: "Hello" }, + ], + }, + stream: true + }); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.created + data: {"id":"thread_123","object":"thread","created_at":1710348075,"metadata":{}} + + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[], "metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[], "metadata":{}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"Hello","annotations":[]}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" today"}}]}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"?"}}]}} + + event: thread.message.completed + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710348077,"role":"assistant","content":[{"type":"text","text":{"value":"Hello! How can I assist you today?","annotations":[]}}], "metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710348077,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} + + event: thread.run.completed + {"id":"run_123","object":"thread.run","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1713226836,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1713226837,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: done + data: [DONE] + + - title: Streaming with Functions + request: + curl: | + curl https://api.openai.com/v1/threads/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123", + "thread": { + "messages": [ + {"role": "user", "content": "What is the weather like in San Francisco?"} + ] + }, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + + stream = client.beta.threads.create_and_run( + thread={ + "messages": [ + {"role": "user", "content": "What is the weather like in San Francisco?"} + ] + }, + assistant_id="asst_abc123", + tools=tools, + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + const tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ]; + + async function main() { + const stream = await openai.beta.threads.createAndRun({ + assistant_id: "asst_123", + thread: { + messages: [ + { role: "user", content: "What is the weather like in San Francisco?" }, + ], + }, + tools: tools, + stream: true + }); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.created + data: {"id":"thread_123","object":"thread","created_at":1710351818,"metadata":{}} + + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710351819,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710352418,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[]},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710351819,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710352418,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[]},"usage":null} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"id":"call_XXNp8YGaFrjrSjgqxtC8JJ1B","type":"function","function":{"name":"get_current_weather","arguments":"","output":null}}]}}} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"{\""}}]}}} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"location"}}]}}} + + ... + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"ahrenheit"}}]}}} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"\"}"}}]}}} + + event: thread.run.requires_action + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"requires_action","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":{"type":"submit_tool_outputs","submit_tool_outputs":{"tool_calls":[{"id":"call_XXNp8YGaFrjrSjgqxtC8JJ1B","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}"}}]}},"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + /threads/{thread_id}/runs: + get: + operationId: listRuns + tags: + - Assistants + summary: Returns a list of runs belonging to a thread. + parameters: + - name: thread_id + in: path + required: true + schema: + type: string + description: The ID of the thread the run belongs to. + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListRunsResponse" + x-oaiMeta: + name: List runs + group: threads + beta: true + returns: A list of [run](/docs/api-reference/runs/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + runs = client.beta.threads.runs.list( + "thread_abc123" + ) + + print(runs) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const runs = await openai.beta.threads.runs.list( + "thread_abc123" + ); + + console.log(runs); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699075072, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699075072, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699075073, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "tool_resources": { + "code_interpreter": { + "file_ids": [ + "file-abc123", + "file-abc456" + ] + } + }, + "metadata": {}, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + }, + { + "id": "run_abc456", + "object": "thread.run", + "created_at": 1699063290, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699063290, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699063291, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "tool_resources": { + "code_interpreter": { + "file_ids": [ + "file-abc123", + "file-abc456" + ] + } + }, + "metadata": {}, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + ], + "first_id": "run_abc123", + "last_id": "run_abc456", + "has_more": false + } + post: + operationId: createRun + tags: + - Assistants + summary: Create a run. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to run. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Create run + group: threads + beta: true + returns: A [run](/docs/api-reference/runs/object) object. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123" + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.create( + thread_id="thread_abc123", + assistant_id="asst_abc123" + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.create( + "thread_abc123", + { assistant_id: "asst_abc123" } + ); + + console.log(run); + } + + main(); + response: &run_object_example | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699063290, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "queued", + "started_at": 1699063290, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699063291, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "metadata": {}, + "usage": null, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/threads/thread_123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_123", + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + stream = client.beta.threads.runs.create( + thread_id="thread_123", + assistant_id="asst_123", + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.beta.threads.runs.create( + "thread_123", + { assistant_id: "asst_123", stream: true } + ); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710330641,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_001","object":"thread.message","created_at":1710330641,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_001","object":"thread.message","created_at":1710330641,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"Hello","annotations":[]}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" today"}}]}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"?"}}]}} + + event: thread.message.completed + data: {"id":"msg_001","object":"thread.message","created_at":1710330641,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710330642,"role":"assistant","content":[{"type":"text","text":{"value":"Hello! How can I assist you today?","annotations":[]}}],"metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710330642,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} + + event: thread.run.completed + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710330641,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710330642,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + - title: Streaming with Functions + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123", + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + + stream = client.beta.threads.runs.create( + thread_id="thread_abc123", + assistant_id="asst_abc123", + tools=tools, + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + const tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ]; + + async function main() { + const stream = await openai.beta.threads.runs.create( + "thread_abc123", + { + assistant_id: "asst_abc123", + tools: tools, + stream: true + } + ); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710348075,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"Hello","annotations":[]}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" today"}}]}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"?"}}]}} + + event: thread.message.completed + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710348077,"role":"assistant","content":[{"type":"text","text":{"value":"Hello! How can I assist you today?","annotations":[]}}],"metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710348077,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} + + event: thread.run.completed + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710348075,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710348077,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + /threads/{thread_id}/runs/{run_id}: + get: + operationId: getRun + tags: + - Assistants + summary: Retrieves a run. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) that was run. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Retrieve run + group: threads + beta: true + returns: The [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.retrieve( + thread_id="thread_abc123", + run_id="run_abc123" + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.retrieve( + "thread_abc123", + "run_abc123" + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699075072, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699075072, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699075073, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "metadata": {}, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + post: + operationId: modifyRun + tags: + - Assistants + summary: Modifies a run. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) that was run. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Modify run + group: threads + beta: true + returns: The modified [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "metadata": { + "user_id": "user_abc123" + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.update( + thread_id="thread_abc123", + run_id="run_abc123", + metadata={"user_id": "user_abc123"}, + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.update( + "thread_abc123", + "run_abc123", + { + metadata: { + user_id: "user_abc123", + }, + } + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699075072, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699075072, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699075073, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "tool_resources": { + "code_interpreter": { + "file_ids": [ + "file-abc123", + "file-abc456" + ] + } + }, + "metadata": { + "user_id": "user_abc123" + }, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + /threads/{thread_id}/runs/{run_id}/submit_tool_outputs: + post: + operationId: submitToolOuputsToRun + tags: + - Assistants + summary: | + When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`, this endpoint can be used to submit the outputs from the tool calls once they're all completed. All outputs must be submitted in a single request. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) to which this run belongs. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run that requires the tool output submission. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/SubmitToolOutputsRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Submit tool outputs to run + group: threads + beta: true + returns: The modified [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/threads/thread_123/runs/run_123/submit_tool_outputs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "tool_outputs": [ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ] + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.submit_tool_outputs( + thread_id="thread_123", + run_id="run_123", + tool_outputs=[ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ] + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.submitToolOutputs( + "thread_123", + "run_123", + { + tool_outputs: [ + { + tool_call_id: "call_001", + output: "70 degrees and sunny.", + }, + ], + } + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_123", + "object": "thread.run", + "created_at": 1699075592, + "assistant_id": "asst_123", + "thread_id": "thread_123", + "status": "queued", + "started_at": 1699075592, + "expires_at": 1699076192, + "cancelled_at": null, + "failed_at": null, + "completed_at": null, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "metadata": {}, + "usage": null, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/threads/thread_123/runs/run_123/submit_tool_outputs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "tool_outputs": [ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + stream = client.beta.threads.runs.submit_tool_outputs( + thread_id="thread_123", + run_id="run_123", + tool_outputs=[ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ], + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.beta.threads.runs.submitToolOutputs( + "thread_123", + "run_123", + { + tool_outputs: [ + { + tool_call_id: "call_001", + output: "70 degrees and sunny.", + }, + ], + } + ); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710352449,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"completed","cancelled_at":null,"completed_at":1710352475,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[{"id":"call_iWr0kQ2EaYMaxNdl0v3KYkx7","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}","output":"70 degrees and sunny."}}]},"usage":{"prompt_tokens":291,"completion_tokens":24,"total_tokens":315}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":1710352448,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710352475,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_002","object":"thread.message","created_at":1710352476,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_002","object":"thread.message","created_at":1710352476,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"The","annotations":[]}}]}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" current"}}]}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" weather"}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" sunny"}}]}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"."}}]}} + + event: thread.message.completed + data: {"id":"msg_002","object":"thread.message","created_at":1710352476,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710352477,"role":"assistant","content":[{"type":"text","text":{"value":"The current weather in San Francisco, CA is 70 degrees Fahrenheit and sunny.","annotations":[]}}],"metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710352477,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":{"prompt_tokens":329,"completion_tokens":18,"total_tokens":347}} + + event: thread.run.completed + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710352475,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710352477,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + /threads/{thread_id}/runs/{run_id}/cancel: + post: + operationId: cancelRun + tags: + - Assistants + summary: Cancels a run that is `in_progress`. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which this run belongs. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to cancel. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Cancel a run + group: threads + beta: true + returns: The modified [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -X POST + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.cancel( + thread_id="thread_abc123", + run_id="run_abc123" + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.cancel( + "thread_abc123", + "run_abc123" + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699076126, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "cancelling", + "started_at": 1699076126, + "expires_at": 1699076726, + "cancelled_at": null, + "failed_at": null, + "completed_at": null, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": "You summarize books.", + "tools": [ + { + "type": "file_search" + } + ], + "tool_resources": { + "file_search": { + "vector_store_ids": ["vs_123"] + } + }, + "metadata": {}, + "usage": null, + "temperature": 1.0, + "top_p": 1.0, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + /threads/{thread_id}/runs/{run_id}/steps: + get: + operationId: listRunSteps + tags: + - Assistants + summary: Returns a list of run steps belonging to a run. + parameters: + - name: thread_id + in: path + required: true + schema: + type: string + description: The ID of the thread the run and run steps belong to. + - name: run_id + in: path + required: true + schema: + type: string + description: The ID of the run the run steps belong to. + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListRunStepsResponse" + x-oaiMeta: + name: List run steps + group: threads + beta: true + returns: A list of [run step](/docs/api-reference/runs/step-object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123/steps \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + run_steps = client.beta.threads.runs.steps.list( + thread_id="thread_abc123", + run_id="run_abc123" + ) + + print(run_steps) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const runStep = await openai.beta.threads.runs.steps.list( + "thread_abc123", + "run_abc123" + ); + console.log(runStep); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "step_abc123", + "object": "thread.run.step", + "created_at": 1699063291, + "run_id": "run_abc123", + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "type": "message_creation", + "status": "completed", + "cancelled_at": null, + "completed_at": 1699063291, + "expired_at": null, + "failed_at": null, + "last_error": null, + "step_details": { + "type": "message_creation", + "message_creation": { + "message_id": "msg_abc123" + } + }, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + } + } + ], + "first_id": "step_abc123", + "last_id": "step_abc456", + "has_more": false + } + + /threads/{thread_id}/runs/{run_id}/steps/{step_id}: + get: + operationId: getRunStep + tags: + - Assistants + summary: Retrieves a run step. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which the run and run step belongs. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to which the run step belongs. + - in: path + name: step_id + required: true + schema: + type: string + description: The ID of the run step to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunStepObject" + x-oaiMeta: + name: Retrieve run step + group: threads + beta: true + returns: The [run step](/docs/api-reference/runs/step-object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123/steps/step_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + run_step = client.beta.threads.runs.steps.retrieve( + thread_id="thread_abc123", + run_id="run_abc123", + step_id="step_abc123" + ) + + print(run_step) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const runStep = await openai.beta.threads.runs.steps.retrieve( + "thread_abc123", + "run_abc123", + "step_abc123" + ); + console.log(runStep); + } + + main(); + response: &run_step_object_example | + { + "id": "step_abc123", + "object": "thread.run.step", + "created_at": 1699063291, + "run_id": "run_abc123", + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "type": "message_creation", + "status": "completed", + "cancelled_at": null, + "completed_at": 1699063291, + "expired_at": null, + "failed_at": null, + "last_error": null, + "step_details": { + "type": "message_creation", + "message_creation": { + "message_id": "msg_abc123" + } + }, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + } + } + + /vector_stores: + get: + operationId: listVectorStores + tags: + - Vector Stores + summary: Returns a list of vector stores. + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListVectorStoresResponse" + x-oaiMeta: + name: List vector stores + group: vector_stores + beta: true + returns: A list of [vector store](/docs/api-reference/vector-stores/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_stores = client.beta.vector_stores.list() + print(vector_stores) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStores = await openai.beta.vectorStores.list(); + console.log(vectorStores); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + }, + { + "id": "vs_abc456", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ v2", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + } + ], + "first_id": "vs_abc123", + "last_id": "vs_abc456", + "has_more": false + } + post: + operationId: createVectorStore + tags: + - Vector Stores + summary: Create a vector store. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateVectorStoreRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreObject" + x-oaiMeta: + name: Create vector store + group: vector_stores + beta: true + returns: A [vector store](/docs/api-reference/vector-stores/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + -d '{ + "name": "Support FAQ" + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + print(vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStore = await openai.beta.vectorStores.create({ + name: "Support FAQ" + }); + console.log(vectorStore); + } + + main(); + response: | + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + } + + /vector_stores/{vector_store_id}: + get: + operationId: getVectorStore + tags: + - Vector Stores + summary: Retrieves a vector store. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreObject" + x-oaiMeta: + name: Retrieve vector store + group: vector_stores + beta: true + returns: The [vector store](/docs/api-reference/vector-stores/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store = client.beta.vector_stores.retrieve( + vector_store_id="vs_abc123" + ) + print(vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStore = await openai.beta.vectorStores.retrieve( + "vs_abc123" + ); + console.log(vectorStore); + } + + main(); + response: | + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776 + } + post: + operationId: modifyVectorStore + tags: + - Vector Stores + summary: Modifies a vector store. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateVectorStoreRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreObject" + x-oaiMeta: + name: Modify vector store + group: vector_stores + beta: true + returns: The modified [vector store](/docs/api-reference/vector-stores/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + -d '{ + "name": "Support FAQ" + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store = client.beta.vector_stores.update( + vector_store_id="vs_abc123", + name="Support FAQ" + ) + print(vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStore = await openai.beta.vectorStores.update( + "vs_abc123", + { + name: "Support FAQ" + } + ); + console.log(vectorStore); + } + + main(); + response: | + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + } + + delete: + operationId: deleteVectorStore + tags: + - Vector Stores + summary: Delete a vector store. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteVectorStoreResponse" + x-oaiMeta: + name: Delete vector store + group: vector_stores + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id="vs_abc123" + ) + print(deleted_vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const deletedVectorStore = await openai.beta.vectorStores.del( + "vs_abc123" + ); + console.log(deletedVectorStore); + } + + main(); + response: | + { + id: "vs_abc123", + object: "vector_store.deleted", + deleted: true + } + + /vector_stores/{vector_store_id}/files: + get: + operationId: listVectorStoreFiles + tags: + - Vector Stores + summary: Returns a list of vector store files. + parameters: + - name: vector_store_id + in: path + description: The ID of the vector store that the files belong to. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + - name: filter + in: query + description: "Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`." + schema: + type: string + enum: ["in_progress", "completed", "failed", "cancelled"] + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListVectorStoreFilesResponse" + x-oaiMeta: + name: List vector store files + group: vector_stores + beta: true + returns: A list of [vector store file](/docs/api-reference/vector-stores-files/file-object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_files = client.beta.vector_stores.files.list( + vector_store_id="vs_abc123" + ) + print(vector_store_files) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFiles = await openai.beta.vectorStores.files.list( + "vs_abc123" + ); + console.log(vectorStoreFiles); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + }, + { + "id": "file-abc456", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + } + ], + "first_id": "file-abc123", + "last_id": "file-abc456", + "has_more": false + } + post: + operationId: createVectorStoreFile + tags: + - Vector Stores + summary: Create a vector store file by attaching a [File](/docs/api-reference/files) to a [vector store](/docs/api-reference/vector-stores/object). + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: | + The ID of the vector store for which to create a File. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateVectorStoreFileRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileObject" + x-oaiMeta: + name: Create vector store file + group: vector_stores + beta: true + returns: A [vector store file](/docs/api-reference/vector-stores-files/file-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "file_id": "file-abc123" + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file = client.beta.vector_stores.files.create( + vector_store_id="vs_abc123", + file_id="file-abc123" + ) + print(vector_store_file) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const myVectorStoreFile = await openai.beta.vectorStores.files.create( + "vs_abc123", + { + file_id: "file-abc123" + } + ); + console.log(myVectorStoreFile); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "usage_bytes": 1234, + "vector_store_id": "vs_abcd", + "status": "completed", + "last_error": null + } + + /vector_stores/{vector_store_id}/files/{file_id}: + get: + operationId: getVectorStoreFile + tags: + - Vector Stores + summary: Retrieves a vector store file. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: The ID of the vector store that the file belongs to. + - in: path + name: file_id + required: true + schema: + type: string + example: file-abc123 + description: The ID of the file being retrieved. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileObject" + x-oaiMeta: + name: Retrieve vector store file + group: vector_stores + beta: true + returns: The [vector store file](/docs/api-reference/vector-stores-files/file-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files/file-abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file = client.beta.vector_stores.files.retrieve( + vector_store_id="vs_abc123", + file_id="file-abc123" + ) + print(vector_store_file) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFile = await openai.beta.vectorStores.files.retrieve( + "vs_abc123", + "file-abc123" + ); + console.log(vectorStoreFile); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abcd", + "status": "completed", + "last_error": null + } + delete: + operationId: deleteVectorStoreFile + tags: + - Vector Stores + summary: Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use the [delete file](/docs/api-reference/files/delete) endpoint. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store that the file belongs to. + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteVectorStoreFileResponse" + x-oaiMeta: + name: Delete vector store file + group: vector_stores + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files/file-abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + deleted_vector_store_file = client.beta.vector_stores.files.delete( + vector_store_id="vs_abc123", + file_id="file-abc123" + ) + print(deleted_vector_store_file) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const deletedVectorStoreFile = await openai.beta.vectorStores.files.del( + "vs_abc123", + "file-abc123" + ); + console.log(deletedVectorStoreFile); + } + + main(); + response: | + { + id: "file-abc123", + object: "vector_store.file.deleted", + deleted: true + } + + /vector_stores/{vector_store_id}/file_batches: + post: + operationId: createVectorStoreFileBatch + tags: + - Vector Stores + summary: Create a vector store file batch. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: | + The ID of the vector store for which to create a File Batch. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateVectorStoreFileBatchRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileBatchObject" + x-oaiMeta: + name: Create vector store file batch + group: vector_stores + beta: true + returns: A [vector store file batch](/docs/api-reference/vector-stores-file-batches/batch-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/file_batches \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "file_ids": ["file-abc123", "file-abc456"] + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file_batch = client.beta.vector_stores.file_batches.create( + vector_store_id="vs_abc123", + file_ids=["file-abc123", "file-abc456"] + ) + print(vector_store_file_batch) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const myVectorStoreFileBatch = await openai.beta.vectorStores.fileBatches.create( + "vs_abc123", + { + file_ids: ["file-abc123", "file-abc456"] + } + ); + console.log(myVectorStoreFileBatch); + } + + main(); + response: | + { + "id": "vsfb_abc123", + "object": "vector_store.file_batch", + "created_at": 1699061776, + "vector_store_id": "vs_abc123", + "status": "in_progress", + "file_counts": { + "in_progress": 1, + "completed": 1, + "failed": 0, + "cancelled": 0, + "total": 0, + } + } + + /vector_stores/{vector_store_id}/file_batches/{batch_id}: + get: + operationId: getVectorStoreFileBatch + tags: + - Vector Stores + summary: Retrieves a vector store file batch. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: The ID of the vector store that the file batch belongs to. + - in: path + name: batch_id + required: true + schema: + type: string + example: vsfb_abc123 + description: The ID of the file batch being retrieved. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileBatchObject" + x-oaiMeta: + name: Retrieve vector store file batch + group: vector_stores + beta: true + returns: The [vector store file batch](/docs/api-reference/vector-stores-file-batches/batch-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files_batches/vsfb_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file_batch = client.beta.vector_stores.file_batches.retrieve( + vector_store_id="vs_abc123", + batch_id="vsfb_abc123" + ) + print(vector_store_file_batch) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFileBatch = await openai.beta.vectorStores.fileBatches.retrieve( + "vs_abc123", + "vsfb_abc123" + ); + console.log(vectorStoreFileBatch); + } + + main(); + response: | + { + "id": "vsfb_abc123", + "object": "vector_store.file_batch", + "created_at": 1699061776, + "vector_store_id": "vs_abc123", + "status": "in_progress", + "file_counts": { + "in_progress": 1, + "completed": 1, + "failed": 0, + "cancelled": 0, + "total": 0, + } + } + + /vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel: + post: + operationId: cancelVectorStoreFileBatch + tags: + - Vector Stores + summary: Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store that the file batch belongs to. + - in: path + name: batch_id + required: true + schema: + type: string + description: The ID of the file batch to cancel. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileBatchObject" + x-oaiMeta: + name: Cancel vector store file batch + group: vector_stores + beta: true + returns: The modified vector store file batch object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files_batches/vsfb_abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -X POST + python: | + from openai import OpenAI + client = OpenAI() + + deleted_vector_store_file_batch = client.beta.vector_stores.file_batches.cancel( + vector_store_id="vs_abc123", + file_batch_id="vsfb_abc123" + ) + print(deleted_vector_store_file_batch) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const deletedVectorStoreFileBatch = await openai.vector_stores.fileBatches.cancel( + "vs_abc123", + "vsfb_abc123" + ); + console.log(deletedVectorStoreFileBatch); + } + + main(); + response: | + { + "id": "vsfb_abc123", + "object": "vector_store.file_batch", + "created_at": 1699061776, + "vector_store_id": "vs_abc123", + "status": "cancelling", + "file_counts": { + "in_progress": 12, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 15, + } + } + + /vector_stores/{vector_store_id}/file_batches/{batch_id}/files: + get: + operationId: listFilesInVectorStoreBatch + tags: + - Vector Stores + summary: Returns a list of vector store files in a batch. + parameters: + - name: vector_store_id + in: path + description: The ID of the vector store that the files belong to. + required: true + schema: + type: string + - name: batch_id + in: path + description: The ID of the file batch that the files belong to. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + - name: filter + in: query + description: "Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`." + schema: + type: string + enum: ["in_progress", "completed", "failed", "cancelled"] + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListVectorStoreFilesResponse" + x-oaiMeta: + name: List vector store files in a batch + group: vector_stores + beta: true + returns: A list of [vector store file](/docs/api-reference/vector-stores-files/file-object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files_batches/vsfb_abc123/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_files = client.beta.vector_stores.file_batches.list_files( + vector_store_id="vs_abc123", + batch_id="vsfb_abc123" + ) + print(vector_store_files) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFiles = await openai.beta.vectorStores.fileBatches.listFiles( + "vs_abc123", + "vsfb_abc123" + ); + console.log(vectorStoreFiles); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + }, + { + "id": "file-abc456", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + } + ], + "first_id": "file-abc123", + "last_id": "file-abc456", + "has_more": false + } + + /batches: + post: + summary: Creates and executes a batch from an uploaded file of requests + operationId: createBatch + tags: + - Batch + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - input_file_id + - endpoint + - completion_window + properties: + input_file_id: + type: string + description: | + The ID of an uploaded file that contains requests for the new batch. + + See [upload file](/docs/api-reference/files/create) for how to upload a file. + + Your input file must be formatted as a [JSONL file](/docs/api-reference/batch/request-input), and must be uploaded with the purpose `batch`. The file can contain up to 50,000 requests, and can be up to 100 MB in size. + endpoint: + type: string + enum: + [ + "/v1/chat/completions", + "/v1/embeddings", + "/v1/completions", + ] + description: The endpoint to be used for all requests in the batch. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. Note that `/v1/embeddings` batches are also restricted to a maximum of 50,000 embedding inputs across all requests in the batch. + completion_window: + type: string + enum: ["24h"] + description: The time frame within which the batch should be processed. Currently only `24h` is supported. + metadata: + type: object + additionalProperties: + type: string + description: Optional custom metadata for the batch. + nullable: true + responses: + "200": + description: Batch created successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Batch" + x-oaiMeta: + name: Create batch + group: batch + returns: The created [Batch](/docs/api-reference/batch/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "input_file_id": "file-abc123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.create( + input_file_id="file-abc123", + endpoint="/v1/chat/completions", + completion_window="24h" + ) + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const batch = await openai.batches.create({ + input_file_id: "file-abc123", + endpoint: "/v1/chat/completions", + completion_window: "24h" + }); + + console.log(batch); + } + + main(); + response: | + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "validating", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": null, + "expires_at": null, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 0, + "completed": 0, + "failed": 0 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + } + } + get: + operationId: listBatches + tags: + - Batch + summary: List your organization's batches. + parameters: + - in: query + name: after + required: false + schema: + type: string + description: *pagination_after_param_description + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + responses: + "200": + description: Batch listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ListBatchesResponse" + x-oaiMeta: + name: List batch + group: batch + returns: A list of paginated [Batch](/docs/api-reference/batch/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches?limit=2 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.list() + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.batches.list(); + + for await (const batch of list) { + console.log(batch); + } + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job", + } + }, + { ... }, + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + } + + /batches/{batch_id}: + get: + operationId: retrieveBatch + tags: + - Batch + summary: Retrieves a batch. + parameters: + - in: path + name: batch_id + required: true + schema: + type: string + description: The ID of the batch to retrieve. + responses: + "200": + description: Batch retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Batch" + x-oaiMeta: + name: Retrieve batch + group: batch + returns: The [Batch](/docs/api-reference/batch/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches/batch_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.retrieve("batch_abc123") + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const batch = await openai.batches.retrieve("batch_abc123"); + + console.log(batch); + } + + main(); + response: &batch_object | + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + } + } + + /batches/{batch_id}/cancel: + post: + operationId: cancelBatch + tags: + - Batch + summary: Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file. + parameters: + - in: path + name: batch_id + required: true + schema: + type: string + description: The ID of the batch to cancel. + responses: + "200": + description: Batch is cancelling. Returns the cancelling batch's details. + content: + application/json: + schema: + $ref: "#/components/schemas/Batch" + x-oaiMeta: + name: Cancel batch + group: batch + returns: The [Batch](/docs/api-reference/batch/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches/batch_abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -X POST + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.cancel("batch_abc123") + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const batch = await openai.batches.cancel("batch_abc123"); + + console.log(batch); + } + + main(); + response: | + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + } + } + +components: + securitySchemes: + ApiKeyAuth: + type: http + scheme: "bearer" + + schemas: + Error: + type: object + properties: + code: + type: string + nullable: true + message: + type: string + nullable: false + param: + type: string + nullable: true + type: + type: string + nullable: false + required: + - type + - message + - param + - code + ErrorResponse: + type: object + properties: + error: + $ref: "#/components/schemas/Error" + required: + - error + + ListModelsResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/Model" + required: + - object + - data + DeleteModelResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + required: + - id + - object + - deleted + + CreateCompletionRequest: + type: object + properties: + model: + description: &model_description | + ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them. + anyOf: + - type: string + - type: string + enum: ["gpt-3.5-turbo-instruct", "davinci-002", "babbage-002"] + x-oaiTypeLabel: string + prompt: + description: &completions_prompt_description | + The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays. + + Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document. + default: "<|endoftext|>" + nullable: true + oneOf: + - type: string + default: "" + example: "This is a test." + - type: array + items: + type: string + default: "" + example: "This is a test." + - type: array + minItems: 1 + items: + type: integer + example: "[1212, 318, 257, 1332, 13]" + - type: array + minItems: 1 + items: + type: array + minItems: 1 + items: + type: integer + example: "[[1212, 318, 257, 1332, 13]]" + best_of: + type: integer + default: 1 + minimum: 0 + maximum: 20 + nullable: true + description: &completions_best_of_description | + Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed. + + When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`. + + **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + echo: + type: boolean + default: false + nullable: true + description: &completions_echo_description > + Echo back the prompt in addition to the completion + frequency_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: &completions_frequency_penalty_description | + Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + + [See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details) + logit_bias: &completions_logit_bias + type: object + x-oaiTypeLabel: map + default: null + nullable: true + additionalProperties: + type: integer + description: &completions_logit_bias_description | + Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + + As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated. + logprobs: &completions_logprobs_configuration + type: integer + minimum: 0 + maximum: 5 + default: null + nullable: true + description: &completions_logprobs_description | + Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response. + + The maximum value for `logprobs` is 5. + max_tokens: + type: integer + minimum: 0 + default: 16 + example: 16 + nullable: true + description: &completions_max_tokens_description | + The maximum number of [tokens](/tokenizer) that can be generated in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + n: + type: integer + minimum: 1 + maximum: 128 + default: 1 + example: 1 + nullable: true + description: &completions_completions_description | + How many completions to generate for each prompt. + + **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + presence_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: &completions_presence_penalty_description | + Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + + [See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details) + seed: &completions_seed_param + type: integer + minimum: -9223372036854775808 + maximum: 9223372036854775807 + nullable: true + description: | + If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. + + Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. + stop: + description: &completions_stop_description > + Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. + default: null + nullable: true + oneOf: + - type: string + default: <|endoftext|> + example: "\n" + nullable: true + - type: array + minItems: 1 + maxItems: 4 + items: + type: string + example: '["\n"]' + stream: + description: > + Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + type: boolean + nullable: true + default: false + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" + suffix: + description: | + The suffix that comes after a completion of inserted text. + + This parameter is only supported for `gpt-3.5-turbo-instruct`. + default: null + nullable: true + type: string + example: "test." + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: &completions_temperature_description | + What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + + We generally recommend altering this or `top_p` but not both. + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: &completions_top_p_description | + An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + user: &end_user_param_configuration + type: string + example: user-1234 + description: | + A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids). + required: + - model + - prompt + + CreateCompletionResponse: + type: object + description: | + Represents a completion response from the API. Note: both the streamed and non-streamed response objects share the same shape (unlike the chat endpoint). + properties: + id: + type: string + description: A unique identifier for the completion. + choices: + type: array + description: The list of completion choices the model generated for the input prompt. + items: + type: object + required: + - finish_reason + - index + - logprobs + - text + properties: + finish_reason: + type: string + description: &completion_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + `length` if the maximum number of tokens specified in the request was reached, + or `content_filter` if content was omitted due to a flag from our content filters. + enum: ["stop", "length", "content_filter"] + index: + type: integer + logprobs: + type: object + nullable: true + properties: + text_offset: + type: array + items: + type: integer + token_logprobs: + type: array + items: + type: number + tokens: + type: array + items: + type: string + top_logprobs: + type: array + items: + type: object + additionalProperties: + type: number + text: + type: string + created: + type: integer + description: The Unix timestamp (in seconds) of when the completion was created. + model: + type: string + description: The model used for completion. + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always "text_completion" + enum: [text_completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - id + - object + - created + - model + - choices + x-oaiMeta: + name: The completion object + legacy: true + example: | + { + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "gpt-4-turbo", + "choices": [ + { + "text": "\n\nThis is indeed a test", + "index": 0, + "logprobs": null, + "finish_reason": "length" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } + } + + ChatCompletionRequestMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartImage" + x-oaiExpandable: true + + ChatCompletionRequestMessageContentPartImage: + type: object + title: Image content part + properties: + type: + type: string + enum: ["image_url"] + description: The type of the content part. + image_url: + type: object + properties: + url: + type: string + description: Either a URL of the image or the base64 encoded image data. + format: uri + detail: + type: string + description: Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding). + enum: ["auto", "low", "high"] + default: "auto" + required: + - url + required: + - type + - image_url + + ChatCompletionRequestMessageContentPartText: + type: object + title: Text content part + properties: + type: + type: string + enum: ["text"] + description: The type of the content part. + text: + type: string + description: The text content. + required: + - type + - text + + ChatCompletionRequestMessage: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" + - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" + - $ref: "#/components/schemas/ChatCompletionRequestAssistantMessage" + - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" + - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + x-oaiExpandable: true + + ChatCompletionRequestSystemMessage: + type: object + title: System message + properties: + content: + description: The contents of the system message. + type: string + role: + type: string + enum: ["system"] + description: The role of the messages author, in this case `system`. + name: + type: string + description: An optional name for the participant. Provides the model information to differentiate between participants of the same role. + required: + - content + - role + + ChatCompletionRequestUserMessage: + type: object + title: User message + properties: + content: + description: | + The contents of the user message. + oneOf: + - type: string + description: The text contents of the message. + title: Text content + - type: array + description: An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestMessageContentPart" + minItems: 1 + x-oaiExpandable: true + role: + type: string + enum: ["user"] + description: The role of the messages author, in this case `user`. + name: + type: string + description: An optional name for the participant. Provides the model information to differentiate between participants of the same role. + required: + - content + - role + + ChatCompletionRequestAssistantMessage: + type: object + title: Assistant message + properties: + content: + nullable: true + type: string + description: | + The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified. + role: + type: string + enum: ["assistant"] + description: The role of the messages author, in this case `assistant`. + name: + type: string + description: An optional name for the participant. Provides the model information to differentiate between participants of the same role. + tool_calls: + $ref: "#/components/schemas/ChatCompletionMessageToolCalls" + function_call: + type: object + deprecated: true + description: "Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + nullable: true + properties: + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + name: + type: string + description: The name of the function to call. + required: + - arguments + - name + required: + - role + + FineTuneChatCompletionRequestAssistantMessage: + allOf: + - type: object + title: Assistant message + deprecated: false + properties: + weight: + type: integer + enum: [0, 1] + description: "Controls whether the assistant message is trained against (0 or 1)" + - $ref: "#/components/schemas/ChatCompletionRequestAssistantMessage" + required: + - role + + ChatCompletionRequestToolMessage: + type: object + title: Tool message + properties: + role: + type: string + enum: ["tool"] + description: The role of the messages author, in this case `tool`. + content: + type: string + description: The contents of the tool message. + tool_call_id: + type: string + description: Tool call that this message is responding to. + required: + - role + - content + - tool_call_id + + ChatCompletionRequestFunctionMessage: + type: object + title: Function message + deprecated: true + properties: + role: + type: string + enum: ["function"] + description: The role of the messages author, in this case `function`. + content: + nullable: true + type: string + description: The contents of the function message. + name: + type: string + description: The name of the function to call. + required: + - role + - content + - name + + FunctionParameters: + type: object + description: "The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. \n\nOmitting `parameters` defines a function with an empty parameter list." + additionalProperties: true + + ChatCompletionFunctions: + type: object + deprecated: true + properties: + description: + type: string + description: A description of what the function does, used by the model to choose when and how to call the function. + name: + type: string + description: The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + required: + - name + + ChatCompletionFunctionCallOption: + type: object + description: > + Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + properties: + name: + type: string + description: The name of the function to call. + required: + - name + + ChatCompletionTool: + type: object + properties: + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + + FunctionObject: + type: object + properties: + description: + type: string + description: A description of what the function does, used by the model to choose when and how to call the function. + name: + type: string + description: The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + required: + - name + + ChatCompletionToolChoiceOption: + description: | + Controls which (if any) tool is called by the model. + `none` means the model will not call any tool and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools. + Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. + + `none` is the default when no tools are present. `auto` is the default if tools are present. + oneOf: + - type: string + description: > + `none` means the model will not call any tool and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools. + enum: [none, auto, required] + - $ref: "#/components/schemas/ChatCompletionNamedToolChoice" + x-oaiExpandable: true + + ChatCompletionNamedToolChoice: + type: object + description: Specifies a tool the model should use. Use to force the model to call a specific function. + properties: + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + required: + - name + required: + - type + - function + + ParallelToolCalls: + description: Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use. + type: boolean + default: true + + ChatCompletionMessageToolCalls: + type: array + description: The tool calls generated by the model, such as function calls. + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCall" + + ChatCompletionMessageToolCall: + type: object + properties: + # TODO: index included when streaming + id: + type: string + description: The ID of the tool call. + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + description: The function that the model called. + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + required: + - name + - arguments + required: + - id + - type + - function + + ChatCompletionMessageToolCallChunk: + type: object + properties: + index: + type: integer + id: + type: string + description: The ID of the tool call. + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + required: + - index + + # Note, this isn't referenced anywhere, but is kept as a convenience to record all possible roles in one place. + ChatCompletionRole: + type: string + description: The role of the author of a message + enum: + - system + - user + - assistant + - tool + - function + + ChatCompletionStreamOptions: + description: | + Options for streaming response. Only set this when you set `stream: true`. + type: object + nullable: true + default: null + properties: + include_usage: + type: boolean + description: | + If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value. + + ChatCompletionResponseMessage: + type: object + description: A chat completion message generated by the model. + properties: + content: + type: string + description: The contents of the message. + nullable: true + tool_calls: + $ref: "#/components/schemas/ChatCompletionMessageToolCalls" + role: + type: string + enum: ["assistant"] + description: The role of the author of this message. + function_call: + type: object + deprecated: true + description: "Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + properties: + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + name: + type: string + description: The name of the function to call. + required: + - name + - arguments + required: + - role + - content + + ChatCompletionStreamResponseDelta: + type: object + description: A chat completion delta generated by streamed model responses. + properties: + content: + type: string + description: The contents of the chunk message. + nullable: true + function_call: + deprecated: true + type: object + description: "Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + properties: + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + name: + type: string + description: The name of the function to call. + tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCallChunk" + role: + type: string + enum: ["system", "user", "assistant", "tool"] + description: The role of the author of this message. + + CreateChatCompletionRequest: + type: object + properties: + messages: + description: A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models). + type: array + minItems: 1 + items: + $ref: "#/components/schemas/ChatCompletionRequestMessage" + model: + description: ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API. + example: "gpt-4-turbo" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + frequency_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: *completions_frequency_penalty_description + logit_bias: + type: object + x-oaiTypeLabel: map + default: null + nullable: true + additionalProperties: + type: integer + description: | + Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + logprobs: + description: Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`. + type: boolean + default: false + nullable: true + top_logprobs: + description: An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used. + type: integer + minimum: 0 + maximum: 20 + nullable: true + max_tokens: + description: | + The maximum number of [tokens](/tokenizer) that can be generated in the chat completion. + + The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + type: integer + nullable: true + n: + type: integer + minimum: 1 + maximum: 128 + default: 1 + example: 1 + nullable: true + description: How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs. + presence_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: *completions_presence_penalty_description + response_format: + type: object + description: | + An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`. + + Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. + properties: + type: + type: string + enum: ["text", "json_object"] + example: "json_object" + default: "text" + description: Must be one of `text` or `json_object`. + seed: + type: integer + minimum: -9223372036854775808 + maximum: 9223372036854775807 + nullable: true + description: | + This feature is in Beta. + If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. + Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. + x-oaiMeta: + beta: true + service_tier: + description: | + Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service: + - If set to 'auto', the system will utilize scale tier credits until they are exhausted. + - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee. + - When not set, the default behavior is 'auto'. + + When this parameter is set, the response body will include the `service_tier` utilized. + type: string + enum: ["auto", "default"] + nullable: true + default: null + stop: + description: | + Up to 4 sequences where the API will stop generating further tokens. + default: null + oneOf: + - type: string + nullable: true + - type: array + minItems: 1 + maxItems: 4 + items: + type: string + stream: + description: > + If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + type: boolean + nullable: true + default: false + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *completions_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *completions_top_p_description + tools: + type: array + description: > + A list of tools the model may call. Currently, only functions are supported as a tool. + Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. + items: + $ref: "#/components/schemas/ChatCompletionTool" + tool_choice: + $ref: "#/components/schemas/ChatCompletionToolChoiceOption" + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + user: *end_user_param_configuration + function_call: + deprecated: true + description: | + Deprecated in favor of `tool_choice`. + + Controls which (if any) function is called by the model. + `none` means the model will not call a function and instead generates a message. + `auto` means the model can pick between generating a message or calling a function. + Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + + `none` is the default when no functions are present. `auto` is the default if functions are present. + oneOf: + - type: string + description: > + `none` means the model will not call a function and instead generates a message. + `auto` means the model can pick between generating a message or calling a function. + enum: [none, auto] + - $ref: "#/components/schemas/ChatCompletionFunctionCallOption" + x-oaiExpandable: true + functions: + deprecated: true + description: | + Deprecated in favor of `tools`. + + A list of functions the model may generate JSON inputs for. + type: array + minItems: 1 + maxItems: 128 + items: + $ref: "#/components/schemas/ChatCompletionFunctions" + + required: + - model + - messages + + CreateChatCompletionResponse: + type: object + description: Represents a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: A list of chat completion choices. Can be more than one if `n` is greater than 1. + items: + type: object + required: + - finish_reason + - index + - message + - logprobs + properties: + finish_reason: + type: string + description: &chat_completion_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + `length` if the maximum number of tokens specified in the request was reached, + `content_filter` if content was omitted due to a flag from our content filters, + `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + enum: + [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + ] + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/ChatCompletionResponseMessage" + logprobs: &chat_completion_response_logprobs + description: Log probability information for the choice. + type: object + nullable: true + properties: + content: + description: A list of message content tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + nullable: true + required: + - content + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. + model: + type: string + description: The model used for the chat completion. + service_tier: + description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + type: string + enum: ["scale", "default"] + example: "scale" + nullable: true + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion`. + enum: [chat.completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion object + group: chat + example: *chat_completion_example + + CreateChatCompletionFunctionResponse: + type: object + description: Represents a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: A list of chat completion choices. Can be more than one if `n` is greater than 1. + items: + type: object + required: + - finish_reason + - index + - message + - logprobs + properties: + finish_reason: + type: string + description: + &chat_completion_function_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function. + enum: ["stop", "length", "function_call", "content_filter"] + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/ChatCompletionResponseMessage" + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. + model: + type: string + description: The model used for the chat completion. + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion`. + enum: [chat.completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion object + group: chat + example: *chat_completion_function_example + + ChatCompletionTokenLogprob: + type: object + properties: + token: &chat_completion_response_logprobs_token + description: The token. + type: string + logprob: &chat_completion_response_logprobs_token_logprob + description: The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + type: number + bytes: &chat_completion_response_logprobs_bytes + description: A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. + type: array + items: + type: integer + nullable: true + top_logprobs: + description: List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned. + type: array + items: + type: object + properties: + token: *chat_completion_response_logprobs_token + logprob: *chat_completion_response_logprobs_token_logprob + bytes: *chat_completion_response_logprobs_bytes + required: + - token + - logprob + - bytes + required: + - token + - logprob + - bytes + - top_logprobs + + ListPaginatedFineTuningJobsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJob" + has_more: + type: boolean + object: + type: string + enum: [list] + required: + - object + - data + - has_more + + CreateChatCompletionStreamResponse: + type: object + description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. Each chunk has the same ID. + choices: + type: array + description: | + A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the + last chunk if you set `stream_options: {"include_usage": true}`. + items: + type: object + required: + - delta + - finish_reason + - index + properties: + delta: + $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" + logprobs: *chat_completion_response_logprobs + finish_reason: + type: string + description: *chat_completion_finish_reason_description + enum: + [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + ] + nullable: true + index: + type: integer + description: The index of the choice in the list of choices. + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. + model: + type: string + description: The model to generate the completion. + service_tier: + description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + type: string + enum: ["scale", "default"] + example: "scale" + nullable: true + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion.chunk`. + enum: [chat.completion.chunk] + usage: + type: object + description: | + An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. + When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. + properties: + completion_tokens: + type: integer + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + description: Number of tokens in the prompt. + total_tokens: + type: integer + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion chunk object + group: chat + example: *chat_completion_chunk_example + + CreateChatCompletionImageResponse: + type: object + description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + x-oaiMeta: + name: The chat completion chunk object + group: chat + example: *chat_completion_image_example + + CreateImageRequest: + type: object + properties: + prompt: + description: A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`. + type: string + example: "A cute baby sea otter" + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2", "dall-e-3"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-3" + nullable: true + description: The model to use for image generation. + n: &images_n + type: integer + minimum: 1 + maximum: 10 + default: 1 + example: 1 + nullable: true + description: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. + quality: + type: string + enum: ["standard", "hd"] + default: "standard" + example: "standard" + description: The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`. + response_format: &images_response_format + type: string + enum: ["url", "b64_json"] + default: "url" + example: "url" + nullable: true + description: The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated. + size: &images_size + type: string + enum: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + default: "1024x1024" + example: "1024x1024" + nullable: true + description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. + style: + type: string + enum: ["vivid", "natural"] + default: "vivid" + example: "vivid" + nullable: true + description: The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`. + user: *end_user_param_configuration + required: + - prompt + + ImagesResponse: + properties: + created: + type: integer + data: + type: array + items: + $ref: "#/components/schemas/Image" + required: + - created + - data + + Image: + type: object + description: Represents the url or the content of an image generated by the OpenAI API. + properties: + b64_json: + type: string + description: The base64-encoded JSON of the generated image, if `response_format` is `b64_json`. + url: + type: string + description: The URL of the generated image, if `response_format` is `url` (default). + revised_prompt: + type: string + description: The prompt that was used to generate the image, if there was any revision to the prompt. + x-oaiMeta: + name: The image object + example: | + { + "url": "...", + "revised_prompt": "..." + } + + CreateImageEditRequest: + type: object + properties: + image: + description: The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask. + type: string + format: binary + prompt: + description: A text description of the desired image(s). The maximum length is 1000 characters. + type: string + example: "A cute baby sea otter wearing a beret" + mask: + description: An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. + type: string + format: binary + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-2" + nullable: true + description: The model to use for image generation. Only `dall-e-2` is supported at this time. + n: + type: integer + minimum: 1 + maximum: 10 + default: 1 + example: 1 + nullable: true + description: The number of images to generate. Must be between 1 and 10. + size: &dalle2_images_size + type: string + enum: ["256x256", "512x512", "1024x1024"] + default: "1024x1024" + example: "1024x1024" + nullable: true + description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + response_format: *images_response_format + user: *end_user_param_configuration + required: + - prompt + - image + + CreateImageVariationRequest: + type: object + properties: + image: + description: The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. + type: string + format: binary + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-2" + nullable: true + description: The model to use for image generation. Only `dall-e-2` is supported at this time. + n: *images_n + response_format: *images_response_format + size: *dalle2_images_size + user: *end_user_param_configuration + required: + - image + + CreateModerationRequest: + type: object + properties: + input: + description: The input text to classify + oneOf: + - type: string + default: "" + example: "I want to kill them." + - type: array + items: + type: string + default: "" + example: "I want to kill them." + model: + description: | + Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`. + + The default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`. + nullable: false + default: "text-moderation-latest" + example: "text-moderation-stable" + anyOf: + - type: string + - type: string + enum: ["text-moderation-latest", "text-moderation-stable"] + x-oaiTypeLabel: string + required: + - input + + CreateModerationResponse: + type: object + description: Represents if a given text input is potentially harmful. + properties: + id: + type: string + description: The unique identifier for the moderation request. + model: + type: string + description: The model used to generate the moderation results. + results: + type: array + description: A list of moderation objects. + items: + type: object + properties: + flagged: + type: boolean + description: Whether any of the below categories are flagged. + categories: + type: object + description: A list of the categories, and whether they are flagged or not. + properties: + hate: + type: boolean + description: Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment. + hate/threatening: + type: boolean + description: Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. + harassment: + type: boolean + description: Content that expresses, incites, or promotes harassing language towards any target. + harassment/threatening: + type: boolean + description: Harassment content that also includes violence or serious harm towards any target. + self-harm: + type: boolean + description: Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. + self-harm/intent: + type: boolean + description: Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders. + self-harm/instructions: + type: boolean + description: Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts. + sexual: + type: boolean + description: Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). + sexual/minors: + type: boolean + description: Sexual content that includes an individual who is under 18 years old. + violence: + type: boolean + description: Content that depicts death, violence, or physical injury. + violence/graphic: + type: boolean + description: Content that depicts death, violence, or physical injury in graphic detail. + required: + - hate + - hate/threatening + - harassment + - harassment/threatening + - self-harm + - self-harm/intent + - self-harm/instructions + - sexual + - sexual/minors + - violence + - violence/graphic + category_scores: + type: object + description: A list of the categories along with their scores as predicted by model. + properties: + hate: + type: number + description: The score for the category 'hate'. + hate/threatening: + type: number + description: The score for the category 'hate/threatening'. + harassment: + type: number + description: The score for the category 'harassment'. + harassment/threatening: + type: number + description: The score for the category 'harassment/threatening'. + self-harm: + type: number + description: The score for the category 'self-harm'. + self-harm/intent: + type: number + description: The score for the category 'self-harm/intent'. + self-harm/instructions: + type: number + description: The score for the category 'self-harm/instructions'. + sexual: + type: number + description: The score for the category 'sexual'. + sexual/minors: + type: number + description: The score for the category 'sexual/minors'. + violence: + type: number + description: The score for the category 'violence'. + violence/graphic: + type: number + description: The score for the category 'violence/graphic'. + required: + - hate + - hate/threatening + - harassment + - harassment/threatening + - self-harm + - self-harm/intent + - self-harm/instructions + - sexual + - sexual/minors + - violence + - violence/graphic + required: + - flagged + - categories + - category_scores + required: + - id + - model + - results + x-oaiMeta: + name: The moderation object + example: *moderation_example + + ListFilesResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/OpenAIFile" + object: + type: string + enum: [list] + required: + - object + - data + + CreateFileRequest: + type: object + additionalProperties: false + properties: + file: + description: | + The File object (not file name) to be uploaded. + type: string + format: binary + purpose: + description: | + The intended purpose of the uploaded file. + + Use "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning). + type: string + enum: ["assistants", "batch", "fine-tune", "vision"] + required: + - file + - purpose + + DeleteFileResponse: + type: object + properties: + id: + type: string + object: + type: string + enum: [file] + deleted: + type: boolean + required: + - id + - object + - deleted + + CreateUploadRequest: + type: object + additionalProperties: false + properties: + filename: + description: | + The name of the file to upload. + type: string + purpose: + description: | + The intended purpose of the uploaded file. + + See the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose). + type: string + enum: ["assistants", "batch", "fine-tune", "vision"] + bytes: + description: | + The number of bytes in the file you are uploading. + type: integer + mime_type: + description: | + The MIME type of the file. + + This must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision. + type: string + required: + - filename + - purpose + - bytes + - mime_type + + AddUploadPartRequest: + type: object + additionalProperties: false + properties: + data: + description: | + The chunk of bytes for this Part. + type: string + format: binary + required: + - data + + CompleteUploadRequest: + type: object + additionalProperties: false + properties: + part_ids: + type: array + description: | + The ordered list of Part IDs. + items: + type: string + md5: + description: | + The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect. + type: string + required: + - part_ids + + CancelUploadRequest: + type: object + additionalProperties: false + + CreateFineTuningJobRequest: + type: object + properties: + model: + description: | + The name of the model to fine-tune. You can select one of the + [supported models](/docs/guides/fine-tuning/what-models-can-be-fine-tuned). + example: "gpt-3.5-turbo" + anyOf: + - type: string + - type: string + enum: ["babbage-002", "davinci-002", "gpt-3.5-turbo"] + x-oaiTypeLabel: string + training_file: + description: | + The ID of an uploaded file that contains training data. + + See [upload file](/docs/api-reference/files/create) for how to upload a file. + + Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`. + + The contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format. + + See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + type: string + example: "file-abc123" + hyperparameters: + type: object + description: The hyperparameters used for the fine-tuning job. + properties: + batch_size: + description: | + Number of examples in each batch. A larger batch size means that model parameters + are updated less frequently, but with lower variance. + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 256 + default: auto + learning_rate_multiplier: + description: | + Scaling factor for the learning rate. A smaller learning rate may be useful to avoid + overfitting. + oneOf: + - type: string + enum: [auto] + - type: number + minimum: 0 + exclusiveMinimum: true + default: auto + n_epochs: + description: | + The number of epochs to train the model for. An epoch refers to one full cycle + through the training dataset. + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 50 + default: auto + suffix: + description: | + A string of up to 18 characters that will be added to your fine-tuned model name. + + For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`. + type: string + minLength: 1 + maxLength: 40 + default: null + nullable: true + validation_file: + description: | + The ID of an uploaded file that contains validation data. + + If you provide this file, the data is used to generate validation + metrics periodically during fine-tuning. These metrics can be viewed in + the fine-tuning results file. + The same data should not be present in both train and validation files. + + Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`. + + See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + type: string + nullable: true + example: "file-abc123" + integrations: + type: array + description: A list of integrations to enable for your fine-tuning job. + nullable: true + items: + type: object + required: + - type + - wandb + properties: + type: + description: | + The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported. + oneOf: + - type: string + enum: [wandb] + wandb: + type: object + description: | + The settings for your integration with Weights and Biases. This payload specifies the project that + metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + to your run, and set a default entity (team, username, etc) to be associated with your run. + required: + - project + properties: + project: + description: | + The name of the project that the new run will be created under. + type: string + example: "my-wandb-project" + name: + description: | + A display name to set for the run. If not set, we will use the Job ID as the name. + nullable: true + type: string + entity: + description: | + The entity to use for the run. This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered WandB API key is used. + nullable: true + type: string + tags: + description: | + A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + type: array + items: + type: string + example: "custom-tag" + + seed: + description: | + The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. + If a seed is not specified, one will be generated for you. + type: integer + nullable: true + minimum: 0 + maximum: 2147483647 + example: 42 + required: + - model + - training_file + + ListFineTuningJobEventsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJobEvent" + object: + type: string + enum: [list] + required: + - object + - data + + ListFineTuningJobCheckpointsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJobCheckpoint" + object: + type: string + enum: [list] + first_id: + type: string + nullable: true + last_id: + type: string + nullable: true + has_more: + type: boolean + required: + - object + - data + - has_more + + CreateEmbeddingRequest: + type: object + additionalProperties: false + properties: + input: + description: | + Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + example: "The quick brown fox jumped over the lazy dog" + oneOf: + - type: string + title: string + description: The string that will be turned into an embedding. + default: "" + example: "This is a test." + - type: array + title: array + description: The array of strings that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: string + default: "" + example: "['This is a test.']" + - type: array + title: array + description: The array of integers that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: integer + example: "[1212, 318, 257, 1332, 13]" + - type: array + title: array + description: The array of arrays containing integers that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: array + minItems: 1 + items: + type: integer + example: "[[1212, 318, 257, 1332, 13]]" + x-oaiExpandable: true + model: + description: *model_description + example: "text-embedding-3-small" + anyOf: + - type: string + - type: string + enum: + [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ] + x-oaiTypeLabel: string + encoding_format: + description: "The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/)." + example: "float" + default: "float" + type: string + enum: ["float", "base64"] + dimensions: + description: | + The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. + type: integer + minimum: 1 + user: *end_user_param_configuration + required: + - model + - input + + CreateEmbeddingResponse: + type: object + properties: + data: + type: array + description: The list of embeddings generated by the model. + items: + $ref: "#/components/schemas/Embedding" + model: + type: string + description: The name of the model used to generate the embedding. + object: + type: string + description: The object type, which is always "list". + enum: [list] + usage: + type: object + description: The usage information for the request. + properties: + prompt_tokens: + type: integer + description: The number of tokens used by the prompt. + total_tokens: + type: integer + description: The total number of tokens used by the request. + required: + - prompt_tokens + - total_tokens + required: + - object + - model + - data + - usage + + CreateTranscriptionRequest: + type: object + additionalProperties: false + properties: + file: + description: | + The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + type: string + x-oaiTypeLabel: file + format: binary + model: + description: | + ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + example: whisper-1 + anyOf: + - type: string + - type: string + enum: ["whisper-1"] + x-oaiTypeLabel: string + language: + description: | + The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. + type: string + prompt: + description: | + An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language. + type: string + response_format: + description: | + The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + type: string + enum: + - json + - text + - srt + - verbose_json + - vtt + default: json + temperature: + description: | + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + type: number + default: 0 + timestamp_granularities[]: + description: | + The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency. + type: array + items: + type: string + enum: + - word + - segment + default: [segment] + required: + - file + - model + + # Note: This does not currently support the non-default response format types. + CreateTranscriptionResponseJson: + type: object + description: Represents a transcription response returned by model, based on the provided input. + properties: + text: + type: string + description: The transcribed text. + required: + - text + x-oaiMeta: + name: The transcription object (JSON) + group: audio + example: *basic_transcription_response_example + + TranscriptionSegment: + type: object + properties: + id: + type: integer + description: Unique identifier of the segment. + seek: + type: integer + description: Seek offset of the segment. + start: + type: number + format: float + description: Start time of the segment in seconds. + end: + type: number + format: float + description: End time of the segment in seconds. + text: + type: string + description: Text content of the segment. + tokens: + type: array + items: + type: integer + description: Array of token IDs for the text content. + temperature: + type: number + format: float + description: Temperature parameter used for generating the segment. + avg_logprob: + type: number + format: float + description: Average logprob of the segment. If the value is lower than -1, consider the logprobs failed. + compression_ratio: + type: number + format: float + description: Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed. + no_speech_prob: + type: number + format: float + description: Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent. + required: + - id + - seek + - start + - end + - text + - tokens + - temperature + - avg_logprob + - compression_ratio + - no_speech_prob + + TranscriptionWord: + type: object + properties: + word: + type: string + description: The text content of the word. + start: + type: number + format: float + description: Start time of the word in seconds. + end: + type: number + format: float + description: End time of the word in seconds. + required: [word, start, end] + + CreateTranscriptionResponseVerboseJson: + type: object + description: Represents a verbose json transcription response returned by model, based on the provided input. + properties: + language: + type: string + description: The language of the input audio. + duration: + type: string + description: The duration of the input audio. + text: + type: string + description: The transcribed text. + words: + type: array + description: Extracted words and their corresponding timestamps. + items: + $ref: "#/components/schemas/TranscriptionWord" + segments: + type: array + description: Segments of the transcribed text and their corresponding details. + items: + $ref: "#/components/schemas/TranscriptionSegment" + required: [language, duration, text] + x-oaiMeta: + name: The transcription object (Verbose JSON) + group: audio + example: *verbose_transcription_response_example + + CreateTranslationRequest: + type: object + additionalProperties: false + properties: + file: + description: | + The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + type: string + x-oaiTypeLabel: file + format: binary + model: + description: | + ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + example: whisper-1 + anyOf: + - type: string + - type: string + enum: ["whisper-1"] + x-oaiTypeLabel: string + prompt: + description: | + An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English. + type: string + response_format: + description: | + The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + type: string + default: json + temperature: + description: | + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + type: number + default: 0 + required: + - file + - model + + # Note: This does not currently support the non-default response format types. + CreateTranslationResponseJson: + type: object + properties: + text: + type: string + required: + - text + + CreateTranslationResponseVerboseJson: + type: object + properties: + language: + type: string + description: The language of the output translation (always `english`). + duration: + type: string + description: The duration of the input audio. + text: + type: string + description: The translated text. + segments: + type: array + description: Segments of the translated text and their corresponding details. + items: + $ref: "#/components/schemas/TranscriptionSegment" + required: [language, duration, text] + + CreateSpeechRequest: + type: object + additionalProperties: false + properties: + model: + description: | + One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd` + anyOf: + - type: string + - type: string + enum: ["tts-1", "tts-1-hd"] + x-oaiTypeLabel: string + input: + type: string + description: The text to generate audio for. The maximum length is 4096 characters. + maxLength: 4096 + voice: + description: The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options). + type: string + enum: ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + response_format: + description: "The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." + default: "mp3" + type: string + enum: ["mp3", "opus", "aac", "flac", "wav", "pcm"] + speed: + description: "The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default." + type: number + default: 1.0 + minimum: 0.25 + maximum: 4.0 + required: + - model + - input + - voice + + Model: + title: Model + description: Describes an OpenAI model offering that can be used with the API. + properties: + id: + type: string + description: The model identifier, which can be referenced in the API endpoints. + created: + type: integer + description: The Unix timestamp (in seconds) when the model was created. + object: + type: string + description: The object type, which is always "model". + enum: [model] + owned_by: + type: string + description: The organization that owns the model. + required: + - id + - object + - created + - owned_by + x-oaiMeta: + name: The model object + example: *retrieve_model_response + + OpenAIFile: + title: OpenAIFile + description: The `File` object represents a document that has been uploaded to OpenAI. + properties: + id: + type: string + description: The file identifier, which can be referenced in the API endpoints. + bytes: + type: integer + description: The size of the file, in bytes. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the file was created. + filename: + type: string + description: The name of the file. + object: + type: string + description: The object type, which is always `file`. + enum: ["file"] + purpose: + type: string + description: The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`. + enum: + [ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ] + status: + type: string + deprecated: true + description: Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`. + enum: ["uploaded", "processed", "error"] + status_details: + type: string + deprecated: true + description: Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`. + required: + - id + - object + - bytes + - created_at + - filename + - purpose + - status + x-oaiMeta: + name: The file object + example: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "salesOverview.pdf", + "purpose": "assistants", + } + Upload: + type: object + title: Upload + description: | + The Upload object can accept byte chunks in the form of Parts. + properties: + id: + type: string + description: The Upload unique identifier, which can be referenced in API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the Upload was created. + filename: + type: string + description: The name of the file to be uploaded. + bytes: + type: integer + description: The intended number of bytes to be uploaded. + purpose: + type: string + description: The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values. + status: + type: string + description: The status of the Upload. + enum: ["pending", "completed", "cancelled", "expired"] + expires_at: + type: integer + description: The Unix timestamp (in seconds) for when the Upload was created. + object: + type: string + description: The object type, which is always "upload". + enum: [upload] + file: + $ref: "#/components/schemas/OpenAIFile" + nullable: true + description: The ready File object after the Upload is completed. + required: + - bytes + - created_at + - expires_at + - filename + - id + - purpose + - status + - step_number + x-oaiMeta: + name: The upload object + example: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "completed", + "expires_at": 1719127296, + "file": { + "id": "file-xyz321", + "object": "file", + "bytes": 2147483648, + "created_at": 1719186911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + } + } + UploadPart: + type: object + title: UploadPart + description: | + The upload Part represents a chunk of bytes we can add to an Upload object. + properties: + id: + type: string + description: The upload Part unique identifier, which can be referenced in API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the Part was created. + upload_id: + type: string + description: The ID of the Upload object that this Part was added to. + object: + type: string + description: The object type, which is always `upload.part`. + enum: ["upload.part"] + required: + - created_at + - id + - object + - upload_id + x-oaiMeta: + name: The upload part object + example: | + { + "id": "part_def456", + "object": "upload.part", + "created_at": 1719186911, + "upload_id": "upload_abc123" + } + Embedding: + type: object + description: | + Represents an embedding vector returned by embedding endpoint. + properties: + index: + type: integer + description: The index of the embedding in the list of embeddings. + embedding: + type: array + description: | + The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings). + items: + type: number + object: + type: string + description: The object type, which is always "embedding". + enum: [embedding] + required: + - index + - object + - embedding + x-oaiMeta: + name: The embedding object + example: | + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } + + FineTuningJob: + type: object + title: FineTuningJob + description: | + The `fine_tuning.job` object represents a fine-tuning job that has been created through the API. + properties: + id: + type: string + description: The object identifier, which can be referenced in the API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the fine-tuning job was created. + error: + type: object + nullable: true + description: For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure. + properties: + code: + type: string + description: A machine-readable error code. + message: + type: string + description: A human-readable error message. + param: + type: string + description: The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific. + nullable: true + required: + - code + - message + - param + fine_tuned_model: + type: string + nullable: true + description: The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running. + finished_at: + type: integer + nullable: true + description: The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running. + hyperparameters: + type: object + description: The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + properties: + n_epochs: + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 50 + default: auto + description: + The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. + + "auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs. + required: + - n_epochs + model: + type: string + description: The base model that is being fine-tuned. + object: + type: string + description: The object type, which is always "fine_tuning.job". + enum: [fine_tuning.job] + organization_id: + type: string + description: The organization that owns the fine-tuning job. + result_files: + type: array + description: The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents). + items: + type: string + example: file-abc123 + status: + type: string + description: The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. + enum: + [ + "validating_files", + "queued", + "running", + "succeeded", + "failed", + "cancelled", + ] + trained_tokens: + type: integer + nullable: true + description: The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running. + training_file: + type: string + description: The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents). + validation_file: + type: string + nullable: true + description: The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents). + integrations: + type: array + nullable: true + description: A list of integrations to enable for this fine-tuning job. + maxItems: 5 + items: + oneOf: + - $ref: "#/components/schemas/FineTuningIntegration" + x-oaiExpandable: true + seed: + type: integer + description: The seed used for the fine-tuning job. + estimated_finish: + type: integer + nullable: true + description: The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running. + required: + - created_at + - error + - finished_at + - fine_tuned_model + - hyperparameters + - id + - model + - object + - organization_id + - result_files + - status + - trained_tokens + - training_file + - validation_file + - seed + x-oaiMeta: + name: The fine-tuning job object + example: *fine_tuning_example + + FineTuningIntegration: + type: object + title: Fine-Tuning Job Integration + required: + - type + - wandb + properties: + type: + type: string + description: "The type of the integration being enabled for the fine-tuning job" + enum: ["wandb"] + wandb: + type: object + description: | + The settings for your integration with Weights and Biases. This payload specifies the project that + metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + to your run, and set a default entity (team, username, etc) to be associated with your run. + required: + - project + properties: + project: + description: | + The name of the project that the new run will be created under. + type: string + example: "my-wandb-project" + name: + description: | + A display name to set for the run. If not set, we will use the Job ID as the name. + nullable: true + type: string + entity: + description: | + The entity to use for the run. This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered WandB API key is used. + nullable: true + type: string + tags: + description: | + A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + type: array + items: + type: string + example: "custom-tag" + + FineTuningJobEvent: + type: object + description: Fine-tuning job event object + properties: + id: + type: string + created_at: + type: integer + level: + type: string + enum: ["info", "warn", "error"] + message: + type: string + object: + type: string + enum: [fine_tuning.job.event] + required: + - id + - object + - created_at + - level + - message + x-oaiMeta: + name: The fine-tuning job event object + example: | + { + "object": "fine_tuning.job.event", + "id": "ftevent-abc123" + "created_at": 1677610602, + "level": "info", + "message": "Created fine-tuning job" + } + + FineTuningJobCheckpoint: + type: object + title: FineTuningJobCheckpoint + description: | + The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use. + properties: + id: + type: string + description: The checkpoint identifier, which can be referenced in the API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the checkpoint was created. + fine_tuned_model_checkpoint: + type: string + description: The name of the fine-tuned checkpoint model that is created. + step_number: + type: integer + description: The step number that the checkpoint was created at. + metrics: + type: object + description: Metrics at the step number during the fine-tuning job. + properties: + step: + type: number + train_loss: + type: number + train_mean_token_accuracy: + type: number + valid_loss: + type: number + valid_mean_token_accuracy: + type: number + full_valid_loss: + type: number + full_valid_mean_token_accuracy: + type: number + fine_tuning_job_id: + type: string + description: The name of the fine-tuning job that this checkpoint was created from. + object: + type: string + description: The object type, which is always "fine_tuning.job.checkpoint". + enum: [fine_tuning.job.checkpoint] + required: + - created_at + - fine_tuning_job_id + - fine_tuned_model_checkpoint + - id + - metrics + - object + - step_number + x-oaiMeta: + name: The fine-tuning job checkpoint object + example: | + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_qtZ5Gyk4BLq1SfLFWp3RtO3P", + "created_at": 1712211699, + "fine_tuned_model_checkpoint": "ft:gpt-3.5-turbo-0125:my-org:custom_suffix:9ABel2dg:ckpt-step-88", + "fine_tuning_job_id": "ftjob-fpbNQ3H1GrMehXRf8cO97xTN", + "metrics": { + "step": 88, + "train_loss": 0.478, + "train_mean_token_accuracy": 0.924, + "valid_loss": 10.112, + "valid_mean_token_accuracy": 0.145, + "full_valid_loss": 0.567, + "full_valid_mean_token_accuracy": 0.944 + }, + "step_number": 88 + } + + FinetuneChatRequestInput: + type: object + description: The per-line training example of a fine-tuning input file for chat models + properties: + messages: + type: array + minItems: 1 + items: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" + - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" + - $ref: "#/components/schemas/FineTuneChatCompletionRequestAssistantMessage" + - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" + - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + x-oaiExpandable: true + tools: + type: array + description: A list of tools the model may generate JSON inputs for. + items: + $ref: "#/components/schemas/ChatCompletionTool" + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + functions: + deprecated: true + description: A list of functions the model may generate JSON inputs for. + type: array + minItems: 1 + maxItems: 128 + items: + $ref: "#/components/schemas/ChatCompletionFunctions" + x-oaiMeta: + name: Training format for chat models + example: | + { + "messages": [ + { "role": "user", "content": "What is the weather in San Francisco?" }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_id", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}" + } + } + ] + } + ], + "parallel_tool_calls": false, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and country, eg. San Francisco, USA" + }, + "format": { "type": "string", "enum": ["celsius", "fahrenheit"] } + }, + "required": ["location", "format"] + } + } + } + ] + } + + FinetuneCompletionRequestInput: + type: object + description: The per-line training example of a fine-tuning input file for completions models + properties: + prompt: + type: string + description: The input prompt for this training example. + completion: + type: string + description: The desired completion for this training example. + x-oaiMeta: + name: Training format for completions models + example: | + { + "prompt": "What is the answer to 2+2", + "completion": "4" + } + + CompletionUsage: + type: object + description: Usage statistics for the completion request. + properties: + completion_tokens: + type: integer + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + description: Number of tokens in the prompt. + total_tokens: + type: integer + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + + RunCompletionUsage: + type: object + description: Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). + properties: + completion_tokens: + type: integer + description: Number of completion tokens used over the course of the run. + prompt_tokens: + type: integer + description: Number of prompt tokens used over the course of the run. + total_tokens: + type: integer + description: Total number of tokens used (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + nullable: true + + RunStepCompletionUsage: + type: object + description: Usage statistics related to the run step. This value will be `null` while the run step's status is `in_progress`. + properties: + completion_tokens: + type: integer + description: Number of completion tokens used over the course of the run step. + prompt_tokens: + type: integer + description: Number of prompt tokens used over the course of the run step. + total_tokens: + type: integer + description: Total number of tokens used (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + nullable: true + + AssistantsApiResponseFormatOption: + description: | + Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + + Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. + oneOf: + - type: string + description: > + `auto` is the default value + enum: [none, auto] + - $ref: "#/components/schemas/AssistantsApiResponseFormat" + x-oaiExpandable: true + + AssistantsApiResponseFormat: + type: object + description: | + An object describing the expected output of the model. If `json_object` only `function` type `tools` are allowed to be passed to the Run. If `text` the model can return text or any value needed. + properties: + type: + type: string + enum: ["text", "json_object"] + example: "json_object" + default: "text" + description: Must be one of `text` or `json_object`. + + AssistantObject: + type: object + title: Assistant + description: Represents an `assistant` that can call the model and use tools. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `assistant`. + type: string + enum: [assistant] + created_at: + description: The Unix timestamp (in seconds) for when the assistant was created. + type: integer + name: + description: &assistant_name_param_description | + The name of the assistant. The maximum length is 256 characters. + type: string + maxLength: 256 + nullable: true + description: + description: &assistant_description_param_description | + The description of the assistant. The maximum length is 512 characters. + type: string + maxLength: 512 + nullable: true + model: + description: *model_description + type: string + instructions: + description: &assistant_instructions_param_description | + The system instructions that the assistant uses. The maximum length is 256,000 characters. + type: string + maxLength: 256000 + nullable: true + tools: + description: &assistant_tools_param_description | + A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: &metadata_description | + Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: &run_temperature_description | + What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: &run_top_p_description | + An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or temperature but not both. + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - id + - object + - created_at + - name + - description + - model + - instructions + - tools + - metadata + x-oaiMeta: + name: The assistant object + beta: true + example: *create_assistants_example + + CreateAssistantRequest: + type: object + additionalProperties: false + properties: + model: + description: *model_description + example: "gpt-4-turbo" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + name: + description: *assistant_name_param_description + type: string + nullable: true + maxLength: 256 + description: + description: *assistant_description_param_description + type: string + nullable: true + maxLength: 512 + instructions: + description: *assistant_instructions_param_description + type: string + nullable: true + maxLength: 256000 + tools: + description: *assistant_tools_param_description + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + vector_stores: + type: array + description: | + A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. + maxItems: 10000 + items: + type: string + chunking_strategy: + # Ideally we'd reuse the chunking strategy schema here, but it doesn't expand properly + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + - type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + required: + - type + - static + x-oaiExpandable: true + metadata: + type: object + description: | + Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + x-oaiTypeLabel: map + oneOf: + - required: [vector_store_ids] + - required: [vector_stores] + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: *run_temperature_description + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - model + + ModifyAssistantRequest: + type: object + additionalProperties: false + properties: + model: + description: *model_description + anyOf: + - type: string + name: + description: *assistant_name_param_description + type: string + nullable: true + maxLength: 256 + description: + description: *assistant_description_param_description + type: string + nullable: true + maxLength: 512 + instructions: + description: *assistant_instructions_param_description + type: string + nullable: true + maxLength: 256000 + tools: + description: *assistant_tools_param_description + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: *run_temperature_description + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + + DeleteAssistantResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [assistant.deleted] + required: + - id + - object + - deleted + + ListAssistantsResponse: + type: object + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/AssistantObject" + first_id: + type: string + example: "asst_abc123" + last_id: + type: string + example: "asst_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + x-oaiMeta: + name: List assistants response object + group: chat + example: *list_assistants_example + + AssistantToolsCode: + type: object + title: Code interpreter tool + properties: + type: + type: string + description: "The type of tool being defined: `code_interpreter`" + enum: ["code_interpreter"] + required: + - type + + AssistantToolsFileSearch: + type: object + title: FileSearch tool + properties: + type: + type: string + description: "The type of tool being defined: `file_search`" + enum: ["file_search"] + file_search: + type: object + description: Overrides for the file search tool. + properties: + max_num_results: + type: integer + minimum: 1 + maximum: 50 + description: | + The maximum number of results the file search tool should output. The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. This number should be between 1 and 50 inclusive. + + Note that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information. + required: + - type + + AssistantToolsFileSearchTypeOnly: + type: object + title: FileSearch tool + properties: + type: + type: string + description: "The type of tool being defined: `file_search`" + enum: ["file_search"] + required: + - type + + AssistantToolsFunction: + type: object + title: Function tool + properties: + type: + type: string + description: "The type of tool being defined: `function`" + enum: ["function"] + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + + TruncationObject: + type: object + title: Thread Truncation Controls + description: Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + properties: + type: + type: string + description: The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`. + enum: ["auto", "last_messages"] + last_messages: + type: integer + description: The number of most recent messages from the thread when constructing the context for the run. + minimum: 1 + nullable: true + required: + - type + + AssistantsApiToolChoiceOption: + description: | + Controls which (if any) tool is called by the model. + `none` means the model will not call any tools and instead generates a message. + `auto` is the default value and means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools before responding to the user. + Specifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. + + oneOf: + - type: string + description: > + `none` means the model will not call any tools and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools before responding to the user. + enum: [none, auto, required] + - $ref: "#/components/schemas/AssistantsNamedToolChoice" + x-oaiExpandable: true + + AssistantsNamedToolChoice: + type: object + description: Specifies a tool the model should use. Use to force the model to call a specific tool. + properties: + type: + type: string + enum: ["function", "code_interpreter", "file_search"] + description: The type of the tool. If type is `function`, the function name must be set + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + required: + - name + required: + - type + + RunObject: + type: object + title: A run on a thread + description: Represents an execution run on a [thread](/docs/api-reference/threads). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.run`. + type: string + enum: ["thread.run"] + created_at: + description: The Unix timestamp (in seconds) for when the run was created. + type: integer + thread_id: + description: The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run. + type: string + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run. + type: string + status: + description: The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`. + type: string + enum: + [ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ] + required_action: + type: object + description: Details on the action required to continue the run. Will be `null` if no action is required. + nullable: true + properties: + type: + description: For now, this is always `submit_tool_outputs`. + type: string + enum: ["submit_tool_outputs"] + submit_tool_outputs: + type: object + description: Details on the tool outputs needed for this run to continue. + properties: + tool_calls: + type: array + description: A list of the relevant tool calls. + items: + $ref: "#/components/schemas/RunToolCallObject" + required: + - tool_calls + required: + - type + - submit_tool_outputs + last_error: + type: object + description: The last error associated with this run. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`. + enum: ["server_error", "rate_limit_exceeded", "invalid_prompt"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + expires_at: + description: The Unix timestamp (in seconds) for when the run will expire. + type: integer + nullable: true + started_at: + description: The Unix timestamp (in seconds) for when the run was started. + type: integer + nullable: true + cancelled_at: + description: The Unix timestamp (in seconds) for when the run was cancelled. + type: integer + nullable: true + failed_at: + description: The Unix timestamp (in seconds) for when the run failed. + type: integer + nullable: true + completed_at: + description: The Unix timestamp (in seconds) for when the run was completed. + type: integer + nullable: true + incomplete_details: + description: Details on why the run is incomplete. Will be `null` if the run is not incomplete. + type: object + nullable: true + properties: + reason: + description: The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run. + type: string + enum: ["max_completion_tokens", "max_prompt_tokens"] + model: + description: The model that the [assistant](/docs/api-reference/assistants) used for this run. + type: string + instructions: + description: The instructions that the [assistant](/docs/api-reference/assistants) used for this run. + type: string + tools: + description: The list of tools that the [assistant](/docs/api-reference/assistants) used for this run. + default: [] + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + usage: + $ref: "#/components/schemas/RunCompletionUsage" + temperature: + description: The sampling temperature used for this run. If not set, defaults to 1. + type: number + nullable: true + top_p: + description: The nucleus sampling value used for this run. If not set, defaults to 1. + type: number + nullable: true + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens specified to have been used over the course of the run. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens specified to have been used over the course of the run. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - id + - object + - created_at + - thread_id + - assistant_id + - status + - required_action + - last_error + - expires_at + - started_at + - cancelled_at + - failed_at + - completed_at + - model + - instructions + - tools + - metadata + - usage + - incomplete_details + - max_prompt_tokens + - max_completion_tokens + - truncation_strategy + - tool_choice + - parallel_tool_calls + - response_format + x-oaiMeta: + name: The run object + beta: true + example: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1698107661, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699073476, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699073498, + "last_error": null, + "model": "gpt-4-turbo", + "instructions": null, + "tools": [{"type": "file_search"}, {"type": "code_interpreter"}], + "metadata": {}, + "incomplete_details": null, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + CreateRunRequest: + type: object + additionalProperties: false + properties: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + type: string + model: + description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + example: "gpt-4-turbo" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + nullable: true + instructions: + description: Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis. + type: string + nullable: true + additional_instructions: + description: Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. + type: string + nullable: true + additional_messages: + description: Adds additional messages to the thread before creating the run. + type: array + items: + $ref: "#/components/schemas/CreateMessageRequest" + nullable: true + tools: + description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + nullable: true + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *run_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - thread_id + - assistant_id + ListRunsResponse: + type: object + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/RunObject" + first_id: + type: string + example: "run_abc123" + last_id: + type: string + example: "run_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + ModifyRunRequest: + type: object + additionalProperties: false + properties: + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + SubmitToolOutputsRunRequest: + type: object + additionalProperties: false + properties: + tool_outputs: + description: A list of tools for which the outputs are being submitted. + type: array + items: + type: object + properties: + tool_call_id: + type: string + description: The ID of the tool call in the `required_action` object within the run object the output is being submitted for. + output: + type: string + description: The output of the tool call to be submitted to continue the run. + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + required: + - tool_outputs + + RunToolCallObject: + type: object + description: Tool call objects + properties: + id: + type: string + description: The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint. + type: + type: string + description: The type of tool call the output is required for. For now, this is always `function`. + enum: ["function"] + function: + type: object + description: The function definition. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments that the model expects you to pass to the function. + required: + - name + - arguments + required: + - id + - type + - function + + CreateThreadAndRunRequest: + type: object + additionalProperties: false + properties: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + type: string + thread: + $ref: "#/components/schemas/CreateThreadRequest" + description: If no thread is provided, an empty thread will be created. + model: + description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + example: "gpt-4-turbo" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + nullable: true + instructions: + description: Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis. + type: string + nullable: true + tools: + description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + nullable: true + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *run_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - thread_id + - assistant_id + + ThreadObject: + type: object + title: Thread + description: Represents a thread that contains [messages](/docs/api-reference/messages). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread`. + type: string + enum: ["thread"] + created_at: + description: The Unix timestamp (in seconds) for when the thread was created. + type: integer + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - created_at + - tool_resources + - metadata + x-oaiMeta: + name: The thread object + beta: true + example: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1698107661, + "metadata": {} + } + + CreateThreadRequest: + type: object + additionalProperties: false + properties: + messages: + description: A list of [messages](/docs/api-reference/messages) to start the thread with. + type: array + items: + $ref: "#/components/schemas/CreateMessageRequest" + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + vector_stores: + type: array + description: | + A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. + maxItems: 10000 + items: + type: string + chunking_strategy: + # Ideally we'd reuse the chunking strategy schema here, but it doesn't expand properly + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + - type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + required: + - type + - static + x-oaiExpandable: true + metadata: + type: object + description: | + Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + x-oaiTypeLabel: map + x-oaiExpandable: true + oneOf: + - required: [vector_store_ids] + - required: [vector_stores] + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + ModifyThreadRequest: + type: object + additionalProperties: false + properties: + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + DeleteThreadResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [thread.deleted] + required: + - id + - object + - deleted + + ListThreadsResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/ThreadObject" + first_id: + type: string + example: "asst_abc123" + last_id: + type: string + example: "asst_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + MessageObject: + type: object + title: The message object + description: Represents a message within a [thread](/docs/api-reference/threads). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.message`. + type: string + enum: ["thread.message"] + created_at: + description: The Unix timestamp (in seconds) for when the message was created. + type: integer + thread_id: + description: The [thread](/docs/api-reference/threads) ID that this message belongs to. + type: string + status: + description: The status of the message, which can be either `in_progress`, `incomplete`, or `completed`. + type: string + enum: ["in_progress", "incomplete", "completed"] + incomplete_details: + description: On an incomplete message, details about why the message is incomplete. + type: object + properties: + reason: + type: string + description: The reason the message is incomplete. + enum: + [ + "content_filter", + "max_tokens", + "run_cancelled", + "run_expired", + "run_failed", + ] + nullable: true + required: + - reason + completed_at: + description: The Unix timestamp (in seconds) for when the message was completed. + type: integer + nullable: true + incomplete_at: + description: The Unix timestamp (in seconds) for when the message was marked as incomplete. + type: integer + nullable: true + role: + description: The entity that produced the message. One of `user` or `assistant`. + type: string + enum: ["user", "assistant"] + content: + description: The content of the message in array of text and/or images. + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageContentImageFileObject" + - $ref: "#/components/schemas/MessageContentImageUrlObject" + - $ref: "#/components/schemas/MessageContentTextObject" + x-oaiExpandable: true + assistant_id: + description: If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message. + type: string + nullable: true + run_id: + description: The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints. + type: string + nullable: true + attachments: + type: array + items: + type: object + properties: + file_id: + type: string + description: The ID of the file to attach to the message. + tools: + description: The tools to add this file to. + type: array + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" + x-oaiExpandable: true + description: A list of files attached to the message, and the tools they were added to. + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - created_at + - thread_id + - status + - incomplete_details + - completed_at + - incomplete_at + - role + - content + - assistant_id + - run_id + - attachments + - metadata + x-oaiMeta: + name: The message object + beta: true + example: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1698983503, + "thread_id": "thread_abc123", + "role": "assistant", + "content": [ + { + "type": "text", + "text": { + "value": "Hi! How can I help you today?", + "annotations": [] + } + } + ], + "assistant_id": "asst_abc123", + "run_id": "run_abc123", + "attachments": [], + "metadata": {} + } + + MessageDeltaObject: + type: object + title: Message delta object + description: | + Represents a message delta i.e. any changed fields on a message during streaming. + properties: + id: + description: The identifier of the message, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.message.delta`. + type: string + enum: ["thread.message.delta"] + delta: + description: The delta containing the fields that have changed on the Message. + type: object + properties: + role: + description: The entity that produced the message. One of `user` or `assistant`. + type: string + enum: ["user", "assistant"] + content: + description: The content of the message in array of text and/or images. + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageDeltaContentImageFileObject" + - $ref: "#/components/schemas/MessageDeltaContentTextObject" + - $ref: "#/components/schemas/MessageDeltaContentImageUrlObject" + x-oaiExpandable: true + required: + - id + - object + - delta + x-oaiMeta: + name: The message delta object + beta: true + example: | + { + "id": "msg_123", + "object": "thread.message.delta", + "delta": { + "content": [ + { + "index": 0, + "type": "text", + "text": { "value": "Hello", "annotations": [] } + } + ] + } + } + + CreateMessageRequest: + type: object + additionalProperties: false + required: + - role + - content + properties: + role: + type: string + enum: ["user", "assistant"] + description: | + The role of the entity that is creating the message. Allowed values include: + - `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages. + - `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation. + content: + oneOf: + - type: string + description: The text contents of the message. + title: Text content + - type: array + description: An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview). + title: Array of content parts + items: + oneOf: + - $ref: "#/components/schemas/MessageContentImageFileObject" + - $ref: "#/components/schemas/MessageContentImageUrlObject" + - $ref: "#/components/schemas/MessageRequestContentTextObject" + x-oaiExpandable: true + minItems: 1 + x-oaiExpandable: true + attachments: + type: array + items: + type: object + properties: + file_id: + type: string + description: The ID of the file to attach to the message. + tools: + description: The tools to add this file to. + type: array + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" + x-oaiExpandable: true + description: A list of files attached to the message, and the tools they should be added to. + required: + - file_id + - tools + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + ModifyMessageRequest: + type: object + additionalProperties: false + properties: + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + DeleteMessageResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [thread.message.deleted] + required: + - id + - object + - deleted + + ListMessagesResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/MessageObject" + first_id: + type: string + example: "msg_abc123" + last_id: + type: string + example: "msg_abc123" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + MessageContentImageFileObject: + title: Image file + type: object + description: References an image [File](/docs/api-reference/files) in the content of a message. + properties: + type: + description: Always `image_file`. + type: string + enum: ["image_file"] + image_file: + type: object + properties: + file_id: + description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + type: string + detail: + type: string + description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - file_id + required: + - type + - image_file + + MessageDeltaContentImageFileObject: + title: Image file + type: object + description: References an image [File](/docs/api-reference/files) in the content of a message. + properties: + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `image_file`. + type: string + enum: ["image_file"] + image_file: + type: object + properties: + file_id: + description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + type: string + detail: + type: string + description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - index + - type + + MessageContentImageUrlObject: + title: Image URL + type: object + description: References an image URL in the content of a message. + properties: + type: + type: string + enum: ["image_url"] + description: The type of the content part. + image_url: + type: object + properties: + url: + type: string + description: "The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + format: uri + detail: + type: string + description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto` + enum: ["auto", "low", "high"] + default: "auto" + required: + - url + required: + - type + - image_url + + MessageDeltaContentImageUrlObject: + title: Image URL + type: object + description: References an image URL in the content of a message. + properties: + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `image_url`. + type: string + enum: ["image_url"] + image_url: + type: object + properties: + url: + description: "The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + type: string + detail: + type: string + description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - index + - type + + MessageContentTextObject: + title: Text + type: object + description: The text content that is part of a message. + properties: + type: + description: Always `text`. + type: string + enum: ["text"] + text: + type: object + properties: + value: + description: The data that makes up the text. + type: string + annotations: + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageContentTextAnnotationsFileCitationObject" + - $ref: "#/components/schemas/MessageContentTextAnnotationsFilePathObject" + x-oaiExpandable: true + required: + - value + - annotations + required: + - type + - text + + MessageRequestContentTextObject: + title: Text + type: object + description: The text content that is part of a message. + properties: + type: + description: Always `text`. + type: string + enum: ["text"] + text: + type: string + description: Text content to be sent to the model + required: + - type + - text + + MessageContentTextAnnotationsFileCitationObject: + title: File citation + type: object + description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + properties: + type: + description: Always `file_citation`. + type: string + enum: ["file_citation"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_citation: + type: object + properties: + file_id: + description: The ID of the specific File the citation is from. + type: string + required: + - file_id + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - type + - text + - file_citation + - start_index + - end_index + + MessageContentTextAnnotationsFilePathObject: + title: File path + type: object + description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. + properties: + type: + description: Always `file_path`. + type: string + enum: ["file_path"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_path: + type: object + properties: + file_id: + description: The ID of the file that was generated. + type: string + required: + - file_id + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - type + - text + - file_path + - start_index + - end_index + + MessageDeltaContentTextObject: + title: Text + type: object + description: The text content that is part of a message. + properties: + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `text`. + type: string + enum: ["text"] + text: + type: object + properties: + value: + description: The data that makes up the text. + type: string + annotations: + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFileCitationObject" + - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFilePathObject" + x-oaiExpandable: true + required: + - index + - type + + MessageDeltaContentTextAnnotationsFileCitationObject: + title: File citation + type: object + description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + properties: + index: + type: integer + description: The index of the annotation in the text content part. + type: + description: Always `file_citation`. + type: string + enum: ["file_citation"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_citation: + type: object + properties: + file_id: + description: The ID of the specific File the citation is from. + type: string + quote: + description: The specific quote in the file. + type: string + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - index + - type + + MessageDeltaContentTextAnnotationsFilePathObject: + title: File path + type: object + description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. + properties: + index: + type: integer + description: The index of the annotation in the text content part. + type: + description: Always `file_path`. + type: string + enum: ["file_path"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_path: + type: object + properties: + file_id: + description: The ID of the file that was generated. + type: string + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - index + - type + + RunStepObject: + type: object + title: Run steps + description: | + Represents a step in execution of a run. + properties: + id: + description: The identifier of the run step, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.run.step`. + type: string + enum: ["thread.run.step"] + created_at: + description: The Unix timestamp (in seconds) for when the run step was created. + type: integer + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) associated with the run step. + type: string + thread_id: + description: The ID of the [thread](/docs/api-reference/threads) that was run. + type: string + run_id: + description: The ID of the [run](/docs/api-reference/runs) that this run step is a part of. + type: string + type: + description: The type of run step, which can be either `message_creation` or `tool_calls`. + type: string + enum: ["message_creation", "tool_calls"] + status: + description: The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`. + type: string + enum: ["in_progress", "cancelled", "failed", "completed", "expired"] + step_details: + type: object + description: The details of the run step. + oneOf: + - $ref: "#/components/schemas/RunStepDetailsMessageCreationObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsObject" + x-oaiExpandable: true + last_error: + type: object + description: The last error associated with this run step. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error` or `rate_limit_exceeded`. + enum: ["server_error", "rate_limit_exceeded"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + expired_at: + description: The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired. + type: integer + nullable: true + cancelled_at: + description: The Unix timestamp (in seconds) for when the run step was cancelled. + type: integer + nullable: true + failed_at: + description: The Unix timestamp (in seconds) for when the run step failed. + type: integer + nullable: true + completed_at: + description: The Unix timestamp (in seconds) for when the run step completed. + type: integer + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + usage: + $ref: "#/components/schemas/RunStepCompletionUsage" + required: + - id + - object + - created_at + - assistant_id + - thread_id + - run_id + - type + - status + - step_details + - last_error + - expired_at + - cancelled_at + - failed_at + - completed_at + - metadata + - usage + x-oaiMeta: + name: The run step object + beta: true + example: *run_step_object_example + + RunStepDeltaObject: + type: object + title: Run step delta object + description: | + Represents a run step delta i.e. any changed fields on a run step during streaming. + properties: + id: + description: The identifier of the run step, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.run.step.delta`. + type: string + enum: ["thread.run.step.delta"] + delta: + description: The delta containing the fields that have changed on the run step. + type: object + properties: + step_details: + type: object + description: The details of the run step. + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsMessageCreationObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsObject" + x-oaiExpandable: true + required: + - id + - object + - delta + x-oaiMeta: + name: The run step delta object + beta: true + example: | + { + "id": "step_123", + "object": "thread.run.step.delta", + "delta": { + "step_details": { + "type": "tool_calls", + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "code_interpreter", + "code_interpreter": { "input": "", "outputs": [] } + } + ] + } + } + } + + ListRunStepsResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/RunStepObject" + first_id: + type: string + example: "step_abc123" + last_id: + type: string + example: "step_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + RunStepDetailsMessageCreationObject: + title: Message creation + type: object + description: Details of the message creation by the run step. + properties: + type: + description: Always `message_creation`. + type: string + enum: ["message_creation"] + message_creation: + type: object + properties: + message_id: + type: string + description: The ID of the message that was created by this run step. + required: + - message_id + required: + - type + - message_creation + + RunStepDeltaStepDetailsMessageCreationObject: + title: Message creation + type: object + description: Details of the message creation by the run step. + properties: + type: + description: Always `message_creation`. + type: string + enum: ["message_creation"] + message_creation: + type: object + properties: + message_id: + type: string + description: The ID of the message that was created by this run step. + required: + - type + + RunStepDetailsToolCallsObject: + title: Tool calls + type: object + description: Details of the tool call. + properties: + type: + description: Always `tool_calls`. + type: string + enum: ["tool_calls"] + tool_calls: + type: array + description: | + An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + items: + oneOf: + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsFileSearchObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsFunctionObject" + x-oaiExpandable: true + required: + - type + - tool_calls + + RunStepDeltaStepDetailsToolCallsObject: + title: Tool calls + type: object + description: Details of the tool call. + properties: + type: + description: Always `tool_calls`. + type: string + enum: ["tool_calls"] + tool_calls: + type: array + description: | + An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + items: + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFileSearchObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFunctionObject" + x-oaiExpandable: true + required: + - type + + RunStepDetailsToolCallsCodeObject: + title: Code Interpreter tool call + type: object + description: Details of the Code Interpreter tool call the run step was involved in. + properties: + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. + enum: ["code_interpreter"] + code_interpreter: + type: object + description: The Code Interpreter tool call definition. + required: + - input + - outputs + properties: + input: + type: string + description: The input to the Code Interpreter tool call. + outputs: + type: array + description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. + items: + type: object + oneOf: + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputLogsObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputImageObject" + x-oaiExpandable: true + required: + - id + - type + - code_interpreter + + RunStepDeltaStepDetailsToolCallsCodeObject: + title: Code interpreter tool call + type: object + description: Details of the Code Interpreter tool call the run step was involved in. + properties: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. + enum: ["code_interpreter"] + code_interpreter: + type: object + description: The Code Interpreter tool call definition. + properties: + input: + type: string + description: The input to the Code Interpreter tool call. + outputs: + type: array + description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. + items: + type: object + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputImageObject" + x-oaiExpandable: true + required: + - index + - type + + RunStepDetailsToolCallsCodeOutputLogsObject: + title: Code Interpreter log output + type: object + description: Text output from the Code Interpreter tool call as part of a run step. + properties: + type: + description: Always `logs`. + type: string + enum: ["logs"] + logs: + type: string + description: The text output from the Code Interpreter tool call. + required: + - type + - logs + + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject: + title: Code interpreter log output + type: object + description: Text output from the Code Interpreter tool call as part of a run step. + properties: + index: + type: integer + description: The index of the output in the outputs array. + type: + description: Always `logs`. + type: string + enum: ["logs"] + logs: + type: string + description: The text output from the Code Interpreter tool call. + required: + - index + - type + + RunStepDetailsToolCallsCodeOutputImageObject: + title: Code Interpreter image output + type: object + properties: + type: + description: Always `image`. + type: string + enum: ["image"] + image: + type: object + properties: + file_id: + description: The [file](/docs/api-reference/files) ID of the image. + type: string + required: + - file_id + required: + - type + - image + + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject: + title: Code interpreter image output + type: object + properties: + index: + type: integer + description: The index of the output in the outputs array. + type: + description: Always `image`. + type: string + enum: ["image"] + image: + type: object + properties: + file_id: + description: The [file](/docs/api-reference/files) ID of the image. + type: string + required: + - index + - type + + RunStepDetailsToolCallsFileSearchObject: + title: File search tool call + type: object + properties: + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `file_search` for this type of tool call. + enum: ["file_search"] + file_search: + type: object + description: For now, this is always going to be an empty object. + x-oaiTypeLabel: map + required: + - id + - type + - file_search + + RunStepDeltaStepDetailsToolCallsFileSearchObject: + title: File search tool call + type: object + properties: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `file_search` for this type of tool call. + enum: ["file_search"] + file_search: + type: object + description: For now, this is always going to be an empty object. + x-oaiTypeLabel: map + required: + - index + - type + - file_search + + RunStepDetailsToolCallsFunctionObject: + type: object + title: Function tool call + properties: + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `function` for this type of tool call. + enum: ["function"] + function: + type: object + description: The definition of the function that was called. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments passed to the function. + output: + type: string + description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. + nullable: true + required: + - name + - arguments + - output + required: + - id + - type + - function + + RunStepDeltaStepDetailsToolCallsFunctionObject: + type: object + title: Function tool call + properties: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `function` for this type of tool call. + enum: ["function"] + function: + type: object + description: The definition of the function that was called. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments passed to the function. + output: + type: string + description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. + nullable: true + required: + - index + - type + + VectorStoreExpirationAfter: + type: object + title: Vector store expiration policy + description: The expiration policy for a vector store. + properties: + anchor: + description: "Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." + type: string + enum: ["last_active_at"] + days: + description: The number of days after the anchor time that the vector store will expire. + type: integer + minimum: 1 + maximum: 365 + required: + - anchor + - days + + VectorStoreObject: + type: object + title: Vector store + description: A vector store is a collection of processed files can be used by the `file_search` tool. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store`. + type: string + enum: ["vector_store"] + created_at: + description: The Unix timestamp (in seconds) for when the vector store was created. + type: integer + name: + description: The name of the vector store. + type: string + usage_bytes: + description: The total number of bytes used by the files in the vector store. + type: integer + file_counts: + type: object + properties: + in_progress: + description: The number of files that are currently being processed. + type: integer + completed: + description: The number of files that have been successfully processed. + type: integer + failed: + description: The number of files that have failed to process. + type: integer + cancelled: + description: The number of files that were cancelled. + type: integer + total: + description: The total number of files. + type: integer + required: + - in_progress + - completed + - failed + - cancelled + - total + status: + description: The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use. + type: string + enum: ["expired", "in_progress", "completed"] + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + expires_at: + description: The Unix timestamp (in seconds) for when the vector store will expire. + type: integer + nullable: true + last_active_at: + description: The Unix timestamp (in seconds) for when the vector store was last active. + type: integer + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - usage_bytes + - created_at + - status + - last_active_at + - name + - file_counts + - metadata + x-oaiMeta: + name: The vector store object + beta: true + example: | + { + "id": "vs_123", + "object": "vector_store", + "created_at": 1698107661, + "usage_bytes": 123456, + "last_active_at": 1698107661, + "name": "my_vector_store", + "status": "completed", + "file_counts": { + "in_progress": 0, + "completed": 100, + "cancelled": 0, + "failed": 0, + "total": 100 + }, + "metadata": {}, + "last_used_at": 1698107661 + } + + CreateVectorStoreRequest: + type: object + additionalProperties: false + properties: + file_ids: + description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + type: array + maxItems: 500 + items: + type: string + name: + description: The name of the vector store. + type: string + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + chunking_strategy: + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty. + oneOf: + - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" + - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + UpdateVectorStoreRequest: + type: object + additionalProperties: false + properties: + name: + description: The name of the vector store. + type: string + nullable: true + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + ListVectorStoresResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/VectorStoreObject" + first_id: + type: string + example: "vs_abc123" + last_id: + type: string + example: "vs_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + DeleteVectorStoreResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [vector_store.deleted] + required: + - id + - object + - deleted + + VectorStoreFileObject: + type: object + title: Vector store files + description: A list of files attached to a vector store. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store.file`. + type: string + enum: ["vector_store.file"] + usage_bytes: + description: The total vector store usage in bytes. Note that this may be different from the original file size. + type: integer + created_at: + description: The Unix timestamp (in seconds) for when the vector store file was created. + type: integer + vector_store_id: + description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. + type: string + status: + description: The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use. + type: string + enum: ["in_progress", "completed", "cancelled", "failed"] + last_error: + type: object + description: The last error associated with this vector store file. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error` or `rate_limit_exceeded`. + enum: + [ + "internal_error", + "file_not_found", + "parsing_error", + "unhandled_mime_type", + ] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + chunking_strategy: + type: object + description: The strategy used to chunk the file. + oneOf: + - $ref: "#/components/schemas/StaticChunkingStrategyResponseParam" + - $ref: "#/components/schemas/OtherChunkingStrategyResponseParam" + x-oaiExpandable: true + required: + - id + - object + - usage_bytes + - created_at + - vector_store_id + - status + - last_error + x-oaiMeta: + name: The vector store file object + beta: true + example: | + { + "id": "file-abc123", + "object": "vector_store.file", + "usage_bytes": 1234, + "created_at": 1698107661, + "vector_store_id": "vs_abc123", + "status": "completed", + "last_error": null, + "chunking_strategy": { + "type": "static", + "static": { + "max_chunk_size_tokens": 800, + "chunk_overlap_tokens": 400 + } + } + } + + OtherChunkingStrategyResponseParam: + type: object + title: Other Chunking Strategy + description: This is returned when the chunking strategy is unknown. Typically, this is because the file was indexed before the `chunking_strategy` concept was introduced in the API. + additionalProperties: false + properties: + type: + type: string + description: Always `other`. + enum: ["other"] + required: + - type + + StaticChunkingStrategyResponseParam: + type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + $ref: "#/components/schemas/StaticChunkingStrategy" + required: + - type + - static + + StaticChunkingStrategy: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + + AutoChunkingStrategyRequestParam: + type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + + StaticChunkingStrategyRequestParam: + type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + $ref: "#/components/schemas/StaticChunkingStrategy" + required: + - type + - static + + ChunkingStrategyRequestParam: + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" + - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" + x-oaiExpandable: true + + CreateVectorStoreFileRequest: + type: object + additionalProperties: false + properties: + file_id: + description: A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files. + type: string + chunking_strategy: + $ref: "#/components/schemas/ChunkingStrategyRequestParam" + required: + - file_id + + ListVectorStoreFilesResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/VectorStoreFileObject" + first_id: + type: string + example: "file-abc123" + last_id: + type: string + example: "file-abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + DeleteVectorStoreFileResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [vector_store.file.deleted] + required: + - id + - object + - deleted + + VectorStoreFileBatchObject: + type: object + title: Vector store file batch + description: A batch of files attached to a vector store. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store.file_batch`. + type: string + enum: ["vector_store.files_batch"] + created_at: + description: The Unix timestamp (in seconds) for when the vector store files batch was created. + type: integer + vector_store_id: + description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. + type: string + status: + description: The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`. + type: string + enum: ["in_progress", "completed", "cancelled", "failed"] + file_counts: + type: object + properties: + in_progress: + description: The number of files that are currently being processed. + type: integer + completed: + description: The number of files that have been processed. + type: integer + failed: + description: The number of files that have failed to process. + type: integer + cancelled: + description: The number of files that where cancelled. + type: integer + total: + description: The total number of files. + type: integer + required: + - in_progress + - completed + - cancelled + - failed + - total + required: + - id + - object + - created_at + - vector_store_id + - status + - file_counts + x-oaiMeta: + name: The vector store files batch object + beta: true + example: | + { + "id": "vsfb_123", + "object": "vector_store.files_batch", + "created_at": 1698107661, + "vector_store_id": "vs_abc123", + "status": "completed", + "file_counts": { + "in_progress": 0, + "completed": 100, + "failed": 0, + "cancelled": 0, + "total": 100 + } + } + + CreateVectorStoreFileBatchRequest: + type: object + additionalProperties: false + properties: + file_ids: + description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + type: array + minItems: 1 + maxItems: 500 + items: + type: string + chunking_strategy: + $ref: "#/components/schemas/ChunkingStrategyRequestParam" + required: + - file_ids + + AssistantStreamEvent: + description: | + Represents an event emitted when streaming a Run. + + Each event in a server-sent events stream has an `event` and `data` property: + + ``` + event: thread.created + data: {"id": "thread_123", "object": "thread", ...} + ``` + + We emit events whenever a new object is created, transitions to a new state, or is being + streamed in parts (deltas). For example, we emit `thread.run.created` when a new run + is created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses + to create a message during a run, we emit a `thread.message.created event`, a + `thread.message.in_progress` event, many `thread.message.delta` events, and finally a + `thread.message.completed` event. + + We may add additional events over time, so we recommend handling unknown events gracefully + in your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to + integrate the Assistants API with streaming. + oneOf: + - $ref: "#/components/schemas/ThreadStreamEvent" + - $ref: "#/components/schemas/RunStreamEvent" + - $ref: "#/components/schemas/RunStepStreamEvent" + - $ref: "#/components/schemas/MessageStreamEvent" + - $ref: "#/components/schemas/ErrorEvent" + - $ref: "#/components/schemas/DoneEvent" + x-oaiMeta: + name: Assistant stream events + beta: true + + ThreadStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.created"] + data: + $ref: "#/components/schemas/ThreadObject" + required: + - event + - data + description: Occurs when a new [thread](/docs/api-reference/threads/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [thread](/docs/api-reference/threads/object)" + + RunStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.run.created"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a new [run](/docs/api-reference/runs/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.queued"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `queued` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.in_progress"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to an `in_progress` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.requires_action"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `requires_action` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.completed"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.incomplete"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) ends with status `incomplete`. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.failed"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) fails. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.cancelling"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `cancelling` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.cancelled"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) is cancelled. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.expired"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) expires. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + + RunStepStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.run.step.created"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is created. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.in_progress"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) moves to an `in_progress` state. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.delta"] + data: + $ref: "#/components/schemas/RunStepDeltaObject" + required: + - event + - data + description: Occurs when parts of a [run step](/docs/api-reference/runs/step-object) are being streamed. + x-oaiMeta: + dataDescription: "`data` is a [run step delta](/docs/api-reference/assistants-streaming/run-step-delta-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.completed"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.failed"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) fails. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.cancelled"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is cancelled. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.expired"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) expires. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + + MessageStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.message.created"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.in_progress"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) moves to an `in_progress` state. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.delta"] + data: + $ref: "#/components/schemas/MessageDeltaObject" + required: + - event + - data + description: Occurs when parts of a [Message](/docs/api-reference/messages/object) are being streamed. + x-oaiMeta: + dataDescription: "`data` is a [message delta](/docs/api-reference/assistants-streaming/message-delta-object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.completed"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.incomplete"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) ends before it is completed. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + + ErrorEvent: + type: object + properties: + event: + type: string + enum: ["error"] + data: + $ref: "#/components/schemas/Error" + required: + - event + - data + description: Occurs when an [error](/docs/guides/error-codes/api-errors) occurs. This can happen due to an internal server error or a timeout. + x-oaiMeta: + dataDescription: "`data` is an [error](/docs/guides/error-codes/api-errors)" + + DoneEvent: + type: object + properties: + event: + type: string + enum: ["done"] + data: + type: string + enum: ["[DONE]"] + required: + - event + - data + description: Occurs when a stream ends. + x-oaiMeta: + dataDescription: "`data` is `[DONE]`" + + Batch: + type: object + properties: + id: + type: string + object: + type: string + enum: [batch] + description: The object type, which is always `batch`. + endpoint: + type: string + description: The OpenAI API endpoint used by the batch. + + errors: + type: object + properties: + object: + type: string + description: The object type, which is always `list`. + data: + type: array + items: + type: object + properties: + code: + type: string + description: An error code identifying the error type. + message: + type: string + description: A human-readable message providing more details about the error. + param: + type: string + description: The name of the parameter that caused the error, if applicable. + nullable: true + line: + type: integer + description: The line number of the input file where the error occurred, if applicable. + nullable: true + input_file_id: + type: string + description: The ID of the input file for the batch. + completion_window: + type: string + description: The time frame within which the batch should be processed. + status: + type: string + description: The current status of the batch. + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + output_file_id: + type: string + description: The ID of the file containing the outputs of successfully executed requests. + error_file_id: + type: string + description: The ID of the file containing the outputs of requests with errors. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was created. + in_progress_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started processing. + expires_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch will expire. + finalizing_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started finalizing. + completed_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was completed. + failed_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch failed. + expired_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch expired. + cancelling_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started cancelling. + cancelled_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was cancelled. + request_counts: + type: object + properties: + total: + type: integer + description: Total number of requests in the batch. + completed: + type: integer + description: Number of requests that have been completed successfully. + failed: + type: integer + description: Number of requests that have failed. + required: + - total + - completed + - failed + description: The request counts for different statuses within the batch. + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - endpoint + - input_file_id + - completion_window + - status + - created_at + x-oaiMeta: + name: The batch object + example: *batch_object + + BatchRequestInput: + type: object + description: The per-line object of the batch input file + properties: + custom_id: + type: string + description: A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch. + method: + type: string + enum: ["POST"] + description: The HTTP method to be used for the request. Currently only `POST` is supported. + url: + type: string + description: The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. + x-oaiMeta: + name: The request input object + example: | + {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}} + + BatchRequestOutput: + type: object + description: The per-line object of the batch output and error files + properties: + id: + type: string + custom_id: + type: string + description: A developer-provided per-request id that will be used to match outputs to inputs. + response: + type: object + nullable: true + properties: + status_code: + type: integer + description: The HTTP status code of the response + request_id: + type: string + description: An unique identifier for the OpenAI API request. Please include this request ID when contacting support. + body: + type: object + x-oaiTypeLabel: map + description: The JSON body of the response + error: + type: object + nullable: true + description: For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure. + properties: + code: + type: string + description: A machine-readable error code. + message: + type: string + description: A human-readable error message. + x-oaiMeta: + name: The request output object + example: | + {"id": "batch_req_wnaDys", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "req_c187b3", "body": {"id": "chatcmpl-9758Iw", "object": "chat.completion", "created": 1711475054, "model": "gpt-3.5-turbo", "choices": [{"index": 0, "message": {"role": "assistant", "content": "2 + 2 equals 4."}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 24, "completion_tokens": 15, "total_tokens": 39}, "system_fingerprint": null}}, "error": null} + + ListBatchesResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/Batch" + first_id: + type: string + example: "batch_abc123" + last_id: + type: string + example: "batch_abc456" + has_more: + type: boolean + object: + type: string + enum: [list] + required: + - object + - data + - has_more + +security: + - ApiKeyAuth: [] + +x-oaiMeta: + navigationGroups: + - id: endpoints + title: Endpoints + - id: assistants + title: Assistants + - id: legacy + title: Legacy + groups: + # > General Notes + # The `groups` section is used to generate the API reference pages and navigation, in the same + # order listed below. Additionally, each `group` can have a list of `sections`, each of which + # will become a navigation subroute and subsection under the group. Each section has: + # - `type`: Currently, either an `endpoint` or `object`, depending on how the section needs to + # be rendered + # - `key`: The reference key that can be used to lookup the section definition + # - `path`: The path (url) of the section, which is used to generate the navigation link. + # + # > The `object` sections maps to a schema component and the following fields are read for rendering + # - `x-oaiMeta.name`: The name of the object, which will become the section title + # - `x-oaiMeta.example`: The example object, which will be used to generate the example sample (always JSON) + # - `description`: The description of the object, which will be used to generate the section description + # + # > The `endpoint` section maps to an operation path and the following fields are read for rendering: + # - `x-oaiMeta.name`: The name of the endpoint, which will become the section title + # - `x-oaiMeta.examples`: The endpoint examples, which can be an object (meaning a single variation, most + # endpoints, or an array of objects, meaning multiple variations, e.g. the + # chat completion and completion endpoints, with streamed and non-streamed examples. + # - `x-oaiMeta.returns`: text describing what the endpoint returns. + # - `summary`: The summary of the endpoint, which will be used to generate the section description + - id: audio + title: Audio + description: | + Learn how to turn audio into text or text into audio. + + Related guide: [Speech to text](/docs/guides/speech-to-text) + navigationGroup: endpoints + sections: + - type: endpoint + key: createSpeech + path: createSpeech + - type: endpoint + key: createTranscription + path: createTranscription + - type: endpoint + key: createTranslation + path: createTranslation + - type: object + key: CreateTranscriptionResponseJson + path: json-object + - type: object + key: CreateTranscriptionResponseVerboseJson + path: verbose-json-object + - id: chat + title: Chat + description: | + Given a list of messages comprising a conversation, the model will return a response. + + Related guide: [Chat Completions](/docs/guides/text-generation) + navigationGroup: endpoints + sections: + - type: endpoint + key: createChatCompletion + path: create + - type: object + key: CreateChatCompletionResponse + path: object + - type: object + key: CreateChatCompletionStreamResponse + path: streaming + - id: embeddings + title: Embeddings + description: | + Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. + + Related guide: [Embeddings](/docs/guides/embeddings) + navigationGroup: endpoints + sections: + - type: endpoint + key: createEmbedding + path: create + - type: object + key: Embedding + path: object + - id: fine-tuning + title: Fine-tuning + description: | + Manage fine-tuning jobs to tailor a model to your specific training data. + + Related guide: [Fine-tune models](/docs/guides/fine-tuning) + navigationGroup: endpoints + sections: + - type: endpoint + key: createFineTuningJob + path: create + - type: endpoint + key: listPaginatedFineTuningJobs + path: list + - type: endpoint + key: listFineTuningEvents + path: list-events + - type: endpoint + key: listFineTuningJobCheckpoints + path: list-checkpoints + - type: endpoint + key: retrieveFineTuningJob + path: retrieve + - type: endpoint + key: cancelFineTuningJob + path: cancel + - type: object + key: FinetuneChatRequestInput + path: chat-input + - type: object + key: FinetuneCompletionRequestInput + path: completions-input + - type: object + key: FineTuningJob + path: object + - type: object + key: FineTuningJobEvent + path: event-object + - type: object + key: FineTuningJobCheckpoint + path: checkpoint-object + - id: batch + title: Batch + description: | + Create large batches of API requests for asynchronous processing. The Batch API returns completions within 24 hours for a 50% discount. + + Related guide: [Batch](/docs/guides/batch) + navigationGroup: endpoints + sections: + - type: endpoint + key: createBatch + path: create + - type: endpoint + key: retrieveBatch + path: retrieve + - type: endpoint + key: cancelBatch + path: cancel + - type: endpoint + key: listBatches + path: list + - type: object + key: Batch + path: object + - type: object + key: BatchRequestInput + path: request-input + - type: object + key: BatchRequestOutput + path: request-output + - id: files + title: Files + description: | + Files are used to upload documents that can be used with features like [Assistants](/docs/api-reference/assistants), [Fine-tuning](/docs/api-reference/fine-tuning), and [Batch API](/docs/guides/batch). + navigationGroup: endpoints + sections: + - type: endpoint + key: createFile + path: create + - type: endpoint + key: listFiles + path: list + - type: endpoint + key: retrieveFile + path: retrieve + - type: endpoint + key: deleteFile + path: delete + - type: endpoint + key: downloadFile + path: retrieve-contents + - type: object + key: OpenAIFile + path: object + - id: uploads + title: Uploads + description: | + Allows you to upload large files in multiple parts. + navigationGroup: endpoints + sections: + - type: endpoint + key: createUpload + path: create + - type: endpoint + key: addUploadPart + path: add-part + - type: endpoint + key: completeUpload + path: complete + - type: endpoint + key: cancelUpload + path: cancel + - type: object + key: Upload + path: object + - type: object + key: UploadPart + path: part-object + - id: images + title: Images + description: | + Given a prompt and/or an input image, the model will generate a new image. + + Related guide: [Image generation](/docs/guides/images) + navigationGroup: endpoints + sections: + - type: endpoint + key: createImage + path: create + - type: endpoint + key: createImageEdit + path: createEdit + - type: endpoint + key: createImageVariation + path: createVariation + - type: object + key: Image + path: object + - id: models + title: Models + description: | + List and describe the various models available in the API. You can refer to the [Models](/docs/models) documentation to understand what models are available and the differences between them. + navigationGroup: endpoints + sections: + - type: endpoint + key: listModels + path: list + - type: endpoint + key: retrieveModel + path: retrieve + - type: endpoint + key: deleteModel + path: delete + - type: object + key: Model + path: object + - id: moderations + title: Moderations + description: | + Given some input text, outputs if the model classifies it as potentially harmful across several categories. + + Related guide: [Moderations](/docs/guides/moderation) + navigationGroup: endpoints + sections: + - type: endpoint + key: createModeration + path: create + - type: object + key: CreateModerationResponse + path: object + - id: assistants + title: Assistants + beta: true + description: | + Build assistants that can call models and use tools to perform tasks. + + [Get started with the Assistants API](/docs/assistants) + navigationGroup: assistants + sections: + - type: endpoint + key: createAssistant + path: createAssistant + - type: endpoint + key: listAssistants + path: listAssistants + - type: endpoint + key: getAssistant + path: getAssistant + - type: endpoint + key: modifyAssistant + path: modifyAssistant + - type: endpoint + key: deleteAssistant + path: deleteAssistant + - type: object + key: AssistantObject + path: object + - id: threads + title: Threads + beta: true + description: | + Create threads that assistants can interact with. + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: createThread + path: createThread + - type: endpoint + key: getThread + path: getThread + - type: endpoint + key: modifyThread + path: modifyThread + - type: endpoint + key: deleteThread + path: deleteThread + - type: object + key: ThreadObject + path: object + - id: messages + title: Messages + beta: true + description: | + Create messages within threads + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: createMessage + path: createMessage + - type: endpoint + key: listMessages + path: listMessages + - type: endpoint + key: getMessage + path: getMessage + - type: endpoint + key: modifyMessage + path: modifyMessage + - type: endpoint + key: deleteMessage + path: deleteMessage + - type: object + key: MessageObject + path: object + - id: runs + title: Runs + beta: true + description: | + Represents an execution run on a thread. + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: createRun + path: createRun + - type: endpoint + key: createThreadAndRun + path: createThreadAndRun + - type: endpoint + key: listRuns + path: listRuns + - type: endpoint + key: getRun + path: getRun + - type: endpoint + key: modifyRun + path: modifyRun + - type: endpoint + key: submitToolOuputsToRun + path: submitToolOutputs + - type: endpoint + key: cancelRun + path: cancelRun + - type: object + key: RunObject + path: object + - id: run-steps + title: Run Steps + beta: true + description: | + Represents the steps (model and tool calls) taken during the run. + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: listRunSteps + path: listRunSteps + - type: endpoint + key: getRunStep + path: getRunStep + - type: object + key: RunStepObject + path: step-object + - id: vector-stores + title: Vector Stores + beta: true + description: | + Vector stores are used to store files for use by the `file_search` tool. + + Related guide: [File Search](/docs/assistants/tools/file-search) + navigationGroup: assistants + sections: + - type: endpoint + key: createVectorStore + path: create + - type: endpoint + key: listVectorStores + path: list + - type: endpoint + key: getVectorStore + path: retrieve + - type: endpoint + key: modifyVectorStore + path: modify + - type: endpoint + key: deleteVectorStore + path: delete + - type: object + key: VectorStoreObject + path: object + - id: vector-stores-files + title: Vector Store Files + beta: true + description: | + Vector store files represent files inside a vector store. + + Related guide: [File Search](/docs/assistants/tools/file-search) + navigationGroup: assistants + sections: + - type: endpoint + key: createVectorStoreFile + path: createFile + - type: endpoint + key: listVectorStoreFiles + path: listFiles + - type: endpoint + key: getVectorStoreFile + path: getFile + - type: endpoint + key: deleteVectorStoreFile + path: deleteFile + - type: object + key: VectorStoreFileObject + path: file-object + - id: vector-stores-file-batches + title: Vector Store File Batches + beta: true + description: | + Vector store file batches represent operations to add multiple files to a vector store. + + Related guide: [File Search](/docs/assistants/tools/file-search) + navigationGroup: assistants + sections: + - type: endpoint + key: createVectorStoreFileBatch + path: createBatch + - type: endpoint + key: getVectorStoreFileBatch + path: getBatch + - type: endpoint + key: cancelVectorStoreFileBatch + path: cancelBatch + - type: endpoint + key: listFilesInVectorStoreBatch + path: listBatchFiles + - type: object + key: VectorStoreFileBatchObject + path: batch-object + - id: assistants-streaming + title: Streaming + beta: true + description: | + Stream the result of executing a Run or resuming a Run after submitting tool outputs. + + You can stream events from the [Create Thread and Run](/docs/api-reference/runs/createThreadAndRun), + [Create Run](/docs/api-reference/runs/createRun), and [Submit Tool Outputs](/docs/api-reference/runs/submitToolOutputs) + endpoints by passing `"stream": true`. The response will be a [Server-Sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events) stream. + + Our Node and Python SDKs provide helpful utilities to make streaming easy. Reference the + [Assistants API quickstart](/docs/assistants/overview) to learn more. + navigationGroup: assistants + sections: + - type: object + key: MessageDeltaObject + path: message-delta-object + - type: object + key: RunStepDeltaObject + path: run-step-delta-object + - type: object + key: AssistantStreamEvent + path: events + - id: completions + title: Completions + legacy: true + navigationGroup: legacy + description: | + Given a prompt, the model will return one or more predicted completions along with the probabilities of alternative tokens at each position. Most developer should use our [Chat Completions API](/docs/guides/text-generation/text-generation-models) to leverage our best and newest models. + sections: + - type: endpoint + key: createCompletion + path: create + - type: object + key: CreateCompletionResponse + path: object From dcc5ff84c6d634bb5b3e963d3df1c71006bff256 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 29 Jul 2024 10:30:06 -0700 Subject: [PATCH 349/425] Bump istio proxy memory for gateway (#580) --- charts/model-engine/templates/gateway_deployment.yaml | 2 ++ model-engine/model_engine_server/api/app.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/charts/model-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml index e6466544..937071b2 100644 --- a/charts/model-engine/templates/gateway_deployment.yaml +++ b/charts/model-engine/templates/gateway_deployment.yaml @@ -26,6 +26,8 @@ spec: "service": {{ include "modelEngine.fullname" . | quote }}, "source": "python" }] + sidecar.istio.io/proxyMemoryLimit: "5Gi" + sidecar.istio.io/proxyMemory: "1Gi" labels: {{- include "modelEngine.selectorLabels.gateway" . | nindent 8 }} {{- include "modelEngine.labels" . | nindent 8 }} diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index c45eedaf..851f0183 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -46,7 +46,6 @@ class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): try: - logger.info(f"Received request at {request.url.path}") LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) # we intentionally exclude healthcheck routes from the concurrency limiter From 6e35c71cf82622fe2ad6e745728a65a1ff6f3984 Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 29 Jul 2024 15:10:23 -0700 Subject: [PATCH 350/425] Make configs backwards-compatible (#581) --- model-engine/model_engine_server/common/config.py | 7 ++++++- model-engine/model_engine_server/core/config.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 16a4dd00..201918d6 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -1,6 +1,7 @@ # Keep in line with service_config_{*}.yaml # This file loads sensitive data that shouldn't make it to inference docker images # Do not include this file in our inference/endpoint code +import inspect import os from dataclasses import dataclass from pathlib import Path @@ -76,11 +77,15 @@ class HostedModelInferenceServiceConfig: str ] = None # Not an env var because the redis cache info is already here + @classmethod + def from_json(cls, json): + return cls(**{k: v for k, v in json.items() if k in inspect.signature(cls).parameters}) + @classmethod def from_yaml(cls, yaml_path): with open(yaml_path, "r") as f: raw_data = yaml.safe_load(f) - return HostedModelInferenceServiceConfig(**raw_data) + return HostedModelInferenceServiceConfig.from_json(raw_data) @property def cache_redis_url(self) -> str: diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 301b4f80..e68b61dd 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -4,6 +4,7 @@ If this is not set, the default configuration file is used from model_engine_server.core/configs/default.yaml. """ +import inspect import os from contextlib import contextmanager from copy import deepcopy @@ -47,11 +48,15 @@ class InfraConfig: firehose_role_arn: Optional[str] = None firehose_stream_name: Optional[str] = None + @classmethod + def from_json(cls, json): + return cls(**{k: v for k, v in json.items() if k in inspect.signature(cls).parameters}) + @classmethod def from_yaml(cls, yaml_path) -> "InfraConfig": with open(yaml_path, "r") as f: raw_data = yaml.safe_load(f) - return InfraConfig(**raw_data) + return InfraConfig.from_json(raw_data) def read_default_config(): From d03363826d5f531b1ef22fe71a830c2137bfa056 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Mon, 29 Jul 2024 15:58:11 -0700 Subject: [PATCH 351/425] Reduce connection pool size (#582) --- model-engine/model_engine_server/db/base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index d6beefe9..b5164366 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -127,32 +127,32 @@ def refresh_sessions(): echo=False, future=True, pool_pre_ping=True, - pool_size=20, - max_overflow=30, + pool_size=10, + max_overflow=10, ) pg_engine_read_only = create_engine( get_engine_url(read_only=True, sync=True), echo=False, future=True, pool_pre_ping=True, - pool_size=20, - max_overflow=30, + pool_size=10, + max_overflow=10, ) pg_engine_async = create_async_engine( get_engine_url(read_only=False, sync=False), echo=False, future=True, pool_pre_ping=True, - pool_size=20, - max_overflow=30, + pool_size=10, + max_overflow=10, ) pg_engine_read_only_async = create_async_engine( get_engine_url(read_only=True, sync=False), echo=False, future=True, pool_pre_ping=True, - pool_size=20, - max_overflow=30, + pool_size=10, + max_overflow=10, ) pg_engine_async_null_pool = create_async_engine( get_engine_url(read_only=False, sync=False), From 353c472dcf3c87869c76d73ce0fba84c508438fe Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 30 Jul 2024 10:10:38 -0700 Subject: [PATCH 352/425] Up storage limit (#575) --- model-engine/model_engine_server/common/resource_limits.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index c65e6cd6..5a760845 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -37,7 +37,7 @@ H100_INSTANCE_LIMITS = dict(cpus=191, memory="2000Gi") H100_1G_20GB_INSTANCE_LIMITS = dict(cpus=47, memory="500Gi") H100_3G_40GB_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") -STORAGE_LIMIT = "500G" # TODO: figure out an actual limit. +STORAGE_LIMIT = "640Gi" # TODO: figure out an actual limit. REQUESTS_BY_GPU_TYPE = { None: CPU_INSTANCE_LIMITS, GpuType.NVIDIA_TESLA_T4: T4_INSTANCE_LIMITS, From e4f08544d3ef9a155a7d3d92549771c4f921b3b7 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Tue, 30 Jul 2024 13:15:56 -0700 Subject: [PATCH 353/425] Use session for sts boto3 client for logging hook (#583) * use session * fix unit tests --- .../firehose_streaming_storage_gateway.py | 3 ++- ...test_firehose_streaming_storage_gateway.py | 24 ++++++++----------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py index 801178af..9ec93f44 100644 --- a/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py +++ b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py @@ -29,7 +29,8 @@ def __init__(self): """ def _get_firehose_client(self): - sts_client = boto3.client("sts", region_name=infra_config().default_region) + sts_session = boto3.Session(region_name=infra_config().default_region) + sts_client = sts_session.client("sts") assumed_role_object = sts_client.assume_role( RoleArn=infra_config().firehose_role_arn, RoleSessionName="AssumeMlLoggingRoleSession", diff --git a/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py index 3ae72a6e..d4902b29 100644 --- a/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py @@ -44,13 +44,6 @@ def mock_firehose_client(*args, **kwargs): return mock_client -def mock_session(*args, **kwargs): - mock_session_obj = mock.Mock() - mock_firehose = mock_firehose_client() - mock_session_obj.client.return_value = mock_firehose - return mock_session_obj - - def mock_firehose_client_with_exception(*args, **kwargs): mock_client = mock.Mock() mock_client.put_record.return_value = { @@ -61,13 +54,16 @@ def mock_firehose_client_with_exception(*args, **kwargs): return mock_client -def mock_session_with_exception(*args, **kwargs): - mock_session_obj = mock.Mock() - mock_firehose = mock_firehose_client_with_exception() +mock_sts_session = mock.Mock() +mock_sts_session.client.return_value = mock_sts_client() + + +mock_firehose_session = mock.Mock() +mock_firehose_session.client.return_value = mock_firehose_client() - mock_session_obj.client.return_value = mock_firehose - return mock_session_obj +mock_session_with_exception = mock.Mock() +mock_session_with_exception.client.return_value = mock_firehose_client_with_exception() def test_firehose_streaming_storage_gateway_put_record(streaming_storage_gateway, fake_record): @@ -76,7 +72,7 @@ def test_firehose_streaming_storage_gateway_put_record(streaming_storage_gateway mock_sts_client, ), mock.patch( "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", - mock_session, + side_effect=[mock_sts_session, mock_firehose_session], ): assert streaming_storage_gateway.put_record(stream_name, fake_record) is return_value @@ -89,7 +85,7 @@ def test_firehose_streaming_storage_gateway_put_record_with_exception( mock_sts_client, ), mock.patch( "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", - mock_session_with_exception, + side_effect=[mock_sts_session, mock_session_with_exception], ): with pytest.raises(StreamPutException): streaming_storage_gateway.put_record(stream_name, fake_record) From a6e2eda0abab6a4093efa8d380ce1777129db7b0 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 31 Jul 2024 23:03:28 -0700 Subject: [PATCH 354/425] Add env label (#584) --- charts/model-engine/templates/_helpers.tpl | 1 + 1 file changed, 1 insertion(+) diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index e0766b26..50383770 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -50,6 +50,7 @@ team: infra app.kubernetes.io/version: {{ .Values.tag }} tags.datadoghq.com/version: {{ .Values.tag }} tags.datadoghq.com/env: {{ .Values.context }} +env: {{ .Values.context }} {{- if .Values.azure }} azure.workload.identity/use: "true" {{- end }} From 3174f506bfe5741885b472ae2ba6e1175abf56d0 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 1 Aug 2024 14:15:00 -0700 Subject: [PATCH 355/425] Various Db configuration improvements (#585) * Add connection pool logs for db engine * Extract db settings into configs * Update default vars * Fix mypy typechecking * fix null pool init * Create DBManager to better encapsulate connection pooling + refreshing * Clean up old db pool * Refactors * Refactors --- charts/model-engine/values_sample.yaml | 6 + .../model_engine_server/api/dependencies.py | 2 +- .../model_engine_server/common/config.py | 12 +- .../model_engine_server/core/config.py | 14 +- .../core/configs/default.yaml | 5 + model-engine/model_engine_server/db/base.py | 357 ++++++++++-------- .../entrypoints/init_database.py | 2 +- .../gateways/resources/k8s_resource_types.py | 22 +- model-engine/mypy.ini | 3 + 9 files changed, 258 insertions(+), 165 deletions(-) diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index c9365466..430abea6 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -181,6 +181,12 @@ config: # redis_aws_secret_name: sample-prod/redis-credentials # s3_bucket [required] is the S3 bucket you wish to connect s3_bucket: "llm-engine" + # DB engine configs (This is SQLAlchemy heavy) + db_engine_pool_size: 10 + db_engine_max_overflow: 10 + db_engine_echo: false + db_engine_echo_pool: false + db_engine_disconnect_strategy: "pessimistic" launch: # endpoint_namespace [required] is K8s namespace the endpoints will be created in endpoint_namespace: llm-engine diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index cb664583..dc54c188 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -370,7 +370,7 @@ def get_default_external_interfaces() -> ExternalInterfaces: def get_default_external_interfaces_read_only() -> ExternalInterfaces: - session = async_scoped_session( # type: ignore + session = async_scoped_session( get_session_read_only_async(), scopefunc=asyncio.current_task # type: ignore ) return _get_external_interfaces(read_only=True, session=session) diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 201918d6..4531cd2a 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -139,4 +139,14 @@ def read_default_config(): return HostedModelInferenceServiceConfig.from_yaml(SERVICE_CONFIG_PATH) -hmi_config = read_default_config() +_hmi_config: Optional[HostedModelInferenceServiceConfig] = None + + +def get_hmi_config() -> HostedModelInferenceServiceConfig: + global _hmi_config + if _hmi_config is None: + _hmi_config = read_default_config() + return _hmi_config + + +hmi_config = get_hmi_config() diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index e68b61dd..b4c05042 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -31,7 +31,7 @@ @dataclass -class InfraConfig: +class _InfraConfig: cloud_provider: str env: str k8s_cluster_name: str @@ -48,6 +48,18 @@ class InfraConfig: firehose_role_arn: Optional[str] = None firehose_stream_name: Optional[str] = None + +@dataclass +class DBEngineConfig: + db_engine_pool_size: int = 10 + db_engine_max_overflow: int = 10 + db_engine_echo: bool = False + db_engine_echo_pool: bool = False + db_engine_disconnect_strategy: str = "pessimistic" + + +@dataclass +class InfraConfig(DBEngineConfig, _InfraConfig): @classmethod def from_json(cls, json): return cls(**{k: v for k, v in json.items() if k in inspect.signature(cls).parameters}) diff --git a/model-engine/model_engine_server/core/configs/default.yaml b/model-engine/model_engine_server/core/configs/default.yaml index 3529c814..2e2e6ec0 100644 --- a/model-engine/model_engine_server/core/configs/default.yaml +++ b/model-engine/model_engine_server/core/configs/default.yaml @@ -9,3 +9,8 @@ redis_host: "redis-message-broker-master.default" s3_bucket: "test-bucket" profile_ml_worker: "default" profile_ml_inference_worker: "default" +db_engine_pool_size: 10 +db_engine_max_overflow: 10 +db_engine_echo: false +db_engine_echo_pool: false +db_engine_disconnect_strategy: "pessimistic" diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index b5164366..9a67eb65 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -1,26 +1,23 @@ -import asyncio import os import sys import time +from dataclasses import dataclass from typing import Iterator, Optional import sqlalchemy from azure.identity import DefaultAzureCredential from azure.keyvault.secrets import SecretClient from model_engine_server.core.aws.secrets import get_key_file -from model_engine_server.core.config import infra_config +from model_engine_server.core.config import InfraConfig, infra_config from model_engine_server.core.loggers import logger_name, make_logger -from sqlalchemy import create_engine -from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine +from sqlalchemy import Engine, create_engine +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool logger = make_logger(logger_name()) -database_credential_expiration_timestamp = time.time() -EXPIRATION_BUFFER = 300 # 5 minutes - def get_key_file_name(environment: str) -> str: if infra_config().cloud_provider == "azure": @@ -28,13 +25,19 @@ def get_key_file_name(environment: str) -> str: return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "") +@dataclass +class DBConnection: + url: str + expiry_in_sec: Optional[int] = None + + def get_engine_url( env: Optional[str] = None, read_only: bool = True, sync: bool = True, - reset_expiration_timestamp: bool = False, -) -> str: +) -> DBConnection: """Gets the URL of the Postgresql engine depending on the environment.""" + expiry_in_sec: Optional[int] = None if os.getenv("ML_INFRA_DATABASE_URL"): # In CircleCI environment, we set up a test in another container and specify the URL. engine_url = os.getenv("ML_INFRA_DATABASE_URL") @@ -54,7 +57,7 @@ def get_engine_url( env = infra_config().env if key_file is None: key_file = get_key_file_name(env) # type: ignore - logger.info(f"Using key file {key_file}") + logger.debug(f"Using key file {key_file}") if infra_config().cloud_provider == "azure": client = SecretClient( @@ -67,12 +70,12 @@ def get_engine_url( "https://ossrdbms-aad.database.windows.net/.default" ) password = token.token - if reset_expiration_timestamp: - global database_credential_expiration_timestamp - database_credential_expiration_timestamp = token.expires_on logger.info(f"Connecting to db {db} as user {user}") + # TODO: https://docs.sqlalchemy.org/en/20/core/engines.html#generating-dynamic-authentication-tokens + # for recommendations on how to work with rotating auth credentials engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require" + expiry_in_sec = token.expires_on else: db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE") creds = get_key_file(key_file, db_secret_aws_profile) @@ -93,163 +96,217 @@ def get_engine_url( engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://").replace( "sslmode", "ssl" ) - return engine_url - - -# Synchronous sessions (Session and SessionReadOnly) are fairly straightforward, and both -# can be used at any time. To use asynchronous sqlalchemy, use the SessionAsyncNullPool -# if you're running a synchronous program where concurrency of database connections is not -# super important (e.g. Celery workers that use long-standing connections, and Celery is currently -# synchronous). Use SessionAsync and SessionReadOnlyAsync in ASGI applications. - -_Session: Optional[sessionmaker] = None -_SessionReadOnly: Optional[sessionmaker] = None -_SessionAsync: Optional[async_scoped_session] = None -_SessionAsyncNullPool: Optional[async_scoped_session] = None -_SessionReadOnlyAsync: Optional[async_scoped_session] = None - - -def refresh_sessions(): - # Try pool_pre_ping=True, see - # https://docs.sqlalchemy.org/en/14/core/engines.html - # ?highlight=create_engine#sqlalchemy.create_engine.params.pool_pre_ping - # tl;dr is hopefully it stops the psycopg errors from happening - # Another probably less impactful (ie it shouldn't increase latency by as much, - # but also shouldn't get rid of as many errors e.g. 95% compared to 99.9%) - # option is to try pool_recycle = something kinda short e.g. a minute - # pool_pre_ping=True seems to not increase latency by very much - # (I profiled 2.7 ms -> 3.3 ms on GET model_bundles/) - # but hopefully should completely eliminate - # any of the postgres connection errors we've been seeing. - - pg_engine = create_engine( - get_engine_url(read_only=False, sync=True, reset_expiration_timestamp=True), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=10, - max_overflow=10, - ) - pg_engine_read_only = create_engine( - get_engine_url(read_only=True, sync=True), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=10, - max_overflow=10, - ) - pg_engine_async = create_async_engine( - get_engine_url(read_only=False, sync=False), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=10, - max_overflow=10, - ) - pg_engine_read_only_async = create_async_engine( - get_engine_url(read_only=True, sync=False), - echo=False, - future=True, - pool_pre_ping=True, - pool_size=10, - max_overflow=10, - ) - pg_engine_async_null_pool = create_async_engine( - get_engine_url(read_only=False, sync=False), - echo=False, - future=True, - poolclass=NullPool, - pool_pre_ping=True, - ) - - global _Session - global _SessionReadOnly - global _SessionAsync - global _SessionAsyncNullPool - global _SessionReadOnlyAsync - - _Session = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine) - _SessionReadOnly = sessionmaker(autocommit=False, autoflush=False, bind=pg_engine_read_only) - _SessionAsync = async_scoped_session( - session_factory=async_sessionmaker( - autocommit=False, - autoflush=False, - bind=pg_engine_async, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, - ) - _SessionAsyncNullPool = async_scoped_session( - session_factory=async_sessionmaker( - autocommit=False, - autoflush=False, - bind=pg_engine_async_null_pool, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, - ) - _SessionReadOnlyAsync = async_scoped_session( - async_sessionmaker( - autocommit=False, - autoflush=False, - bind=pg_engine_read_only_async, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, - ) - - -refresh_sessions() + return DBConnection(engine_url, expiry_in_sec) + + +@dataclass +class SyncDBSession: + engine: Engine + session: sessionmaker + + +@dataclass +class AsyncDBSession: + engine: AsyncEngine + session: async_sessionmaker + + +@dataclass +class DBSessions: + session_sync: SyncDBSession + session_sync_ro: SyncDBSession + session_async: AsyncDBSession + session_async_ro: AsyncDBSession + session_async_null_pool: AsyncDBSession + + +@dataclass +class DBEngineConfig: + pool_pre_ping: bool + pool_size: int + max_overflow: int + echo: bool + echo_pool: bool + + +class DBManager: + sessions: DBSessions + config: DBEngineConfig + + credential_expiration_timestamp: Optional[float] = None + credential_expiration_buffer_sec: int = 300 + + def _get_engine_url(self, read_only: bool, sync: bool) -> DBConnection: + return get_engine_url(read_only=read_only, sync=sync) + + def __init__(self, infra_config: InfraConfig): + self.pool_pre_ping = infra_config.db_engine_disconnect_strategy == "pessimistic" + self.pool_size = infra_config.db_engine_pool_size + self.max_overflow = infra_config.db_engine_max_overflow + self.echo = infra_config.db_engine_echo + self.echo_pool = infra_config.db_engine_echo_pool + self.sessions = self.refresh_sessions() + + def refresh_sessions(self) -> DBSessions: + db_connection = get_engine_url(read_only=False, sync=True) + # use sync engine as proxy for credential expiration + self.credential_expiration_timestamp = db_connection.expiry_in_sec + pg_engine = create_engine( + db_connection.url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="sync", + ) + session_sync = SyncDBSession( + engine=pg_engine, + session=sessionmaker(autocommit=False, autoflush=False, bind=pg_engine), + ) + pg_engine_ro = create_engine( + url=get_engine_url(read_only=True, sync=True).url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="sync_ro", + ) + session_sync_ro = SyncDBSession( + engine=pg_engine_ro, + session=sessionmaker(autocommit=False, autoflush=False, bind=pg_engine_ro), + ) -def get_session(): - global _Session - global database_credential_expiration_timestamp + pg_engine_async = create_async_engine( + url=get_engine_url(read_only=False, sync=False).url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="async", + ) + session_async = AsyncDBSession( + engine=pg_engine_async, + session=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async, + expire_on_commit=False, + ), + ) - if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: - refresh_sessions() + pg_engine_async_ro = create_async_engine( + url=get_engine_url(read_only=True, sync=False).url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="async_ro", + ) + session_async_ro = AsyncDBSession( + engine=pg_engine_async_ro, + session=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async_ro, + expire_on_commit=False, + ), + ) - return _Session + pg_engine_async_null_pool = create_async_engine( + url=get_engine_url(read_only=False, sync=False).url, + echo=self.echo, + echo_pool=self.echo_pool, + future=True, + poolclass=NullPool, + logging_name="async_null", + ) + session_async_null_pool = AsyncDBSession( + engine=pg_engine_async_null_pool, + session=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async_null_pool, + expire_on_commit=False, + ), + ) -def get_session_read_only(): - global _SessionReadOnly - global database_credential_expiration_timestamp + return DBSessions( + session_sync=session_sync, + session_sync_ro=session_sync_ro, + session_async=session_async, + session_async_ro=session_async_ro, + session_async_null_pool=session_async_null_pool, + ) - if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: - refresh_sessions() + def _is_credentials_expired(self): + return ( + self.credential_expiration_timestamp is not None + and time.time() + > self.credential_expiration_timestamp - self.credential_expiration_buffer_sec + ) - return _SessionReadOnly + def _maybe_refresh_sessions(self): + if self._is_credentials_expired(): + old_sessions = self.sessions + self.sessions = self.refresh_sessions() + old_sessions.session_sync.engine.dispose() + old_sessions.session_sync_ro.engine.dispose() + old_sessions.session_async.engine.dispose() + old_sessions.session_async_ro.engine.dispose() + old_sessions.session_async_null_pool.engine.dispose() + def get_session_sync(self) -> sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_sync.session -def get_session_async(): - global _SessionAsync - global database_credential_expiration_timestamp + def get_session_sync_ro(self) -> sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_sync_ro.session - if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: - refresh_sessions() + def get_session_async(self) -> async_sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_async.session - return _SessionAsync + def get_session_async_ro(self) -> async_sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_async_ro.session + def get_session_async_null_pool(self) -> async_sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_async_null_pool.session -def get_session_async_null_pool(): - global _SessionAsyncNullPool - global database_credential_expiration_timestamp - if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: - refresh_sessions() +db_manager = DBManager(infra_config()) - return _SessionAsyncNullPool +def get_session(): + return db_manager.get_session_sync() -def get_session_read_only_async(): - global _SessionReadOnlyAsync - global database_credential_expiration_timestamp - if time.time() > database_credential_expiration_timestamp - EXPIRATION_BUFFER: - refresh_sessions() +def get_session_read_only(): + return db_manager.get_session_sync_ro() + - return _SessionReadOnlyAsync +def get_session_async(): + return db_manager.get_session_async() + + +def get_session_async_null_pool(): + return db_manager.get_session_async_null_pool() + + +def get_session_read_only_async(): + return db_manager.get_session_async_ro() Base = declarative_base() diff --git a/model-engine/model_engine_server/entrypoints/init_database.py b/model-engine/model_engine_server/entrypoints/init_database.py index 5f80ef64..14f7ac77 100644 --- a/model-engine/model_engine_server/entrypoints/init_database.py +++ b/model-engine/model_engine_server/entrypoints/init_database.py @@ -41,7 +41,7 @@ def init_database_and_engine(database_url) -> Engine: # If we are at this point, we want to init the db. if url is None: print("No k8s secret for DB url found, trying AWS secret") - url = get_engine_url(read_only=False, sync=True) + url = get_engine_url(read_only=False, sync=True).url for attempt in Retrying( stop=stop_after_attempt(6), wait=wait_exponential(), diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index b83f404e..481ab0a6 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -589,7 +589,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -637,7 +637,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -687,7 +687,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -732,7 +732,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -779,7 +779,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -823,7 +823,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -869,7 +869,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -925,7 +925,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -983,7 +983,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -1035,7 +1035,7 @@ def get_endpoint_resource_arguments_from_request( PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DD_TRACE_ENABLED=dd_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -1151,7 +1151,7 @@ def get_endpoint_resource_arguments_from_request( MAX_WORKERS=build_endpoint_request.max_workers, # CONCURRENCY=build_endpoint_request.concurrency, REDIS_HOST_PORT=hmi_config.cache_redis_host_port, - REDIS_DB_INDEX=hmi_config.cache_redis_db_index, + REDIS_DB_INDEX=str(hmi_config.cache_redis_db_index), SERVICEBUS_NAMESPACE=os.getenv("SERVICEBUS_NAMESPACE"), AUTHENTICATION_REF="azure-workload-identity", ) diff --git a/model-engine/mypy.ini b/model-engine/mypy.ini index cfd7e38d..f5c39968 100644 --- a/model-engine/mypy.ini +++ b/model-engine/mypy.ini @@ -17,6 +17,9 @@ ignore_errors = True [mypy-model_engine_server.db.*] ignore_errors = True +[mypy-model_engine_server.db.base] +ignore_errors = False + [mypy-model_engine_server.infra.repositories.*] ignore_errors = True From 44bbba1e756f53e8b3f76d1131e2ee986d8534b1 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 2 Aug 2024 11:10:19 -0700 Subject: [PATCH 356/425] Enable passing in headers through the client (#586) * Allow passing in headers to client * Allow passing in headers to client * bump version * Fix header type --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/api_engine.py | 63 +++++++++++++++++++------- clients/python/llmengine/completion.py | 8 ++++ clients/python/llmengine/model.py | 28 ++++++++++-- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 6 files changed, 80 insertions(+), 25 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index b2ea471a..3dcbc6b7 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b34" +__version__ = "0.0.0beta35" import os from typing import Sequence diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index a1b955be..a2b07f63 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -60,13 +60,15 @@ def validate_api_key(cls): ) @classmethod - def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + def _get( + cls, resource_name: str, timeout: int, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: base_path = get_base_path() api_key = get_api_key() response = requests.get( urljoin(base_path, resource_name), timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=(api_key, ""), ) if response.status_code != 200: @@ -76,7 +78,11 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: @classmethod def put( - cls, resource_name: str, data: Optional[Dict[str, Any]], timeout: int + cls, + resource_name: str, + data: Optional[Dict[str, Any]], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: base_path = get_base_path() api_key = get_api_key() @@ -84,7 +90,7 @@ def put( urljoin(base_path, resource_name), json=data, timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=(api_key, ""), ) if response.status_code != 200: @@ -93,13 +99,15 @@ def put( return payload @classmethod - def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + def _delete( + cls, resource_name: str, timeout: int, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: base_path = get_base_path() api_key = get_api_key() response = requests.delete( urljoin(base_path, resource_name), timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=(api_key, ""), ) if response.status_code != 200: @@ -108,15 +116,20 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: return payload @classmethod - def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + def post_sync( + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: base_path = get_base_path() api_key = get_api_key() response = requests.post( urljoin(base_path, resource_name), json=data, timeout=timeout, - headers={"x-api-key": api_key}, - auth=(api_key, ""), + headers={"x-api-key": api_key, **(headers or {})}, ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -125,7 +138,11 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di @classmethod def post_stream( - cls, resource_name: str, data: Dict[str, Any], timeout: int + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Iterator[Dict[str, Any]]: base_path = get_base_path() api_key = get_api_key() @@ -133,7 +150,7 @@ def post_stream( urljoin(base_path, resource_name), json=data, timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=(api_key, ""), stream=True, ) @@ -158,7 +175,11 @@ def post_stream( @classmethod def post_file( - cls, resource_name: str, files: Dict[str, BufferedReader], timeout: int + cls, + resource_name: str, + files: Dict[str, BufferedReader], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: base_path = get_base_path() api_key = get_api_key() @@ -166,7 +187,7 @@ def post_file( urljoin(base_path, resource_name), files=files, timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=(api_key, ""), ) if response.status_code != 200: @@ -176,13 +197,17 @@ def post_file( @classmethod async def apost_sync( - cls, resource_name: str, data: Dict[str, Any], timeout: int + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: base_path = get_base_path() api_key = get_api_key() async with ClientSession( timeout=ClientTimeout(timeout), - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=BasicAuth(api_key, ""), ) as session: async with session.post(urljoin(base_path, resource_name), json=data) as resp: @@ -193,13 +218,17 @@ async def apost_sync( @classmethod async def apost_stream( - cls, resource_name: str, data: Dict[str, Any], timeout: int + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> AsyncIterable[Dict[str, Any]]: base_path = get_base_path() api_key = get_api_key() async with ClientSession( timeout=ClientTimeout(timeout), - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, auth=BasicAuth(api_key, ""), ) as session: async with session.post(urljoin(base_path, resource_name), json=data) as resp: diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 4cbbaf75..076b9031 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -50,6 +50,7 @@ async def acreate( guided_grammar: Optional[str] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, + request_headers: Optional[Dict[str, str]] = None, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: """ Creates a completion for the provided prompt and parameters asynchronously (with `asyncio`). @@ -203,6 +204,7 @@ async def _acreate_stream( resource_name=f"v1/llm/completions-stream?model_endpoint_name={model}", data=data, timeout=timeout, + headers=request_headers, ) async for chunk in response: yield CompletionStreamResponse.parse_obj(chunk) @@ -234,6 +236,7 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse: resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", data=data, timeout=timeout, + headers=request_headers, ) return CompletionSyncResponse.parse_obj(response) @@ -274,6 +277,7 @@ def create( guided_grammar: Optional[str] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, + request_headers: Optional[Dict[str, str]] = None, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: """ Creates a completion for the provided prompt and parameters synchronously. @@ -419,6 +423,7 @@ def _create_stream(**kwargs): resource_name=f"v1/llm/completions-stream?model_endpoint_name={model}", data=data_stream, timeout=timeout, + headers=request_headers, ) for chunk in response_stream: yield CompletionStreamResponse.parse_obj(chunk) @@ -461,6 +466,7 @@ def _create_stream(**kwargs): resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", data=data, timeout=timeout, + headers=request_headers, ) return CompletionSyncResponse.parse_obj(response) @@ -474,6 +480,7 @@ def batch_create( data_parallelism: int = 1, max_runtime_sec: int = 24 * 3600, tool_config: Optional[ToolConfig] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> CreateBatchCompletionsResponse: """ Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. @@ -586,5 +593,6 @@ def batch_create( resource_name="v1/llm/batch-completions", data=data, timeout=HTTP_TIMEOUT, + headers=request_headers, ) return CreateBatchCompletionsResponse.parse_obj(response) diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index 4d5a6bb1..cca90657 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -57,6 +57,7 @@ def create( default_callback_url: Optional[str] = None, public_inference: Optional[bool] = True, labels: Optional[Dict[str, str]] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> CreateLLMEndpointResponse: """ Create an LLM model. Note: This API is only available for self-hosted users. @@ -313,6 +314,7 @@ def create( resource_name="v1/llm/model-endpoints", data=request.dict(), timeout=DEFAULT_TIMEOUT, + headers=request_headers, ) return CreateLLMEndpointResponse.parse_obj(response) @@ -320,6 +322,7 @@ def create( def get( cls, model: str, + request_headers: Optional[Dict[str, str]] = None, ) -> GetLLMEndpointResponse: """ Get information about an LLM model. @@ -363,11 +366,16 @@ def get( } ``` """ - response = cls._get(f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT) + response = cls._get( + f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT, headers=request_headers + ) return GetLLMEndpointResponse.parse_obj(response) @classmethod - def list(cls) -> ListLLMEndpointsResponse: + def list( + cls, + request_headers: Optional[Dict[str, str]] = None, + ) -> ListLLMEndpointsResponse: """ List LLM models available to call inference on. @@ -440,7 +448,9 @@ def list(cls) -> ListLLMEndpointsResponse: } ``` """ - response = cls._get("v1/llm/model-endpoints", timeout=DEFAULT_TIMEOUT) + response = cls._get( + "v1/llm/model-endpoints", timeout=DEFAULT_TIMEOUT, headers=request_headers + ) return ListLLMEndpointsResponse.parse_obj(response) @classmethod @@ -470,6 +480,7 @@ def update( default_callback_url: Optional[str] = None, public_inference: Optional[bool] = None, labels: Optional[Dict[str, str]] = None, + request_headers: Optional[Dict[str, str]] = None, ) -> UpdateLLMEndpointResponse: """ Update an LLM model. Note: This API is only available for self-hosted users. @@ -618,11 +629,16 @@ def update( resource_name=f"v1/llm/model-endpoints/{name}", data=request.dict(), timeout=DEFAULT_TIMEOUT, + headers=request_headers, ) return UpdateLLMEndpointResponse.parse_obj(response) @classmethod - def delete(cls, model_endpoint_name: str) -> DeleteLLMEndpointResponse: + def delete( + cls, + model_endpoint_name: str, + request_headers: Optional[Dict[str, str]] = None, + ) -> DeleteLLMEndpointResponse: """ Deletes an LLM model. @@ -655,7 +671,9 @@ def delete(cls, model_endpoint_name: str) -> DeleteLLMEndpointResponse: ``` """ response = cls._delete( - f"v1/llm/model-endpoints/{model_endpoint_name}", timeout=DEFAULT_TIMEOUT + f"v1/llm/model-endpoints/{model_endpoint_name}", + timeout=DEFAULT_TIMEOUT, + headers=request_headers, ) return DeleteLLMEndpointResponse.parse_obj(response) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 910fa162..5eeccff5 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta34" +version = "0.0.0.beta35" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index c8d30e11..6636af4b 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta34", + version="0.0.0.beta35", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From 0d39f29f4c55ec47dd191680d9678db646a05f86 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 2 Aug 2024 13:21:12 -0700 Subject: [PATCH 357/425] Re-add auth header (#588) --- clients/python/llmengine/api_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index a2b07f63..05d298cd 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -129,6 +129,7 @@ def post_sync( urljoin(base_path, resource_name), json=data, timeout=timeout, + auth=(api_key, ""), headers={"x-api-key": api_key, **(headers or {})}, ) if response.status_code != 200: From 7e7f3bffc5f1b731acfc4b1b2a5875c1712aa74f Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 5 Aug 2024 09:58:50 -0700 Subject: [PATCH 358/425] Make storage required for endpoint creation requests (#587) --- integration_tests/rest_api_utils.py | 2 ++ .../model_engine_server/common/dtos/model_endpoints.py | 2 +- .../domain/services/model_endpoint_service.py | 2 +- .../gateways/live_model_endpoint_infra_gateway.py | 2 +- .../infra/gateways/model_endpoint_infra_gateway.py | 2 +- .../infra/services/live_batch_job_service.py | 4 +++- .../infra/services/live_model_endpoint_service.py | 2 +- model-engine/tests/unit/api/conftest.py | 10 +++++----- model-engine/tests/unit/conftest.py | 4 ++-- 9 files changed, 17 insertions(+), 13 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index db087992..6f6a9407 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -116,6 +116,7 @@ def my_model(**keyword_args): "endpoint_type": "async", "cpus": "0.5", "memory": "500Mi", + "storage": "1Gi", "min_workers": 1, "max_workers": 1, "gpus": 0, @@ -136,6 +137,7 @@ def my_model(**keyword_args): "cpus": "1", "gpus": 0, "memory": "1Gi", + "storage": "2Gi", "optimize_costs": False, "min_workers": 1, "max_workers": 1, diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index e8620890..a173cfe0 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -57,7 +57,7 @@ class CreateModelEndpointV1Request(BaseModel): gpus: int = Field(..., ge=0) memory: StorageSpecificationType gpu_type: Optional[GpuType] = None - storage: Optional[StorageSpecificationType] = None + storage: StorageSpecificationType optimize_costs: Optional[bool] = None min_workers: int = Field(..., ge=0) max_workers: int = Field(..., ge=0) diff --git a/model-engine/model_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py index 4c3471b4..aed90ddd 100644 --- a/model-engine/model_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -75,7 +75,7 @@ async def create_model_endpoint( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, optimize_costs: bool, min_workers: int, max_workers: int, diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py index 4b73a386..b92726e2 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py @@ -61,7 +61,7 @@ def create_model_endpoint_infra( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, optimize_costs: bool, aws_role: str, results_s3_bucket: str, diff --git a/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py index 7d349657..044bc038 100644 --- a/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py @@ -29,7 +29,7 @@ def create_model_endpoint_infra( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, optimize_costs: bool, aws_role: str, results_s3_bucket: str, diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_service.py index f9e6f904..78cea2d1 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_service.py @@ -21,6 +21,7 @@ DEFAULT_ENDPOINT_CPUS_BATCH_JOB = 3 DEFAULT_ENDPOINT_MEMORY_BATCH_JOB = "12Gi" +DEFAULT_ENDPOINT_STORAGE_BATCH_JOB = "16Gi" # to match launch-python-client endpoint default DEFAULT_ENDPOINT_GPUS_BATCH_JOB = 1 DEFAULT_ENDPOINT_GPU_TYPE_BATCH_JOB = GpuType.NVIDIA_TESLA_T4 DEFAULT_ENDPOINT_MAX_WORKERS_BATCH_JOB = 50 @@ -76,6 +77,7 @@ async def create_batch_job( else DEFAULT_ENDPOINT_GPUS_BATCH_JOB ) memory = resource_requests.memory or DEFAULT_ENDPOINT_MEMORY_BATCH_JOB + storage = resource_requests.storage or DEFAULT_ENDPOINT_STORAGE_BATCH_JOB gpu_type = None if gpus == 0 and resource_requests.gpu_type is not None: raise EndpointResourceInvalidRequestException( @@ -101,7 +103,7 @@ async def create_batch_job( gpus=gpus, # type: ignore memory=memory, # type: ignore gpu_type=gpu_type, # type: ignore - storage=resource_requests.storage, + storage=storage, optimize_costs=False, min_workers=0, max_workers=max_workers, # type: ignore diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 475fbca8..0750fd84 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -145,7 +145,7 @@ async def create_model_endpoint( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, optimize_costs: bool, min_workers: int, max_workers: int, diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index 2ca38500..d713b25a 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -609,7 +609,7 @@ def create_model_endpoint_request_async( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", - "storage": None, + "storage": "2G", "min_workers": 0, "max_workers": 5, "per_worker": 3, @@ -635,7 +635,7 @@ def create_model_endpoint_request_sync( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 3, @@ -661,7 +661,7 @@ def create_model_endpoint_request_streaming( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 1, @@ -687,7 +687,7 @@ def create_model_endpoint_request_streaming_invalid_bundle( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 1, @@ -713,7 +713,7 @@ def create_model_endpoint_request_sync_invalid_streaming_bundle( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 1, diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 0e8b59d9..3e02e5bc 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1023,7 +1023,7 @@ def create_model_endpoint_infra( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, optimize_costs: bool, aws_role: str, results_s3_bucket: str, @@ -1689,7 +1689,7 @@ async def create_model_endpoint( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, optimize_costs: bool, min_workers: int, max_workers: int, From 554d30d39101083a98b85877f110e53be169c268 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:32:33 -0700 Subject: [PATCH 359/425] More Batch Inference Options (#590) * add an optional override for the config file data * temp change repo for testing, also use vllm 0.5.4 because might as well * . * revert vllm to 0.5.3.post1, otherwise there's some stuff that doesn't quite work * revert testing change * fix unit tests * try fixing tests * try fixing tests pt 2 --- .../inference/batch_inference/dto.py | 2 +- .../inference/batch_inference/vllm_batch.py | 24 +++++++-- .../tests/unit/inference/test_vllm_batch.py | 54 +++++++++++-------- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py index c682e14f..63f02efe 100644 --- a/model-engine/model_engine_server/inference/batch_inference/dto.py +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -128,7 +128,7 @@ class CreateBatchCompletionsRequest(BaseModel): When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. """ - data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) + data_parallelism: int = Field(default=1, ge=1, le=64) """ Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. """ diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 7881e182..3dda4a34 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -1,3 +1,4 @@ +import argparse import asyncio import json import multiprocessing @@ -311,10 +312,16 @@ def tool_func(text: str, past_context: Optional[str]): return results -async def batch_inference(): +async def batch_inference(config_file_data: Optional[str]): job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0)) - request = CreateBatchCompletionsEngineRequest.parse_file(CONFIG_FILE) + if config_file_data is None: + if CONFIG_FILE is None or not os.path.exists(CONFIG_FILE): + raise FileNotFoundError(f"Config file {CONFIG_FILE} not found") + with open(CONFIG_FILE, "r") as f: + config_file_data = f.read() + + request = CreateBatchCompletionsEngineRequest.model_validate_json(config_file_data) if request.model_cfg.checkpoint_path is not None: download_model(request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER) @@ -322,7 +329,7 @@ async def batch_inference(): content = request.content if content is None: with smart_open.open(request.input_data_path, "r") as f: - content = CreateBatchCompletionsRequestContent.parse_raw(f.read()) + content = CreateBatchCompletionsRequestContent.model_validate_json(f.read()) model = MODEL_WEIGHTS_FOLDER if request.model_cfg.checkpoint_path else request.model_cfg.model is_finetuned = request.model_cfg.checkpoint_path is not None @@ -506,5 +513,14 @@ def check_unknown_startup_memory_usage(): # pragma: no cover if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-file-data", + "--config_file_data", + type=str, + default=None, + help="Optional override for the config file data, as a json string", + ) + args = parser.parse_args() check_unknown_startup_memory_usage() - asyncio.run(batch_inference()) + asyncio.run(batch_inference(args.config_file_data)) diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py index 45223b20..c097f858 100644 --- a/model-engine/tests/unit/inference/test_vllm_batch.py +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -21,7 +21,9 @@ new_callable=mock_open, read_data="Mocked content", ) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") async def test_batch_inference( + mock_builtins_open_func, mock_open_func, mock_popen, mock_get_s3_client, @@ -38,10 +40,10 @@ async def test_batch_inference( # Mock the necessary objects and data mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - mock_create_batch_completions_engine_request.parse_file.return_value = ( + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( create_batch_completions_engine_request ) - mock_create_batch_completions_request_content.parse_raw.return_value = ( + mock_create_batch_completions_request_content.model_validate_json.return_value = ( create_batch_completions_request_content ) @@ -49,10 +51,10 @@ async def test_batch_inference( mock_generate_with_vllm.return_value = [mock_completion_output] # Call the function - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") # Assertions - mock_create_batch_completions_engine_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -79,7 +81,9 @@ async def test_batch_inference( new_callable=mock_open, read_data="Mocked content", ) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") async def test_batch_inference_failed_to_download_model_but_proceed( + mock_builtins_open_func, mock_open_func, mock_popen, mock_get_s3_client, @@ -97,10 +101,10 @@ async def test_batch_inference_failed_to_download_model_but_proceed( mock_process.returncode = 1 # Failed to download model mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - mock_create_batch_completions_engine_request.parse_file.return_value = ( + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( create_batch_completions_engine_request ) - mock_create_batch_completions_request_content.parse_raw.return_value = ( + mock_create_batch_completions_request_content.model_validate_json.return_value = ( create_batch_completions_request_content ) @@ -108,10 +112,10 @@ async def test_batch_inference_failed_to_download_model_but_proceed( mock_generate_with_vllm.return_value = [mock_completion_output] # Call the function - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") # Assertions - mock_create_batch_completions_engine_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -138,9 +142,11 @@ async def test_batch_inference_failed_to_download_model_but_proceed( new_callable=mock_open, read_data="Mocked content", ) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") @patch("model_engine_server.inference.batch_inference.vllm_batch.os.getenv") async def test_batch_inference_two_workers( mock_getenv, + mock_builtins_open_func, mock_open_func, mock_popen, mock_get_s3_client, @@ -158,10 +164,10 @@ async def test_batch_inference_two_workers( mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client create_batch_completions_engine_request.data_parallelism = 2 - mock_create_batch_completions_engine_request.parse_file.return_value = ( + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( create_batch_completions_engine_request ) - mock_create_batch_completions_request_content.parse_raw.return_value = ( + mock_create_batch_completions_request_content.model_validate_json.return_value = ( create_batch_completions_request_content ) @@ -177,10 +183,10 @@ def side_effect(key, default): mock_getenv.side_effect = side_effect # Batch completion worker 1 - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") # Assertions - mock_create_batch_completions_engine_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -191,7 +197,7 @@ def side_effect(key, default): ) # Batch completion worker 0 - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -224,9 +230,11 @@ def side_effect(key, default): new_callable=mock_open, read_data="Mocked content", ) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") @patch("model_engine_server.inference.batch_inference.vllm_batch.os.getenv") async def test_batch_inference_delete_chunks( mock_getenv, + mock_builtins_open_func, mock_open_func, mock_popen, mock_get_s3_client, @@ -245,10 +253,10 @@ async def test_batch_inference_delete_chunks( mock_get_s3_client.return_value = mock_s3_client create_batch_completions_engine_request.data_parallelism = 2 create_batch_completions_engine_request.output_data_path = "s3://bucket/key" - mock_create_batch_completions_engine_request.parse_file.return_value = ( + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( create_batch_completions_engine_request ) - mock_create_batch_completions_request_content.parse_raw.return_value = ( + mock_create_batch_completions_request_content.model_validate_json.return_value = ( create_batch_completions_request_content ) @@ -264,10 +272,10 @@ def side_effect(key, default): mock_getenv.side_effect = side_effect # Batch completion worker 1 - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") # Assertions - mock_create_batch_completions_engine_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -278,7 +286,7 @@ def side_effect(key, default): ) # Batch completion worker 0 - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") mock_open_func.assert_has_calls( [ call("input_data_path", "r"), @@ -341,7 +349,9 @@ def test_file_exists_no_such_key(): new_callable=mock_open, read_data="Mocked content", ) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") async def test_batch_inference_tool_completion( + mock_builtins_open_func, mock_open_func, mock_run, mock_popen, @@ -362,10 +372,10 @@ async def test_batch_inference_tool_completion( mock_run.return_value = mock_run_output mock_popen.return_value = mock_process mock_get_s3_client.return_value = mock_s3_client - mock_create_batch_completions_engine_request.parse_file.return_value = ( + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( create_batch_completions_tool_completion_request ) - mock_create_batch_completions_request_content.parse_raw.return_value = ( + mock_create_batch_completions_request_content.model_validate_json.return_value = ( create_batch_completions_tool_completion_request_content ) @@ -376,10 +386,10 @@ async def test_batch_inference_tool_completion( ] # Call the function - await batch_inference() + await batch_inference("this config data gets ignored because we mock model_validate_json") # Assertions - mock_create_batch_completions_engine_request.parse_file.assert_called_once() + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() mock_open_func.assert_has_calls( [ call("input_data_path", "r"), From 5c815e3cbb6bb413c69ec3aa2c586c44423bed0b Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 12 Aug 2024 12:10:28 -0700 Subject: [PATCH 360/425] Allow support for vllm batch with checkpoints (#591) * Allow support for vllm batch with checkpoints * Disable flaky test --- model-engine/model_engine_server/common/service_requests.py | 4 +++- .../domain/use_cases/llm_model_endpoint_use_cases.py | 5 +++-- model-engine/tests/unit/inference/test_http_forwarder.py | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/common/service_requests.py b/model-engine/model_engine_server/common/service_requests.py index 96aeb6f0..ea77f2b9 100644 --- a/model-engine/model_engine_server/common/service_requests.py +++ b/model-engine/model_engine_server/common/service_requests.py @@ -37,7 +37,9 @@ def make_sync_request_with_retries( wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), ): with attempt: - logger.info(f"Retry number {attempt.retry_state.attempt_number}") + logger.debug( + f"Retry number {attempt.retry_state.attempt_number}" + ) # pragma: no cover resp = requests.post( request_url, json=payload_json, diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 7bd5a878..24bb2351 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -370,10 +370,11 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None: def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: checkpoint_path = None - if SUPPORTED_MODELS_INFO[model_name].s3_repo: - checkpoint_path = get_models_s3_uri(SUPPORTED_MODELS_INFO[model_name].s3_repo, "") + models_info = SUPPORTED_MODELS_INFO.get(model_name, None) if checkpoint_path_override: checkpoint_path = checkpoint_path_override + elif models_info and models_info.s3_repo: + checkpoint_path = get_models_s3_uri(models_info.s3_repo, "") # pragma: no cover if not checkpoint_path: raise InvalidRequestException(f"No checkpoint path found for model {model_name}") diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index 837812e3..fff38834 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -144,6 +144,7 @@ def test_get_concurrency_limiter(): @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) +@pytest.mark.skip(reason="This test is flaky") def test_http_service_429(mock_request, post_inference_hooks_handler): mock_forwarder = Forwarder( "ignored", From 37a0bd9391933d8bd0dc7ee87da3d6b79cb94218 Mon Sep 17 00:00:00 2001 From: Tiffany Zhao <142925794+tiffzhao5@users.noreply.github.com> Date: Mon, 12 Aug 2024 13:58:54 -0700 Subject: [PATCH 361/425] MLI-2510 Validate json logs to test hypothesis on no records in Snowflake (#592) * add try catch * use regex * format * format --- .../inference/post_inference_hooks.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py index 5d45b5cb..7f0e2f4d 100644 --- a/model-engine/model_engine_server/inference/post_inference_hooks.py +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -133,6 +133,25 @@ def handle( "BUNDLE_ID": self._bundle_id, "LABELS": self._labels, } + try: # pragma: no cover + json_string = json.dumps(data_record) # pragma: no cover + # Check for unexpected double quotes or escape characters + import re # pragma: no cover + + pattern = r'\\[ntrbfv\'"]|["\']' # pragma: no cover + matches = re.findall(pattern, repr(json_string)) # pragma: no cover + if matches: # pragma: no cover + logger.info( # pragma: no cover + "The JSON string contains double quotes or escape characters.", + extra={"json_string": json_string, "matches": matches}, + ) + else: + logger.info("The JSON string is valid.") # pragma: no cover + except (TypeError, ValueError) as e: # pragma: no cover + logger.warning( + f"Error: The data_record object is not a valid JSON object. {e}" + ) # pragma: no cover + stream_name = infra_config().firehose_stream_name if stream_name is None: logger.warning("No firehose stream name specified. Logging hook will not be executed.") From 3d9a770ca3d7e4eae41779d4201ee67408433395 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 13 Aug 2024 16:14:03 -0700 Subject: [PATCH 362/425] [Batch Completions V2] DTO models + Batch completions service (#593) * Add dto for openai compatible completion/chatcompletions * LLM Batch completion service * PR review * PR review --- .../model_engine_server/api/llms_v1.py | 25 +- .../model_engine_server/common/dtos/llms.py | 593 ------------------ .../common/dtos/llms/__init__.py | 8 + .../common/dtos/llms/batch_completion.py | 305 +++++++++ .../common/dtos/llms/chat_completion.py | 134 ++++ .../common/dtos/llms/completion.py | 270 ++++++++ .../common/dtos/llms/model_endpoints.py | 203 ++++++ .../common/resource_limits.py | 2 +- .../common/types/__init__.py | 2 + .../common/{types.py => types/endpoint.py} | 0 .../services/llm_batch_completions_service.py | 71 +++ .../use_cases/llm_model_endpoint_use_cases.py | 164 ++++- .../live_llm_batch_completions_service.py | 81 +++ model-engine/tests/unit/api/test_llms.py | 5 + model-engine/tests/unit/domain/conftest.py | 24 +- .../tests/unit/domain/test_llm_use_cases.py | 40 +- 16 files changed, 1272 insertions(+), 655 deletions(-) delete mode 100644 model-engine/model_engine_server/common/dtos/llms.py create mode 100644 model-engine/model_engine_server/common/dtos/llms/__init__.py create mode 100644 model-engine/model_engine_server/common/dtos/llms/batch_completion.py create mode 100644 model-engine/model_engine_server/common/dtos/llms/chat_completion.py create mode 100644 model-engine/model_engine_server/common/dtos/llms/completion.py create mode 100644 model-engine/model_engine_server/common/dtos/llms/model_endpoints.py create mode 100644 model-engine/model_engine_server/common/types/__init__.py rename model-engine/model_engine_server/common/{types.py => types/endpoint.py} (100%) create mode 100644 model-engine/model_engine_server/domain/services/llm_batch_completions_service.py create mode 100644 model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index db4da712..53194fb3 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -20,8 +20,8 @@ CompletionStreamV1Response, CompletionSyncV1Request, CompletionSyncV1Response, - CreateBatchCompletionsRequest, - CreateBatchCompletionsResponse, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1Response, CreateFineTuneRequest, CreateFineTuneResponse, CreateLLMModelEndpointV1Request, @@ -226,7 +226,8 @@ async def list_model_endpoints( @llm_router_v1.get( - "/model-endpoints/{model_endpoint_name}", response_model=GetLLMModelEndpointV1Response + "/model-endpoints/{model_endpoint_name}", + response_model=GetLLMModelEndpointV1Response, ) async def get_model_endpoint( model_endpoint_name: str, @@ -255,7 +256,8 @@ async def get_model_endpoint( @llm_router_v1.put( - "/model-endpoints/{model_endpoint_name}", response_model=UpdateLLMModelEndpointV1Response + "/model-endpoints/{model_endpoint_name}", + response_model=UpdateLLMModelEndpointV1Response, ) async def update_model_endpoint( model_endpoint_name: str, @@ -343,7 +345,7 @@ async def create_completion_sync_task( background_tasks.add_task( external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, TokenUsage( - num_prompt_tokens=response.output.num_prompt_tokens if response.output else None, + num_prompt_tokens=(response.output.num_prompt_tokens if response.output else None), num_completion_tokens=( response.output.num_completion_tokens if response.output else None ), @@ -426,7 +428,8 @@ async def create_completion_stream_task( raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: raise HTTPException( - status_code=500, detail="Internal error occurred. Our team has been notified." + status_code=500, + detail="Internal error occurred. Our team has been notified.", ) from exc async def event_generator(): @@ -440,7 +443,9 @@ async def event_generator(): background_tasks.add_task( external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, TokenUsage( - num_prompt_tokens=message.output.num_prompt_tokens if message.output else None, + num_prompt_tokens=( + message.output.num_prompt_tokens if message.output else None + ), num_completion_tokens=( message.output.num_completion_tokens if message.output else None ), @@ -626,12 +631,12 @@ async def delete_llm_model_endpoint( ) from exc -@llm_router_v1.post("/batch-completions", response_model=CreateBatchCompletionsResponse) +@llm_router_v1.post("/batch-completions", response_model=CreateBatchCompletionsV1Response) async def create_batch_completions( - request: CreateBatchCompletionsRequest, + request: CreateBatchCompletionsV1Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CreateBatchCompletionsResponse: +) -> CreateBatchCompletionsV1Response: logger.info(f"POST /batch-completions with {request} for {auth}") try: use_case = CreateBatchCompletionsUseCase( diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py deleted file mode 100644 index 232b8ba6..00000000 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ /dev/null @@ -1,593 +0,0 @@ -""" -DTOs for LLM APIs. - -Make sure to keep this in sync with inference/batch_inference/dto.py. -""" - -from typing import Any, Dict, List, Optional - -from model_engine_server.common.dtos.core import HttpUrlStr -from model_engine_server.common.dtos.model_endpoints import ( - CpuSpecificationType, - GetModelEndpointV1Response, - GpuType, - ModelEndpointType, - StorageSpecificationType, -) -from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field -from model_engine_server.domain.entities import ( - BatchJobStatus, - CallbackAuth, - FineTuneHparamValueType, - LLMFineTuneEvent, - LLMInferenceFramework, - LLMSource, - ModelEndpointStatus, - Quantization, -) - - -class CreateLLMModelEndpointV1Request(BaseModel): - name: str - - # LLM specific fields - model_name: str - source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM - inference_framework_image_tag: str = "latest" - num_shards: int = 1 - """ - Number of shards to distribute the model onto GPUs. - """ - - quantize: Optional[Quantization] = None - """ - Whether to quantize the model. - """ - - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. - """ - - # General endpoint fields - metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] = None - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC - cpus: Optional[CpuSpecificationType] = None - gpus: Optional[int] = None - memory: Optional[StorageSpecificationType] = None - gpu_type: Optional[GpuType] = None - storage: Optional[StorageSpecificationType] = None - optimize_costs: Optional[bool] = None - min_workers: int - max_workers: int - per_worker: int - labels: Dict[str, str] - prewarm: Optional[bool] = None - high_priority: Optional[bool] = None - billing_tags: Optional[Dict[str, Any]] = None - default_callback_url: Optional[HttpUrlStr] = None - default_callback_auth: Optional[CallbackAuth] = None - public_inference: Optional[bool] = True # LLM endpoints are public by default. - - -class CreateLLMModelEndpointV1Response(BaseModel): - endpoint_creation_task_id: str - - -class GetLLMModelEndpointV1Response(BaseModel): - id: str - """ - The autogenerated ID of the Launch endpoint. - """ - - name: str - model_name: str - source: LLMSource - status: ModelEndpointStatus - inference_framework: LLMInferenceFramework - inference_framework_image_tag: Optional[str] = None - num_shards: Optional[int] = None - quantize: Optional[Quantization] = None - checkpoint_path: Optional[str] = None - spec: Optional[GetModelEndpointV1Response] = None - - -class ListLLMModelEndpointsV1Response(BaseModel): - model_endpoints: List[GetLLMModelEndpointV1Response] - - -class UpdateLLMModelEndpointV1Request(BaseModel): - # LLM specific fields - model_name: Optional[str] = None - source: Optional[LLMSource] = None - inference_framework_image_tag: Optional[str] = None - num_shards: Optional[int] = None - """ - Number of shards to distribute the model onto GPUs. - """ - - quantize: Optional[Quantization] = None - """ - Whether to quantize the model. - """ - - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. - """ - - # General endpoint fields - metadata: Optional[Dict[str, Any]] = None - post_inference_hooks: Optional[List[str]] = None - cpus: Optional[CpuSpecificationType] = None - gpus: Optional[int] = None - memory: Optional[StorageSpecificationType] = None - gpu_type: Optional[GpuType] = None - storage: Optional[StorageSpecificationType] = None - optimize_costs: Optional[bool] = None - min_workers: Optional[int] = None - max_workers: Optional[int] = None - per_worker: Optional[int] = None - labels: Optional[Dict[str, str]] = None - prewarm: Optional[bool] = None - high_priority: Optional[bool] = None - billing_tags: Optional[Dict[str, Any]] = None - default_callback_url: Optional[HttpUrlStr] = None - default_callback_auth: Optional[CallbackAuth] = None - public_inference: Optional[bool] = None - - -class UpdateLLMModelEndpointV1Response(BaseModel): - endpoint_creation_task_id: str - - -# Delete uses the default Launch endpoint APIs. - - -class CompletionSyncV1Request(BaseModel): - """ - Request object for a synchronous prompt completion task. - """ - - prompt: str - max_new_tokens: int - temperature: float = Field(ge=0.0, le=1.0) - """ - Temperature of the sampling. Setting to 0 equals to greedy sampling. - """ - stop_sequences: Optional[List[str]] = None - """ - List of sequences to stop the completion at. - """ - return_token_log_probs: Optional[bool] = False - """ - Whether to return the log probabilities of the tokens. - """ - presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty - """ - frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty - """ - top_k: Optional[int] = Field(default=None, ge=-1) - """ - Controls the number of top tokens to consider. -1 means consider all tokens. - """ - top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - """ - Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. - """ - include_stop_str_in_output: Optional[bool] = None - """ - Whether to include the stop strings in output text. - """ - guided_json: Optional[Dict[str, Any]] = None - """ - JSON schema for guided decoding. Only supported in vllm. - """ - guided_regex: Optional[str] = None - """ - Regex for guided decoding. Only supported in vllm. - """ - guided_choice: Optional[List[str]] = None - """ - Choices for guided decoding. Only supported in vllm. - """ - guided_grammar: Optional[str] = None - """ - Context-free grammar for guided decoding. Only supported in vllm. - """ - skip_special_tokens: Optional[bool] = True - """ - Whether to skip special tokens in the output. Only supported in vllm. - """ - - -class TokenOutput(BaseModel): - token: str - log_prob: float - - -class CompletionOutput(BaseModel): - text: str - num_prompt_tokens: int - num_completion_tokens: int - tokens: Optional[List[TokenOutput]] = None - - -class CompletionSyncV1Response(BaseModel): - """ - Response object for a synchronous prompt completion task. - """ - - request_id: Optional[str] = None - output: Optional[CompletionOutput] = None - - -class CompletionStreamV1Request(BaseModel): - """ - Request object for a stream prompt completion task. - """ - - prompt: str - max_new_tokens: int - temperature: float = Field(ge=0.0, le=1.0) - """ - Temperature of the sampling. Setting to 0 equals to greedy sampling. - """ - stop_sequences: Optional[List[str]] = None - """ - List of sequences to stop the completion at. - """ - return_token_log_probs: Optional[bool] = False - """ - Whether to return the log probabilities of the tokens. - """ - presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty - """ - frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty - """ - top_k: Optional[int] = Field(default=None, ge=-1) - """ - Controls the number of top tokens to consider. -1 means consider all tokens. - """ - top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - """ - Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. - """ - include_stop_str_in_output: Optional[bool] = None - """ - Whether to include the stop strings in output text. - """ - guided_json: Optional[Dict[str, Any]] = None - """ - JSON schema for guided decoding. Only supported in vllm. - """ - guided_regex: Optional[str] = None - """ - Regex for guided decoding. Only supported in vllm. - """ - guided_choice: Optional[List[str]] = None - """ - Choices for guided decoding. Only supported in vllm. - """ - guided_grammar: Optional[str] = None - """ - Context-free grammar for guided decoding. Only supported in vllm. - """ - skip_special_tokens: Optional[bool] = True - """ - Whether to skip special tokens in the output. Only supported in vllm. - """ - - -class CompletionStreamOutput(BaseModel): - text: str - finished: bool - num_prompt_tokens: Optional[int] = None - num_completion_tokens: Optional[int] = None - token: Optional[TokenOutput] = None - - -class StreamErrorContent(BaseModel): - error: str - """Error message.""" - timestamp: str - """Timestamp of the error.""" - - -class StreamError(BaseModel): - """ - Error object for a stream prompt completion task. - """ - - status_code: int - """The HTTP status code of the error.""" - content: StreamErrorContent - """The error content.""" - - -class CompletionStreamV1Response(BaseModel): - """ - Response object for a stream prompt completion task. - """ - - request_id: Optional[str] = None - output: Optional[CompletionStreamOutput] = None - error: Optional[StreamError] = None - """Error of the response (if any).""" - - -class TokenUsage(BaseModel): - """ - Token usage for a prompt completion task. - """ - - num_prompt_tokens: Optional[int] = 0 - num_completion_tokens: Optional[int] = 0 - total_duration: Optional[float] = None - """Includes time spent waiting for the model to be ready.""" - - time_to_first_token: Optional[float] = None # Only for streaming requests - - @property - def num_total_tokens(self) -> int: - return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) - - @property - def total_tokens_per_second(self) -> float: - return ( - self.num_total_tokens / self.total_duration - if self.total_duration and self.total_duration > 0 - else 0.0 - ) - - @property - def inter_token_latency(self) -> Optional[float]: # Only for streaming requests - # Note: we calculate a single inter-token latency for the entire request. - # Calculating latency between each token seems a bit heavyweight, although we can do this if we wanted - if ( - self.time_to_first_token is None - or self.num_completion_tokens is None - or self.total_duration is None - ): - return None - if self.num_completion_tokens < 2: - return None - return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) - - -class CreateFineTuneRequest(BaseModel): - model: str - training_file: str - validation_file: Optional[str] = None - # fine_tuning_method: str # TODO enum + uncomment when we support multiple methods - hyperparameters: Dict[str, FineTuneHparamValueType] # validated somewhere else - suffix: Optional[str] = None - wandb_config: Optional[Dict[str, Any]] = None - """ - Config to pass to wandb for init. See https://docs.wandb.ai/ref/python/init - Must include `api_key` field which is the wandb API key. - """ - - -class CreateFineTuneResponse(BaseModel): - id: str - - -class GetFineTuneResponse(BaseModel): - id: str = Field(..., description="Unique ID of the fine tune") - fine_tuned_model: Optional[str] = Field( - default=None, - description="Name of the resulting fine-tuned model. This can be plugged into the " - "Completion API ones the fine-tune is complete", - ) - status: BatchJobStatus = Field(..., description="Status of the requested fine tune.") - - -class ListFineTunesResponse(BaseModel): - jobs: List[GetFineTuneResponse] - - -class CancelFineTuneResponse(BaseModel): - success: bool - - -class GetFineTuneEventsResponse(BaseModel): - # LLMFineTuneEvent is entity layer technically, but it's really simple - events: List[LLMFineTuneEvent] - - -class ModelDownloadRequest(BaseModel): - model_name: str = Field(..., description="Name of the fine tuned model") - download_format: Optional[str] = Field( - default="hugging_face", - description="Format that you want the downloaded urls to be compatible with. Currently only supports hugging_face", - ) - - -class ModelDownloadResponse(BaseModel): - urls: Dict[str, str] = Field( - ..., description="Dictionary of (file_name, url) pairs to download the model from." - ) - - -class DeleteLLMEndpointResponse(BaseModel): - deleted: bool - - -class CreateBatchCompletionsRequestContent(BaseModel): - prompts: List[str] - max_new_tokens: int - temperature: float = Field(ge=0.0, le=1.0) - """ - Temperature of the sampling. Setting to 0 equals to greedy sampling. - """ - stop_sequences: Optional[List[str]] = None - """ - List of sequences to stop the completion at. - """ - return_token_log_probs: Optional[bool] = False - """ - Whether to return the log probabilities of the tokens. - """ - presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty - """ - frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty - """ - top_k: Optional[int] = Field(default=None, ge=-1) - """ - Controls the number of top tokens to consider. -1 means consider all tokens. - """ - top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - """ - Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. - """ - skip_special_tokens: Optional[bool] = True - """ - Whether to skip special tokens in the output. - """ - - -class CreateBatchCompletionsModelConfig(BaseModel): - model: str - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. - """ - labels: Dict[str, str] - """ - Labels to attach to the batch inference job. - """ - num_shards: Optional[int] = 1 - """ - Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. - System may decide to use a different number than the given value. - """ - quantize: Optional[Quantization] = None - """ - Whether to quantize the model. - """ - seed: Optional[int] = None - """ - Random seed for the model. - """ - - -class ToolConfig(BaseModel): - """ - Configuration for tool use. - NOTE: this config is highly experimental and signature will change significantly in future iterations. - """ - - name: str - """ - Name of the tool to use for the batch inference. - """ - max_iterations: Optional[int] = 10 - """ - Maximum number of iterations to run the tool. - """ - execution_timeout_seconds: Optional[int] = 60 - """ - Maximum runtime of the tool in seconds. - """ - should_retry_on_error: Optional[bool] = True - """ - Whether to retry the tool on error. - """ - - -class CreateBatchCompletionsRequest(BaseModel): - """ - Request object for batch completions. - """ - - model_config = ConfigDict(protected_namespaces=()) - - input_data_path: Optional[str] = None - output_data_path: str - """ - Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. - """ - content: Optional[CreateBatchCompletionsRequestContent] = None - """ - Either `input_data_path` or `content` needs to be provided. - When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. - """ - model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config") - """ - Model configuration for the batch inference. Hardware configurations are inferred. - - We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which - reserves model_config as a keyword. - """ - - data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) - """ - Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. - """ - max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600) - """ - Maximum runtime of the batch inference in seconds. Default to one day. - """ - tool_config: Optional[ToolConfig] = None - """ - Configuration for tool use. - NOTE: this config is highly experimental and signature will change significantly in future iterations. - """ - - -class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): - """ - Internal model for representing request to the llm engine. This contains additional fields that we want - hidden from the DTO exposed to the client. - """ - - max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0) - """ - Maximum GPU memory utilization for the batch inference. Default to 90%. - """ - - @staticmethod - def from_api(request: CreateBatchCompletionsRequest) -> "CreateBatchCompletionsEngineRequest": - return CreateBatchCompletionsEngineRequest( - input_data_path=request.input_data_path, - output_data_path=request.output_data_path, - content=request.content, - model_config=request.model_cfg, - model_cfg=request.model_cfg, - data_parallelism=request.data_parallelism, - max_runtime_sec=request.max_runtime_sec, - tool_config=request.tool_config, - ) - - -class CreateBatchCompletionsResponse(BaseModel): - job_id: str - - -class GetBatchCompletionsResponse(BaseModel): - progress: float - """ - Progress of the batch inference in percentage from 0 to 100. - """ - finished: bool diff --git a/model-engine/model_engine_server/common/dtos/llms/__init__.py b/model-engine/model_engine_server/common/dtos/llms/__init__.py new file mode 100644 index 00000000..ae7bef45 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/__init__.py @@ -0,0 +1,8 @@ +""" +DTOs for LLM APIs. +""" + +from .batch_completion import * # noqa: F403 +from .chat_completion import * # noqa: F403 +from .completion import * # noqa: F403 +from .model_endpoints import * # noqa: F403 diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py new file mode 100644 index 00000000..4fa59281 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -0,0 +1,305 @@ +# Make sure to keep this in sync with inference/batch_inference/dto.py. +from enum import Enum +from typing import Dict, List, Optional, Union + +from model_engine_server.common.dtos.llms.chat_completion import ( + ChatCompletionV2Request, + ChatCompletionV2Response, +) +from model_engine_server.common.dtos.llms.completion import ( + CompletionV2Request, + CompletionV2Response, +) +from model_engine_server.common.pydantic_types import BaseModel, Field +from typing_extensions import TypeAlias + + +# Common DTOs for batch completions +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + +class BatchCompletionsModelConfig(BaseModel): + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + checkpoint_path: Optional[str] = Field( + default=None, description="Path to the checkpoint to load the model from." + ) + + num_shards: Optional[int] = Field( + default=1, + ge=1, + description=""" +Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. +System may decide to use a different number than the given value. +""", + ) + + seed: Optional[int] = Field(default=None, description="Random seed for the model.") + + +class BatchCompletionsRequestBase(BaseModel): + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + data_parallelism: Optional[int] = Field( + default=1, + ge=1, + le=64, + description="Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference.", + ) + + max_runtime_sec: Optional[int] = Field( + default=24 * 3600, + ge=1, + le=2 * 24 * 3600, + description="Maximum runtime of the batch inference in seconds. Default to one day.", + ) + + priority: Optional[int] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + tool_config: Optional[ToolConfig] = Field( + default=None, + description=""" +Configuration for tool use. +NOTE: this config is highly experimental and signature will change significantly in future iterations.""", + ) + + +# V1 DTOs for batch completions +class CreateBatchCompletionsV1ModelConfig(BatchCompletionsModelConfig): + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + +class CreateBatchCompletionsV1RequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ + + +class CreateBatchCompletionsV1Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[CreateBatchCompletionsV1RequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + model_cfg: CreateBatchCompletionsV1ModelConfig = Field(alias="model_config") + """ + Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. + """ + + +class CreateBatchCompletionsV1Response(BaseModel): + job_id: str + + +# V2 DTOs for batch completions +CompletionRequest: TypeAlias = Union[CompletionV2Request, ChatCompletionV2Request] +CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2Response] +CreateBatchCompletionsV2RequestContent: TypeAlias = List[CompletionRequest] +CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig + + +class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[CreateBatchCompletionsV2RequestContent] = Field( + default=None, + description=""" +Either `input_data_path` or `content` needs to be provided. +When input_data_path is provided, the input file should be a JSON file of type List[CreateBatchCompletionsRequestContent]. +""", + ) + + # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + # reserves model_config as a keyword. + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + +class BatchCompletionsJobStatus(Enum): + Queued = "queued" + Running = "running" + Completed = "completed" + Failed = "failed" + Cancelled = "cancelled" + + +class BatchCompletionsJob(BaseModel): + job_id: str + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + # reserves model_config as a keyword. + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + priority: Optional[int] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + status: BatchCompletionsJobStatus + created_at: str + expires_at: str + completed_at: Optional[str] + metadata: Optional[Dict[str, str]] + + +CreateBatchCompletionsV2Response: TypeAlias = BatchCompletionsJob + + +class ListBatchCompletionV2Response(BaseModel): + jobs: List[BatchCompletionsJob] + + +class GetBatchCompletionV2Response(BaseModel): + job: BatchCompletionsJob + + +BatchCompletionContent = Union[ + CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent +] + + +class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase): + """ + Internal model for representing request to the inference framework. This contains additional fields that we want + hidden from the DTO exposed to the client. + """ + + content: Optional[BatchCompletionContent] = Field( + default=None, + description="Content is a union of the content from v1 and v2 requests.", + ) + + model_cfg: BatchCompletionsModelConfig = Field( + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + max_gpu_memory_utilization: Optional[float] = Field( + default=0.9, + le=1.0, + description="Maximum GPU memory utilization for the batch inference. Default to 90%.", + ) + + @staticmethod + def from_api_v1( + request: CreateBatchCompletionsV1Request, + ) -> "CreateBatchCompletionsEngineRequest": + return CreateBatchCompletionsEngineRequest( + input_data_path=request.input_data_path, + output_data_path=request.output_data_path, + content=request.content, + model_config=request.model_cfg, + model_cfg=request.model_cfg, + data_parallelism=request.data_parallelism, + max_runtime_sec=request.max_runtime_sec, + tool_config=request.tool_config, + labels=request.model_cfg.labels, + priority=request.priority, + ) + + @staticmethod + def from_api_v2( + request: CreateBatchCompletionsV2Request, + ) -> "CreateBatchCompletionsEngineRequest": + return CreateBatchCompletionsEngineRequest( + input_data_path=request.input_data_path, + output_data_path=request.output_data_path, + content=request.content, + model_config=request.model_cfg, + model_cfg=request.model_cfg, + data_parallelism=request.data_parallelism, + max_runtime_sec=request.max_runtime_sec, + labels=request.labels, + priority=request.priority, + ) diff --git a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py new file mode 100644 index 00000000..c573b526 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py @@ -0,0 +1,134 @@ +from typing import Any, Dict, List, Optional + +from model_engine_server.common.types.gen.openai import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, +) +from pydantic import Field +from typing_extensions import Annotated + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class VLLMAdditionalFields: + chat_template: Annotated[ + Optional[str], + Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." + ), + ), + ] + chat_template_kwargs: Annotated[ + Optional[Dict[str, Any]], + Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ), + ] + + guided_json: Annotated[ + Optional[Dict[str, Any]], + Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ), + ] + + guided_regex: Annotated[ + Optional[str], + Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ), + ] + guided_choice: Annotated[ + Optional[List[str]], + Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ), + ] + + guided_grammar: Annotated[ + Optional[str], + Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ), + ] + + guided_decoding_backend: Annotated[ + Optional[str], + Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ), + ] + + guided_whitespace_pattern: Annotated[ + Optional[str], + Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ), + ] + + skip_special_tokens: Annotated[ + Optional[bool], + Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ), + ] + + +class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + + include_stop_str_in_output: Annotated[ + Optional[bool], + Field(None, description="Whether to include the stop strings in output text."), + ] + + +class ChatCompletionV2Response(CreateChatCompletionResponse): + pass diff --git a/model-engine/model_engine_server/common/dtos/llms/completion.py b/model-engine/model_engine_server/common/dtos/llms/completion.py new file mode 100644 index 00000000..25dc0caa --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/completion.py @@ -0,0 +1,270 @@ +from typing import Any, Dict, List, Optional + +from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.common.types.gen.openai import ( + CreateCompletionRequest, + CreateCompletionResponse, +) +from typing_extensions import Annotated + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class CompletionSyncV1Request(BaseModel): + """ + Request object for a synchronous prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class TokenOutput(BaseModel): + token: str + log_prob: float + + +class CompletionOutput(BaseModel): + text: str + num_prompt_tokens: int + num_completion_tokens: int + tokens: Optional[List[TokenOutput]] = None + + +class CompletionSyncV1Response(BaseModel): + """ + Response object for a synchronous prompt completion task. + """ + + request_id: Optional[str] = None + output: Optional[CompletionOutput] = None + + +class CompletionStreamV1Request(BaseModel): + """ + Request object for a stream prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class CompletionStreamOutput(BaseModel): + text: str + finished: bool + num_prompt_tokens: Optional[int] = None + num_completion_tokens: Optional[int] = None + token: Optional[TokenOutput] = None + + +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" + + +class CompletionStreamV1Response(BaseModel): + """ + Response object for a stream prompt completion task. + """ + + request_id: Optional[str] = None + output: Optional[CompletionStreamOutput] = None + error: Optional[StreamError] = None + """Error of the response (if any).""" + + +class TokenUsage(BaseModel): + """ + Token usage for a prompt completion task. + """ + + num_prompt_tokens: Optional[int] = 0 + num_completion_tokens: Optional[int] = 0 + total_duration: Optional[float] = None + """Includes time spent waiting for the model to be ready.""" + + time_to_first_token: Optional[float] = None # Only for streaming requests + + @property + def num_total_tokens(self) -> int: + return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) + + @property + def total_tokens_per_second(self) -> float: + return ( + self.num_total_tokens / self.total_duration + if self.total_duration and self.total_duration > 0 + else 0.0 + ) + + @property + def inter_token_latency(self) -> Optional[float]: # Only for streaming requests + # Note: we calculate a single inter-token latency for the entire request. + # Calculating latency between each token seems a bit heavyweight, although we can do this if we wanted + if ( + self.time_to_first_token is None + or self.num_completion_tokens is None + or self.total_duration is None + ): + return None + if self.num_completion_tokens < 2: + return None + return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) + + +class CompletionV2Request(CreateCompletionRequest): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + + include_stop_str_in_output: Annotated[ + Optional[bool], + Field(None, description="Whether to include the stop strings in output text."), + ] + + +class CompletionV2Response(CreateCompletionResponse): + pass diff --git a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py new file mode 100644 index 00000000..5b870532 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py @@ -0,0 +1,203 @@ +""" +DTOs for LLM APIs. + +""" + +from typing import Any, Dict, List, Optional + +from model_engine_server.common.dtos.core import HttpUrlStr +from model_engine_server.common.dtos.model_endpoints import ( + CpuSpecificationType, + GetModelEndpointV1Response, + GpuType, + ModelEndpointType, + StorageSpecificationType, +) +from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.domain.entities import ( + BatchJobStatus, + CallbackAuth, + FineTuneHparamValueType, + LLMFineTuneEvent, + LLMInferenceFramework, + LLMSource, + ModelEndpointStatus, + Quantization, +) + + +class CreateLLMModelEndpointV1Request(BaseModel): + name: str + + # LLM specific fields + model_name: str + source: LLMSource = LLMSource.HUGGING_FACE + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM + inference_framework_image_tag: str = "latest" + num_shards: int = 1 + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Dict[str, Any] # TODO: JSON type + post_inference_hooks: Optional[List[str]] = None + endpoint_type: ModelEndpointType = ModelEndpointType.SYNC + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None + min_workers: int + max_workers: int + per_worker: int + labels: Dict[str, str] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = True # LLM endpoints are public by default. + + +class CreateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +class GetLLMModelEndpointV1Response(BaseModel): + id: str + """ + The autogenerated ID of the Launch endpoint. + """ + + name: str + model_name: str + source: LLMSource + status: ModelEndpointStatus + inference_framework: LLMInferenceFramework + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None + spec: Optional[GetModelEndpointV1Response] = None + + +class ListLLMModelEndpointsV1Response(BaseModel): + model_endpoints: List[GetLLMModelEndpointV1Response] + + +class UpdateLLMModelEndpointV1Request(BaseModel): + # LLM specific fields + model_name: Optional[str] = None + source: Optional[LLMSource] = None + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] = None + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None + min_workers: Optional[int] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None + + +class UpdateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +class CreateFineTuneRequest(BaseModel): + model: str + training_file: str + validation_file: Optional[str] = None + # fine_tuning_method: str # TODO enum + uncomment when we support multiple methods + hyperparameters: Dict[str, FineTuneHparamValueType] # validated somewhere else + suffix: Optional[str] = None + wandb_config: Optional[Dict[str, Any]] = None + """ + Config to pass to wandb for init. See https://docs.wandb.ai/ref/python/init + Must include `api_key` field which is the wandb API key. + """ + + +class CreateFineTuneResponse(BaseModel): + id: str + + +class GetFineTuneResponse(BaseModel): + id: str = Field(..., description="Unique ID of the fine tune") + fine_tuned_model: Optional[str] = Field( + default=None, + description="Name of the resulting fine-tuned model. This can be plugged into the " + "Completion API ones the fine-tune is complete", + ) + status: BatchJobStatus = Field(..., description="Status of the requested fine tune.") + + +class ListFineTunesResponse(BaseModel): + jobs: List[GetFineTuneResponse] + + +class CancelFineTuneResponse(BaseModel): + success: bool + + +class GetFineTuneEventsResponse(BaseModel): + # LLMFineTuneEvent is entity layer technically, but it's really simple + events: List[LLMFineTuneEvent] + + +class ModelDownloadRequest(BaseModel): + model_name: str = Field(..., description="Name of the fine tuned model") + download_format: Optional[str] = Field( + default="hugging_face", + description="Format that you want the downloaded urls to be compatible with. Currently only supports hugging_face", + ) + + +class ModelDownloadResponse(BaseModel): + urls: Dict[str, str] = Field( + ..., + description="Dictionary of (file_name, url) pairs to download the model from.", + ) + + +# Delete uses the default Launch endpoint APIs. +class DeleteLLMEndpointResponse(BaseModel): + deleted: bool diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index 5a760845..04a07edc 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -58,7 +58,7 @@ def validate_resource_requests( - bundle: Union[ModelBundle, DockerImageBatchJobBundle], + bundle: Optional[Union[ModelBundle, DockerImageBatchJobBundle]], cpus: Optional[CpuSpecificationType], memory: Optional[StorageSpecificationType], storage: Optional[StorageSpecificationType], diff --git a/model-engine/model_engine_server/common/types/__init__.py b/model-engine/model_engine_server/common/types/__init__.py new file mode 100644 index 00000000..5cdfe557 --- /dev/null +++ b/model-engine/model_engine_server/common/types/__init__.py @@ -0,0 +1,2 @@ +from .endpoint import * # noqa: F403 +from .gen import * # noqa: F403 diff --git a/model-engine/model_engine_server/common/types.py b/model-engine/model_engine_server/common/types/endpoint.py similarity index 100% rename from model-engine/model_engine_server/common/types.py rename to model-engine/model_engine_server/common/types/endpoint.py diff --git a/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py b/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py new file mode 100644 index 00000000..31c79bcf --- /dev/null +++ b/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional + +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms import CreateBatchCompletionsEngineRequest +from model_engine_server.common.dtos.llms.batch_completion import BatchCompletionsJob +from model_engine_server.core.auth.authentication_repository import User + + +class LLMBatchCompletionsService(ABC): + """ + Base class for LLM batch completions services. + """ + + @abstractmethod + async def create_batch_job( + self, + *, + user: User, + image_repo: str, + image_tag: str, + job_request: CreateBatchCompletionsEngineRequest, + resource_requests: CreateDockerImageBatchJobResourceRequests, + max_runtime_sec: int = 24 * 60 * 60, + labels: Dict[str, str] = {}, + priority: Optional[int] = 0, + num_workers: Optional[int] = 1, + ) -> BatchCompletionsJob: + """ + Create a batch completion job. + + Args: + owner: The user who requested the batch job + image_repo: The docker repo where the image is stored + image_tag: The tag of the batch completions image + job_config: The user-specified input to the batch job. Exposed as a file mounted at mount_location to the batch job + labels: Labels to apply to the batch job. + resource_requests: The resource requests for the batch job. + max_runtime_sec: The timeout of the batch job in seconds. + num_workers: The number of workers to run in the job. + + Returns: + The ID of the batch job. + """ + pass + + @abstractmethod + async def get_batch_job(self, batch_job_id: str) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + pass + + @abstractmethod + async def cancel_batch_job(self, batch_job_id: str) -> bool: + """ + Update a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + Whether the batch job was updated successfully. + """ + pass diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 24bb2351..54a88383 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -24,8 +24,10 @@ CompletionSyncV1Request, CompletionSyncV1Response, CreateBatchCompletionsEngineRequest, - CreateBatchCompletionsRequest, - CreateBatchCompletionsResponse, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1Response, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2Response, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, DeleteLLMEndpointResponse, @@ -90,6 +92,9 @@ TokenizerRepository, ) from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway from model_engine_server.infra.repositories.live_tokenizer_repository import ( SUPPORTED_MODELS_INFO, @@ -272,6 +277,16 @@ def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRep return len(tokenizer.encode(input)) +async def _get_latest_batch_tag(inference_framework: LLMInferenceFramework) -> str: + config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) + batch_key = f"{inference_framework}_batch" + if batch_key not in config_map: + raise LatestImageTagNotFoundException( + f"Could not find latest batch job tag for inference framework {inference_framework}." + ) + return config_map[batch_key] + + async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str: config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) if inference_framework not in config_map: @@ -1196,7 +1211,10 @@ def __init__( self.docker_repository = docker_repository async def execute( - self, user: User, model_endpoint_name: str, request: UpdateLLMModelEndpointV1Request + self, + user: User, + model_endpoint_name: str, + request: UpdateLLMModelEndpointV1Request, ) -> UpdateLLMModelEndpointV1Response: if request.labels is not None: validate_labels(request.labels) @@ -1253,7 +1271,9 @@ async def execute( validate_model_name(model_name, inference_framework) validate_num_shards( - num_shards, inference_framework, request.gpus or infra_state.resource_state.gpus + num_shards, + inference_framework, + request.gpus or infra_state.resource_state.gpus, ) validate_quantization(quantize, inference_framework) @@ -1420,7 +1440,10 @@ def validate_and_update_completion_params( if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: request.top_k = None if request.top_k == -1 else request.top_k request.top_p = None if request.top_p == 1.0 else request.top_p - if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: + if inference_framework in [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: request.top_k = -1 if request.top_k is None else request.top_k request.top_p = 1.0 if request.top_p is None else request.top_p else: @@ -1430,7 +1453,10 @@ def validate_and_update_completion_params( ) # presence_penalty, frequency_penalty - if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]: + if inference_framework in [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: request.presence_penalty = ( 0.0 if request.presence_penalty is None else request.presence_penalty ) @@ -1558,7 +1584,10 @@ def model_output_to_completion_output( tokens = None if with_token_probs: tokens = [ - TokenOutput(token=model_output["tokens"][index], log_prob=list(t.values())[0]) + TokenOutput( + token=model_output["tokens"][index], + log_prob=list(t.values())[0], + ) for index, t in enumerate(model_output["log_probs"]) ] return CompletionOutput( @@ -1711,7 +1740,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: @@ -1759,7 +1789,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1777,7 +1808,10 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.prompt, request.return_token_log_probs + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, ), ) elif endpoint_content.inference_framework == LLMInferenceFramework.VLLM: @@ -1814,7 +1848,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1831,7 +1866,10 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.prompt, request.return_token_log_probs + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, ), ) elif endpoint_content.inference_framework == LLMInferenceFramework.LIGHTLLM: @@ -1860,7 +1898,8 @@ async def execute( timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, ) predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1877,7 +1916,10 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.prompt, request.return_token_log_probs + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, ), ) elif endpoint_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: @@ -1917,7 +1959,10 @@ async def execute( return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( - output, model_endpoint, request.prompt, request.return_token_log_probs + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, ), ) else: @@ -2237,7 +2282,7 @@ async def _response_chunk_generator( output=CompletionStreamOutput( text=result["result"]["token"]["text"], finished=finished, - num_prompt_tokens=num_prompt_tokens if finished else None, + num_prompt_tokens=(num_prompt_tokens if finished else None), num_completion_tokens=num_completion_tokens, token=token, ), @@ -2492,21 +2537,19 @@ def __init__( async def create_batch_job_bundle( self, user: User, - request: CreateBatchCompletionsRequest, + request: CreateBatchCompletionsEngineRequest, hardware: CreateDockerImageBatchJobResourceRequests, ) -> DockerImageBatchJobBundle: + assert hardware.gpu_type is not None + bundle_name = ( f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" ) - image_tag = self.docker_repository.get_latest_image_tag( - hmi_config.batch_inference_vllm_repository - ) + image_tag = await _get_latest_batch_tag(LLMInferenceFramework.VLLM) config_file_path = "/opt/config.json" - assert hardware.gpu_type is not None - batch_bundle = ( await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( name=bundle_name, @@ -2534,8 +2577,8 @@ async def create_batch_job_bundle( return batch_bundle async def execute( - self, user: User, request: CreateBatchCompletionsRequest - ) -> CreateBatchCompletionsResponse: + self, user: User, request: CreateBatchCompletionsV1Request + ) -> CreateBatchCompletionsV1Response: if ( request.data_parallelism is not None and request.data_parallelism > 1 ): # pragma: no cover @@ -2552,14 +2595,10 @@ async def execute( request.model_cfg.checkpoint_path, is_batch_job=True, ) - # Reconcile gpus count with num_shards from request assert hardware.gpus is not None - if request.model_cfg.num_shards: - hardware.gpus = max(hardware.gpus, request.model_cfg.num_shards) - engine_request = CreateBatchCompletionsEngineRequest.from_api(request) + engine_request = CreateBatchCompletionsEngineRequest.from_api_v1(request) engine_request.model_cfg.num_shards = hardware.gpus - if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator": raise ObjectHasInvalidValueException( "Only code_evaluator tool is supported for batch completions." @@ -2597,9 +2636,72 @@ async def execute( repo=batch_bundle.image_repository, tag=batch_bundle.image_tag, resource_requests=hardware, - labels=engine_request.model_cfg.labels, + labels=engine_request.labels, mount_location=batch_bundle.mount_location, override_job_max_runtime_s=engine_request.max_runtime_sec, num_workers=engine_request.data_parallelism, ) - return CreateBatchCompletionsResponse(job_id=job_id) + return CreateBatchCompletionsV1Response(job_id=job_id) + + +class CreateBatchCompletionsV2UseCase: + def __init__( + self, + llm_batch_completions_service: LLMBatchCompletionsService, + llm_artifact_gateway: LLMArtifactGateway, + ): + self.llm_batch_completions_service = llm_batch_completions_service + self.llm_artifact_gateway = llm_artifact_gateway + + async def execute( + self, user: User, request: CreateBatchCompletionsV2Request + ) -> CreateBatchCompletionsV2Response: + request.model_cfg.checkpoint_path = get_checkpoint_path( + request.model_cfg.model, request.model_cfg.checkpoint_path + ) + hardware = await _infer_hardware( + self.llm_artifact_gateway, + request.model_cfg.model, + request.model_cfg.checkpoint_path, + is_batch_job=True, + ) + + engine_request = CreateBatchCompletionsEngineRequest.from_api_v2(request) + engine_request.model_cfg.num_shards = hardware.gpus + + validate_resource_requests( + bundle=None, + cpus=hardware.cpus, + memory=hardware.memory, + storage=hardware.storage, + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + ) + + if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1: + raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") + + # Right now we only support VLLM for batch inference. Refactor this if we support more inference frameworks. + image_repo = hmi_config.batch_inference_vllm_repository + image_tag = await _get_latest_batch_tag(LLMInferenceFramework.VLLM) + + additional_engine_args = infer_addition_engine_args_from_model_name( + engine_request.model_cfg.model + ) + + if additional_engine_args.gpu_memory_utilization is not None: + engine_request.max_gpu_memory_utilization = ( + additional_engine_args.gpu_memory_utilization + ) + + return await self.llm_batch_completions_service.create_batch_job( + user=user, + job_request=engine_request, + image_repo=image_repo, + image_tag=image_tag, + resource_requests=hardware, + labels=engine_request.labels, + max_runtime_sec=engine_request.max_runtime_sec, + priority=engine_request.priority, + num_workers=engine_request.data_parallelism, + ) diff --git a/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py b/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py new file mode 100644 index 00000000..40c155f5 --- /dev/null +++ b/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py @@ -0,0 +1,81 @@ +from datetime import datetime, timedelta +from typing import Dict, Optional + +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms import ( + BatchCompletionsJob, + BatchCompletionsJobStatus, + CreateBatchCompletionsEngineRequest, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( + DockerImageBatchJobGateway, +) +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) + + +class LiveLLMBatchCompletionsService(LLMBatchCompletionsService): + def __init__( + self, + docker_image_batch_job_gateway: DockerImageBatchJobGateway, + ): + self.docker_image_batch_job_gateway = docker_image_batch_job_gateway + + async def create_batch_job( + self, + *, + user: User, + image_repo: str, + image_tag: str, + job_request: CreateBatchCompletionsEngineRequest, + resource_requests: CreateDockerImageBatchJobResourceRequests, + max_runtime_sec: int = 24 * 60 * 60, + labels: Dict[str, str] = {}, + priority: Optional[int] = 0, + num_workers: Optional[int] = 1, + ): + config_file_path = "/opt/config.json" + env = {"CONFIG_FILE": config_file_path} + command = [ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ] + + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( + created_by=user.user_id, + owner=user.team_id, + job_config=job_request.model_dump(by_alias=True), + env=env, + command=command, + repo=image_repo, + tag=image_tag, + mount_location=config_file_path, + resource_requests=resource_requests, + labels=labels, + override_job_max_runtime_s=max_runtime_sec, + num_workers=num_workers, + ) + return BatchCompletionsJob( + job_id=job_id, + input_data_path=job_request.input_data_path, + output_data_path=job_request.output_data_path, + model_config=job_request.model_cfg, + priority=job_request.priority, + status=BatchCompletionsJobStatus.Queued, + created_at=datetime.now().isoformat(), + expires_at=(datetime.now() + timedelta(seconds=max_runtime_sec)).isoformat(), + completed_at=None, + metadata={"labels": job_request.labels}, + ) + + async def get_batch_job(self, batch_job_id: str) -> Optional[BatchCompletionsJob]: + raise NotImplementedError("Not implemented") + + async def cancel_batch_job(self, batch_job_id: str) -> bool: + # TODO: implement + raise NotImplementedError("Not implemented") diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 1c65cea6..9e8fbc95 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -6,6 +6,7 @@ from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.domain.entities import ModelEndpoint +from tests.unit.domain.test_llm_use_cases import mocked__get_latest_batch_tag from ..conftest import mocked__get_recommended_hardware_config_map @@ -263,6 +264,10 @@ def test_completion_stream_misc_server_error_returns_500( "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", mocked__get_recommended_hardware_config_map(), ) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_batch_tag", + mocked__get_latest_batch_tag(), +) def test_create_batch_completions_success( create_batch_completions_request: Dict[str, Any], test_api_key: str, diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index aaad807c..23abfd9d 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -6,9 +6,9 @@ from model_engine_server.common.dtos.llms import ( CompletionStreamV1Request, CompletionSyncV1Request, - CreateBatchCompletionsModelConfig, - CreateBatchCompletionsRequest, - CreateBatchCompletionsRequestContent, + CreateBatchCompletionsV1ModelConfig, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1RequestContent, CreateLLMModelEndpointV1Request, UpdateLLMModelEndpointV1Request, ) @@ -162,7 +162,7 @@ def update_model_endpoint_request( @pytest.fixture -def create_docker_image_batch_job_bundle_request() -> CreateDockerImageBatchJobBundleV1Request: +def create_docker_image_batch_job_bundle_request() -> (CreateDockerImageBatchJobBundleV1Request): return CreateDockerImageBatchJobBundleV1Request( name="name", image_repository="repo", @@ -394,7 +394,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( @pytest.fixture -def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpointV1Request: +def create_llm_model_endpoint_trt_llm_request_streaming() -> (CreateLLMModelEndpointV1Request): return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_trt_llm_streaming", model_name="llama-2-7b", @@ -421,7 +421,7 @@ def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpo @pytest.fixture -def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV1Request: +def create_llm_model_endpoint_trt_llm_request_async() -> (CreateLLMModelEndpointV1Request): return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_tgi_async", model_name="llama-2-7b", @@ -449,7 +449,7 @@ def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV @pytest.fixture -def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndpointV1Request: +def create_llm_model_endpoint_request_invalid_model_name() -> (CreateLLMModelEndpointV1Request): return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_1", model_name="nonexist", @@ -475,7 +475,7 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp @pytest.fixture -def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEndpointV1Request: +def create_llm_model_endpoint_request_invalid_quantization() -> (CreateLLMModelEndpointV1Request): return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_1", model_name="nonexist", @@ -521,16 +521,16 @@ def completion_stream_request() -> CompletionStreamV1Request: @pytest.fixture -def create_batch_completions_request() -> CreateBatchCompletionsRequest: - return CreateBatchCompletionsRequest( +def create_batch_completions_v1_request() -> CreateBatchCompletionsV1Request: + return CreateBatchCompletionsV1Request( input_data_path="test_input_data_path", output_data_path="test_output_data_path", - content=CreateBatchCompletionsRequestContent( + content=CreateBatchCompletionsV1RequestContent( prompts=["What is machine learning?"], max_new_tokens=10, temperature=0.5, ), - model_config=CreateBatchCompletionsModelConfig( + model_config=CreateBatchCompletionsV1ModelConfig( model="mpt-7b", checkpoint_path="s3://test_checkpoint_path", labels={}, diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index d36165d0..cc3744ed 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -7,7 +7,7 @@ CompletionOutput, CompletionStreamV1Request, CompletionSyncV1Request, - CreateBatchCompletionsRequest, + CreateBatchCompletionsV1Request, CreateFineTuneRequest, CreateLLMModelEndpointV1Request, CreateLLMModelEndpointV1Response, @@ -60,6 +60,13 @@ from ..conftest import mocked__get_recommended_hardware_config_map +def mocked__get_latest_batch_tag(): + async def async_mock(*args, **kwargs): # noqa + return "fake_docker_repository_latest_image_tag" + + return mock.AsyncMock(side_effect=async_mock) + + def mocked__get_latest_tag(): async def async_mock(*args, **kwargs): # noqa return "fake_docker_repository_latest_image_tag" @@ -201,7 +208,12 @@ async def test_create_model_endpoint_use_case_success( @pytest.mark.parametrize( "inference_framework, model_name, checkpoint_path, expected_error", [ - (LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b", None, InvalidRequestException), + ( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + "mpt-7b", + None, + InvalidRequestException, + ), ( LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "mpt-7b-instruct", @@ -1943,7 +1955,12 @@ async def test_validate_checkpoint_files_no_safetensors(): @pytest.mark.asyncio async def test_validate_checkpoint_files_safetensors_with_other_files(): - fake_model_files = ["model-fake.bin", "model-fake2.safetensors", "model.json", "optimizer.pt"] + fake_model_files = [ + "model-fake.bin", + "model-fake2.safetensors", + "model.json", + "optimizer.pt", + ] validate_checkpoint_files(fake_model_files) # No exception should be raised @@ -2069,7 +2086,10 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 hardware = await _infer_hardware( - fake_llm_artifact_gateway, "deepseek-coder-v2-lite-instruct", "", is_batch_job=True + fake_llm_artifact_gateway, + "deepseek-coder-v2-lite-instruct", + "", + is_batch_job=True, ) assert hardware.cpus == 160 assert hardware.gpus == 8 @@ -2602,13 +2622,17 @@ async def test_fill_hardware_info(fake_llm_artifact_gateway): "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", mocked__get_recommended_hardware_config_map(), ) -async def test_create_batch_completions( +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_batch_tag", + mocked__get_latest_batch_tag(), +) +async def test_create_batch_completions_v1( fake_docker_image_batch_job_gateway, fake_docker_repository_image_always_exists, fake_docker_image_batch_job_bundle_repository, fake_llm_artifact_gateway, test_api_key: str, - create_batch_completions_request: CreateBatchCompletionsRequest, + create_batch_completions_v1_request: CreateBatchCompletionsV1Request, ): use_case = CreateBatchCompletionsUseCase( docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, @@ -2618,10 +2642,10 @@ async def test_create_batch_completions( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - result = await use_case.execute(user, create_batch_completions_request) + result = await use_case.execute(user, create_batch_completions_v1_request) job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job(result.job_id) - assert job.num_workers == create_batch_completions_request.data_parallelism + assert job.num_workers == create_batch_completions_v1_request.data_parallelism bundle = list(fake_docker_image_batch_job_bundle_repository.db.values())[0] assert bundle.command == [ From b58cf41ede95f9729cb4288580009f3d922e5afd Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 14 Aug 2024 10:58:00 -0700 Subject: [PATCH 363/425] Add Qwen2 72b instruct (#594) * Add Qwen2 72b instruct * update * fommatting --- .../use_cases/llm_model_endpoint_use_cases.py | 1 + .../repositories/live_tokenizer_repository.py | 4 +++ .../tests/unit/domain/test_llm_use_cases.py | 32 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 54a88383..a27e9813 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -231,6 +231,7 @@ "deepseek-coder-v2-instruct", "deepseek-coder-v2-lite", "deepseek-coder-v2-lite-instruct", + "qwen2-72b-instruct", ] ), LLMInferenceFramework.LIGHTLLM: set( diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 9f55217c..94ad089c 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -103,6 +103,10 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "deepseek-coder-v2-lite-instruct": ModelInfo( "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", None ), + "qwen2-72b-instruct": ModelInfo( + "Qwen/Qwen2-72B-Instruct", + None, + ), } diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index cc3744ed..8d6bb930 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -2571,6 +2571,38 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + fake_llm_artifact_gateway.model_config = { + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 29568, + "max_position_embeddings": 32768, + "max_window_layers": 80, + "model_type": "qwen2", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 131072, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 152064, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "qwen2-72b-instruct", "") + assert hardware.cpus == 80 + assert hardware.gpus == 4 + assert hardware.memory == "320Gi" + assert hardware.storage == "320Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + with pytest.raises(ObjectHasInvalidValueException): await _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") From 065fb9de05a2a4baab9d836b5381fb806386f665 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 14 Aug 2024 21:45:51 -0700 Subject: [PATCH 364/425] Get Gemma2 working (#595) * Install flash-infer == 0.0.8 * Lots of changes... * cleanup --- .../common/dtos/llms/batch_completion.py | 21 +- .../use_cases/llm_model_endpoint_use_cases.py | 100 +++++----- .../inference/batch_inference/Dockerfile_vllm | 23 --- .../inference/batch_inference/README.md | 3 + .../batch_inference/build_and_upload_image.sh | 21 -- .../inference/batch_inference/dto.py | 26 ++- .../generate_tool_sample_data.py | 0 .../examples/sample_config.json | 14 ++ .../examples/sample_config_gemma.json | 14 ++ .../examples/sample_config_mixtral.json | 13 ++ .../examples/sample_config_tool.json | 16 ++ .../{ => examples}/sample_data.json | 6 +- .../{ => examples}/sample_data_tool.json | 0 .../batch_inference/sample_config.json | 11 -- .../batch_inference/sample_config_tool.json | 14 -- .../inference/batch_inference/vllm_batch.py | 20 +- .../inference/vllm/Dockerfile | 17 -- .../inference/vllm/Dockerfile.vllm | 43 +++++ .../inference/vllm/build_and_upload_image.sh | 33 +++- .../inference/vllm/requirements.txt | 2 +- .../inference/vllm/vllm_server.py | 179 ++++++------------ 21 files changed, 299 insertions(+), 277 deletions(-) delete mode 100644 model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm create mode 100644 model-engine/model_engine_server/inference/batch_inference/README.md delete mode 100755 model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh rename model-engine/model_engine_server/inference/batch_inference/{ => examples}/generate_tool_sample_data.py (100%) create mode 100644 model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json create mode 100644 model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json create mode 100644 model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json create mode 100644 model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json rename model-engine/model_engine_server/inference/batch_inference/{ => examples}/sample_data.json (57%) rename model-engine/model_engine_server/inference/batch_inference/{ => examples}/sample_data_tool.json (100%) delete mode 100644 model-engine/model_engine_server/inference/batch_inference/sample_config.json delete mode 100644 model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json delete mode 100644 model-engine/model_engine_server/inference/vllm/Dockerfile create mode 100644 model-engine/model_engine_server/inference/vllm/Dockerfile.vllm diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 4fa59281..da45589e 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -250,7 +250,20 @@ class GetBatchCompletionV2Response(BaseModel): ] -class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase): +class VLLMEngineAdditionalArgs(BaseModel): + max_gpu_memory_utilization: Optional[float] = Field( + default=0.9, + le=1.0, + description="Maximum GPU memory utilization for the batch inference. Default to 90%.", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + + +class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase, VLLMEngineAdditionalArgs): """ Internal model for representing request to the inference framework. This contains additional fields that we want hidden from the DTO exposed to the client. @@ -265,12 +278,6 @@ class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase): description="""Model configuration for the batch inference. Hardware configurations are inferred.""", ) - max_gpu_memory_utilization: Optional[float] = Field( - default=0.9, - le=1.0, - description="Maximum GPU memory utilization for the batch inference. Default to 90%.", - ) - @staticmethod def from_api_v1( request: CreateBatchCompletionsV1Request, diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index a27e9813..659215c4 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -9,7 +9,7 @@ import math import os import re -from dataclasses import asdict, dataclass +from dataclasses import asdict from functools import lru_cache from typing import Any, AsyncIterable, Dict, List, Optional, Union @@ -39,6 +39,7 @@ UpdateLLMModelEndpointV1Request, UpdateLLMModelEndpointV1Response, ) +from model_engine_server.common.dtos.llms.batch_completion import VLLMEngineAdditionalArgs from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus @@ -798,6 +799,9 @@ async def create_vllm_bundle( if "llama-3-70b" in model_name: subcommands[-1] = subcommands[-1] + " --gpu-memory-utilization 0.95 --enforce-eager" + if "gemma-2" in model_name: + subcommands[-1] = subcommands[-1] + " --attention-backend FLASHINFER" + command = [ "/bin/bash", "-c", @@ -2423,27 +2427,8 @@ async def _fill_hardware_info( request.num_shards = hardware_info.gpus -@lru_cache() -async def _infer_hardware( - llm_artifact_gateway: LLMArtifactGateway, - model_name: str, - checkpoint_path: str, - is_batch_job: bool = False, -) -> CreateDockerImageBatchJobResourceRequests: - config = llm_artifact_gateway.get_model_config(checkpoint_path) - - dtype_size = 2 - kv_multiplier = 20 if is_batch_job else 2 - - min_kv_cache_size = ( - kv_multiplier - * dtype_size - * config["num_hidden_layers"] - * config["hidden_size"] - * config["max_position_embeddings"] - // (config["num_attention_heads"] // config["num_key_value_heads"]) - ) - +def get_model_param_count_b(model_name: str) -> int: + """Get the number of parameters in the model in billions""" if "mixtral-8x7b" in model_name: model_param_count_b = 47 elif "mixtral-8x22b" in model_name: @@ -2465,7 +2450,31 @@ async def _infer_hardware( f"Unable to infer number of parameters for {model_name}." ) model_param_count_b = int(numbers[-1]) + return model_param_count_b + +@lru_cache() +async def _infer_hardware( + llm_artifact_gateway: LLMArtifactGateway, + model_name: str, + checkpoint_path: str, + is_batch_job: bool = False, +) -> CreateDockerImageBatchJobResourceRequests: + config = llm_artifact_gateway.get_model_config(checkpoint_path) + + dtype_size = 2 + kv_multiplier = 20 if is_batch_job else 2 + + min_kv_cache_size = ( + kv_multiplier + * dtype_size + * config["num_hidden_layers"] + * config["hidden_size"] + * config["max_position_embeddings"] + // (config["num_attention_heads"] // config["num_key_value_heads"]) + ) + + model_param_count_b = get_model_param_count_b(model_name) model_weights_size = dtype_size * model_param_count_b * 1_000_000_000 min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) @@ -2501,25 +2510,25 @@ async def _infer_hardware( ) -@dataclass -class VLLMEngineArgs: - gpu_memory_utilization: Optional[float] = None - - -def infer_addition_engine_args_from_model_name(model_name: str) -> VLLMEngineArgs: - numbers = re.findall(r"\d+", model_name) - if len(numbers) == 0: - raise ObjectHasInvalidValueException( - f"Model {model_name} is not supported for batch completions." - ) - - b_params = int(numbers[-1]) - if b_params >= 70: +def infer_addition_engine_args_from_model_name( + model_name: str, +) -> VLLMEngineAdditionalArgs: + # Increase max gpu utilization for larger models + model_param_count_b = get_model_param_count_b(model_name) + if model_param_count_b >= 70: gpu_memory_utilization = 0.95 else: gpu_memory_utilization = 0.9 - return VLLMEngineArgs(gpu_memory_utilization=gpu_memory_utilization) + # Gemma 2 requires flashinfer attention backend + attention_backend = None + if model_name.startswith("gemma-2"): + attention_backend = "FLASHINFER" + + return VLLMEngineAdditionalArgs( + max_gpu_memory_utilization=gpu_memory_utilization, + attention_backend=attention_backend, + ) class CreateBatchCompletionsUseCase: @@ -2609,10 +2618,10 @@ async def execute( engine_request.model_cfg.model ) - if additional_engine_args.gpu_memory_utilization is not None: - engine_request.max_gpu_memory_utilization = ( - additional_engine_args.gpu_memory_utilization - ) + engine_request.max_gpu_memory_utilization = ( + additional_engine_args.max_gpu_memory_utilization + ) + engine_request.attention_backend = additional_engine_args.attention_backend batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware) @@ -2689,11 +2698,10 @@ async def execute( additional_engine_args = infer_addition_engine_args_from_model_name( engine_request.model_cfg.model ) - - if additional_engine_args.gpu_memory_utilization is not None: - engine_request.max_gpu_memory_utilization = ( - additional_engine_args.gpu_memory_utilization - ) + engine_request.max_gpu_memory_utilization = ( + additional_engine_args.max_gpu_memory_utilization + ) + engine_request.attention_backend = additional_engine_args.attention_backend return await self.llm_batch_completions_service.create_batch_job( user=user, diff --git a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm b/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm deleted file mode 100644 index 3b08756c..00000000 --- a/model-engine/model_engine_server/inference/batch_inference/Dockerfile_vllm +++ /dev/null @@ -1,23 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:23.09-py3 - -RUN apt-get update && \ - apt-get install -y dumb-init psmisc && \ - apt-get autoremove -y && \ - rm -rf /var/lib/apt/lists/* && \ - apt-get clean - -RUN pip uninstall torch -y -RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121 - -RUN pip uninstall xformers -y -RUN pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121 - -RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz -RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz - -COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt -RUN pip install -r requirements.txt - -COPY model-engine /workspace/model-engine -RUN pip install -e /workspace/model-engine -COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py diff --git a/model-engine/model_engine_server/inference/batch_inference/README.md b/model-engine/model_engine_server/inference/batch_inference/README.md new file mode 100644 index 00000000..0c380633 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/README.md @@ -0,0 +1,3 @@ +# Notes + +We will merge this with inference/vllm. In the meantime, you can build the batch image via inference/vllm/build_and_publish_image.sh \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh b/model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh deleted file mode 100755 index 2bd519ed..00000000 --- a/model-engine/model_engine_server/inference/batch_inference/build_and_upload_image.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Build and push batch inference vLLM docker image to AWS ECR. - -set -eo pipefail - -if [ -z "$1" ]; then - echo "Must supply AWS account ID" - exit 1; -fi - -if [ -z "$2" ]; then - echo "Must supply the image tag" - exit 1; -fi - -IMAGE_TAG=$2 -ACCOUNT=$1 -aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com -DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG -f Dockerfile_vllm ../../../../ -docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py index 63f02efe..0d3e2c05 100644 --- a/model-engine/model_engine_server/inference/batch_inference/dto.py +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -1,4 +1,4 @@ -# This is a copy of model_engine_server.common.dtos.llm +# This is a copy of model_engine_server.common.dtos.llms.batch_completion.py # This is done to decouple the pydantic requirements since vllm requires pydantic >2 # while model engine is on 1.x from enum import Enum @@ -143,13 +143,26 @@ class CreateBatchCompletionsRequest(BaseModel): """ -class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): +class VLLMEngineAdditionalArgs(BaseModel): + max_gpu_memory_utilization: Optional[float] = Field( + default=0.9, + le=1.0, + description="Maximum GPU memory utilization for the model. Default to 90%.", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + + +class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest, VLLMEngineAdditionalArgs): """ - Internal model for representing request to the llm engine. This contains additional fields that we want + Internal model for representing request to the inference framework. This contains additional fields that we want hidden from the DTO exposed to the client. """ - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, protected_namespaces=()) model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config") """ @@ -160,8 +173,3 @@ class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest): We alias `model_config` for deserialization for backwards compatibility. """ - - max_gpu_memory_utilization: Optional[float] = Field(default=0.9, le=1.0) - """ - Maximum GPU memory utilization for the batch inference. Default to 90%. - """ diff --git a/model-engine/model_engine_server/inference/batch_inference/generate_tool_sample_data.py b/model-engine/model_engine_server/inference/batch_inference/examples/generate_tool_sample_data.py similarity index 100% rename from model-engine/model_engine_server/inference/batch_inference/generate_tool_sample_data.py rename to model-engine/model_engine_server/inference/batch_inference/examples/generate_tool_sample_data.py diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json new file mode 100644 index 00000000..8944e66c --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json @@ -0,0 +1,14 @@ +{ + "input_data_path": "./examples/sample_data.json", + "output_data_path": "./examples/sample_output.json", + "model_config": { + "model": "mixtral-8x7b-instruct-v0.1", + "checkpoint_path": "my_path", + "num_shards": 2, + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json new file mode 100644 index 00000000..e988c2f9 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json @@ -0,0 +1,14 @@ +{ + "input_data_path": "./examples/sample_data.json", + "output_data_path": "./examples/sample_output.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "my_path", + "num_shards": 1, + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json new file mode 100644 index 00000000..2c5fcc97 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json @@ -0,0 +1,13 @@ +{ + "input_data_path": "./examples/sample_data.json", + "output_data_path": "./examples/sample_output.json", + "model_config": { + "model": "mixtral-8x7b-instruct-v0.1", + "checkpoint_path": "my_path", + "num_shards": 2, + "labels": { + "team": "my_team" + } + }, + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json new file mode 100644 index 00000000..457131d7 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json @@ -0,0 +1,16 @@ +{ + "input_data_path": "./sample_data_tool.json", + "output_data_path": "./sample_output_tool.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "/workspace/model_files/gemma-2-2b-it", + "num_shards": 1, + "labels": { + "team": "my_team" + } + }, + "data_parallelism": 2, + "tool_config": { + "name": "code_evaluator" + } +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_data.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data.json similarity index 57% rename from model-engine/model_engine_server/inference/batch_inference/sample_data.json rename to model-engine/model_engine_server/inference/batch_inference/examples/sample_data.json index 87eb3169..d8fa3a68 100644 --- a/model-engine/model_engine_server/inference/batch_inference/sample_data.json +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data.json @@ -1,7 +1,7 @@ { - "prompts":[ - "deep learning is", - "san francisco is" + "prompts": [ + "san francisco is", + "deep learning is" ], "max_new_tokens": 100, "temperature": 0.0, diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_data_tool.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data_tool.json similarity index 100% rename from model-engine/model_engine_server/inference/batch_inference/sample_data_tool.json rename to model-engine/model_engine_server/inference/batch_inference/examples/sample_data_tool.json diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_config.json b/model-engine/model_engine_server/inference/batch_inference/sample_config.json deleted file mode 100644 index d047d7f8..00000000 --- a/model-engine/model_engine_server/inference/batch_inference/sample_config.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "input_data_path":"./sample_data.json", - "output_data_path":"./sample_output.json", - "model_config":{ - "model":"llama-2-7b", - "checkpoint_path":"my_path", - "num_shards": 1, - "labels": {"team": "my_team"} - }, - "data_parallelism":2 -} diff --git a/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json b/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json deleted file mode 100644 index 3f21befe..00000000 --- a/model-engine/model_engine_server/inference/batch_inference/sample_config_tool.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "input_data_path":"./sample_data_tool.json", - "output_data_path":"./sample_output_tool.json", - "model_config":{ - "model":"mistral-7b", - "checkpoint_path":"my_path", - "num_shards": 1, - "labels": {"team": "my_team"} - }, - "data_parallelism":2, - "tool_config": { - "name": "code_evaluator" - } -} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 3dda4a34..74f8782b 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -63,9 +63,13 @@ def download_model(checkpoint_path, final_weights_folder): # Need to override these env vars so s5cmd uses AWS_PROFILE env["AWS_ROLE_ARN"] = "" env["AWS_WEB_IDENTITY_TOKEN_FILE"] = "" - # nosemgrep process = subprocess.Popen( - s5cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env + s5cmd, + shell=True, # nosemgrep + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=env, ) for line in process.stdout: print(line, flush=True) @@ -323,6 +327,9 @@ async def batch_inference(config_file_data: Optional[str]): request = CreateBatchCompletionsEngineRequest.model_validate_json(config_file_data) + if request.attention_backend is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = request.attention_backend + if request.model_cfg.checkpoint_path is not None: download_model(request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER) @@ -432,7 +439,7 @@ async def generate_with_vllm( frequency_penalty=frequency_penalty or 0.0, top_k=top_k or -1, top_p=top_p or 1.0, - skip_special_tokens=skip_special_tokens if skip_special_tokens is not None else True, + skip_special_tokens=(skip_special_tokens if skip_special_tokens is not None else True), ) results_generator = await engine.add_request( request_id, prompt, sampling_params, time.monotonic(), None @@ -503,9 +510,11 @@ def check_unknown_startup_memory_usage(): # pragma: no cover f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." ) try: - # nosemgrep output = subprocess.run( - ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True + ["fuser -v /dev/nvidia*"], + shell=True, # nosemgrep + capture_output=True, + text=True, ).stdout print(f"Processes using GPU: {output}") except Exception as e: @@ -522,5 +531,6 @@ def check_unknown_startup_memory_usage(): # pragma: no cover help="Optional override for the config file data, as a json string", ) args = parser.parse_args() + check_unknown_startup_memory_usage() asyncio.run(batch_inference(args.config_file_data)) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile b/model-engine/model_engine_server/inference/vllm/Dockerfile deleted file mode 100644 index 75b9e1f5..00000000 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile +++ /dev/null @@ -1,17 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:23.09-py3 - -RUN apt-get update \ - && apt-get install -y \ - gdb \ - psmisc \ - && apt-get autoremove -y \ - && rm -rf /var/lib/apt/lists/* - -RUN pip uninstall torch -y -COPY requirements.txt /workspace/requirements.txt -RUN pip install -r requirements.txt - -RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz -RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz - -COPY vllm_server.py /workspace/vllm_server.py diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm new file mode 100644 index 00000000..4109939f --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -0,0 +1,43 @@ +# syntax=docker/dockerfile:1 +ARG VLLM_VERSION=0.5.3.post1 +ARG VLLM_BASE_IMAGE=vllm/vllm-openai:v${VLLM_VERSION} +FROM ${VLLM_BASE_IMAGE} AS base + +RUN apt-get update \ + && apt-get install -y wget gdb psmisc dumb-init \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* \ + apt-get clean + +# Need to fix flashinfer at 0.0.8 to support gemma models +# See https://github.com/vllm-project/vllm/issues/7060#issuecomment-2266248014 +# vLLM 0.5.3 depends on torch 2.3.1 +RUN pip uninstall flashinfer -y +RUN pip install flashinfer==0.0.8 --index-url https://flashinfer.ai/whl/cu121/torch2.3 + +WORKDIR /workspace + +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + +# symlink python to python3 +RUN ln -s /usr/bin/python3 /usr/bin/python + +FROM base AS vllm + +COPY model-engine/model_engine_server/inference/vllm/vllm_server.py /workspace/vllm_server.py + +# Need to override entrypoint from parent image +ENTRYPOINT ["/bin/env"] + +FROM base AS vllm_batch + +COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt +RUN pip install -r requirements.txt + +COPY model-engine /workspace/model-engine +RUN pip install -e /workspace/model-engine +COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py + +# Need to override entrypoint from parent image +ENTRYPOINT ["/bin/env"] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh index 750da0e0..8b6175b6 100755 --- a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -1,8 +1,14 @@ #!/bin/bash +set -eo pipefail + # Build and push vLLM docker image to AWS ECR. +# +# Usage: VLLM_VERSION=0.5.3.post1 ./build_and_upload_image.sh vllm|vllm_batch -set -eo pipefail +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +PROJECT_DIR=$SCRIPT_DIR/../../../.. +DOCKERFILE=$PROJECT_DIR/model_engine_server/inference/vllm/Dockerfile.vllm if [ -z "$1" ]; then echo "Must supply AWS account ID" @@ -14,8 +20,27 @@ if [ -z "$2" ]; then exit 1; fi -IMAGE_TAG=$2 +if [ -z "$3" ]; then + echo "Must supply the build target (either vllm or vllm_batch)" + exit 1; +fi + ACCOUNT=$1 +IMAGE_TAG=$2 +BUILD_TARGET=$3 +VLLM_VERSION=${VLLM_VERSION:-"0.5.3.post1"} + +# if build target = vllm use vllm otherwise use vllm_batch +if [ "$BUILD_TARGET" == "vllm" ]; then + IMAGE=$ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/vllm:$IMAGE_TAG +else + IMAGE=$ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG +fi + aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com -DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/vllm:$IMAGE_TAG . -docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/vllm:$IMAGE_TAG +DOCKER_BUILDKIT=1 docker build \ + --build-arg VLLM_VERSION=${VLLM_VERSION} \ + -f Dockerfile.vllm \ + --target ${BUILD_TARGET} \ + -t $IMAGE ${PROJECT_DIR} +docker push $IMAGE diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 8fa7cb6e..c6984a80 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1,2 @@ -vllm==0.5.3.post1 +vllm>=0.5.4 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index a30d4a25..5d19111a 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -1,7 +1,6 @@ import asyncio import code import json -import logging import os import signal import subprocess @@ -9,19 +8,14 @@ from logging import Logger from typing import AsyncGenerator, Dict, List, Optional -import uvicorn -from fastapi import BackgroundTasks, FastAPI, HTTPException, Request -from fastapi.responses import JSONResponse, Response, StreamingResponse -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine +from fastapi import APIRouter, BackgroundTasks, HTTPException, Request +from fastapi.responses import Response, StreamingResponse +from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import build_async_engine_client, init_app from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import ChatCompletionRequest as OpenAIChatCompletionRequest -from vllm.entrypoints.openai.protocol import ChatCompletionResponse as OpenAIChatCompletionResponse from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest -from vllm.entrypoints.openai.protocol import ErrorResponse -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams @@ -29,61 +23,18 @@ from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION -logging.basicConfig( - format="%(asctime)s | %(levelname)s: %(message)s", - datefmt="%b/%d %H:%M:%S", - level=logging.INFO, -) - logger = Logger("vllm_server") +async_engine_client: AsyncEngineClient + TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds -app = FastAPI() - -openai_serving_chat: OpenAIServingChat -openai_serving_completion: OpenAIServingCompletion -openai_serving_embedding: OpenAIServingEmbedding - - -@app.get("/healthz") -@app.get("/health") -async def healthcheck(): - await openai_serving_chat.engine.check_health() - return Response(status_code=200) - - -@app.get("/v1/models") -async def show_available_models(): - models = await openai_serving_chat.show_available_models() - return JSONResponse(content=models.model_dump()) - - -@app.post("/v1/chat/completions") -async def create_chat_completion(request: OpenAIChatCompletionRequest, raw_request: Request): - generator = await openai_serving_chat.create_chat_completion(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, media_type="text/event-stream") - else: - assert isinstance(generator, OpenAIChatCompletionResponse) - return JSONResponse(content=generator.model_dump()) - -@app.post("/v1/completions") -async def create_completion(request: OpenAICompletionRequest, raw_request: Request): - generator = await openai_serving_completion.create_completion(request, raw_request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, media_type="text/event-stream") - else: - return JSONResponse(content=generator.model_dump()) +router = APIRouter() -@app.post("/predict") -@app.post("/stream") +@router.post("/predict") +@router.post("/stream") async def generate(request: Request) -> Response: """Generate completion for the request. @@ -94,7 +45,7 @@ async def generate(request: Request) -> Response: """ # check health before accepting request and fail fast if engine isn't healthy try: - await engine.check_health() + await async_engine_client.check_health() request_dict = await request.json() prompt = request_dict.pop("prompt") @@ -119,12 +70,17 @@ async def generate(request: Request) -> Response: ) except Exception: raise HTTPException( - status_code=400, detail="Bad request: failed to parse guided decoding parameters." + status_code=400, + detail="Bad request: failed to parse guided decoding parameters.", ) - guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend + guided_decoding_backend = ( + await async_engine_client.get_decoding_config() + ).guided_decoding_backend guided_decode_logit_processor = await get_guided_decoding_logits_processor( - guided_decoding_backend, partial_openai_request, await engine.get_tokenizer() + guided_decoding_backend, + partial_openai_request, + await async_engine_client.get_tokenizer(lora_request=None), ) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: @@ -133,10 +89,10 @@ async def generate(request: Request) -> Response: request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id) + results_generator = async_engine_client.generate(prompt, sampling_params, request_id) async def abort_request() -> None: - await engine.abort(request_id) + await async_engine_client.abort(request_id) if stream: # Streaming case @@ -148,9 +104,9 @@ async def stream_results() -> AsyncGenerator[str, None]: "text": request_output.outputs[-1].text[len(last_output_text) :], "count_prompt_tokens": len(request_output.prompt_token_ids), "count_output_tokens": len(request_output.outputs[0].token_ids), - "log_probs": log_probs[-1] - if log_probs and sampling_params.logprobs - else None, + "log_probs": ( + log_probs[-1] if log_probs and sampling_params.logprobs else None + ), "finished": request_output.finished, } last_output_text = request_output.outputs[-1].text @@ -171,7 +127,7 @@ async def stream_results() -> AsyncGenerator[str, None]: last_output_text = request_output.outputs[-1].text if await request.is_disconnected(): # Abort the request if the client disconnects. - await engine.abort(request_id) + await async_engine_client.abort(request_id) return Response(status_code=499) final_output = request_output @@ -220,7 +176,10 @@ def check_unknown_startup_memory_usage(): try: # nosemgrep output = subprocess.run( - ["fuser -v /dev/nvidia*"], shell=True, capture_output=True, text=True + ["fuser -v /dev/nvidia*"], + shell=False, + capture_output=True, + text=True, ).stdout logger.info(f"Processes using GPU: {output}") except Exception as e: @@ -240,7 +199,9 @@ def debug(sig, frame): i.interact(message) -def format_logprobs(request_output: CompletionOutput) -> Optional[List[Dict[int, float]]]: +def format_logprobs( + request_output: CompletionOutput, +) -> Optional[List[Dict[int, float]]]: """Given a request output, format the logprobs if they exist.""" output_logprobs = request_output.outputs[0].logprobs if output_logprobs is None: @@ -254,55 +215,41 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: def parse_args(parser: FlexibleArgumentParser): parser = make_arg_parser(parser) + parser.add_argument("--attention-backend", type=str, help="The attention backend to use") return parser.parse_args() +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + global async_engine_client + async with build_async_engine_client(args) as async_engine_client: + app = await init_app(async_engine_client, args) + app.include_router(router) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + if __name__ == "__main__": check_unknown_startup_memory_usage() parser = FlexibleArgumentParser() - # host, port, and AsyncEngineArgs are already given by make_arg_parser() in parse_args() - # host == None -> IPv4 / IPv6 dualstack args = parse_args(parser) - - logger.info("vLLM version %s", VLLM_VERSION) - logger.info("args: %s", args) - - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - - signal.signal(signal.SIGUSR1, debug) - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args(engine_args) - - model_config = asyncio.run(engine.get_model_config()) - - openai_serving_chat = OpenAIServingChat( - engine, - model_config, - served_model_names, - args.response_role, - lora_modules=args.lora_modules, - chat_template=args.chat_template, - prompt_adapters=args.prompt_adapters, - request_logger=None, - ) - openai_serving_completion = OpenAIServingCompletion( - engine, - model_config, - served_model_names, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - request_logger=None, - ) - - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ) + if args.attention_backend is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_backend + asyncio.run(run_server(args)) From 6de5b7c9275dfa41929b806e5a8a084863ed5667 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 15 Aug 2024 10:11:54 -0700 Subject: [PATCH 365/425] Allow setting max context length for batch jobs (#598) --- .../common/dtos/llms/batch_completion.py | 6 ++++++ .../use_cases/llm_model_endpoint_use_cases.py | 13 +++++++++++-- .../inference/batch_inference/dto.py | 6 ++++++ .../inference/batch_inference/vllm_batch.py | 1 + .../tests/unit/domain/test_llm_use_cases.py | 13 +++++++++++++ 5 files changed, 37 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index da45589e..248056cb 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -58,6 +58,12 @@ class BatchCompletionsModelConfig(BaseModel): """, ) + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model", + ) + seed: Optional[int] = Field(default=None, description="Random seed for the model.") diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 659215c4..eb108194 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -2459,18 +2459,25 @@ async def _infer_hardware( model_name: str, checkpoint_path: str, is_batch_job: bool = False, + max_context_length: Optional[int] = None, ) -> CreateDockerImageBatchJobResourceRequests: config = llm_artifact_gateway.get_model_config(checkpoint_path) dtype_size = 2 kv_multiplier = 20 if is_batch_job else 2 + max_position_embeddings = ( + min(max_context_length, config["max_position_embeddings"]) + if max_context_length + else config["max_position_embeddings"] + ) + min_kv_cache_size = ( kv_multiplier * dtype_size * config["num_hidden_layers"] * config["hidden_size"] - * config["max_position_embeddings"] + * max_position_embeddings // (config["num_attention_heads"] // config["num_key_value_heads"]) ) @@ -2480,7 +2487,7 @@ async def _infer_hardware( min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) logger.info( - f"Memory calculation result: {min_memory_gb=} for {model_name}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" + f"Memory calculation result: {min_memory_gb=} for {model_name} context_size: {max_position_embeddings}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" ) config_map = await _get_recommended_hardware_config_map() @@ -2604,6 +2611,7 @@ async def execute( request.model_cfg.model, request.model_cfg.checkpoint_path, is_batch_job=True, + max_context_length=request.model_cfg.max_context_length, ) assert hardware.gpus is not None @@ -2674,6 +2682,7 @@ async def execute( request.model_cfg.model, request.model_cfg.checkpoint_path, is_batch_job=True, + max_context_length=request.model_cfg.max_context_length, ) engine_request = CreateBatchCompletionsEngineRequest.from_api_v2(request) diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py index 0d3e2c05..73f30537 100644 --- a/model-engine/model_engine_server/inference/batch_inference/dto.py +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -142,6 +142,12 @@ class CreateBatchCompletionsRequest(BaseModel): NOTE: this config is highly experimental and signature will change significantly in future iterations. """ + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model", + ) + class VLLMEngineAdditionalArgs(BaseModel): max_gpu_memory_utilization: Optional[float] = Field( diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index 74f8782b..b31e7331 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -160,6 +160,7 @@ def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest): seed=request.model_cfg.seed or 0, disable_log_requests=True, gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, + max_model_len=request.max_context_length, ) llm = AsyncLLMEngine.from_engine_args(engine_args) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 8d6bb930..7dc149fd 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -2097,6 +2097,19 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + hardware = await _infer_hardware( + fake_llm_artifact_gateway, + "deepseek-coder-v2-lite-instruct", + "", + is_batch_job=True, + max_context_length=4096, + ) + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + # Phi 3 mini from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json fake_llm_artifact_gateway.model_config = { "architectures": ["Phi3ForCausalLM"], From b84018f5828c0572f8b6c42f129aef07746375f6 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 15 Aug 2024 11:06:44 -0700 Subject: [PATCH 366/425] Fix dto for batch completion (#599) --- .../model_engine_server/inference/batch_inference/dto.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py index 73f30537..f46f62a1 100644 --- a/model-engine/model_engine_server/inference/batch_inference/dto.py +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -69,10 +69,6 @@ class CreateBatchCompletionsModelConfig(BaseModel): """ Path to the checkpoint to load the model from. """ - labels: Dict[str, str] - """ - Labels to attach to the batch inference job. - """ num_shards: Optional[int] = 1 """ Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. @@ -122,6 +118,9 @@ class CreateBatchCompletionsRequest(BaseModel): """ Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. """ + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) content: Optional[CreateBatchCompletionsRequestContent] = None """ Either `input_data_path` or `content` needs to be provided. From 092d9f4b036b0f6c45426b6d03b4407aca53279f Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 15 Aug 2024 12:53:47 -0700 Subject: [PATCH 367/425] Update client with new max_context_length (#600) * Update client with new max_context_length * Bump version --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types.py | 9 ++++++++- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 3dcbc6b7..e8c40fd3 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta35" +__version__ = "0.0.0beta36" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 7f775e5a..2de743c6 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -621,7 +621,8 @@ class ModelDownloadResponse(BaseModel): """ urls: Dict[str, str] = Field( - ..., description="Dictionary of (file_name, url) pairs to download the model from." + ..., + description="Dictionary of (file_name, url) pairs to download the model from.", ) @@ -732,6 +733,12 @@ class CreateBatchCompletionsModelConfig(BaseModel): Random seed for the model. """ + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model", + ) + class ToolConfig(BaseModel): """ diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 5eeccff5..30e680d2 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta35" +version = "0.0.0.beta36" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 6636af4b..d486a8d6 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta35", + version="0.0.0.beta36", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From 439e001fe8011c0292d55ef4a009e6f3bcb0c621 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 21 Aug 2024 14:17:57 -0700 Subject: [PATCH 368/425] Batch completions V2 job (#602) * Batch V2 - v1 payload working * Fix types * Add more examples * Update sdk * fix package versions for docs * more updates * Batch completions V2 API * Fix test * Bump client min python * Add v2 router to app * Add v2 router to app * Some fixes * Some fixes * Update client --- .../templates/inference_framework_config.yaml | 2 + clients/python/llmengine/__init__.py | 34 +- clients/python/llmengine/completion.py | 106 + .../python/llmengine/data_types/__init__.py | 23 + .../llmengine/data_types/batch_completion.py | 291 + .../llmengine/data_types/chat_completion.py | 132 + .../python/llmengine/data_types/completion.py | 323 + .../llmengine/data_types/gen/__init__.py | 0 .../python/llmengine/data_types/gen/openai.py | 6314 +++++++++ .../llmengine/data_types/pydantic_types.py | 15 + .../{data_types.py => data_types/rest.py} | 304 +- clients/python/mypy.ini | 3 + clients/python/poetry.lock | 1897 +-- clients/python/pyproject.toml | 7 +- clients/python/setup.py | 4 +- model-engine/model_engine_server/api/app.py | 7 +- .../model_engine_server/api/dependencies.py | 12 + .../model_engine_server/api/v2/__init__.py | 10 + .../api/v2/batch_completion.py | 150 + .../model_engine_server/api/v2/common.py | 33 + .../common/dtos/llms/batch_completion.py | 55 +- .../common/dtos/llms/completion.py | 60 +- .../common/types/gen/openai.py | 2602 ++-- .../services/llm_batch_completions_service.py | 25 +- .../use_cases/llm_model_endpoint_use_cases.py | 91 +- .../model_engine_server/inference/utils.py | 117 + .../inference/vllm/Dockerfile.vllm | 12 + .../inference/vllm/README.md | 61 + .../vllm/examples/v2/sample_config_gemma.json | 15 + .../examples/v2/sample_data_chat_gemma.json | 1 + .../inference/vllm/gen_sample_data.py | 36 + .../inference/vllm/requirements-batch.txt | 6 + .../inference/vllm/requirements-dev.txt | 1 + .../inference/vllm/requirements.txt | 1 - .../inference/vllm/vllm_batch.py | 320 + .../repositories/live_tokenizer_repository.py | 8 +- .../live_llm_batch_completions_service.py | 114 +- model-engine/mypy.ini | 3 - model-engine/tests/unit/conftest.py | 7 + requirements-docs.txt | 3 +- scripts/generate-openai-types.sh | 1 + scripts/openai-spec.yaml | 11736 ++++++++++------ 42 files changed, 18186 insertions(+), 6756 deletions(-) create mode 100644 clients/python/llmengine/data_types/__init__.py create mode 100644 clients/python/llmengine/data_types/batch_completion.py create mode 100644 clients/python/llmengine/data_types/chat_completion.py create mode 100644 clients/python/llmengine/data_types/completion.py create mode 100644 clients/python/llmengine/data_types/gen/__init__.py create mode 100644 clients/python/llmengine/data_types/gen/openai.py create mode 100644 clients/python/llmengine/data_types/pydantic_types.py rename clients/python/llmengine/{data_types.py => data_types/rest.py} (61%) create mode 100644 model-engine/model_engine_server/api/v2/__init__.py create mode 100644 model-engine/model_engine_server/api/v2/batch_completion.py create mode 100644 model-engine/model_engine_server/api/v2/common.py create mode 100644 model-engine/model_engine_server/inference/utils.py create mode 100644 model-engine/model_engine_server/inference/vllm/README.md create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json create mode 100644 model-engine/model_engine_server/inference/vllm/gen_sample_data.py create mode 100644 model-engine/model_engine_server/inference/vllm/requirements-batch.txt create mode 100644 model-engine/model_engine_server/inference/vllm/requirements-dev.txt create mode 100644 model-engine/model_engine_server/inference/vllm/vllm_batch.py diff --git a/charts/model-engine/templates/inference_framework_config.yaml b/charts/model-engine/templates/inference_framework_config.yaml index d97d1920..45759d77 100644 --- a/charts/model-engine/templates/inference_framework_config.yaml +++ b/charts/model-engine/templates/inference_framework_config.yaml @@ -12,5 +12,7 @@ data: deepspeed: "latest" text_generation_inference: "latest" vllm: "latest" + vllm_batch: "latest" + vllm_batch_v2: "latest" lightllm: "latest" tensorrt_llm: "latest" diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index e8c40fd3..6e201069 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta36" +__version__ = "0.0.0beta37" import os from typing import Sequence @@ -20,19 +20,34 @@ import requests from llmengine.completion import Completion from llmengine.data_types import ( + BatchCompletionsJob, + BatchCompletionsJobStatus, + BatchCompletionsModelConfig, CancelFineTuneResponse, + ChatCompletionV2Request, + ChatCompletionV2Response, CompletionOutput, CompletionStreamOutput, CompletionStreamResponse, + CompletionStreamV1Request, + CompletionStreamV1Response, CompletionSyncResponse, + CompletionSyncV1Request, + CompletionSyncV1Response, CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, CreateBatchCompletionsResponse, + CreateBatchCompletionsV2ModelConfig, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2RequestContent, + CreateBatchCompletionsV2Response, CreateFineTuneRequest, CreateFineTuneResponse, DeleteFileResponse, DeleteLLMEndpointResponse, + FilteredChatCompletionV2Request, + FilteredCompletionV2Request, GetFileContentResponse, GetFileResponse, GetFineTuneResponse, @@ -43,13 +58,26 @@ ModelDownloadRequest, ModelDownloadResponse, UploadFileResponse, + VLLMAdditionalFields, ) from llmengine.file import File from llmengine.fine_tuning import FineTune from llmengine.model import Model __all__: Sequence[str] = ( + "BatchCompletionsJob", + "CreateBatchCompletionsV2Response", + "FilteredCompletionV2Request", + "FilteredChatCompletionV2Request", + "BatchCompletionsJobStatus", + "CompletionSyncV1Request", + "CompletionSyncV1Response", + "CompletionStreamV1Request", + "CompletionStreamV1Response", "CancelFineTuneResponse", + "ChatCompletionV2Request", + "ChatCompletionV2Response", + "VLLMAdditionalFields", "Completion", "CompletionOutput", "CompletionStreamOutput", @@ -59,6 +87,10 @@ "CreateBatchCompletionsRequest", "CreateBatchCompletionsRequestContent", "CreateBatchCompletionsResponse", + "CreateBatchCompletionsV2Request", + "CreateBatchCompletionsV2RequestContent", + "CreateBatchCompletionsV2ModelConfig", + "BatchCompletionsModelConfig", "CreateFineTuneRequest", "CreateFineTuneResponse", "DeleteFileResponse", diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 076b9031..bb0d5ffa 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -2,6 +2,7 @@ from llmengine.api_engine import APIEngine from llmengine.data_types import ( + BatchCompletionsModelConfig, CompletionStreamResponse, CompletionStreamV1Request, CompletionSyncResponse, @@ -10,6 +11,9 @@ CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, CreateBatchCompletionsResponse, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2RequestContent, + CreateBatchCompletionsV2Response, ToolConfig, ) @@ -596,3 +600,105 @@ def batch_create( headers=request_headers, ) return CreateBatchCompletionsResponse.parse_obj(response) + + @classmethod + def batch_create_v2( + cls, + *, + output_data_path: str, + model_config: BatchCompletionsModelConfig, + content: Optional[List[CreateBatchCompletionsV2RequestContent]] = None, + input_data_path: Optional[str] = None, + data_parallelism: int = 1, + max_runtime_sec: int = 24 * 3600, + labels: Dict[str, str] = {}, + tool_config: Optional[ToolConfig] = None, + request_headers: Optional[Dict[str, str]] = None, + ) -> CreateBatchCompletionsV2Response: + """ + Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. + + Prompts can be passed in from an input file, or as a part of the request. + + Args: + output_data_path (str): + The path to the output file. The output file will be a JSON file containing the completions. + + model_config (BatchCompletionsModelConfig): + The model configuration to use for the batch completion. + + content (Optional[List[CreateBatchCompletionsV2RequestContent]]): + The content to use for the batch completion. Either one of `content` or `input_data_path` must be provided. + + input_data_path (Optional[str]): + The path to the input file. The input file should be a JSON file with data of type `BatchCompletionsRequestContent`. Either one of `content` or `input_data_path` must be provided. + + data_parallelism (int): + The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1. + + max_runtime_sec (int): + The maximum runtime of the batch completion in seconds. Defaults to 24 hours. + + tool_config (Optional[ToolConfig]): + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + Currently only Python code evaluator is supported. + Python code context starts with "\`\`\`python\\n" and ends with "\\n>>>\\n", data before "\\n\`\`\`\\n" and content end will be replaced by the Python execution results. + Please format prompts accordingly and provide examples so LLMs could properly generate Python code. + + Returns: + response (CreateBatchCompletionsV2Response): The response containing the job id. + + === "Batch completions with prompts in the request" + ```python + from llmengine import ( + Completion, + ) + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, FilteredChatCompletionV2Request, + + model_config = CreateBatchCompletionsModelConfig( + model="gemma-2-2b-it", + checkpoint_path="s3://path-to-checkpoint", + ) + + content = { + "messages": [ + { + "role": "user", + "content": "What is a good place for travel in the US?", + }, + {"role": "assistant", "content": "California."}, + {"role": "user", "content": "What can I do in California?"}, + ], + "logprobs": True, + } + + response = Completion.batch_create_v2( + output_data_path="testoutput", + model_config=model_config, + content=[FilteredChatCompletionV2Request(**content)], + labels={"team": "my-team", "product": "my-product"}, + ) + + print(response.json()) + ``` + + """ + data = CreateBatchCompletionsV2Request( + model_config=model_config, + content=content, + input_data_path=input_data_path, + output_data_path=output_data_path, + data_parallelism=data_parallelism, + labels=labels, + max_runtime_sec=max_runtime_sec, + tool_config=tool_config, + ).model_dump(exclude_none=True, by_alias=True) + response = cls.post_sync( + resource_name="v2/batch-completions", + data=data, + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return CreateBatchCompletionsV2Response.parse_obj(response) diff --git a/clients/python/llmengine/data_types/__init__.py b/clients/python/llmengine/data_types/__init__.py new file mode 100644 index 00000000..ff34f72b --- /dev/null +++ b/clients/python/llmengine/data_types/__init__.py @@ -0,0 +1,23 @@ +""" +DTOs for LLM APIs. +""" + +from typing_extensions import TypeAlias + +from .batch_completion import * # noqa: F403 +from .chat_completion import * # noqa: F403 +from .completion import * # noqa: F403 +from .rest import * # noqa: F403 + +# Alias for backwards compatibility +CreateBatchCompletionsRequestContent: TypeAlias = ( + CreateBatchCompletionsV1RequestContent # noqa: F405 +) +CreateBatchCompletionsRequest: TypeAlias = CreateBatchCompletionsV1Request # noqa: F405 +CreateBatchCompletionsResponse: TypeAlias = CreateBatchCompletionsV1Response # noqa: F405 +CreateBatchCompletionsModelConfig: TypeAlias = CreateBatchCompletionsV1ModelConfig # noqa: F405 + +CompletionSyncRequest: TypeAlias = CompletionSyncV1Request # noqa: F405 +CompletionSyncResponse: TypeAlias = CompletionSyncV1Response # noqa: F405 +CompletionStreamRequest: TypeAlias = CompletionStreamV1Request # noqa: F405 +CompletionStreamResponse: TypeAlias = CompletionStreamV1Response # noqa: F405 diff --git a/clients/python/llmengine/data_types/batch_completion.py b/clients/python/llmengine/data_types/batch_completion.py new file mode 100644 index 00000000..cfb31248 --- /dev/null +++ b/clients/python/llmengine/data_types/batch_completion.py @@ -0,0 +1,291 @@ +from enum import Enum +from typing import Dict, List, Optional, Union + +from typing_extensions import TypeAlias + +from .chat_completion import ChatCompletionV2Request, ChatCompletionV2Response +from .completion import CompletionOutput, CompletionV2Request, CompletionV2Response +from .pydantic_types import BaseModel, Field + + +# Common DTOs for batch completions +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + +class BatchCompletionsModelConfig(BaseModel): + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + checkpoint_path: Optional[str] = Field( + default=None, description="Path to the checkpoint to load the model from." + ) + + num_shards: Optional[int] = Field( + default=1, + ge=1, + description=""" +Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. +System may decide to use a different number than the given value. +""", + ) + + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model", + ) + + seed: Optional[int] = Field(default=None, description="Random seed for the model.") + + response_role: Optional[str] = Field( + default=None, + description="Role of the response in the conversation. Only supported in chat completions.", + ) + + +class BatchCompletionsRequestBase(BaseModel): + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + data_parallelism: Optional[int] = Field( + default=1, + ge=1, + le=64, + description="Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference.", + ) + + max_runtime_sec: Optional[int] = Field( + default=24 * 3600, + ge=1, + le=2 * 24 * 3600, + description="Maximum runtime of the batch inference in seconds. Default to one day.", + ) + + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + tool_config: Optional[ToolConfig] = Field( + default=None, + description=""" +Configuration for tool use. +NOTE: this config is highly experimental and signature will change significantly in future iterations.""", + ) + + +# V1 DTOs for batch completions +CompletionV1Output = CompletionOutput + + +class CreateBatchCompletionsV1ModelConfig(BatchCompletionsModelConfig): + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + +class CreateBatchCompletionsV1RequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ + + +class CreateBatchCompletionsV1Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[CreateBatchCompletionsV1RequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + model_cfg: CreateBatchCompletionsV1ModelConfig = Field(alias="model_config") + """ + Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. + """ + + +class CreateBatchCompletionsV1Response(BaseModel): + job_id: str + + +class FilteredCompletionV2Request(CompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +class FilteredChatCompletionV2Request(ChatCompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +# V2 DTOs for batch completions +CompletionRequest: TypeAlias = Union[FilteredCompletionV2Request, FilteredChatCompletionV2Request] +CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2Response] +CreateBatchCompletionsV2RequestContent: TypeAlias = Union[ + List[FilteredCompletionV2Request], List[FilteredChatCompletionV2Request] +] +CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig + + +class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[CreateBatchCompletionsV2RequestContent] = Field( + default=None, + description=""" +Either `input_data_path` or `content` needs to be provided. +When input_data_path is provided, the input file should be a JSON file of type List[CreateBatchCompletionsRequestContent]. +""", + ) + + # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + # reserves model_config as a keyword. + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + +class BatchCompletionsJobStatus(str, Enum): + Queued = "queued" + Running = "running" + Completed = "completed" + Failed = "failed" + Cancelled = "cancelled" + Unknown = "unknown" + + +class BatchCompletionsJob(BaseModel): + job_id: str + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + # reserves model_config as a keyword. + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + priority: Optional[int] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + status: BatchCompletionsJobStatus + created_at: str + expires_at: str + completed_at: Optional[str] + metadata: Optional[Dict[str, str]] + + +CreateBatchCompletionsV2Response: TypeAlias = BatchCompletionsJob + + +class UpdateBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + priority: Optional[int] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + +class UpdateBatchCompletionsV2Response(BatchCompletionsJob): + success: bool = Field(description="Whether the update was successful") + + +class CancelBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + + +class CancelBatchCompletionsV2Response(BaseModel): + success: bool = Field(description="Whether the cancellation was successful") + + +class ListBatchCompletionV2Response(BaseModel): + jobs: List[BatchCompletionsJob] + + +class GetBatchCompletionV2Response(BaseModel): + job: BatchCompletionsJob + + +BatchCompletionContent = Union[ + CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent +] diff --git a/clients/python/llmengine/data_types/chat_completion.py b/clients/python/llmengine/data_types/chat_completion.py new file mode 100644 index 00000000..ab2c94a0 --- /dev/null +++ b/clients/python/llmengine/data_types/chat_completion.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, List, Optional + +from pydantic import Field +from typing_extensions import Annotated + +from .gen.openai import CreateChatCompletionRequest, CreateChatCompletionResponse + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class VLLMAdditionalFields: + chat_template: Annotated[ + Optional[str], + Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." + ), + ), + ] + chat_template_kwargs: Annotated[ + Optional[Dict[str, Any]], + Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ), + ] + + guided_json: Annotated[ + Optional[Dict[str, Any]], + Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ), + ] + + guided_regex: Annotated[ + Optional[str], + Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ), + ] + guided_choice: Annotated[ + Optional[List[str]], + Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ), + ] + + guided_grammar: Annotated[ + Optional[str], + Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ), + ] + + guided_decoding_backend: Annotated[ + Optional[str], + Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ), + ] + + guided_whitespace_pattern: Annotated[ + Optional[str], + Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ), + ] + + skip_special_tokens: Annotated[ + Optional[bool], + Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ), + ] + + +class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + + include_stop_str_in_output: Annotated[ + Optional[bool], + Field(None, description="Whether to include the stop strings in output text."), + ] + + +class ChatCompletionV2Response(CreateChatCompletionResponse): + pass diff --git a/clients/python/llmengine/data_types/completion.py b/clients/python/llmengine/data_types/completion.py new file mode 100644 index 00000000..24978263 --- /dev/null +++ b/clients/python/llmengine/data_types/completion.py @@ -0,0 +1,323 @@ +from typing import Any, Dict, List, Optional + +from typing_extensions import Annotated + +from .gen.openai import CreateCompletionRequest, CreateCompletionResponse +from .pydantic_types import BaseModel, Field + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class CompletionSyncV1Request(BaseModel): + """ + Request object for a synchronous prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class TokenOutput(BaseModel): + """ + Detailed token information. + """ + + token: str + """ + The token text. + """ + + log_prob: float + """ + The log probability of the token. + """ + + +class CompletionOutput(BaseModel): + """ + Represents the output of a completion request to a model. + """ + + text: str + """The text of the completion.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + # If we send request to api.spellbook.scale.com, we don't get this back. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + + num_completion_tokens: int + """Number of tokens in the completion.""" + + tokens: Optional[List[TokenOutput]] = None + """Detailed token information.""" + + +class CompletionSyncV1Response(BaseModel): + """ + Response object for a synchronous prompt completion. + """ + + request_id: str + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + + output: CompletionOutput + """Completion output.""" + + +class CompletionStreamV1Request(BaseModel): + """ + Request object for a stream prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class CompletionStreamOutput(BaseModel): + text: str + """The text of the completion.""" + + finished: bool + """Whether the completion is finished.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + + num_completion_tokens: Optional[int] = None + """Number of tokens in the completion.""" + + token: Optional[TokenOutput] = None + """Detailed token information.""" + + +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" + + +class CompletionStreamV1Response(BaseModel): + """Error of the response (if any).""" + + """ + Response object for a stream prompt completion task. + """ + + request_id: str + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + + output: Optional[CompletionStreamOutput] = None + """Completion output.""" + + error: Optional[StreamError] = None + """Error of the response (if any).""" + + +class TokenUsage(BaseModel): + """ + Token usage for a prompt completion task. + """ + + num_prompt_tokens: Optional[int] = 0 + num_completion_tokens: Optional[int] = 0 + total_duration: Optional[float] = None + """Includes time spent waiting for the model to be ready.""" + + time_to_first_token: Optional[float] = None # Only for streaming requests + + @property + def num_total_tokens(self) -> int: + return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) + + @property + def total_tokens_per_second(self) -> float: + return ( + self.num_total_tokens / self.total_duration + if self.total_duration and self.total_duration > 0 + else 0.0 + ) + + @property + def inter_token_latency(self) -> Optional[float]: # Only for streaming requests + # Note: we calculate a single inter-token latency for the entire request. + # Calculating latency between each token seems a bit heavyweight, although we can do this if we wanted + if ( + self.time_to_first_token is None + or self.num_completion_tokens is None + or self.total_duration is None + ): + return None + if self.num_completion_tokens < 2: + return None + return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) + + +class CompletionV2Request(CreateCompletionRequest): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + + include_stop_str_in_output: Annotated[ + Optional[bool], + Field(None, description="Whether to include the stop strings in output text."), + ] + + +class CompletionV2Response(CreateCompletionResponse): + pass diff --git a/clients/python/llmengine/data_types/gen/__init__.py b/clients/python/llmengine/data_types/gen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clients/python/llmengine/data_types/gen/openai.py b/clients/python/llmengine/data_types/gen/openai.py new file mode 100644 index 00000000..b8222667 --- /dev/null +++ b/clients/python/llmengine/data_types/gen/openai.py @@ -0,0 +1,6314 @@ +# mypy: ignore-errors +# generated by datamodel-codegen: +# filename: openai-spec.yaml +# timestamp: 2024-08-20T08:20:04+00:00 + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + +from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel +from typing_extensions import Annotated, Literal + + +class Error(BaseModel): + code: str + message: str + param: str + type: str + + +class ErrorResponse(BaseModel): + error: Error + + +class DeleteModelResponse(BaseModel): + id: str + deleted: bool + object: str + + +class Prompt(RootModel[Optional[List[int]]]): + root: Annotated[ + Optional[List[int]], + Field( + "<|endoftext|>", + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + examples=["[1212, 318, 257, 1332, 13]"], + min_length=1, + ), + ] = "<|endoftext|>" + + +class Prompt1Item(RootModel[List[int]]): + root: Annotated[List[int], Field(min_length=1)] + + +class Prompt1(RootModel[Optional[List[Prompt1Item]]]): + root: Annotated[ + Optional[List[Prompt1Item]], + Field( + "<|endoftext|>", + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + examples=["[[1212, 318, 257, 1332, 13]]"], + min_length=1, + ), + ] = "<|endoftext|>" + + +class Stop(RootModel[Optional[List[str]]]): + root: Annotated[ + Optional[List[str]], + Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + max_length=4, + min_length=1, + ), + ] = None + + +class Logprobs(BaseModel): + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +class Choice(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "content_filter"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n" + ), + ] + index: int + logprobs: Logprobs + text: str + + +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Annotated[Literal["text"], Field(description="The type of the content part.")] + text: Annotated[str, Field(description="The text content.")] + + +class ImageUrl(BaseModel): + url: Annotated[ + AnyUrl, + Field(description="Either a URL of the image or the base64 encoded image data."), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + "auto", + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", + ), + ] + + +class ChatCompletionRequestMessageContentPartImage(BaseModel): + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] + image_url: ImageUrl + + +class ChatCompletionRequestMessageContentPartRefusal(BaseModel): + type: Annotated[Literal["refusal"], Field(description="The type of the content part.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] + + +class ChatCompletionRequestSystemMessageContentPart( + RootModel[ChatCompletionRequestMessageContentPartText] +): + root: ChatCompletionRequestMessageContentPartText + + +class ChatCompletionRequestUserMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class ChatCompletionRequestAssistantMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + + +class ChatCompletionRequestToolMessageContentPart( + RootModel[ChatCompletionRequestMessageContentPartText] +): + root: ChatCompletionRequestMessageContentPartText + + +class Content(RootModel[List[ChatCompletionRequestSystemMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestSystemMessageContentPart], + Field( + description="An array of content parts with a defined type. For system messages, only type `text` is supported.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestSystemMessage(BaseModel): + content: Annotated[ + Union[str, Content], Field(description="The contents of the system message.") + ] + role: Annotated[ + Literal["system"], + Field(description="The role of the messages author, in this case `system`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] + + +class Content1(RootModel[List[ChatCompletionRequestUserMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestUserMessageContentPart], + Field( + description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Annotated[ + Union[str, Content1], Field(description="The contents of the user message.\n") + ] + role: Annotated[ + Literal["user"], + Field(description="The role of the messages author, in this case `user`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] + + +class Content2(RootModel[Optional[List[ChatCompletionRequestAssistantMessageContentPart]]]): + root: Annotated[ + Optional[List[ChatCompletionRequestAssistantMessageContentPart]], + Field( + None, + description="An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`.", + min_length=1, + title="Array of content parts", + ), + ] = None + + +class FunctionCall(BaseModel): + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class Content3(RootModel[List[ChatCompletionRequestToolMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestToolMessageContentPart], + Field( + description="An array of content parts with a defined type. For tool messages, only type `text` is supported.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestToolMessage(BaseModel): + role: Annotated[ + Literal["tool"], + Field(description="The role of the messages author, in this case `tool`."), + ] + content: Annotated[Union[str, Content3], Field(description="The contents of the tool message.")] + tool_call_id: Annotated[str, Field(description="Tool call that this message is responding to.")] + + +class ChatCompletionRequestFunctionMessage(BaseModel): + role: Annotated[ + Literal["function"], + Field(description="The role of the messages author, in this case `function`."), + ] + content: Annotated[str, Field(description="The contents of the function message.")] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionParameters(BaseModel): + pass + model_config = ConfigDict( + extra="allow", + ) + + +class ChatCompletionFunctions(BaseModel): + description: Annotated[ + Optional[str], + Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ), + ] + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionFunctionCallOption(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionObject(BaseModel): + description: Annotated[ + Optional[str], + Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ), + ] + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + strict: Annotated[ + Optional[bool], + Field( + False, + description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling).", + ), + ] + + +class ResponseFormatText(BaseModel): + type: Annotated[ + Literal["text"], + Field(description="The type of response format being defined: `text`"), + ] + + +class ResponseFormatJsonObject(BaseModel): + type: Annotated[ + Literal["json_object"], + Field(description="The type of response format being defined: `json_object`"), + ] + + +class ResponseFormatJsonSchemaSchema(BaseModel): + pass + model_config = ConfigDict( + extra="allow", + ) + + +class JsonSchema(BaseModel): + description: Annotated[ + Optional[str], + Field( + None, + description="A description of what the response format is for, used by the model to determine how to respond in the format.", + ), + ] + name: Annotated[ + str, + Field( + description="The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(None, alias="schema")] + strict: Annotated[ + Optional[bool], + Field( + False, + description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs).", + ), + ] + + +class ResponseFormatJsonSchema(BaseModel): + type: Annotated[ + Literal["json_schema"], + Field(description="The type of response format being defined: `json_schema`"), + ] + json_schema: JsonSchema + + +class Function(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Function + + +class ParallelToolCalls(RootModel[bool]): + root: Annotated[ + bool, + Field( + description="Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use." + ), + ] + + +class Function1(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + + +class ChatCompletionMessageToolCall(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Annotated[Function1, Field(description="The function that the model called.")] + + +class Function2(BaseModel): + name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + arguments: Annotated[ + Optional[str], + Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ), + ] + + +class ChatCompletionMessageToolCallChunk(BaseModel): + index: int + id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + type: Annotated[ + Optional[Literal["function"]], + Field( + None, + description="The type of the tool. Currently, only `function` is supported.", + ), + ] + function: Optional[Function2] = None + + +class ChatCompletionRole(RootModel[Literal["system", "user", "assistant", "tool", "function"]]): + root: Annotated[ + Literal["system", "user", "assistant", "tool", "function"], + Field(description="The role of the author of a message"), + ] + + +class ChatCompletionStreamOptions(BaseModel): + include_usage: Annotated[ + Optional[bool], + Field( + None, + description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n", + ), + ] + + +class FunctionCall2(BaseModel): + arguments: Annotated[ + Optional[str], + Field( + None, + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + ), + ] + name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + + +class ChatCompletionStreamResponseDelta(BaseModel): + content: Annotated[Optional[str], Field(None, description="The contents of the chunk message.")] + function_call: Annotated[ + Optional[FunctionCall2], + Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ), + ] + tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None + role: Annotated[ + Optional[Literal["system", "user", "assistant", "tool"]], + Field(None, description="The role of the author of this message."), + ] + refusal: Annotated[ + Optional[str], + Field(None, description="The refusal message generated by the model."), + ] + + +class Stop1(RootModel[List[str]]): + root: Annotated[ + List[str], + Field( + description="Up to 4 sequences where the API will stop generating further tokens.\n", + max_length=4, + min_length=1, + ), + ] + + +class TopLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + List[int], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + + +class ChatCompletionTokenLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + List[int], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + top_logprobs: Annotated[ + List[TopLogprob], + Field( + description="List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned." + ), + ] + + +class Logprobs2(BaseModel): + content: Annotated[ + List[ChatCompletionTokenLogprob], + Field(description="A list of message content tokens with log probability information."), + ] + refusal: Annotated[ + List[ChatCompletionTokenLogprob], + Field(description="A list of message refusal tokens with log probability information."), + ] + + +class Choice3(BaseModel): + delta: ChatCompletionStreamResponseDelta + logprobs: Annotated[ + Optional[Logprobs2], + Field(None, description="Log probability information for the choice."), + ] + finish_reason: Annotated[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + + +class Usage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class CreateChatCompletionStreamResponse(BaseModel): + id: Annotated[ + str, + Field( + description="A unique identifier for the chat completion. Each chunk has the same ID." + ), + ] + choices: Annotated[ + List[Choice3], + Field( + description='A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the\nlast chunk if you set `stream_options: {"include_usage": true}`.\n' + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp." + ), + ] + model: Annotated[str, Field(description="The model to generate the completion.")] + service_tier: Annotated[ + Optional[Literal["scale", "default"]], + Field( + None, + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + examples=["scale"], + ), + ] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Literal["chat.completion.chunk"], + Field(description="The object type, which is always `chat.completion.chunk`."), + ] + usage: Annotated[ + Optional[Usage], + Field( + None, + description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n', + ), + ] + + +class CreateChatCompletionImageResponse(BaseModel): + pass + + +class CreateImageRequest(BaseModel): + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`.", + examples=["A cute baby sea otter"], + ), + ] + model: Annotated[ + Optional[Union[str, Literal["dall-e-2", "dall-e-3"]]], + Field( + "dall-e-2", + description="The model to use for image generation.", + examples=["dall-e-3"], + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + examples=[1], + ge=1, + le=10, + ), + ] + quality: Annotated[ + Optional[Literal["standard", "hd"]], + Field( + "standard", + description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", + examples=["standard"], + ), + ] + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]], + Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", + examples=["1024x1024"], + ), + ] + style: Annotated[ + Optional[Literal["vivid", "natural"]], + Field( + "vivid", + description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", + examples=["vivid"], + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class Image(BaseModel): + b64_json: Annotated[ + Optional[str], + Field( + None, + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`.", + ), + ] + url: Annotated[ + Optional[str], + Field( + None, + description="The URL of the generated image, if `response_format` is `url` (default).", + ), + ] + revised_prompt: Annotated[ + Optional[str], + Field( + None, + description="The prompt that was used to generate the image, if there was any revision to the prompt.", + ), + ] + + +class CreateImageEditRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask." + ), + ] + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters.", + examples=["A cute baby sea otter wearing a beret"], + ), + ] + mask: Annotated[ + Optional[bytes], + Field( + None, + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`.", + ), + ] + model: Annotated[ + Optional[Union[str, Literal["dall-e-2"]]], + Field( + "dall-e-2", + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + examples=["dall-e-2"], + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="The number of images to generate. Must be between 1 and 10.", + examples=[1], + ge=1, + le=10, + ), + ] + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024"]], + Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + examples=["1024x1024"], + ), + ] + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class CreateImageVariationRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square." + ), + ] + model: Annotated[ + Optional[Union[str, Literal["dall-e-2"]]], + Field( + "dall-e-2", + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + examples=["dall-e-2"], + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + examples=[1], + ge=1, + le=10, + ), + ] + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + "url", + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024"]], + Field( + "1024x1024", + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + examples=["1024x1024"], + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class CreateModerationRequest(BaseModel): + input: Annotated[Union[str, List[str]], Field(description="The input text to classify")] + model: Annotated[ + Optional[Union[str, Literal["text-moderation-latest", "text-moderation-stable"]]], + Field( + "text-moderation-latest", + description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", + examples=["text-moderation-stable"], + ), + ] + + +class Categories(BaseModel): + hate: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment." + ), + ] + hate_threatening: Annotated[ + bool, + Field( + alias="hate/threatening", + description="Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste.", + ), + ] + harassment: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes harassing language towards any target." + ), + ] + harassment_threatening: Annotated[ + bool, + Field( + alias="harassment/threatening", + description="Harassment content that also includes violence or serious harm towards any target.", + ), + ] + self_harm: Annotated[ + bool, + Field( + alias="self-harm", + description="Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_intent: Annotated[ + bool, + Field( + alias="self-harm/intent", + description="Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_instructions: Annotated[ + bool, + Field( + alias="self-harm/instructions", + description="Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts.", + ), + ] + sexual: Annotated[ + bool, + Field( + description="Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness)." + ), + ] + sexual_minors: Annotated[ + bool, + Field( + alias="sexual/minors", + description="Sexual content that includes an individual who is under 18 years old.", + ), + ] + violence: Annotated[ + bool, + Field(description="Content that depicts death, violence, or physical injury."), + ] + violence_graphic: Annotated[ + bool, + Field( + alias="violence/graphic", + description="Content that depicts death, violence, or physical injury in graphic detail.", + ), + ] + + +class CategoryScores(BaseModel): + hate: Annotated[float, Field(description="The score for the category 'hate'.")] + hate_threatening: Annotated[ + float, + Field( + alias="hate/threatening", + description="The score for the category 'hate/threatening'.", + ), + ] + harassment: Annotated[float, Field(description="The score for the category 'harassment'.")] + harassment_threatening: Annotated[ + float, + Field( + alias="harassment/threatening", + description="The score for the category 'harassment/threatening'.", + ), + ] + self_harm: Annotated[ + float, + Field(alias="self-harm", description="The score for the category 'self-harm'."), + ] + self_harm_intent: Annotated[ + float, + Field( + alias="self-harm/intent", + description="The score for the category 'self-harm/intent'.", + ), + ] + self_harm_instructions: Annotated[ + float, + Field( + alias="self-harm/instructions", + description="The score for the category 'self-harm/instructions'.", + ), + ] + sexual: Annotated[float, Field(description="The score for the category 'sexual'.")] + sexual_minors: Annotated[ + float, + Field( + alias="sexual/minors", + description="The score for the category 'sexual/minors'.", + ), + ] + violence: Annotated[float, Field(description="The score for the category 'violence'.")] + violence_graphic: Annotated[ + float, + Field( + alias="violence/graphic", + description="The score for the category 'violence/graphic'.", + ), + ] + + +class Result(BaseModel): + flagged: Annotated[bool, Field(description="Whether any of the below categories are flagged.")] + categories: Annotated[ + Categories, + Field(description="A list of the categories, and whether they are flagged or not."), + ] + category_scores: Annotated[ + CategoryScores, + Field( + description="A list of the categories along with their scores as predicted by model." + ), + ] + + +class CreateModerationResponse(BaseModel): + id: Annotated[str, Field(description="The unique identifier for the moderation request.")] + model: Annotated[str, Field(description="The model used to generate the moderation results.")] + results: Annotated[List[Result], Field(description="A list of moderation objects.")] + + +class CreateFileRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[bytes, Field(description="The File object (not file name) to be uploaded.\n")] + purpose: Annotated[ + Literal["assistants", "batch", "fine-tune", "vision"], + Field( + description='The intended purpose of the uploaded file.\n\nUse "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning).\n' + ), + ] + + +class DeleteFileResponse(BaseModel): + id: str + object: Literal["file"] + deleted: bool + + +class CreateUploadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + filename: Annotated[str, Field(description="The name of the file to upload.\n")] + purpose: Annotated[ + Literal["assistants", "batch", "fine-tune", "vision"], + Field( + description="The intended purpose of the uploaded file.\n\nSee the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose).\n" + ), + ] + bytes: Annotated[int, Field(description="The number of bytes in the file you are uploading.\n")] + mime_type: Annotated[ + str, + Field( + description="The MIME type of the file.\n\nThis must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision.\n" + ), + ] + + +class AddUploadPartRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + data: Annotated[bytes, Field(description="The chunk of bytes for this Part.\n")] + + +class CompleteUploadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + part_ids: Annotated[List[str], Field(description="The ordered list of Part IDs.\n")] + md5: Annotated[ + Optional[str], + Field( + None, + description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n", + ), + ] + + +class CancelUploadRequest(BaseModel): + pass + model_config = ConfigDict( + extra="forbid", + ) + + +class BatchSize(RootModel[int]): + root: Annotated[ + int, + Field( + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ge=1, + le=256, + ), + ] + + +class LearningRateMultiplier(RootModel[float]): + root: Annotated[ + float, + Field( + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + gt=0.0, + ), + ] + + +class NEpochs(RootModel[int]): + root: Annotated[ + int, + Field( + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + ge=1, + le=50, + ), + ] + + +class Hyperparameters(BaseModel): + batch_size: Annotated[ + Optional[Union[Literal["auto"], BatchSize]], + Field( + "auto", + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ), + ] + learning_rate_multiplier: Annotated[ + Optional[Union[Literal["auto"], LearningRateMultiplier]], + Field( + "auto", + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + ), + ] + n_epochs: Annotated[ + Optional[Union[Literal["auto"], NEpochs]], + Field( + "auto", + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + ), + ] + + +class Wandb(BaseModel): + project: Annotated[ + str, + Field( + description="The name of the project that the new run will be created under.\n", + examples=["my-wandb-project"], + ), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="A display name to set for the run. If not set, we will use the Job ID as the name.\n", + ), + ] + entity: Annotated[ + Optional[str], + Field( + None, + description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n", + ), + ] + tags: Annotated[ + Optional[List[str]], + Field( + None, + description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n', + ), + ] + + +class Integration(BaseModel): + type: Annotated[ + Literal["wandb"], + Field( + description='The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.\n' + ), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class CreateFineTuningJobRequest(BaseModel): + model: Annotated[ + Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"]], + Field( + description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned).\n", + examples=["gpt-4o-mini"], + ), + ] + training_file: Annotated[ + str, + Field( + description="The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/create) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.\n\nThe contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + examples=["file-abc123"], + ), + ] + hyperparameters: Annotated[ + Optional[Hyperparameters], + Field(None, description="The hyperparameters used for the fine-tuning job."), + ] + suffix: Annotated[ + Optional[str], + Field( + None, + description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.\n', + max_length=40, + min_length=1, + ), + ] + validation_file: Annotated[ + Optional[str], + Field( + None, + description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + examples=["file-abc123"], + ), + ] + integrations: Annotated[ + Optional[List[Integration]], + Field( + None, + description="A list of integrations to enable for your fine-tuning job.", + ), + ] + seed: Annotated[ + Optional[int], + Field( + None, + description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.\nIf a seed is not specified, one will be generated for you.\n", + examples=[42], + ge=0, + le=2147483647, + ), + ] + + +class Input(RootModel[List[str]]): + root: Annotated[ + List[str], + Field( + description="The array of strings that will be turned into an embedding.", + examples=["The quick brown fox jumped over the lazy dog"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Input1(RootModel[List[int]]): + root: Annotated[ + List[int], + Field( + description="The array of integers that will be turned into an embedding.", + examples=["[1212, 318, 257, 1332, 13]"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Input2Item(RootModel[List[int]]): + root: Annotated[List[int], Field(min_length=1)] + + +class Input2(RootModel[List[Input2Item]]): + root: Annotated[ + List[Input2Item], + Field( + description="The array of arrays containing integers that will be turned into an embedding.", + examples=["[[1212, 318, 257, 1332, 13]]"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class CreateEmbeddingRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + input: Annotated[ + Union[str, Input, Input1, Input2], + Field( + description="Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=["The quick brown fox jumped over the lazy dog"], + ), + ] + model: Annotated[ + Union[ + str, + Literal[ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ], + ], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + examples=["text-embedding-3-small"], + ), + ] + encoding_format: Annotated[ + Optional[Literal["float", "base64"]], + Field( + "float", + description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", + examples=["float"], + ), + ] + dimensions: Annotated[ + Optional[int], + Field( + None, + description="The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.\n", + ge=1, + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class Usage1(BaseModel): + prompt_tokens: Annotated[int, Field(description="The number of tokens used by the prompt.")] + total_tokens: Annotated[ + int, Field(description="The total number of tokens used by the request.") + ] + + +class CreateTranscriptionRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Literal["whisper-1"]], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + examples=["whisper-1"], + ), + ] + language: Annotated[ + Optional[str], + Field( + None, + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n", + ), + ] + prompt: Annotated[ + Optional[str], + Field( + None, + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n", + ), + ] + response_format: Annotated[ + Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]], + Field( + "json", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 0, + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + ), + ] + timestamp_granularities__: Annotated[ + Optional[List[Literal["word", "segment"]]], + Field( + ["segment"], + alias="timestamp_granularities[]", + description="The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.\n", + ), + ] + + +class CreateTranscriptionResponseJson(BaseModel): + text: Annotated[str, Field(description="The transcribed text.")] + + +class TranscriptionSegment(BaseModel): + id: Annotated[int, Field(description="Unique identifier of the segment.")] + seek: Annotated[int, Field(description="Seek offset of the segment.")] + start: Annotated[float, Field(description="Start time of the segment in seconds.")] + end: Annotated[float, Field(description="End time of the segment in seconds.")] + text: Annotated[str, Field(description="Text content of the segment.")] + tokens: Annotated[List[int], Field(description="Array of token IDs for the text content.")] + temperature: Annotated[ + float, + Field(description="Temperature parameter used for generating the segment."), + ] + avg_logprob: Annotated[ + float, + Field( + description="Average logprob of the segment. If the value is lower than -1, consider the logprobs failed." + ), + ] + compression_ratio: Annotated[ + float, + Field( + description="Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed." + ), + ] + no_speech_prob: Annotated[ + float, + Field( + description="Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent." + ), + ] + + +class TranscriptionWord(BaseModel): + word: Annotated[str, Field(description="The text content of the word.")] + start: Annotated[float, Field(description="Start time of the word in seconds.")] + end: Annotated[float, Field(description="End time of the word in seconds.")] + + +class CreateTranscriptionResponseVerboseJson(BaseModel): + language: Annotated[str, Field(description="The language of the input audio.")] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The transcribed text.")] + words: Annotated[ + Optional[List[TranscriptionWord]], + Field(None, description="Extracted words and their corresponding timestamps."), + ] + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field( + None, + description="Segments of the transcribed text and their corresponding details.", + ), + ] + + +class CreateTranslationRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Literal["whisper-1"]], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + examples=["whisper-1"], + ), + ] + prompt: Annotated[ + Optional[str], + Field( + None, + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n", + ), + ] + response_format: Annotated[ + Optional[str], + Field( + "json", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 0, + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + ), + ] + + +class CreateTranslationResponseJson(BaseModel): + text: str + + +class CreateTranslationResponseVerboseJson(BaseModel): + language: Annotated[ + str, + Field(description="The language of the output translation (always `english`)."), + ] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The translated text.")] + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field( + None, + description="Segments of the translated text and their corresponding details.", + ), + ] + + +class CreateSpeechRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Union[str, Literal["tts-1", "tts-1-hd"]], + Field( + description="One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd`\n" + ), + ] + input: Annotated[ + str, + Field( + description="The text to generate audio for. The maximum length is 4096 characters.", + max_length=4096, + ), + ] + voice: Annotated[ + Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + Field( + description="The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options)." + ), + ] + response_format: Annotated[ + Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]], + Field( + "mp3", + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.", + ), + ] + speed: Annotated[ + Optional[float], + Field( + 1.0, + description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", + ge=0.25, + le=4.0, + ), + ] + + +class Model(BaseModel): + id: Annotated[ + str, + Field(description="The model identifier, which can be referenced in the API endpoints."), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) when the model was created."), + ] + object: Annotated[ + Literal["model"], Field(description='The object type, which is always "model".') + ] + owned_by: Annotated[str, Field(description="The organization that owns the model.")] + + +class OpenAIFile(BaseModel): + id: Annotated[ + str, + Field(description="The file identifier, which can be referenced in the API endpoints."), + ] + bytes: Annotated[int, Field(description="The size of the file, in bytes.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the file was created."), + ] + filename: Annotated[str, Field(description="The name of the file.")] + object: Annotated[ + Literal["file"], Field(description="The object type, which is always `file`.") + ] + purpose: Annotated[ + Literal[ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ], + Field( + description="The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`." + ), + ] + status: Annotated[ + Literal["uploaded", "processed", "error"], + Field( + description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`." + ), + ] + status_details: Annotated[ + Optional[str], + Field( + None, + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`.", + ), + ] + + +class Upload(BaseModel): + id: Annotated[ + str, + Field( + description="The Upload unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + filename: Annotated[str, Field(description="The name of the file to be uploaded.")] + bytes: Annotated[int, Field(description="The intended number of bytes to be uploaded.")] + purpose: Annotated[ + str, + Field( + description="The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values." + ), + ] + status: Annotated[ + Literal["pending", "completed", "cancelled", "expired"], + Field(description="The status of the Upload."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + object: Annotated[ + Optional[Literal["upload"]], + Field(None, description='The object type, which is always "upload".'), + ] + file: Annotated[ + Optional[OpenAIFile], + Field(None, description="The ready File object after the Upload is completed."), + ] + + +class UploadPart(BaseModel): + id: Annotated[ + str, + Field( + description="The upload Part unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Part was created."), + ] + upload_id: Annotated[ + str, + Field(description="The ID of the Upload object that this Part was added to."), + ] + object: Annotated[ + Literal["upload.part"], + Field(description="The object type, which is always `upload.part`."), + ] + + +class Embedding(BaseModel): + index: Annotated[ + int, Field(description="The index of the embedding in the list of embeddings.") + ] + embedding: Annotated[ + List[float], + Field( + description="The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).\n" + ), + ] + object: Annotated[ + Literal["embedding"], + Field(description='The object type, which is always "embedding".'), + ] + + +class Error1(BaseModel): + code: Annotated[str, Field(description="A machine-readable error code.")] + message: Annotated[str, Field(description="A human-readable error message.")] + param: Annotated[ + str, + Field( + description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific." + ), + ] + + +class NEpochs1(RootModel[int]): + root: Annotated[ + int, + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.', + ge=1, + le=50, + ), + ] + + +class Hyperparameters1(BaseModel): + n_epochs: Annotated[ + Union[Literal["auto"], NEpochs1], + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.' + ), + ] + + +class FineTuningIntegration(BaseModel): + type: Annotated[ + Literal["wandb"], + Field(description="The type of the integration being enabled for the fine-tuning job"), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class FineTuningJobEvent(BaseModel): + id: str + created_at: int + level: Literal["info", "warn", "error"] + message: str + object: Literal["fine_tuning.job.event"] + + +class Metrics(BaseModel): + step: Optional[float] = None + train_loss: Optional[float] = None + train_mean_token_accuracy: Optional[float] = None + valid_loss: Optional[float] = None + valid_mean_token_accuracy: Optional[float] = None + full_valid_loss: Optional[float] = None + full_valid_mean_token_accuracy: Optional[float] = None + + +class FineTuningJobCheckpoint(BaseModel): + id: Annotated[ + str, + Field( + description="The checkpoint identifier, which can be referenced in the API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the checkpoint was created."), + ] + fine_tuned_model_checkpoint: Annotated[ + str, + Field(description="The name of the fine-tuned checkpoint model that is created."), + ] + step_number: Annotated[ + int, Field(description="The step number that the checkpoint was created at.") + ] + metrics: Annotated[ + Metrics, + Field(description="Metrics at the step number during the fine-tuning job."), + ] + fine_tuning_job_id: Annotated[ + str, + Field(description="The name of the fine-tuning job that this checkpoint was created from."), + ] + object: Annotated[ + Literal["fine_tuning.job.checkpoint"], + Field(description='The object type, which is always "fine_tuning.job.checkpoint".'), + ] + + +class FinetuneCompletionRequestInput(BaseModel): + prompt: Annotated[ + Optional[str], + Field(None, description="The input prompt for this training example."), + ] + completion: Annotated[ + Optional[str], + Field(None, description="The desired completion for this training example."), + ] + + +class CompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class RunCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class RunStepCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run step."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run step."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class AssistantsApiResponseFormatOption( + RootModel[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ] + ] +): + root: Annotated[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ], + Field( + description='Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' + ), + ] + + +class CodeInterpreter(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class FileSearch(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources(BaseModel): + code_interpreter: Optional[CodeInterpreter] = None + file_search: Optional[FileSearch] = None + + +class CodeInterpreter1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class ChunkingStrategy(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class Static(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class ChunkingStrategy1(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy, ChunkingStrategy1]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch1(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore]], + Field( + None, + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ChunkingStrategy2(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy3(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy2, ChunkingStrategy3]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch2(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + List[VectorStore1], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources1(BaseModel): + code_interpreter: Optional[CodeInterpreter1] = None + file_search: Optional[Union[FileSearch1, FileSearch2]] = None + + +class CodeInterpreter2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class FileSearch3(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources2(BaseModel): + code_interpreter: Optional[CodeInterpreter2] = None + file_search: Optional[FileSearch3] = None + + +class DeleteAssistantResponse(BaseModel): + id: str + deleted: bool + object: Literal["assistant.deleted"] + + +class AssistantToolsCode(BaseModel): + type: Annotated[ + Literal["code_interpreter"], + Field(description="The type of tool being defined: `code_interpreter`"), + ] + + +class FileSearch4(BaseModel): + max_num_results: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", + ge=1, + le=50, + ), + ] + + +class AssistantToolsFileSearch(BaseModel): + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] + file_search: Annotated[ + Optional[FileSearch4], + Field(None, description="Overrides for the file search tool."), + ] + + +class AssistantToolsFileSearchTypeOnly(BaseModel): + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] + + +class AssistantToolsFunction(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of tool being defined: `function`"), + ] + function: FunctionObject + + +class TruncationObject(BaseModel): + type: Annotated[ + Literal["auto", "last_messages"], + Field( + description="The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`." + ), + ] + last_messages: Annotated[ + Optional[int], + Field( + None, + description="The number of most recent messages from the thread when constructing the context for the run.", + ge=1, + ), + ] + + +class Function3(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class AssistantsNamedToolChoice(BaseModel): + type: Annotated[ + Literal["function", "code_interpreter", "file_search"], + Field( + description="The type of the tool. If type is `function`, the function name must be set" + ), + ] + function: Optional[Function3] = None + + +class LastError(BaseModel): + code: Annotated[ + Literal["server_error", "rate_limit_exceeded", "invalid_prompt"], + Field(description="One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class IncompleteDetails(BaseModel): + reason: Annotated[ + Optional[Literal["max_completion_tokens", "max_prompt_tokens"]], + Field( + None, + description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run.", + ), + ] + + +class ModifyRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class ToolOutput(BaseModel): + tool_call_id: Annotated[ + Optional[str], + Field( + None, + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for.", + ), + ] + output: Annotated[ + Optional[str], + Field( + None, + description="The output of the tool call to be submitted to continue the run.", + ), + ] + + +class SubmitToolOutputsRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + tool_outputs: Annotated[ + List[ToolOutput], + Field(description="A list of tools for which the outputs are being submitted."), + ] + stream: Annotated[ + Optional[bool], + Field( + None, + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + ), + ] + + +class Function4(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[ + str, + Field(description="The arguments that the model expects you to pass to the function."), + ] + + +class RunToolCallObject(BaseModel): + id: Annotated[ + str, + Field( + description="The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint." + ), + ] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call the output is required for. For now, this is always `function`." + ), + ] + function: Annotated[Function4, Field(description="The function definition.")] + + +class CodeInterpreter3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + [], + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] + + +class FileSearch5(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources3(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch5] = None + + +class FileSearch6(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources4(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch6] = None + + +class ThreadObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread"], + Field(description="The object type, which is always `thread`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the thread was created."), + ] + tool_resources: Annotated[ + ToolResources4, + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class ChunkingStrategy4(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy5(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy4, ChunkingStrategy5]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch7(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore2]], + Field( + None, + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ChunkingStrategy6(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy7(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy6, ChunkingStrategy7]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class FileSearch8(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + List[VectorStore3], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources5(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[Union[FileSearch7, FileSearch8]] = None + + +class FileSearch9(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources6(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch9] = None + + +class ModifyThreadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + tool_resources: Annotated[ + Optional[ToolResources6], + Field( + None, + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class DeleteThreadResponse(BaseModel): + id: str + deleted: bool + object: Literal["thread.deleted"] + + +class ListThreadsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[ThreadObject] + first_id: Annotated[str, Field(examples=["asst_abc123"])] + last_id: Annotated[str, Field(examples=["asst_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class IncompleteDetails1(BaseModel): + reason: Annotated[ + Literal["content_filter", "max_tokens", "run_cancelled", "run_expired", "run_failed"], + Field(description="The reason the message is incomplete."), + ] + + +class Attachment(BaseModel): + file_id: Annotated[ + Optional[str], + Field(None, description="The ID of the file to attach to the message."), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearchTypeOnly]]], + Field(None, description="The tools to add this file to."), + ] + + +class ModifyMessageRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class DeleteMessageResponse(BaseModel): + id: str + deleted: bool + object: Literal["thread.message.deleted"] + + +class ImageFile(BaseModel): + file_id: Annotated[ + str, + Field( + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' + ), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + "auto", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + ), + ] + + +class MessageContentImageFileObject(BaseModel): + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] + image_file: ImageFile + + +class ImageFile1(BaseModel): + file_id: Annotated[ + Optional[str], + Field( + None, + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.', + ), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + "auto", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + ), + ] + + +class MessageDeltaContentImageFileObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] + image_file: Optional[ImageFile1] = None + + +class ImageUrl1(BaseModel): + url: Annotated[ + AnyUrl, + Field( + description="The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + ), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + "auto", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`", + ), + ] + + +class MessageContentImageUrlObject(BaseModel): + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] + image_url: ImageUrl1 + + +class ImageUrl2(BaseModel): + url: Annotated[ + Optional[str], + Field( + None, + description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp.", + ), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + "auto", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + ), + ] + + +class MessageDeltaContentImageUrlObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["image_url"], Field(description="Always `image_url`.")] + image_url: Optional[ImageUrl2] = None + + +class MessageContentRefusalObject(BaseModel): + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: str + + +class MessageRequestContentTextObject(BaseModel): + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Annotated[str, Field(description="Text content to be sent to the model")] + + +class FileCitation(BaseModel): + file_id: Annotated[str, Field(description="The ID of the specific File the citation is from.")] + + +class MessageContentTextAnnotationsFileCitationObject(BaseModel): + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_citation: FileCitation + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class FilePath(BaseModel): + file_id: Annotated[str, Field(description="The ID of the file that was generated.")] + + +class MessageContentTextAnnotationsFilePathObject(BaseModel): + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_path: FilePath + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class MessageDeltaContentRefusalObject(BaseModel): + index: Annotated[int, Field(description="The index of the refusal part in the message.")] + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: Optional[str] = None + + +class FileCitation1(BaseModel): + file_id: Annotated[ + Optional[str], + Field(None, description="The ID of the specific File the citation is from."), + ] + quote: Annotated[Optional[str], Field(None, description="The specific quote in the file.")] + + +class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] + text: Annotated[ + Optional[str], + Field( + None, + description="The text in the message content that needs to be replaced.", + ), + ] + file_citation: Optional[FileCitation1] = None + start_index: Annotated[Optional[int], Field(None, ge=0)] + end_index: Annotated[Optional[int], Field(None, ge=0)] + + +class FilePath1(BaseModel): + file_id: Annotated[ + Optional[str], Field(None, description="The ID of the file that was generated.") + ] + + +class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] + text: Annotated[ + Optional[str], + Field( + None, + description="The text in the message content that needs to be replaced.", + ), + ] + file_path: Optional[FilePath1] = None + start_index: Annotated[Optional[int], Field(None, ge=0)] + end_index: Annotated[Optional[int], Field(None, ge=0)] + + +class LastError1(BaseModel): + code: Annotated[ + Literal["server_error", "rate_limit_exceeded"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class MessageCreation(BaseModel): + message_id: Annotated[ + str, + Field(description="The ID of the message that was created by this run step."), + ] + + +class RunStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] + message_creation: MessageCreation + + +class MessageCreation1(BaseModel): + message_id: Annotated[ + Optional[str], + Field(None, description="The ID of the message that was created by this run step."), + ] + + +class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] + message_creation: Optional[MessageCreation1] = None + + +class RunStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] + logs: Annotated[str, Field(description="The text output from the Code Interpreter tool call.")] + + +class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] + logs: Annotated[ + Optional[str], + Field(None, description="The text output from the Code Interpreter tool call."), + ] + + +class Image1(BaseModel): + file_id: Annotated[ + str, Field(description="The [file](/docs/api-reference/files) ID of the image.") + ] + + +class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): + type: Annotated[Literal["image"], Field(description="Always `image`.")] + image: Image1 + + +class Image2(BaseModel): + file_id: Annotated[ + Optional[str], + Field(None, description="The [file](/docs/api-reference/files) ID of the image."), + ] + + +class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Literal["image"], Field(description="Always `image`.")] + image: Optional[Image2] = None + + +class RunStepDetailsToolCallsFileSearchObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Literal["file_search"], + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + type: Annotated[ + Literal["file_search"], + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class Function5(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[str, Field(description="The arguments passed to the function.")] + output: Annotated[ + str, + Field( + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." + ), + ] + + +class RunStepDetailsToolCallsFunctionObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Function5, Field(description="The definition of the function that was called.") + ] + + +class Function6(BaseModel): + name: Annotated[Optional[str], Field(None, description="The name of the function.")] + arguments: Annotated[ + Optional[str], Field(None, description="The arguments passed to the function.") + ] + output: Annotated[ + Optional[str], + Field( + None, + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet.", + ), + ] + + +class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Optional[Function6], + Field(None, description="The definition of the function that was called."), + ] + + +class VectorStoreExpirationAfter(BaseModel): + anchor: Annotated[ + Literal["last_active_at"], + Field( + description="Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." + ), + ] + days: Annotated[ + int, + Field( + description="The number of days after the anchor time that the vector store will expire.", + ge=1, + le=365, + ), + ] + + +class FileCounts(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[ + int, + Field(description="The number of files that have been successfully processed."), + ] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that were cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store"], + Field(description="The object type, which is always `vector_store`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the vector store was created."), + ] + name: Annotated[str, Field(description="The name of the vector store.")] + usage_bytes: Annotated[ + int, + Field(description="The total number of bytes used by the files in the vector store."), + ] + file_counts: FileCounts + status: Annotated[ + Literal["expired", "in_progress", "completed"], + Field( + description="The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use." + ), + ] + expires_after: Optional[VectorStoreExpirationAfter] = None + expires_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the vector store will expire.", + ), + ] + last_active_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store was last active." + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class UpdateVectorStoreRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + expires_after: Optional[VectorStoreExpirationAfter] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class ListVectorStoresResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[VectorStoreObject] + first_id: Annotated[str, Field(examples=["vs_abc123"])] + last_id: Annotated[str, Field(examples=["vs_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class DeleteVectorStoreResponse(BaseModel): + id: str + deleted: bool + object: Literal["vector_store.deleted"] + + +class LastError2(BaseModel): + code: Annotated[ + Literal["server_error", "unsupported_file", "invalid_file"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class OtherChunkingStrategyResponseParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["other"], Field(description="Always `other`.")] + + +class StaticChunkingStrategy(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class AutoChunkingStrategyRequestParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class StaticChunkingStrategyRequestParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class ChunkingStrategyRequestParam( + RootModel[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]] +): + root: Annotated[ + Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] + + +class CreateVectorStoreFileRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_id: Annotated[ + str, + Field( + description="A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files." + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class DeleteVectorStoreFileResponse(BaseModel): + id: str + deleted: bool + object: Literal["vector_store.file.deleted"] + + +class FileCounts1(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[int, Field(description="The number of files that have been processed.")] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that where cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreFileBatchObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store.files_batch"], + Field(description="The object type, which is always `vector_store.file_batch`."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store files batch was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Literal["in_progress", "completed", "cancelled", "failed"], + Field( + description="The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`." + ), + ] + file_counts: FileCounts1 + + +class CreateVectorStoreFileBatchRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_ids: Annotated[ + List[str], + Field( + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_length=500, + min_length=1, + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class ThreadStreamEvent1(BaseModel): + event: Literal["thread.created"] + data: ThreadObject + + +class ThreadStreamEvent(RootModel[ThreadStreamEvent1]): + root: ThreadStreamEvent1 + + +class ErrorEvent(BaseModel): + event: Literal["error"] + data: Error + + +class DoneEvent(BaseModel): + event: Literal["done"] + data: Literal["[DONE]"] + + +class Datum(BaseModel): + code: Annotated[ + Optional[str], + Field(None, description="An error code identifying the error type."), + ] + message: Annotated[ + Optional[str], + Field( + None, + description="A human-readable message providing more details about the error.", + ), + ] + param: Annotated[ + Optional[str], + Field( + None, + description="The name of the parameter that caused the error, if applicable.", + ), + ] + line: Annotated[ + Optional[int], + Field( + None, + description="The line number of the input file where the error occurred, if applicable.", + ), + ] + + +class Errors(BaseModel): + object: Annotated[ + Optional[str], + Field(None, description="The object type, which is always `list`."), + ] + data: Optional[List[Datum]] = None + + +class RequestCounts(BaseModel): + total: Annotated[int, Field(description="Total number of requests in the batch.")] + completed: Annotated[ + int, + Field(description="Number of requests that have been completed successfully."), + ] + failed: Annotated[int, Field(description="Number of requests that have failed.")] + + +class Batch(BaseModel): + id: str + object: Annotated[ + Literal["batch"], Field(description="The object type, which is always `batch`.") + ] + endpoint: Annotated[str, Field(description="The OpenAI API endpoint used by the batch.")] + errors: Optional[Errors] = None + input_file_id: Annotated[str, Field(description="The ID of the input file for the batch.")] + completion_window: Annotated[ + str, + Field(description="The time frame within which the batch should be processed."), + ] + status: Annotated[ + Literal[ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ], + Field(description="The current status of the batch."), + ] + output_file_id: Annotated[ + Optional[str], + Field( + None, + description="The ID of the file containing the outputs of successfully executed requests.", + ), + ] + error_file_id: Annotated[ + Optional[str], + Field( + None, + description="The ID of the file containing the outputs of requests with errors.", + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the batch was created."), + ] + in_progress_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch started processing.", + ), + ] + expires_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch will expire.", + ), + ] + finalizing_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch started finalizing.", + ), + ] + completed_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch was completed.", + ), + ] + failed_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch failed.", + ), + ] + expired_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch expired.", + ), + ] + cancelling_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch started cancelling.", + ), + ] + cancelled_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the batch was cancelled.", + ), + ] + request_counts: Annotated[ + Optional[RequestCounts], + Field( + None, + description="The request counts for different statuses within the batch.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class BatchRequestInput(BaseModel): + custom_id: Annotated[ + Optional[str], + Field( + None, + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch.", + ), + ] + method: Annotated[ + Optional[Literal["POST"]], + Field( + None, + description="The HTTP method to be used for the request. Currently only `POST` is supported.", + ), + ] + url: Annotated[ + Optional[str], + Field( + None, + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.", + ), + ] + + +class Response(BaseModel): + status_code: Annotated[ + Optional[int], Field(None, description="The HTTP status code of the response") + ] + request_id: Annotated[ + Optional[str], + Field( + None, + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support.", + ), + ] + body: Annotated[ + Optional[Dict[str, Any]], + Field(None, description="The JSON body of the response"), + ] + + +class Error2(BaseModel): + code: Annotated[Optional[str], Field(None, description="A machine-readable error code.")] + message: Annotated[Optional[str], Field(None, description="A human-readable error message.")] + + +class BatchRequestOutput(BaseModel): + id: Optional[str] = None + custom_id: Annotated[ + Optional[str], + Field( + None, + description="A developer-provided per-request id that will be used to match outputs to inputs.", + ), + ] + response: Optional[Response] = None + error: Annotated[ + Optional[Error2], + Field( + None, + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure.", + ), + ] + + +class ListBatchesResponse(BaseModel): + data: List[Batch] + first_id: Annotated[Optional[str], Field(None, examples=["batch_abc123"])] + last_id: Annotated[Optional[str], Field(None, examples=["batch_abc456"])] + has_more: bool + object: Literal["list"] + + +class AuditLogActorServiceAccount(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account id.")] + + +class AuditLogActorUser(BaseModel): + id: Annotated[Optional[str], Field(None, description="The user id.")] + email: Annotated[Optional[str], Field(None, description="The user email.")] + + +class AuditLogActorApiKey(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking id of the API key.")] + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field( + None, + description="The type of API key. Can be either `user` or `service_account`.", + ), + ] + user: Optional[AuditLogActorUser] = None + service_account: Optional[AuditLogActorServiceAccount] = None + + +class AuditLogActorSession(BaseModel): + user: Optional[AuditLogActorUser] = None + ip_address: Annotated[ + Optional[str], + Field(None, description="The IP address from which the action was performed."), + ] + + +class AuditLogActor(BaseModel): + type: Annotated[ + Optional[Literal["session", "api_key"]], + Field(None, description="The type of actor. Is either `session` or `api_key`."), + ] + session: Optional[AuditLogActorSession] = None + api_key: Optional[AuditLogActorApiKey] = None + + +class AuditLogEventType( + RootModel[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ] + ] +): + root: Annotated[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ], + Field(description="The event type."), + ] + + +class Project(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + name: Annotated[Optional[str], Field(None, description="The project title.")] + + +class Data(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field( + None, + description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', + ), + ] + + +class ApiKeyCreated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + data: Annotated[ + Optional[Data], + Field(None, description="The payload used to create the API key."), + ] + + +class ChangesRequested(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field( + None, + description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', + ), + ] + + +class ApiKeyUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + changes_requested: Annotated[ + Optional[ChangesRequested], + Field(None, description="The payload used to update the API key."), + ] + + +class ApiKeyDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + + +class Data1(BaseModel): + email: Annotated[ + Optional[str], Field(None, description="The email invited to the organization.") + ] + role: Annotated[ + Optional[str], + Field( + None, + description="The role the email was invited to be. Is either `owner` or `member`.", + ), + ] + + +class InviteSent(BaseModel): + id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + data: Annotated[ + Optional[Data1], + Field(None, description="The payload used to create the invite."), + ] + + +class InviteAccepted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + + +class InviteDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + + +class LoginFailed(BaseModel): + error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_message: Annotated[ + Optional[str], Field(None, description="The error message of the failure.") + ] + + +class LogoutFailed(BaseModel): + error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_message: Annotated[ + Optional[str], Field(None, description="The error message of the failure.") + ] + + +class Settings(BaseModel): + threads_ui_visibility: Annotated[ + Optional[str], + Field( + None, + description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`.", + ), + ] + usage_dashboard_visibility: Annotated[ + Optional[str], + Field( + None, + description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`.", + ), + ] + + +class ChangesRequested1(BaseModel): + title: Annotated[Optional[str], Field(None, description="The organization title.")] + description: Annotated[Optional[str], Field(None, description="The organization description.")] + name: Annotated[Optional[str], Field(None, description="The organization name.")] + settings: Optional[Settings] = None + + +class OrganizationUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The organization ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested1], + Field(None, description="The payload used to update the organization settings."), + ] + + +class Data2(BaseModel): + name: Annotated[Optional[str], Field(None, description="The project name.")] + title: Annotated[ + Optional[str], + Field(None, description="The title of the project as seen on the dashboard."), + ] + + +class ProjectCreated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + data: Annotated[ + Optional[Data2], + Field(None, description="The payload used to create the project."), + ] + + +class ChangesRequested2(BaseModel): + title: Annotated[ + Optional[str], + Field(None, description="The title of the project as seen on the dashboard."), + ] + + +class ProjectUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested2], + Field(None, description="The payload used to update the project."), + ] + + +class ProjectArchived(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + + +class Data3(BaseModel): + role: Annotated[ + Optional[str], + Field( + None, + description="The role of the service account. Is either `owner` or `member`.", + ), + ] + + +class ServiceAccountCreated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account ID.")] + data: Annotated[ + Optional[Data3], + Field(None, description="The payload used to create the service account."), + ] + + +class ChangesRequested3(BaseModel): + role: Annotated[ + Optional[str], + Field( + None, + description="The role of the service account. Is either `owner` or `member`.", + ), + ] + + +class ServiceAccountUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested3], + Field(None, description="The payload used to updated the service account."), + ] + + +class ServiceAccountDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account ID.")] + + +class Data4(BaseModel): + role: Annotated[ + Optional[str], + Field(None, description="The role of the user. Is either `owner` or `member`."), + ] + + +class UserAdded(BaseModel): + id: Annotated[Optional[str], Field(None, description="The user ID.")] + data: Annotated[ + Optional[Data4], + Field(None, description="The payload used to add the user to the project."), + ] + + +class ChangesRequested4(BaseModel): + role: Annotated[ + Optional[str], + Field(None, description="The role of the user. Is either `owner` or `member`."), + ] + + +class UserUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested4], + Field(None, description="The payload used to update the user."), + ] + + +class UserDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The user ID.")] + + +class AuditLog(BaseModel): + id: Annotated[str, Field(description="The ID of this log.")] + type: AuditLogEventType + effective_at: Annotated[int, Field(description="The Unix timestamp (in seconds) of the event.")] + project: Annotated[ + Optional[Project], + Field( + None, + description="The project that the action was scoped to. Absent for actions not scoped to projects.", + ), + ] + actor: AuditLogActor + api_key_created: Annotated[ + Optional[ApiKeyCreated], + Field( + None, + alias="api_key.created", + description="The details for events with this `type`.", + ), + ] + api_key_updated: Annotated[ + Optional[ApiKeyUpdated], + Field( + None, + alias="api_key.updated", + description="The details for events with this `type`.", + ), + ] + api_key_deleted: Annotated[ + Optional[ApiKeyDeleted], + Field( + None, + alias="api_key.deleted", + description="The details for events with this `type`.", + ), + ] + invite_sent: Annotated[ + Optional[InviteSent], + Field( + None, + alias="invite.sent", + description="The details for events with this `type`.", + ), + ] + invite_accepted: Annotated[ + Optional[InviteAccepted], + Field( + None, + alias="invite.accepted", + description="The details for events with this `type`.", + ), + ] + invite_deleted: Annotated[ + Optional[InviteDeleted], + Field( + None, + alias="invite.deleted", + description="The details for events with this `type`.", + ), + ] + login_failed: Annotated[ + Optional[LoginFailed], + Field( + None, + alias="login.failed", + description="The details for events with this `type`.", + ), + ] + logout_failed: Annotated[ + Optional[LogoutFailed], + Field( + None, + alias="logout.failed", + description="The details for events with this `type`.", + ), + ] + organization_updated: Annotated[ + Optional[OrganizationUpdated], + Field( + None, + alias="organization.updated", + description="The details for events with this `type`.", + ), + ] + project_created: Annotated[ + Optional[ProjectCreated], + Field( + None, + alias="project.created", + description="The details for events with this `type`.", + ), + ] + project_updated: Annotated[ + Optional[ProjectUpdated], + Field( + None, + alias="project.updated", + description="The details for events with this `type`.", + ), + ] + project_archived: Annotated[ + Optional[ProjectArchived], + Field( + None, + alias="project.archived", + description="The details for events with this `type`.", + ), + ] + service_account_created: Annotated[ + Optional[ServiceAccountCreated], + Field( + None, + alias="service_account.created", + description="The details for events with this `type`.", + ), + ] + service_account_updated: Annotated[ + Optional[ServiceAccountUpdated], + Field( + None, + alias="service_account.updated", + description="The details for events with this `type`.", + ), + ] + service_account_deleted: Annotated[ + Optional[ServiceAccountDeleted], + Field( + None, + alias="service_account.deleted", + description="The details for events with this `type`.", + ), + ] + user_added: Annotated[ + Optional[UserAdded], + Field( + None, + alias="user.added", + description="The details for events with this `type`.", + ), + ] + user_updated: Annotated[ + Optional[UserUpdated], + Field( + None, + alias="user.updated", + description="The details for events with this `type`.", + ), + ] + user_deleted: Annotated[ + Optional[UserDeleted], + Field( + None, + alias="user.deleted", + description="The details for events with this `type`.", + ), + ] + + +class ListAuditLogsResponse(BaseModel): + object: Literal["list"] + data: List[AuditLog] + first_id: Annotated[str, Field(examples=["audit_log-defb456h8dks"])] + last_id: Annotated[str, Field(examples=["audit_log-hnbkd8s93s"])] + has_more: bool + + +class Invite(BaseModel): + object: Annotated[ + Literal["organization.invite"], + Field(description="The object type, which is always `organization.invite`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + email: Annotated[ + str, + Field(description="The email address of the individual to whom the invite was sent"), + ] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + status: Annotated[ + Literal["accepted", "expired", "pending"], + Field(description="`accepted`,`expired`, or `pending`"), + ] + invited_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite was sent."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite expires."), + ] + accepted_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) of when the invite was accepted.", + ), + ] + + +class InviteListResponse(BaseModel): + object: Annotated[Literal["list"], Field(description="The object type, which is always `list`")] + data: List[Invite] + first_id: Annotated[ + Optional[str], + Field(None, description="The first `invite_id` in the retrieved `list`"), + ] + last_id: Annotated[ + Optional[str], + Field(None, description="The last `invite_id` in the retrieved `list`"), + ] + has_more: Annotated[ + Optional[bool], + Field( + None, + description="The `has_more` property is used for pagination to indicate there are additional results.", + ), + ] + + +class InviteRequest(BaseModel): + email: Annotated[str, Field(description="Send an email to this address")] + role: Annotated[Literal["reader", "owner"], Field(description="`owner` or `reader`")] + + +class InviteDeleteResponse(BaseModel): + object: Annotated[ + Literal["organization.invite.deleted"], + Field(description="The object type, which is always `organization.invite.deleted`"), + ] + id: str + deleted: bool + + +class User(BaseModel): + object: Annotated[ + Literal["organization.user"], + Field(description="The object type, which is always `organization.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the user was added."), + ] + + +class UserListResponse(BaseModel): + object: Literal["list"] + data: List[User] + first_id: str + last_id: str + has_more: bool + + +class UserRoleUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + + +class UserDeleteResponse(BaseModel): + object: Literal["organization.user.deleted"] + id: str + deleted: bool + + +class Project1(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + object: Annotated[ + Literal["organization.project"], + Field(description="The object type, which is always `organization.project`"), + ] + name: Annotated[str, Field(description="The name of the project. This appears in reporting.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was created."), + ] + archived_at: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) of when the project was archived or `null`.", + ), + ] + status: Annotated[Literal["active", "archived"], Field(description="`active` or `archived`")] + + +class ProjectListResponse(BaseModel): + object: Literal["list"] + data: List[Project1] + first_id: str + last_id: str + has_more: bool + + +class ProjectCreateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The friendly name of the project, this name appears in reports."), + ] + + +class ProjectUpdateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The updated name of the project, this name appears in reports."), + ] + + +class DefaultProjectErrorResponse(BaseModel): + code: int + message: str + + +class ProjectUser(BaseModel): + object: Annotated[ + Literal["organization.project.user"], + Field(description="The object type, which is always `organization.project.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was added."), + ] + + +class ProjectUserListResponse(BaseModel): + object: str + data: List[ProjectUser] + first_id: str + last_id: str + has_more: bool + + +class ProjectUserCreateRequest(BaseModel): + user_id: Annotated[str, Field(description="The ID of the user.")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserDeleteResponse(BaseModel): + object: Literal["organization.project.user.deleted"] + id: str + deleted: bool + + +class ProjectServiceAccount(BaseModel): + object: Annotated[ + Literal["organization.project.service_account"], + Field( + description="The object type, which is always `organization.project.service_account`" + ), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the service account")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the service account was created" + ), + ] + + +class ProjectServiceAccountListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectServiceAccount] + first_id: str + last_id: str + has_more: bool + + +class ProjectServiceAccountCreateRequest(BaseModel): + name: Annotated[str, Field(description="The name of the service account being created.")] + + +class ProjectServiceAccountApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.service_account.api_key"], + Field( + description="The object type, which is always `organization.project.service_account.api_key`" + ), + ] + value: str + name: str + created_at: int + id: str + + +class ProjectServiceAccountDeleteResponse(BaseModel): + object: Literal["organization.project.service_account.deleted"] + id: str + deleted: bool + + +class Owner(BaseModel): + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field(None, description="`user` or `service_account`"), + ] + user: Optional[ProjectUser] = None + service_account: Optional[ProjectServiceAccount] = None + + +class ProjectApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.api_key"], + Field(description="The object type, which is always `organization.project.api_key`"), + ] + redacted_value: Annotated[str, Field(description="The redacted value of the API key")] + name: Annotated[str, Field(description="The name of the API key")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the API key was created"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + owner: Owner + + +class ProjectApiKeyListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectApiKey] + first_id: str + last_id: str + has_more: bool + + +class ProjectApiKeyDeleteResponse(BaseModel): + object: Literal["organization.project.api_key.deleted"] + id: str + deleted: bool + + +class ListModelsResponse(BaseModel): + object: Literal["list"] + data: List[Model] + + +class CreateCompletionRequest(BaseModel): + model: Annotated[ + Union[str, Literal["gpt-3.5-turbo-instruct", "davinci-002", "babbage-002"]], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + prompt: Annotated[ + Union[str, List[str], Prompt, Prompt1], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n" + ), + ] + best_of: Annotated[ + Optional[int], + Field( + 1, + description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', + ge=0, + le=20, + ), + ] + echo: Annotated[ + Optional[bool], + Field(False, description="Echo back the prompt in addition to the completion\n"), + ] + frequency_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + None, + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n', + ), + ] + logprobs: Annotated[ + Optional[int], + Field( + None, + description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", + ge=0, + le=5, + ), + ] + max_tokens: Annotated[ + Optional[int], + Field( + 16, + description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=[16], + ge=0, + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", + examples=[1], + ge=1, + le=128, + ), + ] + presence_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + seed: Annotated[ + Optional[int], + Field( + None, + description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] + stop: Annotated[ + Optional[Union[str, Stop]], + Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + ), + ] + stream: Annotated[ + Optional[bool], + Field( + False, + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + stream_options: Optional[ChatCompletionStreamOptions] = None + suffix: Annotated[ + Optional[str], + Field( + None, + description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", + examples=["test."], + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + + +class CreateCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the completion.")] + choices: Annotated[ + List[Choice], + Field( + description="The list of completion choices the model generated for the input prompt." + ), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the completion was created."), + ] + model: Annotated[str, Field(description="The model used for completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Literal["text_completion"], + Field(description='The object type, which is always "text_completion"'), + ] + usage: Optional[CompletionUsage] = None + + +class ChatCompletionTool(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: FunctionObject + + +class ChatCompletionToolChoiceOption( + RootModel[Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice]] +): + root: Annotated[ + Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n' + ), + ] + + +class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCall]]): + root: Annotated[ + List[ChatCompletionMessageToolCall], + Field(description="The tool calls generated by the model, such as function calls."), + ] + + +class ChatCompletionResponseMessage(BaseModel): + content: Annotated[str, Field(description="The contents of the message.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the author of this message."), + ] + function_call: Annotated[ + Optional[FunctionCall], + Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ), + ] + + +class Choice1(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + logprobs: Annotated[Logprobs2, Field(description="Log probability information for the choice.")] + + +class CreateChatCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice1], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + service_tier: Annotated[ + Optional[Literal["scale", "default"]], + Field( + None, + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + examples=["scale"], + ), + ] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Literal["chat.completion"], + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class Choice2(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "function_call", "content_filter"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + + +class CreateChatCompletionFunctionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice2], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + ), + ] + object: Annotated[ + Literal["chat.completion"], + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class ImagesResponse(BaseModel): + created: int + data: List[Image] + + +class ListFilesResponse(BaseModel): + data: List[OpenAIFile] + object: Literal["list"] + + +class ListFineTuningJobEventsResponse(BaseModel): + data: List[FineTuningJobEvent] + object: Literal["list"] + + +class ListFineTuningJobCheckpointsResponse(BaseModel): + data: List[FineTuningJobCheckpoint] + object: Literal["list"] + first_id: Optional[str] = None + last_id: Optional[str] = None + has_more: bool + + +class CreateEmbeddingResponse(BaseModel): + data: Annotated[ + List[Embedding], + Field(description="The list of embeddings generated by the model."), + ] + model: Annotated[ + str, Field(description="The name of the model used to generate the embedding.") + ] + object: Annotated[ + Literal["list"], Field(description='The object type, which is always "list".') + ] + usage: Annotated[Usage1, Field(description="The usage information for the request.")] + + +class FineTuningJob(BaseModel): + id: Annotated[ + str, + Field(description="The object identifier, which can be referenced in the API endpoints."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was created." + ), + ] + error: Annotated[ + Error1, + Field( + description="For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure." + ), + ] + fine_tuned_model: Annotated[ + str, + Field( + description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running." + ), + ] + finished_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running." + ), + ] + hyperparameters: Annotated[ + Hyperparameters1, + Field( + description="The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details." + ), + ] + model: Annotated[str, Field(description="The base model that is being fine-tuned.")] + object: Annotated[ + Literal["fine_tuning.job"], + Field(description='The object type, which is always "fine_tuning.job".'), + ] + organization_id: Annotated[ + str, Field(description="The organization that owns the fine-tuning job.") + ] + result_files: Annotated[ + List[str], + Field( + description="The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + status: Annotated[ + Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"], + Field( + description="The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`." + ), + ] + trained_tokens: Annotated[ + int, + Field( + description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running." + ), + ] + training_file: Annotated[ + str, + Field( + description="The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + validation_file: Annotated[ + str, + Field( + description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + integrations: Annotated[ + Optional[List[FineTuningIntegration]], + Field( + None, + description="A list of integrations to enable for this fine-tuning job.", + max_length=5, + ), + ] + seed: Annotated[int, Field(description="The seed used for the fine-tuning job.")] + estimated_finish: Annotated[ + Optional[int], + Field( + None, + description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.", + ), + ] + + +class AssistantObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["assistant"], + Field(description="The object type, which is always `assistant`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the assistant was created."), + ] + name: Annotated[ + str, + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + str, + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + model: Annotated[ + str, + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + instructions: Annotated[ + str, + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateAssistantRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + examples=["gpt-4o"], + ), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + Optional[str], + Field( + None, + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + [], + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources1], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ModifyAssistantRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Optional[str], + Field( + None, + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + ), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + Optional[str], + Field( + None, + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + [], + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources2], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ListAssistantsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[AssistantObject] + first_id: Annotated[str, Field(examples=["asst_abc123"])] + last_id: Annotated[str, Field(examples=["asst_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class AssistantsApiToolChoiceOption( + RootModel[Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice]] +): + root: Annotated[ + Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tools and instead generates a message.\n`auto` is the default value and means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools before responding to the user.\nSpecifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n' + ), + ] + + +class SubmitToolOutputs(BaseModel): + tool_calls: Annotated[ + List[RunToolCallObject], Field(description="A list of the relevant tool calls.") + ] + + +class RequiredAction(BaseModel): + type: Annotated[ + Literal["submit_tool_outputs"], + Field(description="For now, this is always `submit_tool_outputs`."), + ] + submit_tool_outputs: Annotated[ + SubmitToolOutputs, + Field(description="Details on the tool outputs needed for this run to continue."), + ] + + +class RunObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread.run"], + Field(description="The object type, which is always `thread.run`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run." + ), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run." + ), + ] + status: Annotated[ + Literal[ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ], + Field( + description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`." + ), + ] + required_action: Annotated[ + RequiredAction, + Field( + description="Details on the action required to continue the run. Will be `null` if no action is required." + ), + ] + last_error: Annotated[ + LastError, + Field( + description="The last error associated with this run. Will be `null` if there are no errors." + ), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run will expire."), + ] + started_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was started."), + ] + cancelled_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was cancelled."), + ] + failed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run failed."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was completed."), + ] + incomplete_details: Annotated[ + IncompleteDetails, + Field( + description="Details on why the run is incomplete. Will be `null` if the run is not incomplete." + ), + ] + model: Annotated[ + str, + Field( + description="The model that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + instructions: Annotated[ + str, + Field( + description="The instructions that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", + max_length=20, + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunCompletionUsage + temperature: Annotated[ + Optional[float], + Field( + None, + description="The sampling temperature used for this run. If not set, defaults to 1.", + ), + ] + top_p: Annotated[ + Optional[float], + Field( + None, + description="The nucleus sampling value used for this run. If not set, defaults to 1.", + ), + ] + max_prompt_tokens: Annotated[ + int, + Field( + description="The maximum number of prompt tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + int, + Field( + description="The maximum number of completion tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] + truncation_strategy: TruncationObject + tool_choice: AssistantsApiToolChoiceOption + parallel_tool_calls: ParallelToolCalls + response_format: AssistantsApiResponseFormatOption + + +class ListRunsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[RunObject] + first_id: Annotated[str, Field(examples=["run_abc123"])] + last_id: Annotated[str, Field(examples=["run_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class Content4( + RootModel[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ] + ] +): + root: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ], + Field( + description="An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview).", + min_length=1, + title="Array of content parts", + ), + ] + + +class CreateMessageRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + role: Annotated[ + Literal["user", "assistant"], + Field( + description="The role of the entity that is creating the message. Allowed values include:\n- `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages.\n- `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation.\n" + ), + ] + content: Union[str, Content4] + attachments: Annotated[ + Optional[List[Attachment]], + Field( + None, + description="A list of files attached to the message, and the tools they should be added to.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class Text(BaseModel): + value: Annotated[str, Field(description="The data that makes up the text.")] + annotations: List[ + Union[ + MessageContentTextAnnotationsFileCitationObject, + MessageContentTextAnnotationsFilePathObject, + ] + ] + + +class MessageContentTextObject(BaseModel): + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Text + + +class Text1(BaseModel): + value: Annotated[Optional[str], Field(None, description="The data that makes up the text.")] + annotations: Optional[ + List[ + Union[ + MessageDeltaContentTextAnnotationsFileCitationObject, + MessageDeltaContentTextAnnotationsFilePathObject, + ] + ] + ] = None + + +class MessageDeltaContentTextObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Optional[Text1] = None + + +class CodeInterpreter7(BaseModel): + input: Annotated[str, Field(description="The input to the Code Interpreter tool call.")] + outputs: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeOutputLogsObject, + RunStepDetailsToolCallsCodeOutputImageObject, + ] + ], + Field( + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." + ), + ] + + +class RunStepDetailsToolCallsCodeObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Literal["code_interpreter"], + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + CodeInterpreter7, + Field(description="The Code Interpreter tool call definition."), + ] + + +class CodeInterpreter8(BaseModel): + input: Annotated[ + Optional[str], + Field(None, description="The input to the Code Interpreter tool call."), + ] + outputs: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject, + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject, + ] + ] + ], + Field( + None, + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type.", + ), + ] + + +class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + type: Annotated[ + Literal["code_interpreter"], + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + Optional[CodeInterpreter8], + Field(None, description="The Code Interpreter tool call definition."), + ] + + +class CreateVectorStoreRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_ids: Annotated[ + Optional[List[str]], + Field( + None, + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_length=500, + ), + ] + name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + expires_after: Optional[VectorStoreExpirationAfter] = None + chunking_strategy: Annotated[ + Optional[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]], + Field( + None, + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class StaticChunkingStrategyResponseParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class RunStreamEvent1(BaseModel): + event: Literal["thread.run.created"] + data: RunObject + + +class RunStreamEvent2(BaseModel): + event: Literal["thread.run.queued"] + data: RunObject + + +class RunStreamEvent3(BaseModel): + event: Literal["thread.run.in_progress"] + data: RunObject + + +class RunStreamEvent4(BaseModel): + event: Literal["thread.run.requires_action"] + data: RunObject + + +class RunStreamEvent5(BaseModel): + event: Literal["thread.run.completed"] + data: RunObject + + +class RunStreamEvent6(BaseModel): + event: Literal["thread.run.incomplete"] + data: RunObject + + +class RunStreamEvent7(BaseModel): + event: Literal["thread.run.failed"] + data: RunObject + + +class RunStreamEvent8(BaseModel): + event: Literal["thread.run.cancelling"] + data: RunObject + + +class RunStreamEvent9(BaseModel): + event: Literal["thread.run.cancelled"] + data: RunObject + + +class RunStreamEvent10(BaseModel): + event: Literal["thread.run.expired"] + data: RunObject + + +class RunStreamEvent( + RootModel[ + Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + ] +): + root: Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + + +class ProjectServiceAccountCreateResponse(BaseModel): + object: Literal["organization.project.service_account"] + id: str + name: str + role: Annotated[ + Literal["member"], + Field(description="Service accounts can only have one role of type `member`"), + ] + created_at: int + api_key: ProjectServiceAccountApiKey + + +class ChatCompletionRequestAssistantMessage(BaseModel): + content: Annotated[ + Optional[Union[str, Content2]], + Field( + None, + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", + ), + ] + refusal: Annotated[ + Optional[str], Field(None, description="The refusal message by the assistant.") + ] + role: Annotated[ + Literal["assistant"], + Field(description="The role of the messages author, in this case `assistant`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + function_call: Annotated[ + Optional[FunctionCall], + Field( + None, + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + ), + ] + + +class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): + weight: Annotated[ + Optional[Literal[0, 1]], + Field( + None, + description="Controls whether the assistant message is trained against (0 or 1)", + ), + ] + role: Annotated[ + Literal["assistant"], + Field(description="The role of the messages author, in this case `assistant`."), + ] + + +class ListPaginatedFineTuningJobsResponse(BaseModel): + data: List[FineTuningJob] + has_more: bool + object: Literal["list"] + + +class FinetuneChatRequestInput(BaseModel): + messages: Annotated[ + Optional[ + List[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + FineTuneChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] + ], + Field(None, min_length=1), + ] + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field(None, description="A list of tools the model may generate JSON inputs for."), + ] + parallel_tool_calls: Optional[ParallelToolCalls] = None + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + None, + description="A list of functions the model may generate JSON inputs for.", + max_length=128, + min_length=1, + ), + ] + + +class CreateRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + model: Annotated[ + Optional[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], + Field( + None, + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + examples=["gpt-4o"], + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis.", + ), + ] + additional_instructions: Annotated[ + Optional[str], + Field( + None, + description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions.", + ), + ] + additional_messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field( + None, + description="Adds additional messages to the thread before creating the run.", + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + None, + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_length=20, + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + stream: Annotated[ + Optional[bool], + Field( + None, + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + ), + ] + max_prompt_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateThreadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field( + None, + description="A list of [messages](/docs/api-reference/messages) to start the thread with.", + ), + ] + tool_resources: Annotated[ + Optional[ToolResources5], + Field( + None, + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class MessageObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread.message"], + Field(description="The object type, which is always `thread.message`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The [thread](/docs/api-reference/threads) ID that this message belongs to." + ), + ] + status: Annotated[ + Literal["in_progress", "incomplete", "completed"], + Field( + description="The status of the message, which can be either `in_progress`, `incomplete`, or `completed`." + ), + ] + incomplete_details: Annotated[ + IncompleteDetails1, + Field(description="On an incomplete message, details about why the message is incomplete."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was completed."), + ] + incomplete_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the message was marked as incomplete." + ), + ] + role: Annotated[ + Literal["user", "assistant"], + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] + content: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageContentTextObject, + MessageContentRefusalObject, + ] + ], + Field(description="The content of the message in array of text and/or images."), + ] + assistant_id: Annotated[ + str, + Field( + description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message." + ), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints." + ), + ] + attachments: Annotated[ + List[Attachment], + Field( + description="A list of files attached to the message, and the tools they were added to." + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class Delta(BaseModel): + role: Annotated[ + Optional[Literal["user", "assistant"]], + Field( + None, + description="The entity that produced the message. One of `user` or `assistant`.", + ), + ] + content: Annotated[ + Optional[ + List[ + Union[ + MessageDeltaContentImageFileObject, + MessageDeltaContentTextObject, + MessageDeltaContentRefusalObject, + MessageDeltaContentImageUrlObject, + ] + ] + ], + Field( + None, + description="The content of the message in array of text and/or images.", + ), + ] + + +class MessageDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the message, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.message.delta"], + Field(description="The object type, which is always `thread.message.delta`."), + ] + delta: Annotated[ + Delta, + Field(description="The delta containing the fields that have changed on the Message."), + ] + + +class ListMessagesResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[MessageObject] + first_id: Annotated[str, Field(examples=["msg_abc123"])] + last_id: Annotated[str, Field(examples=["msg_abc123"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class RunStepDetailsToolCallsObject(BaseModel): + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeObject, + RunStepDetailsToolCallsFileSearchObject, + RunStepDetailsToolCallsFunctionObject, + ] + ], + Field( + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" + ), + ] + + +class RunStepDeltaStepDetailsToolCallsObject(BaseModel): + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeObject, + RunStepDeltaStepDetailsToolCallsFileSearchObject, + RunStepDeltaStepDetailsToolCallsFunctionObject, + ] + ] + ], + Field( + None, + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n", + ), + ] + + +class VectorStoreFileObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store.file"], + Field(description="The object type, which is always `vector_store.file`."), + ] + usage_bytes: Annotated[ + int, + Field( + description="The total vector store usage in bytes. Note that this may be different from the original file size." + ), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store file was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Literal["in_progress", "completed", "cancelled", "failed"], + Field( + description="The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use." + ), + ] + last_error: Annotated[ + LastError2, + Field( + description="The last error associated with this vector store file. Will be `null` if there are no errors." + ), + ] + chunking_strategy: Annotated[ + Optional[Union[StaticChunkingStrategyResponseParam, OtherChunkingStrategyResponseParam]], + Field(None, description="The strategy used to chunk the file."), + ] + + +class ListVectorStoreFilesResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[VectorStoreFileObject] + first_id: Annotated[str, Field(examples=["file-abc123"])] + last_id: Annotated[str, Field(examples=["file-abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class MessageStreamEvent1(BaseModel): + event: Literal["thread.message.created"] + data: MessageObject + + +class MessageStreamEvent2(BaseModel): + event: Literal["thread.message.in_progress"] + data: MessageObject + + +class MessageStreamEvent3(BaseModel): + event: Literal["thread.message.delta"] + data: MessageDeltaObject + + +class MessageStreamEvent4(BaseModel): + event: Literal["thread.message.completed"] + data: MessageObject + + +class MessageStreamEvent5(BaseModel): + event: Literal["thread.message.incomplete"] + data: MessageObject + + +class MessageStreamEvent( + RootModel[ + Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + ] +): + root: Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + + +class ChatCompletionRequestMessage( + RootModel[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] +): + root: Annotated[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ], + Field(discriminator="role"), + ] + + +class CreateChatCompletionRequest(BaseModel): + messages: Annotated[ + List[ChatCompletionRequestMessage], + Field( + description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", + min_length=1, + ), + ] + model: Annotated[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], + Field( + description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", + examples=["gpt-4o"], + ), + ] + frequency_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + None, + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n", + ), + ] + logprobs: Annotated[ + Optional[bool], + Field( + False, + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.", + ), + ] + top_logprobs: Annotated[ + Optional[int], + Field( + None, + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", + ge=0, + le=20, + ), + ] + max_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + ), + ] + n: Annotated[ + Optional[int], + Field( + 1, + description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", + examples=[1], + ge=1, + le=128, + ), + ] + presence_penalty: Annotated[ + Optional[float], + Field( + 0, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] + response_format: Annotated[ + Optional[Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema]], + Field( + None, + description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + ), + ] + seed: Annotated[ + Optional[int], + Field( + None, + description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] + service_tier: Annotated[ + Optional[Literal["auto", "default"]], + Field( + None, + description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n", + ), + ] + stop: Annotated[ + Optional[Union[str, Stop1]], + Field( + None, + description="Up to 4 sequences where the API will stop generating further tokens.\n", + ), + ] + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + stream_options: Optional[ChatCompletionStreamOptions] = None + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field( + None, + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n", + ), + ] + tool_choice: Optional[ChatCompletionToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + user: Annotated[ + Optional[str], + Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] + function_call: Annotated[ + Optional[Union[Literal["none", "auto"], ChatCompletionFunctionCallOption]], + Field( + None, + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n', + ), + ] + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + None, + description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", + max_length=128, + min_length=1, + ), + ] + + +class CreateThreadAndRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + thread: Annotated[ + Optional[CreateThreadRequest], + Field( + None, + description="If no thread is provided, an empty thread will be created.", + ), + ] + model: Annotated[ + Optional[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], + Field( + None, + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + examples=["gpt-4o"], + ), + ] + instructions: Annotated[ + Optional[str], + Field( + None, + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis.", + ), + ] + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + None, + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_length=20, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources3], + Field( + None, + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + temperature: Annotated[ + Optional[float], + Field( + 1, + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] + top_p: Annotated[ + Optional[float], + Field( + 1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] + stream: Annotated[ + Optional[bool], + Field( + None, + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + ), + ] + max_prompt_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + Optional[int], + Field( + None, + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class RunStepObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.run.step"], + Field(description="The object type, which is always `thread.run.step`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was created."), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) associated with the run step." + ), + ] + thread_id: Annotated[ + str, + Field(description="The ID of the [thread](/docs/api-reference/threads) that was run."), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) that this run step is a part of." + ), + ] + type: Annotated[ + Literal["message_creation", "tool_calls"], + Field( + description="The type of run step, which can be either `message_creation` or `tool_calls`." + ), + ] + status: Annotated[ + Literal["in_progress", "cancelled", "failed", "completed", "expired"], + Field( + description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`." + ), + ] + step_details: Annotated[ + Union[RunStepDetailsMessageCreationObject, RunStepDetailsToolCallsObject], + Field(description="The details of the run step."), + ] + last_error: Annotated[ + LastError1, + Field( + description="The last error associated with this run step. Will be `null` if there are no errors." + ), + ] + expired_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired." + ), + ] + cancelled_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was cancelled."), + ] + failed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step failed."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step completed."), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunStepCompletionUsage + + +class Delta1(BaseModel): + step_details: Annotated[ + Optional[ + Union[ + RunStepDeltaStepDetailsMessageCreationObject, + RunStepDeltaStepDetailsToolCallsObject, + ] + ], + Field(None, description="The details of the run step."), + ] + + +class RunStepDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.run.step.delta"], + Field(description="The object type, which is always `thread.run.step.delta`."), + ] + delta: Annotated[ + Delta1, + Field(description="The delta containing the fields that have changed on the run step."), + ] + + +class ListRunStepsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[RunStepObject] + first_id: Annotated[str, Field(examples=["step_abc123"])] + last_id: Annotated[str, Field(examples=["step_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class RunStepStreamEvent1(BaseModel): + event: Literal["thread.run.step.created"] + data: RunStepObject + + +class RunStepStreamEvent2(BaseModel): + event: Literal["thread.run.step.in_progress"] + data: RunStepObject + + +class RunStepStreamEvent3(BaseModel): + event: Literal["thread.run.step.delta"] + data: RunStepDeltaObject + + +class RunStepStreamEvent4(BaseModel): + event: Literal["thread.run.step.completed"] + data: RunStepObject + + +class RunStepStreamEvent5(BaseModel): + event: Literal["thread.run.step.failed"] + data: RunStepObject + + +class RunStepStreamEvent6(BaseModel): + event: Literal["thread.run.step.cancelled"] + data: RunStepObject + + +class RunStepStreamEvent7(BaseModel): + event: Literal["thread.run.step.expired"] + data: RunStepObject + + +class RunStepStreamEvent( + RootModel[ + Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + ] +): + root: Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + + +class AssistantStreamEvent( + RootModel[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ] + ] +): + root: Annotated[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ], + Field( + description='Represents an event emitted when streaming a Run.\n\nEach event in a server-sent events stream has an `event` and `data` property:\n\n```\nevent: thread.created\ndata: {"id": "thread_123", "object": "thread", ...}\n```\n\nWe emit events whenever a new object is created, transitions to a new state, or is being\nstreamed in parts (deltas). For example, we emit `thread.run.created` when a new run\nis created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses\nto create a message during a run, we emit a `thread.message.created event`, a\n`thread.message.in_progress` event, many `thread.message.delta` events, and finally a\n`thread.message.completed` event.\n\nWe may add additional events over time, so we recommend handling unknown events gracefully\nin your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to\nintegrate the Assistants API with streaming.\n' + ), + ] diff --git a/clients/python/llmengine/data_types/pydantic_types.py b/clients/python/llmengine/data_types/pydantic_types.py new file mode 100644 index 00000000..64d89c3d --- /dev/null +++ b/clients/python/llmengine/data_types/pydantic_types.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ( # noqa: F401 + ConfigDict, + Field, + HttpUrl, + RootModel, + ValidationError, + model_validator, +) + + +class BaseModel(PydanticBaseModel): + """Common pydantic configurations for model engine""" + + model_config = ConfigDict(protected_namespaces=()) diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types/rest.py similarity index 61% rename from clients/python/llmengine/data_types.py rename to clients/python/llmengine/data_types/rest.py index 2de743c6..e7f80189 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types/rest.py @@ -6,16 +6,10 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from pydantic.version import VERSION as PYDANTIC_VERSION - -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") -if PYDANTIC_V2: - from pydantic.v1 import BaseModel, Field, HttpUrl -else: - from pydantic import BaseModel, Field, HttpUrl # type: ignore +from .pydantic_types import BaseModel, Field, HttpUrl, RootModel CpuSpecificationType = Union[str, int, float] -StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. +StorageSpecificationType = Union[str, int, float] class LLMInferenceFramework(str, Enum): @@ -73,8 +67,8 @@ class CallbackmTLSAuth(BaseModel): key: str -class CallbackAuth(BaseModel): - __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") +class CallbackAuth(RootModel): + root: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") class ModelEndpointDeploymentState(BaseModel): @@ -314,161 +308,6 @@ class DeleteLLMEndpointResponse(BaseModel): """ -class CompletionSyncV1Request(BaseModel): - """ - Request object for a synchronous prompt completion task. - """ - - prompt: str = Field(..., min_length=1) - max_new_tokens: int = Field(..., gt=0) - temperature: float = Field(..., ge=0.0) - stop_sequences: Optional[List[str]] = Field(default=None) - return_token_log_probs: Optional[bool] = Field(default=False) - presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - top_k: Optional[int] = Field(default=None, ge=-1) - top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - include_stop_str_in_output: Optional[bool] = Field(default=None) - guided_json: Optional[Dict[str, Any]] = Field(default=None) - guided_regex: Optional[str] = Field(default=None) - guided_choice: Optional[List[str]] = Field(default=None) - guided_grammar: Optional[str] = Field(default=None) - skip_special_tokens: Optional[bool] = Field(default=True) - - -class TokenOutput(BaseModel): - """ - Detailed token information. - """ - - token: str - """ - The token text. - """ - - log_prob: float - """ - The log probability of the token. - """ - - -class CompletionOutput(BaseModel): - """ - Represents the output of a completion request to a model. - """ - - text: str - """The text of the completion.""" - - # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. - # If we send request to api.spellbook.scale.com, we don't get this back. - num_prompt_tokens: Optional[int] = None - """Number of tokens in the prompt.""" - - num_completion_tokens: int - """Number of tokens in the completion.""" - - tokens: Optional[List[TokenOutput]] = None - """Detailed token information.""" - - -class CompletionSyncResponse(BaseModel): - """ - Response object for a synchronous prompt completion. - """ - - request_id: str - """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs - associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as - follows: - - * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. - * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability - provider.""" - - output: CompletionOutput - """Completion output.""" - - -class CompletionStreamV1Request(BaseModel): - """ - Request object for a streaming prompt completion. - """ - - prompt: str = Field(..., min_length=1) - max_new_tokens: int = Field(..., gt=0) - temperature: float = Field(..., ge=0.0) - stop_sequences: Optional[List[str]] = Field(default=None) - return_token_log_probs: Optional[bool] = Field(default=False) - presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - top_k: Optional[int] = Field(default=None, ge=-1) - top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - include_stop_str_in_output: Optional[bool] = Field(default=None) - guided_json: Optional[Dict[str, Any]] = Field(default=None) - guided_regex: Optional[str] = Field(default=None) - guided_choice: Optional[List[str]] = Field(default=None) - guided_grammar: Optional[str] = Field(default=None) - skip_special_tokens: Optional[bool] = Field(default=True) - - -class CompletionStreamOutput(BaseModel): - text: str - """The text of the completion.""" - - finished: bool - """Whether the completion is finished.""" - - # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. - num_prompt_tokens: Optional[int] = None - """Number of tokens in the prompt.""" - - num_completion_tokens: Optional[int] = None - """Number of tokens in the completion.""" - - token: Optional[TokenOutput] = None - """Detailed token information.""" - - -class StreamErrorContent(BaseModel): - error: str - """Error message.""" - timestamp: str - """Timestamp of the error.""" - - -class StreamError(BaseModel): - """ - Error object for a stream prompt completion task. - """ - - status_code: int - """The HTTP status code of the error.""" - content: StreamErrorContent - """The error content.""" - - -class CompletionStreamResponse(BaseModel): - """ - Response object for a stream prompt completion task. - """ - - request_id: str - """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs - associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as - follows: - - * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. - * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability - provider.""" - - output: Optional[CompletionStreamOutput] = None - """Completion output.""" - - error: Optional[StreamError] = None - """Error of the response (if any).""" - - class CreateFineTuneRequest(BaseModel): """ Request object for creating a FineTune. @@ -668,138 +507,3 @@ class GetFileContentResponse(BaseModel): content: str = Field(..., description="File content.") """File content.""" - - -class CreateBatchCompletionsRequestContent(BaseModel): - prompts: List[str] - max_new_tokens: int - temperature: float = Field(ge=0.0, le=1.0) - """ - Temperature of the sampling. Setting to 0 equals to greedy sampling. - """ - stop_sequences: Optional[List[str]] = None - """ - List of sequences to stop the completion at. - """ - return_token_log_probs: Optional[bool] = False - """ - Whether to return the log probabilities of the tokens. - """ - presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty - """ - frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) - """ - Only supported in vllm, lightllm - Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty - """ - top_k: Optional[int] = Field(default=None, ge=-1) - """ - Controls the number of top tokens to consider. -1 means consider all tokens. - """ - top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) - """ - Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. - """ - skip_special_tokens: Optional[bool] = True - """ - Whether to skip special tokens in the output. - """ - - -class CreateBatchCompletionsModelConfig(BaseModel): - model: str - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. - """ - labels: Dict[str, str] - """ - Labels to attach to the batch inference job. - """ - num_shards: Optional[int] = 1 - """ - Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. - System may decide to use a different number than the given value. - """ - quantize: Optional[Quantization] = None - """ - Whether to quantize the model. - """ - seed: Optional[int] = None - """ - Random seed for the model. - """ - - max_context_length: Optional[int] = Field( - default=None, - ge=1, - description="Maximum context length to use for the model. Defaults to the max allowed by the model", - ) - - -class ToolConfig(BaseModel): - """ - Configuration for tool use. - NOTE: this config is highly experimental and signature will change significantly in future iterations. - """ - - name: str - """ - Name of the tool to use for the batch inference. - """ - max_iterations: Optional[int] = 10 - """ - Maximum number of iterations to run the tool. - """ - execution_timeout_seconds: Optional[int] = 60 - """ - Maximum runtime of the tool in seconds. - """ - should_retry_on_error: Optional[bool] = True - """ - Whether to retry the tool on error. - """ - - -class CreateBatchCompletionsRequest(BaseModel): - """ - Request object for batch completions. - """ - - input_data_path: Optional[str] - output_data_path: str - """ - Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. - """ - content: Optional[CreateBatchCompletionsRequestContent] = None - """ - Either `input_data_path` or `content` needs to be provided. - When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. - """ - model_config: CreateBatchCompletionsModelConfig - """ - Model configuration for the batch inference. Hardware configurations are inferred. - """ - data_parallelism: Optional[int] = Field(default=1, ge=1, le=64) - """ - Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. - """ - max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600) - """ - Maximum runtime of the batch inference in seconds. Default to one day. - """ - tool_config: Optional[ToolConfig] = None - """ - Configuration for tool use. - NOTE: this config is highly experimental and signature will change significantly in future iterations. - """ - - -class CreateBatchCompletionsResponse(BaseModel): - job_id: str - """ - The ID of the batch completions job. - """ diff --git a/clients/python/mypy.ini b/clients/python/mypy.ini index f35ae689..53164a06 100644 --- a/clients/python/mypy.ini +++ b/clients/python/mypy.ini @@ -6,3 +6,6 @@ namespace_packages = True explicit_package_bases = True strict_optional = True plugins = pydantic.mypy + +[mypy-llmengine.data_types.gen.*] +ignore_errors = True \ No newline at end of file diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index 8d98a933..2b23ca33 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -1,114 +1,127 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. + +[[package]] +name = "aiohappyeyeballs" +version = "2.4.0" +description = "Happy Eyeballs for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohappyeyeballs-2.4.0-py3-none-any.whl", hash = "sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd"}, + {file = "aiohappyeyeballs-2.4.0.tar.gz", hash = "sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2"}, +] [[package]] name = "aiohttp" -version = "3.8.5" +version = "3.10.5" description = "Async http client/server framework (asyncio)" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"}, - {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"}, - {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"}, - {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"}, - {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"}, - {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"}, - {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"}, - {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"}, - {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"}, - {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"}, - {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"}, - {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"}, - {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"}, - {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"}, - {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"}, - {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"}, - {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"}, - {file = "aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"}, - {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"}, - {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"}, - {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"}, - {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"}, - {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"}, - {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"}, - {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"}, - {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"}, - {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"}, - {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"}, - {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"}, - {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"}, - {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"}, - {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"}, - {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"}, - {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"}, - {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"}, - {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"}, - {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, + {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, + {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, + {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, + {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, + {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, + {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, + {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, + {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, + {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, + {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, + {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, + {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, + {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, ] [package.dependencies] +aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = ">=4.0.0a3,<5.0" -asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""} +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" -charset-normalizer = ">=2.0,<4.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" -typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} yarl = ">=1.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns", "cchardet"] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] [[package]] name = "aiosignal" @@ -125,28 +138,50 @@ files = [ frozenlist = ">=1.1.0" [[package]] -name = "async-timeout" -version = "4.0.2" -description = "Timeout context manager for asyncio programs" +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, - {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] [package.dependencies] -typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""} +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} [[package]] -name = "asynctest" -version = "0.13.0" -description = "Enhance the standard unittest package with features for testing asyncio libraries" +name = "anyio" +version = "4.4.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"}, - {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"}, + {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, + {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] [[package]] @@ -161,118 +196,131 @@ files = [ [[package]] name = "attrs" -version = "23.1.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, - {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] -[package.dependencies] -importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} - [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[docs,tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "certifi" -version = "2023.7.22" +version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, - {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, + {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, + {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, ] [[package]] name = "charset-normalizer" -version = "3.2.0" +version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, - {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, + {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] [[package]] @@ -299,71 +347,83 @@ files = [ [[package]] name = "coverage" -version = "7.2.7" +version = "7.6.1" description = "Code coverage measurement for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"}, - {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"}, - {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"}, - {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"}, - {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"}, - {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"}, - {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"}, - {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"}, - {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"}, - {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"}, - {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"}, - {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"}, - {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"}, - {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"}, - {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"}, - {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"}, - {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"}, - {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"}, - {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"}, - {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"}, - {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"}, - {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"}, - {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"}, - {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, + {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, + {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, + {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, + {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, + {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, + {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, + {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, + {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, + {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, + {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, + {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, + {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, ] [package.dependencies] @@ -383,134 +443,199 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "filelock" -version = "3.12.2" +version = "3.15.4" description = "A platform independent file lock." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, - {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +typing = ["typing-extensions (>=4.8)"] [[package]] name = "frozenlist" -version = "1.3.3" +version = "1.4.1" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false python-versions = ">=3.7" files = [ - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"}, - {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"}, - {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"}, - {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"}, - {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"}, - {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"}, - {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"}, - {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"}, - {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"}, - {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"}, - {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"}, + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] [[package]] -name = "idna" -version = "3.7" -description = "Internationalized Domain Names in Applications (IDNA)" +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, - {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, ] +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + [[package]] -name = "importlib-metadata" -version = "6.7.0" -description = "Read metadata from Python packages" +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"}, - {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"}, + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, ] [package.dependencies] -typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} -zipp = ">=0.5" +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + +[[package]] +name = "idna" +version = "3.7" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, +] [[package]] name = "iniconfig" @@ -523,134 +648,220 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "jiter" +version = "0.5.0" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.5.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b599f4e89b3def9a94091e6ee52e1d7ad7bc33e238ebb9c4c63f211d74822c3f"}, + {file = "jiter-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a063f71c4b06225543dddadbe09d203dc0c95ba352d8b85f1221173480a71d5"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc0d5b8b3dd12e91dd184b87273f864b363dfabc90ef29a1092d269f18c7e28"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c22541f0b672f4d741382a97c65609332a783501551445ab2df137ada01e019e"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63314832e302cc10d8dfbda0333a384bf4bcfce80d65fe99b0f3c0da8945a91a"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a25fbd8a5a58061e433d6fae6d5298777c0814a8bcefa1e5ecfff20c594bd749"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503b2c27d87dfff5ab717a8200fbbcf4714516c9d85558048b1fc14d2de7d8dc"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6d1f3d27cce923713933a844872d213d244e09b53ec99b7a7fdf73d543529d6d"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c95980207b3998f2c3b3098f357994d3fd7661121f30669ca7cb945f09510a87"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:afa66939d834b0ce063f57d9895e8036ffc41c4bd90e4a99631e5f261d9b518e"}, + {file = "jiter-0.5.0-cp310-none-win32.whl", hash = "sha256:f16ca8f10e62f25fd81d5310e852df6649af17824146ca74647a018424ddeccf"}, + {file = "jiter-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:b2950e4798e82dd9176935ef6a55cf6a448b5c71515a556da3f6b811a7844f1e"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d4c8e1ed0ef31ad29cae5ea16b9e41529eb50a7fba70600008e9f8de6376d553"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6f16e21276074a12d8421692515b3fd6d2ea9c94fd0734c39a12960a20e85f3"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5280e68e7740c8c128d3ae5ab63335ce6d1fb6603d3b809637b11713487af9e6"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:583c57fc30cc1fec360e66323aadd7fc3edeec01289bfafc35d3b9dcb29495e4"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26351cc14507bdf466b5f99aba3df3143a59da75799bf64a53a3ad3155ecded9"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829df14d656b3fb87e50ae8b48253a8851c707da9f30d45aacab2aa2ba2d614"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42a4bdcf7307b86cb863b2fb9bb55029b422d8f86276a50487982d99eed7c6e"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04d461ad0aebf696f8da13c99bc1b3e06f66ecf6cfd56254cc402f6385231c06"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6375923c5f19888c9226582a124b77b622f8fd0018b843c45eeb19d9701c403"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cec323a853c24fd0472517113768c92ae0be8f8c384ef4441d3632da8baa646"}, + {file = "jiter-0.5.0-cp311-none-win32.whl", hash = "sha256:aa1db0967130b5cab63dfe4d6ff547c88b2a394c3410db64744d491df7f069bb"}, + {file = "jiter-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:aa9d2b85b2ed7dc7697597dcfaac66e63c1b3028652f751c81c65a9f220899ae"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9f664e7351604f91dcdd557603c57fc0d551bc65cc0a732fdacbf73ad335049a"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:702e3520384c88b6e270c55c772d4bd6d7b150608dcc94dea87ceba1b6391248"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:528d742dcde73fad9d63e8242c036ab4a84389a56e04efd854062b660f559544"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf80e5fe6ab582c82f0c3331df27a7e1565e2dcf06265afd5173d809cdbf9ba"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44dfc9ddfb9b51a5626568ef4e55ada462b7328996294fe4d36de02fce42721f"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c451f7922992751a936b96c5f5b9bb9312243d9b754c34b33d0cb72c84669f4e"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:308fce789a2f093dca1ff91ac391f11a9f99c35369117ad5a5c6c4903e1b3e3a"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7f5ad4a7c6b0d90776fdefa294f662e8a86871e601309643de30bf94bb93a64e"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ea189db75f8eca08807d02ae27929e890c7d47599ce3d0a6a5d41f2419ecf338"}, + {file = "jiter-0.5.0-cp312-none-win32.whl", hash = "sha256:e3bbe3910c724b877846186c25fe3c802e105a2c1fc2b57d6688b9f8772026e4"}, + {file = "jiter-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:a586832f70c3f1481732919215f36d41c59ca080fa27a65cf23d9490e75b2ef5"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f04bc2fc50dc77be9d10f73fcc4e39346402ffe21726ff41028f36e179b587e6"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f433a4169ad22fcb550b11179bb2b4fd405de9b982601914ef448390b2954f3"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad4a6398c85d3a20067e6c69890ca01f68659da94d74c800298581724e426c7e"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6baa88334e7af3f4d7a5c66c3a63808e5efbc3698a1c57626541ddd22f8e4fbf"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ece0a115c05efca597c6d938f88c9357c843f8c245dbbb53361a1c01afd7148"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:335942557162ad372cc367ffaf93217117401bf930483b4b3ebdb1223dbddfa7"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649b0ee97a6e6da174bffcb3c8c051a5935d7d4f2f52ea1583b5b3e7822fbf14"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4be354c5de82157886ca7f5925dbda369b77344b4b4adf2723079715f823989"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5206144578831a6de278a38896864ded4ed96af66e1e63ec5dd7f4a1fce38a3a"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8120c60f8121ac3d6f072b97ef0e71770cc72b3c23084c72c4189428b1b1d3b6"}, + {file = "jiter-0.5.0-cp38-none-win32.whl", hash = "sha256:6f1223f88b6d76b519cb033a4d3687ca157c272ec5d6015c322fc5b3074d8a5e"}, + {file = "jiter-0.5.0-cp38-none-win_amd64.whl", hash = "sha256:c59614b225d9f434ea8fc0d0bec51ef5fa8c83679afedc0433905994fb36d631"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0af3838cfb7e6afee3f00dc66fa24695199e20ba87df26e942820345b0afc566"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:550b11d669600dbc342364fd4adbe987f14d0bbedaf06feb1b983383dcc4b961"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:489875bf1a0ffb3cb38a727b01e6673f0f2e395b2aad3c9387f94187cb214bbf"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b250ca2594f5599ca82ba7e68785a669b352156260c5362ea1b4e04a0f3e2389"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ea18e01f785c6667ca15407cd6dabbe029d77474d53595a189bdc813347218e"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:462a52be85b53cd9bffd94e2d788a09984274fe6cebb893d6287e1c296d50653"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92cc68b48d50fa472c79c93965e19bd48f40f207cb557a8346daa020d6ba973b"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c834133e59a8521bc87ebcad773608c6fa6ab5c7a022df24a45030826cf10bc"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab3a71ff31cf2d45cb216dc37af522d335211f3a972d2fe14ea99073de6cb104"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cccd3af9c48ac500c95e1bcbc498020c87e1781ff0345dd371462d67b76643eb"}, + {file = "jiter-0.5.0-cp39-none-win32.whl", hash = "sha256:368084d8d5c4fc40ff7c3cc513c4f73e02c85f6009217922d0823a48ee7adf61"}, + {file = "jiter-0.5.0-cp39-none-win_amd64.whl", hash = "sha256:ce03f7b4129eb72f1687fa11300fbf677b02990618428934662406d2a76742a1"}, + {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, +] + [[package]] name = "multidict" -version = "6.0.4" +version = "6.0.5" description = "multidict implementation" optional = false python-versions = ">=3.7" files = [ - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, - {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, - {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, - {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, - {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, - {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, - {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, - {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, - {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, - {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, - {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, - {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, - {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, ] [[package]] name = "mypy" -version = "1.4.1" +version = "1.11.1" description = "Optional static typing for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, - {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, - {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, - {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, - {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, - {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, - {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, - {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, - {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, - {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, - {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, - {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, - {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, - {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, - {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, - {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, - {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, - {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, - {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, - {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, - {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, - {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, - {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, - {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, - {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, - {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, + {file = "mypy-1.11.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a32fc80b63de4b5b3e65f4be82b4cfa362a46702672aa6a0f443b4689af7008c"}, + {file = "mypy-1.11.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c1952f5ea8a5a959b05ed5f16452fddadbaae48b5d39235ab4c3fc444d5fd411"}, + {file = "mypy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1e30dc3bfa4e157e53c1d17a0dad20f89dc433393e7702b813c10e200843b03"}, + {file = "mypy-1.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2c63350af88f43a66d3dfeeeb8d77af34a4f07d760b9eb3a8697f0386c7590b4"}, + {file = "mypy-1.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:a831671bad47186603872a3abc19634f3011d7f83b083762c942442d51c58d58"}, + {file = "mypy-1.11.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7b6343d338390bb946d449677726edf60102a1c96079b4f002dedff375953fc5"}, + {file = "mypy-1.11.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4fe9f4e5e521b458d8feb52547f4bade7ef8c93238dfb5bbc790d9ff2d770ca"}, + {file = "mypy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:886c9dbecc87b9516eff294541bf7f3655722bf22bb898ee06985cd7269898de"}, + {file = "mypy-1.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fca4a60e1dd9fd0193ae0067eaeeb962f2d79e0d9f0f66223a0682f26ffcc809"}, + {file = "mypy-1.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:0bd53faf56de9643336aeea1c925012837432b5faf1701ccca7fde70166ccf72"}, + {file = "mypy-1.11.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f39918a50f74dc5969807dcfaecafa804fa7f90c9d60506835036cc1bc891dc8"}, + {file = "mypy-1.11.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bc71d1fb27a428139dd78621953effe0d208aed9857cb08d002280b0422003a"}, + {file = "mypy-1.11.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b868d3bcff720dd7217c383474008ddabaf048fad8d78ed948bb4b624870a417"}, + {file = "mypy-1.11.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a707ec1527ffcdd1c784d0924bf5cb15cd7f22683b919668a04d2b9c34549d2e"}, + {file = "mypy-1.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:64f4a90e3ea07f590c5bcf9029035cf0efeae5ba8be511a8caada1a4893f5525"}, + {file = "mypy-1.11.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:749fd3213916f1751fff995fccf20c6195cae941dc968f3aaadf9bb4e430e5a2"}, + {file = "mypy-1.11.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b639dce63a0b19085213ec5fdd8cffd1d81988f47a2dec7100e93564f3e8fb3b"}, + {file = "mypy-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4c956b49c5d865394d62941b109728c5c596a415e9c5b2be663dd26a1ff07bc0"}, + {file = "mypy-1.11.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45df906e8b6804ef4b666af29a87ad9f5921aad091c79cc38e12198e220beabd"}, + {file = "mypy-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:d44be7551689d9d47b7abc27c71257adfdb53f03880841a5db15ddb22dc63edb"}, + {file = "mypy-1.11.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2684d3f693073ab89d76da8e3921883019ea8a3ec20fa5d8ecca6a2db4c54bbe"}, + {file = "mypy-1.11.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:79c07eb282cb457473add5052b63925e5cc97dfab9812ee65a7c7ab5e3cb551c"}, + {file = "mypy-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11965c2f571ded6239977b14deebd3f4c3abd9a92398712d6da3a772974fad69"}, + {file = "mypy-1.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a2b43895a0f8154df6519706d9bca8280cda52d3d9d1514b2d9c3e26792a0b74"}, + {file = "mypy-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:1a81cf05975fd61aec5ae16501a091cfb9f605dc3e3c878c0da32f250b74760b"}, + {file = "mypy-1.11.1-py3-none-any.whl", hash = "sha256:0624bdb940255d2dd24e829d99a13cfeb72e4e9031f9492148f410ed30bcab54"}, + {file = "mypy-1.11.1.tar.gz", hash = "sha256:f404a0b069709f18bbdb702eb3dcfe51910602995de00bd39cea3050b5772d08"}, ] [package.dependencies] mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} -typing-extensions = ">=4.1.0" +typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] install-types = ["pip"] -python2 = ["typed-ast (>=1.4.0,<2)"] +mypyc = ["setuptools (>=50)"] reports = ["lxml"] [[package]] @@ -664,31 +875,52 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "openai" +version = "1.41.1" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.41.1-py3-none-any.whl", hash = "sha256:56fb04105263f79559aff3ceea2e1dd16f8c5385e8238cb66cf0e6888fa8bfcf"}, + {file = "openai-1.41.1.tar.gz", hash = "sha256:e38e376efd91e0d4db071e2a6517b6b4cac1c2a6fd63efdc5ec6be10c5967c1b"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.11,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "packaging" -version = "23.1" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, - {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] name = "pluggy" -version = "1.2.0" +version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, - {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] -[package.dependencies] -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} - [package.extras] dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] @@ -706,55 +938,126 @@ files = [ [[package]] name = "pydantic" -version = "1.10.11" -description = "Data validation and settings management using python type hints" +version = "2.8.2" +description = "Data validation using Python type hints" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pydantic-1.10.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ff44c5e89315b15ff1f7fdaf9853770b810936d6b01a7bcecaa227d2f8fe444f"}, - {file = "pydantic-1.10.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a6c098d4ab5e2d5b3984d3cb2527e2d6099d3de85630c8934efcfdc348a9760e"}, - {file = "pydantic-1.10.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16928fdc9cb273c6af00d9d5045434c39afba5f42325fb990add2c241402d151"}, - {file = "pydantic-1.10.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0588788a9a85f3e5e9ebca14211a496409cb3deca5b6971ff37c556d581854e7"}, - {file = "pydantic-1.10.11-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e9baf78b31da2dc3d3f346ef18e58ec5f12f5aaa17ac517e2ffd026a92a87588"}, - {file = "pydantic-1.10.11-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:373c0840f5c2b5b1ccadd9286782852b901055998136287828731868027a724f"}, - {file = "pydantic-1.10.11-cp310-cp310-win_amd64.whl", hash = "sha256:c3339a46bbe6013ef7bdd2844679bfe500347ac5742cd4019a88312aa58a9847"}, - {file = "pydantic-1.10.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:08a6c32e1c3809fbc49debb96bf833164f3438b3696abf0fbeceb417d123e6eb"}, - {file = "pydantic-1.10.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a451ccab49971af043ec4e0d207cbc8cbe53dbf148ef9f19599024076fe9c25b"}, - {file = "pydantic-1.10.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b02d24f7b2b365fed586ed73582c20f353a4c50e4be9ba2c57ab96f8091ddae"}, - {file = "pydantic-1.10.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f34739a89260dfa420aa3cbd069fbcc794b25bbe5c0a214f8fb29e363484b66"}, - {file = "pydantic-1.10.11-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e297897eb4bebde985f72a46a7552a7556a3dd11e7f76acda0c1093e3dbcf216"}, - {file = "pydantic-1.10.11-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d185819a7a059550ecb85d5134e7d40f2565f3dd94cfd870132c5f91a89cf58c"}, - {file = "pydantic-1.10.11-cp311-cp311-win_amd64.whl", hash = "sha256:4400015f15c9b464c9db2d5d951b6a780102cfa5870f2c036d37c23b56f7fc1b"}, - {file = "pydantic-1.10.11-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2417de68290434461a266271fc57274a138510dca19982336639484c73a07af6"}, - {file = "pydantic-1.10.11-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:331c031ba1554b974c98679bd0780d89670d6fd6f53f5d70b10bdc9addee1713"}, - {file = "pydantic-1.10.11-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8268a735a14c308923e8958363e3a3404f6834bb98c11f5ab43251a4e410170c"}, - {file = "pydantic-1.10.11-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:44e51ba599c3ef227e168424e220cd3e544288c57829520dc90ea9cb190c3248"}, - {file = "pydantic-1.10.11-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d7781f1d13b19700b7949c5a639c764a077cbbdd4322ed505b449d3ca8edcb36"}, - {file = "pydantic-1.10.11-cp37-cp37m-win_amd64.whl", hash = "sha256:7522a7666157aa22b812ce14c827574ddccc94f361237ca6ea8bb0d5c38f1629"}, - {file = "pydantic-1.10.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc64eab9b19cd794a380179ac0e6752335e9555d214cfcb755820333c0784cb3"}, - {file = "pydantic-1.10.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8dc77064471780262b6a68fe67e013298d130414d5aaf9b562c33987dbd2cf4f"}, - {file = "pydantic-1.10.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe429898f2c9dd209bd0632a606bddc06f8bce081bbd03d1c775a45886e2c1cb"}, - {file = "pydantic-1.10.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:192c608ad002a748e4a0bed2ddbcd98f9b56df50a7c24d9a931a8c5dd053bd3d"}, - {file = "pydantic-1.10.11-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ef55392ec4bb5721f4ded1096241e4b7151ba6d50a50a80a2526c854f42e6a2f"}, - {file = "pydantic-1.10.11-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e0bb6efe86281623abbeeb0be64eab740c865388ee934cd3e6a358784aca6e"}, - {file = "pydantic-1.10.11-cp38-cp38-win_amd64.whl", hash = "sha256:265a60da42f9f27e0b1014eab8acd3e53bd0bad5c5b4884e98a55f8f596b2c19"}, - {file = "pydantic-1.10.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:469adf96c8e2c2bbfa655fc7735a2a82f4c543d9fee97bd113a7fb509bf5e622"}, - {file = "pydantic-1.10.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e6cbfbd010b14c8a905a7b10f9fe090068d1744d46f9e0c021db28daeb8b6de1"}, - {file = "pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abade85268cc92dff86d6effcd917893130f0ff516f3d637f50dadc22ae93999"}, - {file = "pydantic-1.10.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9738b0f2e6c70f44ee0de53f2089d6002b10c33264abee07bdb5c7f03038303"}, - {file = "pydantic-1.10.11-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:787cf23e5a0cde753f2eabac1b2e73ae3844eb873fd1f5bdbff3048d8dbb7604"}, - {file = "pydantic-1.10.11-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:174899023337b9fc685ac8adaa7b047050616136ccd30e9070627c1aaab53a13"}, - {file = "pydantic-1.10.11-cp39-cp39-win_amd64.whl", hash = "sha256:1954f8778489a04b245a1e7b8b22a9d3ea8ef49337285693cf6959e4b757535e"}, - {file = "pydantic-1.10.11-py3-none-any.whl", hash = "sha256:008c5e266c8aada206d0627a011504e14268a62091450210eda7c07fabe6963e"}, - {file = "pydantic-1.10.11.tar.gz", hash = "sha256:f66d479cf7eb331372c470614be6511eae96f1f120344c25f3f9bb59fb1b5528"}, + {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, + {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, ] [package.dependencies] -typing-extensions = ">=4.2.0" +annotated-types = ">=0.4.0" +pydantic-core = "2.20.1" +typing-extensions = [ + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, +] [package.extras] -dotenv = ["python-dotenv (>=0.10.4)"] -email = ["email-validator (>=1.0.3)"] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.20.1" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, + {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, + {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, + {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, + {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, + {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, + {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, + {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, + {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, + {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, + {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, + {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, + {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, + {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pytest" @@ -771,7 +1074,6 @@ files = [ atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" @@ -794,7 +1096,6 @@ files = [ [package.dependencies] pytest = ">=6.1.0" -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.8\""} [package.extras] testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"] @@ -832,13 +1133,12 @@ files = [ attrs = ">=19.0" filelock = ">=3.0" mypy = [ - {version = ">=0.500", markers = "python_version < \"3.8\""}, - {version = ">=0.700", markers = "python_version >= \"3.8\" and python_version < \"3.9\""}, {version = ">=0.780", markers = "python_version >= \"3.9\""}, + {version = ">=0.700", markers = "python_version >= \"3.8\" and python_version < \"3.9\""}, ] pytest = [ - {version = ">=4.6", markers = "python_version >= \"3.6\" and python_version < \"3.10\""}, {version = ">=6.2", markers = "python_version >= \"3.10\""}, + {version = ">=4.6", markers = "python_version >= \"3.6\" and python_version < \"3.10\""}, ] [[package]] @@ -863,159 +1163,163 @@ regex = "*" [[package]] name = "pyyaml" -version = "6.0" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, - {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, - {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, - {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, - {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, - {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, - {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, - {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, - {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4"}, - {file = "PyYAML-6.0-cp36-cp36m-win32.whl", hash = "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293"}, - {file = "PyYAML-6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57"}, - {file = "PyYAML-6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9"}, - {file = "PyYAML-6.0-cp37-cp37m-win32.whl", hash = "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737"}, - {file = "PyYAML-6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d"}, - {file = "PyYAML-6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287"}, - {file = "PyYAML-6.0-cp38-cp38-win32.whl", hash = "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78"}, - {file = "PyYAML-6.0-cp38-cp38-win_amd64.whl", hash = "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07"}, - {file = "PyYAML-6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b"}, - {file = "PyYAML-6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0"}, - {file = "PyYAML-6.0-cp39-cp39-win32.whl", hash = "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb"}, - {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, - {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] name = "regex" -version = "2023.6.3" +version = "2024.7.24" description = "Alternative regular expression module, to replace re." optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "regex-2023.6.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:824bf3ac11001849aec3fa1d69abcb67aac3e150a933963fb12bda5151fe1bfd"}, - {file = "regex-2023.6.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:05ed27acdf4465c95826962528f9e8d41dbf9b1aa8531a387dee6ed215a3e9ef"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b49c764f88a79160fa64f9a7b425620e87c9f46095ef9c9920542ab2495c8bc"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e3f1316c2293e5469f8f09dc2d76efb6c3982d3da91ba95061a7e69489a14ef"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43e1dd9d12df9004246bacb79a0e5886b3b6071b32e41f83b0acbf293f820ee8"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4959e8bcbfda5146477d21c3a8ad81b185cd252f3d0d6e4724a5ef11c012fb06"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af4dd387354dc83a3bff67127a124c21116feb0d2ef536805c454721c5d7993d"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2239d95d8e243658b8dbb36b12bd10c33ad6e6933a54d36ff053713f129aa536"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:890e5a11c97cf0d0c550eb661b937a1e45431ffa79803b942a057c4fb12a2da2"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a8105e9af3b029f243ab11ad47c19b566482c150c754e4c717900a798806b222"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:25be746a8ec7bc7b082783216de8e9473803706723b3f6bef34b3d0ed03d57e2"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:3676f1dd082be28b1266c93f618ee07741b704ab7b68501a173ce7d8d0d0ca18"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:10cb847aeb1728412c666ab2e2000ba6f174f25b2bdc7292e7dd71b16db07568"}, - {file = "regex-2023.6.3-cp310-cp310-win32.whl", hash = "sha256:dbbbfce33cd98f97f6bffb17801b0576e653f4fdb1d399b2ea89638bc8d08ae1"}, - {file = "regex-2023.6.3-cp310-cp310-win_amd64.whl", hash = "sha256:c5f8037000eb21e4823aa485149f2299eb589f8d1fe4b448036d230c3f4e68e0"}, - {file = "regex-2023.6.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c123f662be8ec5ab4ea72ea300359023a5d1df095b7ead76fedcd8babbedf969"}, - {file = "regex-2023.6.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9edcbad1f8a407e450fbac88d89e04e0b99a08473f666a3f3de0fd292badb6aa"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcba6dae7de533c876255317c11f3abe4907ba7d9aa15d13e3d9710d4315ec0e"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29cdd471ebf9e0f2fb3cac165efedc3c58db841d83a518b082077e612d3ee5df"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12b74fbbf6cbbf9dbce20eb9b5879469e97aeeaa874145517563cca4029db65c"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c29ca1bd61b16b67be247be87390ef1d1ef702800f91fbd1991f5c4421ebae8"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77f09bc4b55d4bf7cc5eba785d87001d6757b7c9eec237fe2af57aba1a071d9"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ea353ecb6ab5f7e7d2f4372b1e779796ebd7b37352d290096978fea83c4dba0c"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:10590510780b7541969287512d1b43f19f965c2ece6c9b1c00fc367b29d8dce7"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e2fbd6236aae3b7f9d514312cdb58e6494ee1c76a9948adde6eba33eb1c4264f"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:6b2675068c8b56f6bfd5a2bda55b8accbb96c02fd563704732fd1c95e2083461"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74419d2b50ecb98360cfaa2974da8689cb3b45b9deff0dcf489c0d333bcc1477"}, - {file = "regex-2023.6.3-cp311-cp311-win32.whl", hash = "sha256:fb5ec16523dc573a4b277663a2b5a364e2099902d3944c9419a40ebd56a118f9"}, - {file = "regex-2023.6.3-cp311-cp311-win_amd64.whl", hash = "sha256:09e4a1a6acc39294a36b7338819b10baceb227f7f7dbbea0506d419b5a1dd8af"}, - {file = "regex-2023.6.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0654bca0cdf28a5956c83839162692725159f4cda8d63e0911a2c0dc76166525"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:463b6a3ceb5ca952e66550a4532cef94c9a0c80dc156c4cc343041951aec1697"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87b2a5bb5e78ee0ad1de71c664d6eb536dc3947a46a69182a90f4410f5e3f7dd"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6343c6928282c1f6a9db41f5fd551662310e8774c0e5ebccb767002fcf663ca9"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6192d5af2ccd2a38877bfef086d35e6659566a335b1492786ff254c168b1693"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74390d18c75054947e4194019077e243c06fbb62e541d8817a0fa822ea310c14"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:742e19a90d9bb2f4a6cf2862b8b06dea5e09b96c9f2df1779e53432d7275331f"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:8abbc5d54ea0ee80e37fef009e3cec5dafd722ed3c829126253d3e22f3846f1e"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c2b867c17a7a7ae44c43ebbeb1b5ff406b3e8d5b3e14662683e5e66e6cc868d3"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:d831c2f8ff278179705ca59f7e8524069c1a989e716a1874d6d1aab6119d91d1"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:ee2d1a9a253b1729bb2de27d41f696ae893507c7db224436abe83ee25356f5c1"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:61474f0b41fe1a80e8dfa70f70ea1e047387b7cd01c85ec88fa44f5d7561d787"}, - {file = "regex-2023.6.3-cp36-cp36m-win32.whl", hash = "sha256:0b71e63226e393b534105fcbdd8740410dc6b0854c2bfa39bbda6b0d40e59a54"}, - {file = "regex-2023.6.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bbb02fd4462f37060122e5acacec78e49c0fbb303c30dd49c7f493cf21fc5b27"}, - {file = "regex-2023.6.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b862c2b9d5ae38a68b92e215b93f98d4c5e9454fa36aae4450f61dd33ff48487"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:976d7a304b59ede34ca2921305b57356694f9e6879db323fd90a80f865d355a3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:83320a09188e0e6c39088355d423aa9d056ad57a0b6c6381b300ec1a04ec3d16"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9427a399501818a7564f8c90eced1e9e20709ece36be701f394ada99890ea4b3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178bbc1b2ec40eaca599d13c092079bf529679bf0371c602edaa555e10b41c3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:837328d14cde912af625d5f303ec29f7e28cdab588674897baafaf505341f2fc"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d44dc13229905ae96dd2ae2dd7cebf824ee92bc52e8cf03dcead37d926da019"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d54af539295392611e7efbe94e827311eb8b29668e2b3f4cadcfe6f46df9c777"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7117d10690c38a622e54c432dfbbd3cbd92f09401d622902c32f6d377e2300ee"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bb60b503ec8a6e4e3e03a681072fa3a5adcbfa5479fa2d898ae2b4a8e24c4591"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:65ba8603753cec91c71de423a943ba506363b0e5c3fdb913ef8f9caa14b2c7e0"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:271f0bdba3c70b58e6f500b205d10a36fb4b58bd06ac61381b68de66442efddb"}, - {file = "regex-2023.6.3-cp37-cp37m-win32.whl", hash = "sha256:9beb322958aaca059f34975b0df135181f2e5d7a13b84d3e0e45434749cb20f7"}, - {file = "regex-2023.6.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fea75c3710d4f31389eed3c02f62d0b66a9da282521075061ce875eb5300cf23"}, - {file = "regex-2023.6.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f56fcb7ff7bf7404becdfc60b1e81a6d0561807051fd2f1860b0d0348156a07"}, - {file = "regex-2023.6.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d2da3abc88711bce7557412310dfa50327d5769a31d1c894b58eb256459dc289"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a99b50300df5add73d307cf66abea093304a07eb017bce94f01e795090dea87c"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5708089ed5b40a7b2dc561e0c8baa9535b77771b64a8330b684823cfd5116036"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:687ea9d78a4b1cf82f8479cab23678aff723108df3edeac098e5b2498879f4a7"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3850beab9f527f06ccc94b446c864059c57651b3f911fddb8d9d3ec1d1b25d"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8915cc96abeb8983cea1df3c939e3c6e1ac778340c17732eb63bb96247b91d2"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:841d6e0e5663d4c7b4c8099c9997be748677d46cbf43f9f471150e560791f7ff"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9edce5281f965cf135e19840f4d93d55b3835122aa76ccacfd389e880ba4cf82"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b956231ebdc45f5b7a2e1f90f66a12be9610ce775fe1b1d50414aac1e9206c06"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:36efeba71c6539d23c4643be88295ce8c82c88bbd7c65e8a24081d2ca123da3f"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:cf67ca618b4fd34aee78740bea954d7c69fdda419eb208c2c0c7060bb822d747"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b4598b1897837067a57b08147a68ac026c1e73b31ef6e36deeeb1fa60b2933c9"}, - {file = "regex-2023.6.3-cp38-cp38-win32.whl", hash = "sha256:f415f802fbcafed5dcc694c13b1292f07fe0befdb94aa8a52905bd115ff41e88"}, - {file = "regex-2023.6.3-cp38-cp38-win_amd64.whl", hash = "sha256:d4f03bb71d482f979bda92e1427f3ec9b220e62a7dd337af0aa6b47bf4498f72"}, - {file = "regex-2023.6.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccf91346b7bd20c790310c4147eee6ed495a54ddb6737162a36ce9dbef3e4751"}, - {file = "regex-2023.6.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b28f5024a3a041009eb4c333863d7894d191215b39576535c6734cd88b0fcb68"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0bb18053dfcfed432cc3ac632b5e5e5c5b7e55fb3f8090e867bfd9b054dbcbf"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5bfb3004f2144a084a16ce19ca56b8ac46e6fd0651f54269fc9e230edb5e4a"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c6b48d0fa50d8f4df3daf451be7f9689c2bde1a52b1225c5926e3f54b6a9ed1"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:051da80e6eeb6e239e394ae60704d2b566aa6a7aed6f2890a7967307267a5dc6"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a4c3b7fa4cdaa69268748665a1a6ff70c014d39bb69c50fda64b396c9116cf77"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:457b6cce21bee41ac292d6753d5e94dcbc5c9e3e3a834da285b0bde7aa4a11e9"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:aad51907d74fc183033ad796dd4c2e080d1adcc4fd3c0fd4fd499f30c03011cd"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0385e73da22363778ef2324950e08b689abdf0b108a7d8decb403ad7f5191938"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c6a57b742133830eec44d9b2290daf5cbe0a2f1d6acee1b3c7b1c7b2f3606df7"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3e5219bf9e75993d73ab3d25985c857c77e614525fac9ae02b1bebd92f7cecac"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e5087a3c59eef624a4591ef9eaa6e9a8d8a94c779dade95d27c0bc24650261cd"}, - {file = "regex-2023.6.3-cp39-cp39-win32.whl", hash = "sha256:20326216cc2afe69b6e98528160b225d72f85ab080cbdf0b11528cbbaba2248f"}, - {file = "regex-2023.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:bdff5eab10e59cf26bc479f565e25ed71a7d041d1ded04ccf9aee1d9f208487a"}, - {file = "regex-2023.6.3.tar.gz", hash = "sha256:72d1a25bf36d2050ceb35b517afe13864865268dfb45910e2e17a84be6cbfeb0"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b0d3f567fafa0633aee87f08b9276c7062da9616931382993c03808bb68ce"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3426de3b91d1bc73249042742f45c2148803c111d1175b283270177fdf669024"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f273674b445bcb6e4409bf8d1be67bc4b58e8b46fd0d560055d515b8830063cd"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23acc72f0f4e1a9e6e9843d6328177ae3074b4182167e34119ec7233dfeccf53"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65fd3d2e228cae024c411c5ccdffae4c315271eee4a8b839291f84f796b34eca"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c414cbda77dbf13c3bc88b073a1a9f375c7b0cb5e115e15d4b73ec3a2fbc6f59"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7a89eef64b5455835f5ed30254ec19bf41f7541cd94f266ab7cbd463f00c41"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19c65b00d42804e3fbea9708f0937d157e53429a39b7c61253ff15670ff62cb5"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7a5486ca56c8869070a966321d5ab416ff0f83f30e0e2da1ab48815c8d165d46"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f51f9556785e5a203713f5efd9c085b4a45aecd2a42573e2b5041881b588d1f"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a4997716674d36a82eab3e86f8fa77080a5d8d96a389a61ea1d0e3a94a582cf7"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c0abb5e4e8ce71a61d9446040c1e86d4e6d23f9097275c5bd49ed978755ff0fe"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:18300a1d78cf1290fa583cd8b7cde26ecb73e9f5916690cf9d42de569c89b1ce"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:416c0e4f56308f34cdb18c3f59849479dde5b19febdcd6e6fa4d04b6c31c9faa"}, + {file = "regex-2024.7.24-cp310-cp310-win32.whl", hash = "sha256:fb168b5924bef397b5ba13aabd8cf5df7d3d93f10218d7b925e360d436863f66"}, + {file = "regex-2024.7.24-cp310-cp310-win_amd64.whl", hash = "sha256:6b9fc7e9cc983e75e2518496ba1afc524227c163e43d706688a6bb9eca41617e"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:382281306e3adaaa7b8b9ebbb3ffb43358a7bbf585fa93821300a418bb975281"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fdd1384619f406ad9037fe6b6eaa3de2749e2e12084abc80169e8e075377d3b"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3d974d24edb231446f708c455fd08f94c41c1ff4f04bcf06e5f36df5ef50b95a"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec4419a3fe6cf8a4795752596dfe0adb4aea40d3683a132bae9c30b81e8d73"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb563dd3aea54c797adf513eeec819c4213d7dbfc311874eb4fd28d10f2ff0f2"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45104baae8b9f67569f0f1dca5e1f1ed77a54ae1cd8b0b07aba89272710db61e"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:994448ee01864501912abf2bad9203bffc34158e80fe8bfb5b031f4f8e16da51"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fac296f99283ac232d8125be932c5cd7644084a30748fda013028c815ba3364"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7e37e809b9303ec3a179085415cb5f418ecf65ec98cdfe34f6a078b46ef823ee"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:01b689e887f612610c869421241e075c02f2e3d1ae93a037cb14f88ab6a8934c"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f6442f0f0ff81775eaa5b05af8a0ffa1dda36e9cf6ec1e0d3d245e8564b684ce"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:871e3ab2838fbcb4e0865a6e01233975df3a15e6fce93b6f99d75cacbd9862d1"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c918b7a1e26b4ab40409820ddccc5d49871a82329640f5005f73572d5eaa9b5e"}, + {file = "regex-2024.7.24-cp311-cp311-win32.whl", hash = "sha256:2dfbb8baf8ba2c2b9aa2807f44ed272f0913eeeba002478c4577b8d29cde215c"}, + {file = "regex-2024.7.24-cp311-cp311-win_amd64.whl", hash = "sha256:538d30cd96ed7d1416d3956f94d54e426a8daf7c14527f6e0d6d425fcb4cca52"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fe4ebef608553aff8deb845c7f4f1d0740ff76fa672c011cc0bacb2a00fbde86"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:74007a5b25b7a678459f06559504f1eec2f0f17bca218c9d56f6a0a12bfffdad"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7df9ea48641da022c2a3c9c641650cd09f0cd15e8908bf931ad538f5ca7919c9"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1141a1dcc32904c47f6846b040275c6e5de0bf73f17d7a409035d55b76f289"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80c811cfcb5c331237d9bad3bea2c391114588cf4131707e84d9493064d267f9"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7214477bf9bd195894cf24005b1e7b496f46833337b5dedb7b2a6e33f66d962c"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d55588cba7553f0b6ec33130bc3e114b355570b45785cebdc9daed8c637dd440"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558a57cfc32adcf19d3f791f62b5ff564922942e389e3cfdb538a23d65a6b610"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a512eed9dfd4117110b1881ba9a59b31433caed0c4101b361f768e7bcbaf93c5"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:86b17ba823ea76256b1885652e3a141a99a5c4422f4a869189db328321b73799"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5eefee9bfe23f6df09ffb6dfb23809f4d74a78acef004aa904dc7c88b9944b05"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:731fcd76bbdbf225e2eb85b7c38da9633ad3073822f5ab32379381e8c3c12e94"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eaef80eac3b4cfbdd6de53c6e108b4c534c21ae055d1dbea2de6b3b8ff3def38"}, + {file = "regex-2024.7.24-cp312-cp312-win32.whl", hash = "sha256:185e029368d6f89f36e526764cf12bf8d6f0e3a2a7737da625a76f594bdfcbfc"}, + {file = "regex-2024.7.24-cp312-cp312-win_amd64.whl", hash = "sha256:2f1baff13cc2521bea83ab2528e7a80cbe0ebb2c6f0bfad15be7da3aed443908"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:66b4c0731a5c81921e938dcf1a88e978264e26e6ac4ec96a4d21ae0354581ae0"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:88ecc3afd7e776967fa16c80f974cb79399ee8dc6c96423321d6f7d4b881c92b"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64bd50cf16bcc54b274e20235bf8edbb64184a30e1e53873ff8d444e7ac656b2"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb462f0e346fcf41a901a126b50f8781e9a474d3927930f3490f38a6e73b6950"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a82465ebbc9b1c5c50738536fdfa7cab639a261a99b469c9d4c7dcbb2b3f1e57"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68a8f8c046c6466ac61a36b65bb2395c74451df2ffb8458492ef49900efed293"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac8e84fff5d27420f3c1e879ce9929108e873667ec87e0c8eeb413a5311adfe"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba2537ef2163db9e6ccdbeb6f6424282ae4dea43177402152c67ef869cf3978b"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:43affe33137fcd679bdae93fb25924979517e011f9dea99163f80b82eadc7e53"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:c9bb87fdf2ab2370f21e4d5636e5317775e5d51ff32ebff2cf389f71b9b13750"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:945352286a541406f99b2655c973852da7911b3f4264e010218bbc1cc73168f2"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8bc593dcce679206b60a538c302d03c29b18e3d862609317cb560e18b66d10cf"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:3f3b6ca8eae6d6c75a6cff525c8530c60e909a71a15e1b731723233331de4169"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c51edc3541e11fbe83f0c4d9412ef6c79f664a3745fab261457e84465ec9d5a8"}, + {file = "regex-2024.7.24-cp38-cp38-win32.whl", hash = "sha256:d0a07763776188b4db4c9c7fb1b8c494049f84659bb387b71c73bbc07f189e96"}, + {file = "regex-2024.7.24-cp38-cp38-win_amd64.whl", hash = "sha256:8fd5afd101dcf86a270d254364e0e8dddedebe6bd1ab9d5f732f274fa00499a5"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0ffe3f9d430cd37d8fa5632ff6fb36d5b24818c5c986893063b4e5bdb84cdf24"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25419b70ba00a16abc90ee5fce061228206173231f004437730b67ac77323f0d"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33e2614a7ce627f0cdf2ad104797d1f68342d967de3695678c0cb84f530709f8"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33a0021893ede5969876052796165bab6006559ab845fd7b515a30abdd990dc"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04ce29e2c5fedf296b1a1b0acc1724ba93a36fb14031f3abfb7abda2806c1535"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b16582783f44fbca6fcf46f61347340c787d7530d88b4d590a397a47583f31dd"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:836d3cc225b3e8a943d0b02633fb2f28a66e281290302a79df0e1eaa984ff7c1"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:438d9f0f4bc64e8dea78274caa5af971ceff0f8771e1a2333620969936ba10be"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:973335b1624859cb0e52f96062a28aa18f3a5fc77a96e4a3d6d76e29811a0e6e"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c5e69fd3eb0b409432b537fe3c6f44ac089c458ab6b78dcec14478422879ec5f"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fbf8c2f00904eaf63ff37718eb13acf8e178cb940520e47b2f05027f5bb34ce3"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2757ace61bc4061b69af19e4689fa4416e1a04840f33b441034202b5cd02d4"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:44fc61b99035fd9b3b9453f1713234e5a7c92a04f3577252b45feefe1b327759"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:84c312cdf839e8b579f504afcd7b65f35d60b6285d892b19adea16355e8343c9"}, + {file = "regex-2024.7.24-cp39-cp39-win32.whl", hash = "sha256:ca5b2028c2f7af4e13fb9fc29b28d0ce767c38c7facdf64f6c2cd040413055f1"}, + {file = "regex-2024.7.24-cp39-cp39-win_amd64.whl", hash = "sha256:7c479f5ae937ec9985ecaf42e2e10631551d909f203e31308c12d703922742f9"}, + {file = "regex-2024.7.24.tar.gz", hash = "sha256:9cfd009eed1a46b27c14039ad5bbc5e71b6367c5b2e6d5f5da0ea91600817506"}, ] [[package]] name = "requests" -version = "2.31.0" +version = "2.32.3" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -1028,6 +1332,17 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "toml" version = "0.10.2" @@ -1051,187 +1366,157 @@ files = [ ] [[package]] -name = "typed-ast" -version = "1.5.5" -description = "a fork of Python 2 and 3 ast modules with type comment support" +name = "tqdm" +version = "4.66.5" +description = "Fast, Extensible Progress Meter" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "typed_ast-1.5.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4bc1efe0ce3ffb74784e06460f01a223ac1f6ab31c6bc0376a21184bf5aabe3b"}, - {file = "typed_ast-1.5.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5f7a8c46a8b333f71abd61d7ab9255440d4a588f34a21f126bbfc95f6049e686"}, - {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:597fc66b4162f959ee6a96b978c0435bd63791e31e4f410622d19f1686d5e769"}, - {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d41b7a686ce653e06c2609075d397ebd5b969d821b9797d029fccd71fdec8e04"}, - {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5fe83a9a44c4ce67c796a1b466c270c1272e176603d5e06f6afbc101a572859d"}, - {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d5c0c112a74c0e5db2c75882a0adf3133adedcdbfd8cf7c9d6ed77365ab90a1d"}, - {file = "typed_ast-1.5.5-cp310-cp310-win_amd64.whl", hash = "sha256:e1a976ed4cc2d71bb073e1b2a250892a6e968ff02aa14c1f40eba4f365ffec02"}, - {file = "typed_ast-1.5.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c631da9710271cb67b08bd3f3813b7af7f4c69c319b75475436fcab8c3d21bee"}, - {file = "typed_ast-1.5.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b445c2abfecab89a932b20bd8261488d574591173d07827c1eda32c457358b18"}, - {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc95ffaaab2be3b25eb938779e43f513e0e538a84dd14a5d844b8f2932593d88"}, - {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61443214d9b4c660dcf4b5307f15c12cb30bdfe9588ce6158f4a005baeb167b2"}, - {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6eb936d107e4d474940469e8ec5b380c9b329b5f08b78282d46baeebd3692dc9"}, - {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e48bf27022897577d8479eaed64701ecaf0467182448bd95759883300ca818c8"}, - {file = "typed_ast-1.5.5-cp311-cp311-win_amd64.whl", hash = "sha256:83509f9324011c9a39faaef0922c6f720f9623afe3fe220b6d0b15638247206b"}, - {file = "typed_ast-1.5.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:44f214394fc1af23ca6d4e9e744804d890045d1643dd7e8229951e0ef39429b5"}, - {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:118c1ce46ce58fda78503eae14b7664163aa735b620b64b5b725453696f2a35c"}, - {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be4919b808efa61101456e87f2d4c75b228f4e52618621c77f1ddcaae15904fa"}, - {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:fc2b8c4e1bc5cd96c1a823a885e6b158f8451cf6f5530e1829390b4d27d0807f"}, - {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:16f7313e0a08c7de57f2998c85e2a69a642e97cb32f87eb65fbfe88381a5e44d"}, - {file = "typed_ast-1.5.5-cp36-cp36m-win_amd64.whl", hash = "sha256:2b946ef8c04f77230489f75b4b5a4a6f24c078be4aed241cfabe9cbf4156e7e5"}, - {file = "typed_ast-1.5.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2188bc33d85951ea4ddad55d2b35598b2709d122c11c75cffd529fbc9965508e"}, - {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0635900d16ae133cab3b26c607586131269f88266954eb04ec31535c9a12ef1e"}, - {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:57bfc3cf35a0f2fdf0a88a3044aafaec1d2f24d8ae8cd87c4f58d615fb5b6311"}, - {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:fe58ef6a764de7b4b36edfc8592641f56e69b7163bba9f9c8089838ee596bfb2"}, - {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d09d930c2d1d621f717bb217bf1fe2584616febb5138d9b3e8cdd26506c3f6d4"}, - {file = "typed_ast-1.5.5-cp37-cp37m-win_amd64.whl", hash = "sha256:d40c10326893ecab8a80a53039164a224984339b2c32a6baf55ecbd5b1df6431"}, - {file = "typed_ast-1.5.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fd946abf3c31fb50eee07451a6aedbfff912fcd13cf357363f5b4e834cc5e71a"}, - {file = "typed_ast-1.5.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ed4a1a42df8a3dfb6b40c3d2de109e935949f2f66b19703eafade03173f8f437"}, - {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:045f9930a1550d9352464e5149710d56a2aed23a2ffe78946478f7b5416f1ede"}, - {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381eed9c95484ceef5ced626355fdc0765ab51d8553fec08661dce654a935db4"}, - {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bfd39a41c0ef6f31684daff53befddae608f9daf6957140228a08e51f312d7e6"}, - {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8c524eb3024edcc04e288db9541fe1f438f82d281e591c548903d5b77ad1ddd4"}, - {file = "typed_ast-1.5.5-cp38-cp38-win_amd64.whl", hash = "sha256:7f58fabdde8dcbe764cef5e1a7fcb440f2463c1bbbec1cf2a86ca7bc1f95184b"}, - {file = "typed_ast-1.5.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:042eb665ff6bf020dd2243307d11ed626306b82812aba21836096d229fdc6a10"}, - {file = "typed_ast-1.5.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:622e4a006472b05cf6ef7f9f2636edc51bda670b7bbffa18d26b255269d3d814"}, - {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1efebbbf4604ad1283e963e8915daa240cb4bf5067053cf2f0baadc4d4fb51b8"}, - {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0aefdd66f1784c58f65b502b6cf8b121544680456d1cebbd300c2c813899274"}, - {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:48074261a842acf825af1968cd912f6f21357316080ebaca5f19abbb11690c8a"}, - {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:429ae404f69dc94b9361bb62291885894b7c6fb4640d561179548c849f8492ba"}, - {file = "typed_ast-1.5.5-cp39-cp39-win_amd64.whl", hash = "sha256:335f22ccb244da2b5c296e6f96b06ee9bed46526db0de38d2f0e5a6597b81155"}, - {file = "typed_ast-1.5.5.tar.gz", hash = "sha256:94282f7a354f36ef5dbce0ef3467ebf6a258e370ab33d5b40c249fa996e590dd"}, + {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, + {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, ] +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "typing-extensions" -version = "4.7.1" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, - {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] name = "urllib3" -version = "2.0.3" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"}, - {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] [[package]] name = "yarl" -version = "1.9.2" +version = "1.9.4" description = "Yet another URL library" optional = false python-versions = ">=3.7" files = [ - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, - {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, - {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, - {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, - {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, - {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, - {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, - {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, - {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, - {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, - {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, - {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, - {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, ] [package.dependencies] idna = ">=2.0" multidict = ">=4.0" -typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} - -[[package]] -name = "zipp" -version = "3.15.0" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.7" -files = [ - {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"}, - {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [metadata] lock-version = "2.0" -python-versions = "^3.7" -content-hash = "e172656b142f767ce252f458226edc093bec9cee800a0a608340742d11bfa911" +python-versions = "^3.8" +content-hash = "4572f90730a8c15e31847b5238b491116780f50b02fd1b08e45d6353baac1bf8" diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 30e680d2..2a9a8a1c 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta36" +version = "0.0.0.beta37" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] @@ -12,10 +12,11 @@ packages = [{include = "llmengine"}] [tool.poetry.dependencies] -python = "^3.7" -pydantic = ">=1.10" +python = "^3.8" +pydantic = ">=2.0" aiohttp = "^3.8" requests = "^2.31.0" +openai = "^1.30.0" [tool.poetry.dev-dependencies] pytest = "^6.2.5" diff --git a/clients/python/setup.py b/clients/python/setup.py index d486a8d6..4aa4832a 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -2,8 +2,8 @@ setup( name="scale-llm-engine", - python_requires=">=3.7", - version="0.0.0.beta36", + python_requires=">=3.8", + version="0.0.0.beta37", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 851f0183..164768c8 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -21,6 +21,7 @@ from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 from model_engine_server.api.tasks_v1 import inference_task_router_v1 from model_engine_server.api.triggers_v1 import trigger_router_v1 +from model_engine_server.api.v2 import llm_router_v2 from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.core.loggers import ( LoggerTagKey, @@ -83,7 +84,10 @@ async def dispatch(self, request: Request, call_next): app = FastAPI( - title="launch", version="1.0.0", redoc_url="/api", middleware=[Middleware(CustomMiddleware)] + title="launch", + version="1.0.0", + redoc_url="/api", + middleware=[Middleware(CustomMiddleware)], ) app.include_router(batch_job_router_v1) @@ -96,6 +100,7 @@ async def dispatch(self, request: Request, call_next): app.include_router(llm_router_v1) app.include_router(file_router_v1) app.include_router(trigger_router_v1) +app.include_router(llm_router_v2) # TODO: Remove this once we have a better way to serve internal docs diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index dc54c188..0fc61abb 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -45,6 +45,9 @@ LLMModelEndpointService, ModelEndpointService, ) +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( StreamingStorageGateway, ) @@ -118,6 +121,9 @@ LiveBatchJobService, LiveModelEndpointService, ) +from model_engine_server.infra.services.live_llm_batch_completions_service import ( + LiveLLMBatchCompletionsService, +) from model_engine_server.infra.services.live_llm_model_endpoint_service import ( LiveLLMModelEndpointService, ) @@ -142,6 +148,7 @@ class ExternalInterfaces: model_endpoint_service: ModelEndpointService batch_job_service: BatchJobService llm_model_endpoint_service: LLMModelEndpointService + llm_batch_completions_service: LLMBatchCompletionsService llm_fine_tuning_service: LLMFineTuningService llm_fine_tune_events_repository: LLMFineTuneEventsRepository @@ -319,6 +326,10 @@ def _get_external_interfaces( llm_fine_tune_repository=llm_fine_tune_repository, ) + llm_batch_completions_service = LiveLLMBatchCompletionsService( + docker_image_batch_job_gateway=docker_image_batch_job_gateway + ) + file_storage_gateway = ( ABSFileStorageGateway() if infra_config().cloud_provider == "azure" @@ -342,6 +353,7 @@ def _get_external_interfaces( model_bundle_repository=model_bundle_repository, model_endpoint_service=model_endpoint_service, llm_model_endpoint_service=llm_model_endpoint_service, + llm_batch_completions_service=llm_batch_completions_service, batch_job_service=batch_job_service, resource_gateway=resource_gateway, endpoint_creation_task_queue_gateway=infra_task_queue_gateway, diff --git a/model-engine/model_engine_server/api/v2/__init__.py b/model-engine/model_engine_server/api/v2/__init__.py new file mode 100644 index 00000000..dbcb4d67 --- /dev/null +++ b/model-engine/model_engine_server/api/v2/__init__.py @@ -0,0 +1,10 @@ +from typing import Sequence + +from fastapi import APIRouter + +from .batch_completion import batch_completions_router_v2 + +llm_router_v2 = APIRouter(prefix="/v2") +llm_router_v2.include_router(batch_completions_router_v2) + +__all__: Sequence[str] = ("llm_router_v2",) diff --git a/model-engine/model_engine_server/api/v2/batch_completion.py b/model-engine/model_engine_server/api/v2/batch_completion.py new file mode 100644 index 00000000..9412f945 --- /dev/null +++ b/model-engine/model_engine_server/api/v2/batch_completion.py @@ -0,0 +1,150 @@ +from fastapi import APIRouter, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.dtos.llms.batch_completion import ( + CancelBatchCompletionsV2Response, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2Response, + GetBatchCompletionV2Response, + UpdateBatchCompletionsV2Request, + UpdateBatchCompletionsV2Response, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CancelBatchCompletionV2UseCase, + CreateBatchCompletionsV2UseCase, + GetBatchCompletionV2UseCase, + UpdateBatchCompletionV2UseCase, +) + +from .common import get_metric_metadata, record_route_call + +logger = make_logger(logger_name()) + + +batch_completions_router_v2 = APIRouter( + prefix="/batch-completions", dependencies=[Depends(record_route_call)] +) + + +@batch_completions_router_v2.post("/", response_model=CreateBatchCompletionsV2Response) +async def batch_completions( + request: CreateBatchCompletionsV2Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> CreateBatchCompletionsV2Response: + logger.info(f"POST /v2/batch-completions {request} for {auth}") + try: + use_case = CreateBatchCompletionsV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) + + return await use_case.execute(request, user=auth) + except ObjectNotFoundException as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + + except Exception as exc: + logger.exception(f"Error processing request {request} for {auth}") + raise HTTPException( + status_code=500, + detail="Internal server error", + ) from exc + + +@batch_completions_router_v2.get( + "/{batch_completion_id}", + response_model=GetBatchCompletionV2Response, +) +async def get_batch_completion( + batch_completion_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> GetBatchCompletionV2Response: + logger.info(f"GET /v2/batch-completions/{batch_completion_id} for {auth}") + try: + use_case = GetBatchCompletionV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + ) + return await use_case.execute(batch_completion_id=batch_completion_id, user=auth) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + + +@batch_completions_router_v2.post( + "/{batch_completion_id}", + response_model=UpdateBatchCompletionsV2Response, +) +async def update_batch_completion( + batch_completion_id: str, + request: UpdateBatchCompletionsV2Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateBatchCompletionsV2Response: + logger.info(f"POST /v2/batch-completions/{batch_completion_id} {request} for {auth}") + try: + use_case = UpdateBatchCompletionV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + ) + return await use_case.execute( + batch_completion_id=batch_completion_id, request=request, user=auth + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except Exception as exc: + logger.exception(f"Error processing request {request} for {auth}", exc_info=exc) + raise HTTPException( + status_code=500, + detail="Internal server error", + ) from exc + + +@batch_completions_router_v2.post( + "/{batch_completion_id}/actions/cancel", + response_model=CancelBatchCompletionsV2Response, +) +async def cancel_batch_completion( + batch_completion_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CancelBatchCompletionsV2Response: + logger.info(f"POST /v2/batch-completions/{batch_completion_id}/actions/cancel for {auth}") + try: + use_case = CancelBatchCompletionV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + ) + return await use_case.execute(batch_completion_id=batch_completion_id, user=auth) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except Exception as exc: + logger.exception( + f"Error canceling batch completions {batch_completion_id} for {auth}", + exc_info=exc, + ) + raise HTTPException( + status_code=500, + detail="Internal server error", + ) from exc diff --git a/model-engine/model_engine_server/api/v2/common.py b/model-engine/model_engine_server/api/v2/common.py new file mode 100644 index 00000000..2099c3b6 --- /dev/null +++ b/model-engine/model_engine_server/api/v2/common.py @@ -0,0 +1,33 @@ +from fastapi import Depends, Request +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata + + +def format_request_route(request: Request) -> str: + url_path = request.url.path + for path_param in request.path_params: + url_path = url_path.replace(request.path_params[path_param], f":{path_param}") + return f"{request.method}_{url_path}".lower() + + +async def get_metric_metadata( + request: Request, + auth: User = Depends(verify_authentication), +) -> MetricMetadata: + print("body") + print(request.body) + model_name = request.query_params.get("model", None) + return MetricMetadata(user=auth, model_name=model_name) + + +async def record_route_call( + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + route: str = Depends(format_request_route), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +): + external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(route, metric_metadata) diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 248056cb..6b7ebfb5 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -7,10 +7,11 @@ ChatCompletionV2Response, ) from model_engine_server.common.dtos.llms.completion import ( + CompletionOutput, CompletionV2Request, CompletionV2Response, ) -from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field from typing_extensions import TypeAlias @@ -66,6 +67,11 @@ class BatchCompletionsModelConfig(BaseModel): seed: Optional[int] = Field(default=None, description="Random seed for the model.") + response_role: Optional[str] = Field( + default=None, + description="Role of the response in the conversation. Only supported in chat completions.", + ) + class BatchCompletionsRequestBase(BaseModel): input_data_path: Optional[str] = Field( @@ -94,7 +100,7 @@ class BatchCompletionsRequestBase(BaseModel): description="Maximum runtime of the batch inference in seconds. Default to one day.", ) - priority: Optional[int] = Field( + priority: Optional[str] = Field( default=None, description="Priority of the batch inference job. Default to None.", ) @@ -108,6 +114,9 @@ class BatchCompletionsRequestBase(BaseModel): # V1 DTOs for batch completions +CompletionV1Output = CompletionOutput + + class CreateBatchCompletionsV1ModelConfig(BatchCompletionsModelConfig): labels: Dict[str, str] = Field( default={}, description="Labels to attach to the batch inference job." @@ -176,10 +185,22 @@ class CreateBatchCompletionsV1Response(BaseModel): job_id: str +class FilteredCompletionV2Request(CompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +class FilteredChatCompletionV2Request(ChatCompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + # V2 DTOs for batch completions -CompletionRequest: TypeAlias = Union[CompletionV2Request, ChatCompletionV2Request] +CompletionRequest: TypeAlias = Union[FilteredCompletionV2Request, FilteredChatCompletionV2Request] CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2Response] -CreateBatchCompletionsV2RequestContent: TypeAlias = List[CompletionRequest] +CreateBatchCompletionsV2RequestContent: TypeAlias = Union[ + List[FilteredCompletionV2Request], List[FilteredChatCompletionV2Request] +] CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig @@ -204,12 +225,13 @@ class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): ) -class BatchCompletionsJobStatus(Enum): +class BatchCompletionsJobStatus(str, Enum): Queued = "queued" Running = "running" Completed = "completed" Failed = "failed" Cancelled = "cancelled" + Unknown = "unknown" class BatchCompletionsJob(BaseModel): @@ -243,6 +265,26 @@ class BatchCompletionsJob(BaseModel): CreateBatchCompletionsV2Response: TypeAlias = BatchCompletionsJob +class UpdateBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + priority: Optional[int] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + +class UpdateBatchCompletionsV2Response(BatchCompletionsJob): + success: bool = Field(description="Whether the update was successful") + + +class CancelBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + + +class CancelBatchCompletionsV2Response(BaseModel): + success: bool = Field(description="Whether the cancellation was successful") + + class ListBatchCompletionV2Response(BaseModel): jobs: List[BatchCompletionsJob] @@ -275,12 +317,15 @@ class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase, VLLMEngin hidden from the DTO exposed to the client. """ + model_config = ConfigDict(populate_by_name=True, protected_namespaces=()) + content: Optional[BatchCompletionContent] = Field( default=None, description="Content is a union of the content from v1 and v2 requests.", ) model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", description="""Model configuration for the batch inference. Hardware configurations are inferred.""", ) diff --git a/model-engine/model_engine_server/common/dtos/llms/completion.py b/model-engine/model_engine_server/common/dtos/llms/completion.py index 25dc0caa..a680a2b7 100644 --- a/model-engine/model_engine_server/common/dtos/llms/completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/completion.py @@ -75,24 +75,56 @@ class CompletionSyncV1Request(BaseModel): class TokenOutput(BaseModel): + """ + Detailed token information. + """ + token: str + """ + The token text. + """ + log_prob: float + """ + The log probability of the token. + """ class CompletionOutput(BaseModel): + """ + Represents the output of a completion request to a model. + """ + text: str - num_prompt_tokens: int + """The text of the completion.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + num_completion_tokens: int + """Number of tokens in the completion.""" + tokens: Optional[List[TokenOutput]] = None + """Detailed token information.""" class CompletionSyncV1Response(BaseModel): """ - Response object for a synchronous prompt completion task. + Response object for a synchronous prompt completion. """ request_id: Optional[str] = None + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + output: Optional[CompletionOutput] = None + """Completion output.""" class CompletionStreamV1Request(BaseModel): @@ -160,10 +192,20 @@ class CompletionStreamV1Request(BaseModel): class CompletionStreamOutput(BaseModel): text: str + """The text of the completion.""" + finished: bool + """Whether the completion is finished.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + num_completion_tokens: Optional[int] = None + """Number of tokens in the completion.""" + token: Optional[TokenOutput] = None + """Detailed token information.""" class StreamErrorContent(BaseModel): @@ -185,12 +227,24 @@ class StreamError(BaseModel): class CompletionStreamV1Response(BaseModel): + """Error of the response (if any).""" + """ Response object for a stream prompt completion task. """ - request_id: Optional[str] = None + request_id: Optional[str] + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + output: Optional[CompletionStreamOutput] = None + """Completion output.""" + error: Optional[StreamError] = None """Error of the response (if any).""" diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py index 2c58f13b..9ac7a40f 100644 --- a/model-engine/model_engine_server/common/types/gen/openai.py +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -1,14 +1,13 @@ # generated by datamodel-codegen: # filename: openai-spec.yaml -# timestamp: 2024-07-26T21:34:42+00:00 +# timestamp: 2024-08-20T08:20:04+00:00 from __future__ import annotations -from enum import Enum from typing import Any, Dict, List, Optional, Union from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel -from typing_extensions import Annotated +from typing_extensions import Annotated, Literal class Error(BaseModel): @@ -22,22 +21,12 @@ class ErrorResponse(BaseModel): error: Error -class Object(Enum): - list = "list" - - class DeleteModelResponse(BaseModel): id: str deleted: bool object: str -class Model1(Enum): - gpt_3_5_turbo_instruct = "gpt-3.5-turbo-instruct" - davinci_002 = "davinci-002" - babbage_002 = "babbage-002" - - class Prompt(RootModel[Optional[List[int]]]): root: Annotated[ Optional[List[int]], @@ -78,12 +67,6 @@ class Stop(RootModel[Optional[List[str]]]): ] = None -class FinishReason(Enum): - stop = "stop" - length = "length" - content_filter = "content_filter" - - class Logprobs(BaseModel): text_offset: Optional[List[int]] = None token_logprobs: Optional[List[float]] = None @@ -93,7 +76,7 @@ class Logprobs(BaseModel): class Choice(BaseModel): finish_reason: Annotated[ - FinishReason, + Literal["stop", "length", "content_filter"], Field( description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n" ), @@ -103,18 +86,9 @@ class Choice(BaseModel): text: str -class Object1(Enum): - text_completion = "text_completion" - - -class Type(Enum): - image_url = "image_url" - - -class Detail(Enum): - auto = "auto" - low = "low" - high = "high" +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Annotated[Literal["text"], Field(description="The type of the content part.")] + text: Annotated[str, Field(description="The text content.")] class ImageUrl(BaseModel): @@ -123,7 +97,7 @@ class ImageUrl(BaseModel): Field(description="Either a URL of the image or the base64 encoded image data."), ] detail: Annotated[ - Optional[Detail], + Optional[Literal["auto", "low", "high"]], Field( "auto", description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", @@ -132,27 +106,72 @@ class ImageUrl(BaseModel): class ChatCompletionRequestMessageContentPartImage(BaseModel): - type: Annotated[Type, Field(description="The type of the content part.")] + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] image_url: ImageUrl -class Type1(Enum): - text = "text" +class ChatCompletionRequestMessageContentPartRefusal(BaseModel): + type: Annotated[Literal["refusal"], Field(description="The type of the content part.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] -class ChatCompletionRequestMessageContentPartText(BaseModel): - type: Annotated[Type1, Field(description="The type of the content part.")] - text: Annotated[str, Field(description="The text content.")] +class ChatCompletionRequestSystemMessageContentPart( + RootModel[ChatCompletionRequestMessageContentPartText] +): + root: ChatCompletionRequestMessageContentPartText + + +class ChatCompletionRequestUserMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class ChatCompletionRequestAssistantMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + + +class ChatCompletionRequestToolMessageContentPart( + RootModel[ChatCompletionRequestMessageContentPartText] +): + root: ChatCompletionRequestMessageContentPartText -class Role(Enum): - system = "system" +class Content(RootModel[List[ChatCompletionRequestSystemMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestSystemMessageContentPart], + Field( + description="An array of content parts with a defined type. For system messages, only type `text` is supported.", + min_length=1, + title="Array of content parts", + ), + ] class ChatCompletionRequestSystemMessage(BaseModel): - content: Annotated[str, Field(description="The contents of the system message.")] + content: Annotated[ + Union[str, Content], Field(description="The contents of the system message.") + ] role: Annotated[ - Role, + Literal["system"], Field(description="The role of the messages author, in this case `system`."), ] name: Annotated[ @@ -164,12 +183,44 @@ class ChatCompletionRequestSystemMessage(BaseModel): ] -class Role1(Enum): - user = "user" +class Content1(RootModel[List[ChatCompletionRequestUserMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestUserMessageContentPart], + Field( + description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Annotated[ + Union[str, Content1], Field(description="The contents of the user message.\n") + ] + role: Annotated[ + Literal["user"], + Field(description="The role of the messages author, in this case `user`."), + ] + name: Annotated[ + Optional[str], + Field( + None, + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + ), + ] -class Role2(Enum): - assistant = "assistant" +class Content2(RootModel[Optional[List[ChatCompletionRequestAssistantMessageContentPart]]]): + root: Annotated[ + Optional[List[ChatCompletionRequestAssistantMessageContentPart]], + Field( + None, + description="An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`.", + min_length=1, + title="Array of content parts", + ), + ] = None class FunctionCall(BaseModel): @@ -182,31 +233,29 @@ class FunctionCall(BaseModel): name: Annotated[str, Field(description="The name of the function to call.")] -class Weight(Enum): - integer_0 = 0 - integer_1 = 1 - - -class Role3(Enum): - tool = "tool" +class Content3(RootModel[List[ChatCompletionRequestToolMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestToolMessageContentPart], + Field( + description="An array of content parts with a defined type. For tool messages, only type `text` is supported.", + min_length=1, + title="Array of content parts", + ), + ] class ChatCompletionRequestToolMessage(BaseModel): role: Annotated[ - Role3, + Literal["tool"], Field(description="The role of the messages author, in this case `tool`."), ] - content: Annotated[str, Field(description="The contents of the tool message.")] + content: Annotated[Union[str, Content3], Field(description="The contents of the tool message.")] tool_call_id: Annotated[str, Field(description="Tool call that this message is responding to.")] -class Role4(Enum): - function = "function" - - class ChatCompletionRequestFunctionMessage(BaseModel): role: Annotated[ - Role4, + Literal["function"], Field(description="The role of the messages author, in this case `function`."), ] content: Annotated[str, Field(description="The contents of the function message.")] @@ -241,10 +290,6 @@ class ChatCompletionFunctionCallOption(BaseModel): name: Annotated[str, Field(description="The name of the function to call.")] -class Type2(Enum): - function = "function" - - class FunctionObject(BaseModel): description: Annotated[ Optional[str], @@ -260,12 +305,66 @@ class FunctionObject(BaseModel): ), ] parameters: Optional[FunctionParameters] = None + strict: Annotated[ + Optional[bool], + Field( + False, + description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling).", + ), + ] + + +class ResponseFormatText(BaseModel): + type: Annotated[ + Literal["text"], + Field(description="The type of response format being defined: `text`"), + ] + + +class ResponseFormatJsonObject(BaseModel): + type: Annotated[ + Literal["json_object"], + Field(description="The type of response format being defined: `json_object`"), + ] + + +class ResponseFormatJsonSchemaSchema(BaseModel): + pass + model_config = ConfigDict( + extra="allow", + ) + + +class JsonSchema(BaseModel): + description: Annotated[ + Optional[str], + Field( + None, + description="A description of what the response format is for, used by the model to determine how to respond in the format.", + ), + ] + name: Annotated[ + str, + Field( + description="The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(None, alias="schema")] + strict: Annotated[ + Optional[bool], + Field( + False, + description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs).", + ), + ] -class ChatCompletionToolChoiceOption1(Enum): - none = "none" - auto = "auto" - required = "required" +class ResponseFormatJsonSchema(BaseModel): + type: Annotated[ + Literal["json_schema"], + Field(description="The type of response format being defined: `json_schema`"), + ] + json_schema: JsonSchema class Function(BaseModel): @@ -274,7 +373,7 @@ class Function(BaseModel): class ChatCompletionNamedToolChoice(BaseModel): type: Annotated[ - Type2, + Literal["function"], Field(description="The type of the tool. Currently, only `function` is supported."), ] function: Function @@ -302,7 +401,7 @@ class Function1(BaseModel): class ChatCompletionMessageToolCall(BaseModel): id: Annotated[str, Field(description="The ID of the tool call.")] type: Annotated[ - Type2, + Literal["function"], Field(description="The type of the tool. Currently, only `function` is supported."), ] function: Annotated[Function1, Field(description="The function that the model called.")] @@ -323,7 +422,7 @@ class ChatCompletionMessageToolCallChunk(BaseModel): index: int id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] type: Annotated[ - Optional[Type2], + Optional[Literal["function"]], Field( None, description="The type of the tool. Currently, only `function` is supported.", @@ -332,12 +431,11 @@ class ChatCompletionMessageToolCallChunk(BaseModel): function: Optional[Function2] = None -class ChatCompletionRole(Enum): - system = "system" - user = "user" - assistant = "assistant" - tool = "tool" - function = "function" +class ChatCompletionRole(RootModel[Literal["system", "user", "assistant", "tool", "function"]]): + root: Annotated[ + Literal["system", "user", "assistant", "tool", "function"], + Field(description="The role of the author of a message"), + ] class ChatCompletionStreamOptions(BaseModel): @@ -350,10 +448,6 @@ class ChatCompletionStreamOptions(BaseModel): ] -class Role5(Enum): - assistant = "assistant" - - class FunctionCall2(BaseModel): arguments: Annotated[ Optional[str], @@ -365,13 +459,6 @@ class FunctionCall2(BaseModel): name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] -class Role6(Enum): - system = "system" - user = "user" - assistant = "assistant" - tool = "tool" - - class ChatCompletionStreamResponseDelta(BaseModel): content: Annotated[Optional[str], Field(None, description="The contents of the chunk message.")] function_call: Annotated[ @@ -383,58 +470,15 @@ class ChatCompletionStreamResponseDelta(BaseModel): ] tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None role: Annotated[ - Optional[Role6], + Optional[Literal["system", "user", "assistant", "tool"]], Field(None, description="The role of the author of this message."), ] - - -class Model2(Enum): - gpt_4o = "gpt-4o" - gpt_4o_2024_05_13 = "gpt-4o-2024-05-13" - gpt_4o_mini = "gpt-4o-mini" - gpt_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18" - gpt_4_turbo = "gpt-4-turbo" - gpt_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" - gpt_4_0125_preview = "gpt-4-0125-preview" - gpt_4_turbo_preview = "gpt-4-turbo-preview" - gpt_4_1106_preview = "gpt-4-1106-preview" - gpt_4_vision_preview = "gpt-4-vision-preview" - gpt_4 = "gpt-4" - gpt_4_0314 = "gpt-4-0314" - gpt_4_0613 = "gpt-4-0613" - gpt_4_32k = "gpt-4-32k" - gpt_4_32k_0314 = "gpt-4-32k-0314" - gpt_4_32k_0613 = "gpt-4-32k-0613" - gpt_3_5_turbo = "gpt-3.5-turbo" - gpt_3_5_turbo_16k = "gpt-3.5-turbo-16k" - gpt_3_5_turbo_0301 = "gpt-3.5-turbo-0301" - gpt_3_5_turbo_0613 = "gpt-3.5-turbo-0613" - gpt_3_5_turbo_1106 = "gpt-3.5-turbo-1106" - gpt_3_5_turbo_0125 = "gpt-3.5-turbo-0125" - gpt_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" - - -class Type6(Enum): - text = "text" - json_object = "json_object" - - -class ResponseFormat(BaseModel): - type: Annotated[ - Optional[Type6], - Field( - "text", - description="Must be one of `text` or `json_object`.", - examples=["json_object"], - ), + refusal: Annotated[ + Optional[str], + Field(None, description="The refusal message generated by the model."), ] -class ServiceTier(Enum): - auto = "auto" - default = "default" - - class Stop1(RootModel[List[str]]): root: Annotated[ List[str], @@ -446,35 +490,6 @@ class Stop1(RootModel[List[str]]): ] -class FunctionCall3(Enum): - none = "none" - auto = "auto" - - -class FinishReason1(Enum): - stop = "stop" - length = "length" - tool_calls = "tool_calls" - content_filter = "content_filter" - function_call = "function_call" - - -class ServiceTier1(Enum): - scale = "scale" - default = "default" - - -class Object2(Enum): - chat_completion = "chat.completion" - - -class FinishReason2(Enum): - stop = "stop" - length = "length" - function_call = "function_call" - content_filter = "content_filter" - - class TopLogprob(BaseModel): token: Annotated[str, Field(description="The token.")] logprob: Annotated[ @@ -513,23 +528,15 @@ class ChatCompletionTokenLogprob(BaseModel): ] -class Object4(Enum): - list = "list" - - class Logprobs2(BaseModel): content: Annotated[ List[ChatCompletionTokenLogprob], Field(description="A list of message content tokens with log probability information."), ] - - -class FinishReason3(Enum): - stop = "stop" - length = "length" - tool_calls = "tool_calls" - content_filter = "content_filter" - function_call = "function_call" + refusal: Annotated[ + List[ChatCompletionTokenLogprob], + Field(description="A list of message refusal tokens with log probability information."), + ] class Choice3(BaseModel): @@ -539,7 +546,7 @@ class Choice3(BaseModel): Field(None, description="Log probability information for the choice."), ] finish_reason: Annotated[ - FinishReason3, + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], Field( description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" ), @@ -547,10 +554,6 @@ class Choice3(BaseModel): index: Annotated[int, Field(description="The index of the choice in the list of choices.")] -class Object5(Enum): - chat_completion_chunk = "chat.completion.chunk" - - class Usage(BaseModel): completion_tokens: Annotated[ int, Field(description="Number of tokens in the generated completion.") @@ -583,7 +586,7 @@ class CreateChatCompletionStreamResponse(BaseModel): ] model: Annotated[str, Field(description="The model to generate the completion.")] service_tier: Annotated[ - Optional[ServiceTier1], + Optional[Literal["scale", "default"]], Field( None, description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", @@ -598,7 +601,7 @@ class CreateChatCompletionStreamResponse(BaseModel): ), ] object: Annotated[ - Object5, + Literal["chat.completion.chunk"], Field(description="The object type, which is always `chat.completion.chunk`."), ] usage: Annotated[ @@ -614,34 +617,6 @@ class CreateChatCompletionImageResponse(BaseModel): pass -class Model3(Enum): - dall_e_2 = "dall-e-2" - dall_e_3 = "dall-e-3" - - -class Quality(Enum): - standard = "standard" - hd = "hd" - - -class ResponseFormat1(Enum): - url = "url" - b64_json = "b64_json" - - -class Size(Enum): - field_256x256 = "256x256" - field_512x512 = "512x512" - field_1024x1024 = "1024x1024" - field_1792x1024 = "1792x1024" - field_1024x1792 = "1024x1792" - - -class Style(Enum): - vivid = "vivid" - natural = "natural" - - class CreateImageRequest(BaseModel): prompt: Annotated[ str, @@ -651,7 +626,7 @@ class CreateImageRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Model3]], + Optional[Union[str, Literal["dall-e-2", "dall-e-3"]]], Field( "dall-e-2", description="The model to use for image generation.", @@ -669,7 +644,7 @@ class CreateImageRequest(BaseModel): ), ] quality: Annotated[ - Optional[Quality], + Optional[Literal["standard", "hd"]], Field( "standard", description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", @@ -677,7 +652,7 @@ class CreateImageRequest(BaseModel): ), ] response_format: Annotated[ - Optional[ResponseFormat1], + Optional[Literal["url", "b64_json"]], Field( "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", @@ -685,7 +660,7 @@ class CreateImageRequest(BaseModel): ), ] size: Annotated[ - Optional[Size], + Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]], Field( "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", @@ -693,7 +668,7 @@ class CreateImageRequest(BaseModel): ), ] style: Annotated[ - Optional[Style], + Optional[Literal["vivid", "natural"]], Field( "vivid", description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", @@ -734,16 +709,6 @@ class Image(BaseModel): ] -class Model4(Enum): - dall_e_2 = "dall-e-2" - - -class Size1(Enum): - field_256x256 = "256x256" - field_512x512 = "512x512" - field_1024x1024 = "1024x1024" - - class CreateImageEditRequest(BaseModel): image: Annotated[ bytes, @@ -766,7 +731,7 @@ class CreateImageEditRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Model4]], + Optional[Union[str, Literal["dall-e-2"]]], Field( "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", @@ -784,7 +749,7 @@ class CreateImageEditRequest(BaseModel): ), ] size: Annotated[ - Optional[Size1], + Optional[Literal["256x256", "512x512", "1024x1024"]], Field( "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", @@ -792,7 +757,7 @@ class CreateImageEditRequest(BaseModel): ), ] response_format: Annotated[ - Optional[ResponseFormat1], + Optional[Literal["url", "b64_json"]], Field( "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", @@ -817,7 +782,7 @@ class CreateImageVariationRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Model4]], + Optional[Union[str, Literal["dall-e-2"]]], Field( "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", @@ -835,7 +800,7 @@ class CreateImageVariationRequest(BaseModel): ), ] response_format: Annotated[ - Optional[ResponseFormat1], + Optional[Literal["url", "b64_json"]], Field( "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", @@ -843,7 +808,7 @@ class CreateImageVariationRequest(BaseModel): ), ] size: Annotated[ - Optional[Size1], + Optional[Literal["256x256", "512x512", "1024x1024"]], Field( "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", @@ -860,15 +825,10 @@ class CreateImageVariationRequest(BaseModel): ] -class Model6(Enum): - text_moderation_latest = "text-moderation-latest" - text_moderation_stable = "text-moderation-stable" - - class CreateModerationRequest(BaseModel): input: Annotated[Union[str, List[str]], Field(description="The input text to classify")] model: Annotated[ - Optional[Union[str, Model6]], + Optional[Union[str, Literal["text-moderation-latest", "text-moderation-stable"]]], Field( "text-moderation-latest", description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", @@ -1024,37 +984,22 @@ class CreateModerationResponse(BaseModel): results: Annotated[List[Result], Field(description="A list of moderation objects.")] -class Object6(Enum): - list = "list" - - -class Purpose(Enum): - assistants = "assistants" - batch = "batch" - fine_tune = "fine-tune" - vision = "vision" - - class CreateFileRequest(BaseModel): model_config = ConfigDict( extra="forbid", ) file: Annotated[bytes, Field(description="The File object (not file name) to be uploaded.\n")] purpose: Annotated[ - Purpose, + Literal["assistants", "batch", "fine-tune", "vision"], Field( description='The intended purpose of the uploaded file.\n\nUse "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning).\n' ), ] -class Object7(Enum): - file = "file" - - class DeleteFileResponse(BaseModel): id: str - object: Object7 + object: Literal["file"] deleted: bool @@ -1064,7 +1009,7 @@ class CreateUploadRequest(BaseModel): ) filename: Annotated[str, Field(description="The name of the file to upload.\n")] purpose: Annotated[ - Purpose, + Literal["assistants", "batch", "fine-tune", "vision"], Field( description="The intended purpose of the uploaded file.\n\nSee the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose).\n" ), @@ -1106,17 +1051,7 @@ class CancelUploadRequest(BaseModel): ) -class Model7(Enum): - babbage_002 = "babbage-002" - davinci_002 = "davinci-002" - gpt_3_5_turbo = "gpt-3.5-turbo" - - -class BatchSize(Enum): - auto = "auto" - - -class BatchSize1(RootModel[int]): +class BatchSize(RootModel[int]): root: Annotated[ int, Field( @@ -1127,11 +1062,7 @@ class BatchSize1(RootModel[int]): ] -class LearningRateMultiplier(Enum): - auto = "auto" - - -class LearningRateMultiplier1(RootModel[float]): +class LearningRateMultiplier(RootModel[float]): root: Annotated[ float, Field( @@ -1141,11 +1072,7 @@ class LearningRateMultiplier1(RootModel[float]): ] -class NEpochs(Enum): - auto = "auto" - - -class NEpochs1(RootModel[int]): +class NEpochs(RootModel[int]): root: Annotated[ int, Field( @@ -1158,21 +1085,21 @@ class NEpochs1(RootModel[int]): class Hyperparameters(BaseModel): batch_size: Annotated[ - Optional[Union[BatchSize, BatchSize1]], + Optional[Union[Literal["auto"], BatchSize]], Field( "auto", description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", ), ] learning_rate_multiplier: Annotated[ - Optional[Union[LearningRateMultiplier, LearningRateMultiplier1]], + Optional[Union[Literal["auto"], LearningRateMultiplier]], Field( "auto", description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", ), ] n_epochs: Annotated[ - Optional[Union[NEpochs, NEpochs1]], + Optional[Union[Literal["auto"], NEpochs]], Field( "auto", description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", @@ -1180,10 +1107,6 @@ class Hyperparameters(BaseModel): ] -class Type7(Enum): - wandb = "wandb" - - class Wandb(BaseModel): project: Annotated[ str, @@ -1217,7 +1140,7 @@ class Wandb(BaseModel): class Integration(BaseModel): type: Annotated[ - Type7, + Literal["wandb"], Field( description='The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.\n' ), @@ -1232,10 +1155,10 @@ class Integration(BaseModel): class CreateFineTuningJobRequest(BaseModel): model: Annotated[ - Union[str, Model7], + Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"]], Field( - description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/what-models-can-be-fine-tuned).\n", - examples=["gpt-3.5-turbo"], + description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned).\n", + examples=["gpt-4o-mini"], ), ] training_file: Annotated[ @@ -1253,7 +1176,7 @@ class CreateFineTuningJobRequest(BaseModel): Optional[str], Field( None, - description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`.\n', + description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.\n', max_length=40, min_length=1, ), @@ -1285,10 +1208,6 @@ class CreateFineTuningJobRequest(BaseModel): ] -class Object8(Enum): - list = "list" - - class Input(RootModel[List[str]]): root: Annotated[ List[str], @@ -1332,17 +1251,6 @@ class Input2(RootModel[List[Input2Item]]): ] -class Model8(Enum): - text_embedding_ada_002 = "text-embedding-ada-002" - text_embedding_3_small = "text-embedding-3-small" - text_embedding_3_large = "text-embedding-3-large" - - -class EncodingFormat(Enum): - float = "float" - base64 = "base64" - - class CreateEmbeddingRequest(BaseModel): model_config = ConfigDict( extra="forbid", @@ -1355,14 +1263,21 @@ class CreateEmbeddingRequest(BaseModel): ), ] model: Annotated[ - Union[str, Model8], + Union[ + str, + Literal[ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ], + ], Field( description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", examples=["text-embedding-3-small"], ), ] encoding_format: Annotated[ - Optional[EncodingFormat], + Optional[Literal["float", "base64"]], Field( "float", description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", @@ -1394,23 +1309,6 @@ class Usage1(BaseModel): ] -class Model9(Enum): - whisper_1 = "whisper-1" - - -class ResponseFormat4(Enum): - json = "json" - text = "text" - srt = "srt" - verbose_json = "verbose_json" - vtt = "vtt" - - -class TimestampGranularity(Enum): - word = "word" - segment = "segment" - - class CreateTranscriptionRequest(BaseModel): model_config = ConfigDict( extra="forbid", @@ -1422,7 +1320,7 @@ class CreateTranscriptionRequest(BaseModel): ), ] model: Annotated[ - Union[str, Model9], + Union[str, Literal["whisper-1"]], Field( description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", examples=["whisper-1"], @@ -1443,7 +1341,7 @@ class CreateTranscriptionRequest(BaseModel): ), ] response_format: Annotated[ - Optional[ResponseFormat4], + Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]], Field( "json", description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", @@ -1457,7 +1355,7 @@ class CreateTranscriptionRequest(BaseModel): ), ] timestamp_granularities__: Annotated[ - Optional[List[TimestampGranularity]], + Optional[List[Literal["word", "segment"]]], Field( ["segment"], alias="timestamp_granularities[]", @@ -1535,7 +1433,7 @@ class CreateTranslationRequest(BaseModel): ), ] model: Annotated[ - Union[str, Model9], + Union[str, Literal["whisper-1"]], Field( description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", examples=["whisper-1"], @@ -1584,35 +1482,12 @@ class CreateTranslationResponseVerboseJson(BaseModel): ] -class Model11(Enum): - tts_1 = "tts-1" - tts_1_hd = "tts-1-hd" - - -class Voice(Enum): - alloy = "alloy" - echo = "echo" - fable = "fable" - onyx = "onyx" - nova = "nova" - shimmer = "shimmer" - - -class ResponseFormat5(Enum): - mp3 = "mp3" - opus = "opus" - aac = "aac" - flac = "flac" - wav = "wav" - pcm = "pcm" - - class CreateSpeechRequest(BaseModel): model_config = ConfigDict( extra="forbid", ) model: Annotated[ - Union[str, Model11], + Union[str, Literal["tts-1", "tts-1-hd"]], Field( description="One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd`\n" ), @@ -1625,13 +1500,13 @@ class CreateSpeechRequest(BaseModel): ), ] voice: Annotated[ - Voice, + Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], Field( description="The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options)." ), ] response_format: Annotated[ - Optional[ResponseFormat5], + Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]], Field( "mp3", description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.", @@ -1648,10 +1523,6 @@ class CreateSpeechRequest(BaseModel): ] -class Object11(Enum): - model = "model" - - class Model(BaseModel): id: Annotated[ str, @@ -1661,30 +1532,12 @@ class Model(BaseModel): int, Field(description="The Unix timestamp (in seconds) when the model was created."), ] - object: Annotated[Object11, Field(description='The object type, which is always "model".')] + object: Annotated[ + Literal["model"], Field(description='The object type, which is always "model".') + ] owned_by: Annotated[str, Field(description="The organization that owns the model.")] -class Object12(Enum): - file = "file" - - -class Purpose2(Enum): - assistants = "assistants" - assistants_output = "assistants_output" - batch = "batch" - batch_output = "batch_output" - fine_tune = "fine-tune" - fine_tune_results = "fine-tune-results" - vision = "vision" - - -class Status(Enum): - uploaded = "uploaded" - processed = "processed" - error = "error" - - class OpenAIFile(BaseModel): id: Annotated[ str, @@ -1696,15 +1549,25 @@ class OpenAIFile(BaseModel): Field(description="The Unix timestamp (in seconds) for when the file was created."), ] filename: Annotated[str, Field(description="The name of the file.")] - object: Annotated[Object12, Field(description="The object type, which is always `file`.")] + object: Annotated[ + Literal["file"], Field(description="The object type, which is always `file`.") + ] purpose: Annotated[ - Purpose2, + Literal[ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ], Field( description="The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`." ), ] status: Annotated[ - Status, + Literal["uploaded", "processed", "error"], Field( description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`." ), @@ -1718,17 +1581,6 @@ class OpenAIFile(BaseModel): ] -class Status1(Enum): - pending = "pending" - completed = "completed" - cancelled = "cancelled" - expired = "expired" - - -class Object13(Enum): - upload = "upload" - - class Upload(BaseModel): id: Annotated[ str, @@ -1748,13 +1600,16 @@ class Upload(BaseModel): description="The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values." ), ] - status: Annotated[Status1, Field(description="The status of the Upload.")] + status: Annotated[ + Literal["pending", "completed", "cancelled", "expired"], + Field(description="The status of the Upload."), + ] expires_at: Annotated[ int, Field(description="The Unix timestamp (in seconds) for when the Upload was created."), ] object: Annotated[ - Optional[Object13], + Optional[Literal["upload"]], Field(None, description='The object type, which is always "upload".'), ] file: Annotated[ @@ -1763,10 +1618,6 @@ class Upload(BaseModel): ] -class Object14(Enum): - upload_part = "upload.part" - - class UploadPart(BaseModel): id: Annotated[ str, @@ -1783,14 +1634,11 @@ class UploadPart(BaseModel): Field(description="The ID of the Upload object that this Part was added to."), ] object: Annotated[ - Object14, Field(description="The object type, which is always `upload.part`.") + Literal["upload.part"], + Field(description="The object type, which is always `upload.part`."), ] -class Object15(Enum): - embedding = "embedding" - - class Embedding(BaseModel): index: Annotated[ int, Field(description="The index of the embedding in the list of embeddings.") @@ -1801,7 +1649,10 @@ class Embedding(BaseModel): description="The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).\n" ), ] - object: Annotated[Object15, Field(description='The object type, which is always "embedding".')] + object: Annotated[ + Literal["embedding"], + Field(description='The object type, which is always "embedding".'), + ] class Error1(BaseModel): @@ -1815,11 +1666,7 @@ class Error1(BaseModel): ] -class NEpochs2(Enum): - auto = "auto" - - -class NEpochs3(RootModel[int]): +class NEpochs1(RootModel[int]): root: Annotated[ int, Field( @@ -1832,29 +1679,16 @@ class NEpochs3(RootModel[int]): class Hyperparameters1(BaseModel): n_epochs: Annotated[ - Union[NEpochs2, NEpochs3], + Union[Literal["auto"], NEpochs1], Field( description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.' ), ] -class Object16(Enum): - fine_tuning_job = "fine_tuning.job" - - -class Status2(Enum): - validating_files = "validating_files" - queued = "queued" - running = "running" - succeeded = "succeeded" - failed = "failed" - cancelled = "cancelled" - - class FineTuningIntegration(BaseModel): type: Annotated[ - Type7, + Literal["wandb"], Field(description="The type of the integration being enabled for the fine-tuning job"), ] wandb: Annotated[ @@ -1865,22 +1699,12 @@ class FineTuningIntegration(BaseModel): ] -class Level(Enum): - info = "info" - warn = "warn" - error = "error" - - -class Object17(Enum): - fine_tuning_job_event = "fine_tuning.job.event" - - class FineTuningJobEvent(BaseModel): id: str created_at: int - level: Level + level: Literal["info", "warn", "error"] message: str - object: Object17 + object: Literal["fine_tuning.job.event"] class Metrics(BaseModel): @@ -1893,10 +1717,6 @@ class Metrics(BaseModel): full_valid_mean_token_accuracy: Optional[float] = None -class Object18(Enum): - fine_tuning_job_checkpoint = "fine_tuning.job.checkpoint" - - class FineTuningJobCheckpoint(BaseModel): id: Annotated[ str, @@ -1924,7 +1744,7 @@ class FineTuningJobCheckpoint(BaseModel): Field(description="The name of the fine-tuning job that this checkpoint was created from."), ] object: Annotated[ - Object18, + Literal["fine_tuning.job.checkpoint"], Field(description='The object type, which is always "fine_tuning.job.checkpoint".'), ] @@ -1979,31 +1799,29 @@ class RunStepCompletionUsage(BaseModel): ] -class AssistantsApiResponseFormatOption1(Enum): - none = "none" - auto = "auto" - - -class Type9(Enum): - text = "text" - json_object = "json_object" - - -class AssistantsApiResponseFormat(BaseModel): - type: Annotated[ - Optional[Type9], +class AssistantsApiResponseFormatOption( + RootModel[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ] + ] +): + root: Annotated[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ], Field( - "text", - description="Must be one of `text` or `json_object`.", - examples=["json_object"], + description='Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' ), ] -class Object19(Enum): - assistant = "assistant" - - class CodeInterpreter(BaseModel): file_ids: Annotated[ Optional[List[str]], @@ -2031,31 +1849,6 @@ class ToolResources(BaseModel): file_search: Optional[FileSearch] = None -class Model12(Enum): - gpt_4o = "gpt-4o" - gpt_4o_2024_05_13 = "gpt-4o-2024-05-13" - gpt_4o_mini = "gpt-4o-mini" - gpt_4o_mini_2024_07_18 = "gpt-4o-mini-2024-07-18" - gpt_4_turbo = "gpt-4-turbo" - gpt_4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" - gpt_4_0125_preview = "gpt-4-0125-preview" - gpt_4_turbo_preview = "gpt-4-turbo-preview" - gpt_4_1106_preview = "gpt-4-1106-preview" - gpt_4_vision_preview = "gpt-4-vision-preview" - gpt_4 = "gpt-4" - gpt_4_0314 = "gpt-4-0314" - gpt_4_0613 = "gpt-4-0613" - gpt_4_32k = "gpt-4-32k" - gpt_4_32k_0314 = "gpt-4-32k-0314" - gpt_4_32k_0613 = "gpt-4-32k-0613" - gpt_3_5_turbo = "gpt-3.5-turbo" - gpt_3_5_turbo_16k = "gpt-3.5-turbo-16k" - gpt_3_5_turbo_0613 = "gpt-3.5-turbo-0613" - gpt_3_5_turbo_1106 = "gpt-3.5-turbo-1106" - gpt_3_5_turbo_0125 = "gpt-3.5-turbo-0125" - gpt_3_5_turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" - - class CodeInterpreter1(BaseModel): file_ids: Annotated[ Optional[List[str]], @@ -2067,19 +1860,11 @@ class CodeInterpreter1(BaseModel): ] -class Type10(Enum): - auto = "auto" - - class ChunkingStrategy(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type10, Field(description="Always `auto`.")] - - -class Type11(Enum): - static = "static" + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class Static(BaseModel): @@ -2106,7 +1891,7 @@ class ChunkingStrategy1(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type11, Field(description="Always `static`.")] + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -2153,26 +1938,18 @@ class FileSearch1(BaseModel): ] -class Type12(Enum): - auto = "auto" - - class ChunkingStrategy2(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type12, Field(description="Always `auto`.")] - - -class Type13(Enum): - static = "static" + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class ChunkingStrategy3(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type13, Field(description="Always `static`.")] + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -2251,26 +2028,17 @@ class ToolResources2(BaseModel): file_search: Optional[FileSearch3] = None -class Object20(Enum): - assistant_deleted = "assistant.deleted" - - class DeleteAssistantResponse(BaseModel): id: str deleted: bool - object: Object20 - - -class Type14(Enum): - code_interpreter = "code_interpreter" + object: Literal["assistant.deleted"] class AssistantToolsCode(BaseModel): - type: Annotated[Type14, Field(description="The type of tool being defined: `code_interpreter`")] - - -class Type15(Enum): - file_search = "file_search" + type: Annotated[ + Literal["code_interpreter"], + Field(description="The type of tool being defined: `code_interpreter`"), + ] class FileSearch4(BaseModel): @@ -2278,7 +2046,7 @@ class FileSearch4(BaseModel): Optional[int], Field( None, - description="The maximum number of results the file search tool should output. The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", + description="The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", ge=1, le=50, ), @@ -2286,7 +2054,10 @@ class FileSearch4(BaseModel): class AssistantToolsFileSearch(BaseModel): - type: Annotated[Type15, Field(description="The type of tool being defined: `file_search`")] + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] file_search: Annotated[ Optional[FileSearch4], Field(None, description="Overrides for the file search tool."), @@ -2294,26 +2065,23 @@ class AssistantToolsFileSearch(BaseModel): class AssistantToolsFileSearchTypeOnly(BaseModel): - type: Annotated[Type15, Field(description="The type of tool being defined: `file_search`")] - - -class Type17(Enum): - function = "function" + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] class AssistantToolsFunction(BaseModel): - type: Annotated[Type17, Field(description="The type of tool being defined: `function`")] + type: Annotated[ + Literal["function"], + Field(description="The type of tool being defined: `function`"), + ] function: FunctionObject -class Type18(Enum): - auto = "auto" - last_messages = "last_messages" - - class TruncationObject(BaseModel): type: Annotated[ - Type18, + Literal["auto", "last_messages"], Field( description="The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`." ), @@ -2328,25 +2096,13 @@ class TruncationObject(BaseModel): ] -class AssistantsApiToolChoiceOption1(Enum): - none = "none" - auto = "auto" - required = "required" - - -class Type19(Enum): - function = "function" - code_interpreter = "code_interpreter" - file_search = "file_search" - - class Function3(BaseModel): name: Annotated[str, Field(description="The name of the function to call.")] class AssistantsNamedToolChoice(BaseModel): type: Annotated[ - Type19, + Literal["function", "code_interpreter", "file_search"], Field( description="The type of the tool. If type is `function`, the function name must be set" ), @@ -2354,48 +2110,17 @@ class AssistantsNamedToolChoice(BaseModel): function: Optional[Function3] = None -class Object21(Enum): - thread_run = "thread.run" - - -class Status3(Enum): - queued = "queued" - in_progress = "in_progress" - requires_action = "requires_action" - cancelling = "cancelling" - cancelled = "cancelled" - failed = "failed" - completed = "completed" - incomplete = "incomplete" - expired = "expired" - - -class Type20(Enum): - submit_tool_outputs = "submit_tool_outputs" - - -class Code(Enum): - server_error = "server_error" - rate_limit_exceeded = "rate_limit_exceeded" - invalid_prompt = "invalid_prompt" - - class LastError(BaseModel): code: Annotated[ - Code, + Literal["server_error", "rate_limit_exceeded", "invalid_prompt"], Field(description="One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`."), ] message: Annotated[str, Field(description="A human-readable description of the error.")] -class Reason(Enum): - max_completion_tokens = "max_completion_tokens" - max_prompt_tokens = "max_prompt_tokens" - - class IncompleteDetails(BaseModel): reason: Annotated[ - Optional[Reason], + Optional[Literal["max_completion_tokens", "max_prompt_tokens"]], Field( None, description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run.", @@ -2450,10 +2175,6 @@ class SubmitToolOutputsRunRequest(BaseModel): ] -class Type21(Enum): - function = "function" - - class Function4(BaseModel): name: Annotated[str, Field(description="The name of the function.")] arguments: Annotated[ @@ -2470,7 +2191,7 @@ class RunToolCallObject(BaseModel): ), ] type: Annotated[ - Type21, + Literal["function"], Field( description="The type of tool call the output is required for. For now, this is always `function`." ), @@ -2505,10 +2226,6 @@ class ToolResources3(BaseModel): file_search: Optional[FileSearch5] = None -class Object22(Enum): - thread = "thread" - - class FileSearch6(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], @@ -2530,7 +2247,10 @@ class ThreadObject(BaseModel): str, Field(description="The identifier, which can be referenced in API endpoints."), ] - object: Annotated[Object22, Field(description="The object type, which is always `thread`.")] + object: Annotated[ + Literal["thread"], + Field(description="The object type, which is always `thread`."), + ] created_at: Annotated[ int, Field(description="The Unix timestamp (in seconds) for when the thread was created."), @@ -2549,26 +2269,18 @@ class ThreadObject(BaseModel): ] -class Type22(Enum): - auto = "auto" - - class ChunkingStrategy4(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type22, Field(description="Always `auto`.")] - - -class Type23(Enum): - static = "static" + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class ChunkingStrategy5(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type23, Field(description="Always `static`.")] + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -2615,26 +2327,18 @@ class FileSearch7(BaseModel): ] -class Type24(Enum): - auto = "auto" - - class ChunkingStrategy6(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type24, Field(description="Always `auto`.")] - - -class Type25(Enum): - static = "static" + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class ChunkingStrategy7(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type25, Field(description="Always `static`.")] + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -2722,14 +2426,10 @@ class ModifyThreadRequest(BaseModel): ] -class Object23(Enum): - thread_deleted = "thread.deleted" - - class DeleteThreadResponse(BaseModel): id: str deleted: bool - object: Object23 + object: Literal["thread.deleted"] class ListThreadsResponse(BaseModel): @@ -2740,31 +2440,11 @@ class ListThreadsResponse(BaseModel): has_more: Annotated[bool, Field(examples=[False])] -class Object24(Enum): - thread_message = "thread.message" - - -class Status4(Enum): - in_progress = "in_progress" - incomplete = "incomplete" - completed = "completed" - - -class Reason1(Enum): - content_filter = "content_filter" - max_tokens = "max_tokens" - run_cancelled = "run_cancelled" - run_expired = "run_expired" - run_failed = "run_failed" - - class IncompleteDetails1(BaseModel): - reason: Annotated[Reason1, Field(description="The reason the message is incomplete.")] - - -class Role7(Enum): - user = "user" - assistant = "assistant" + reason: Annotated[ + Literal["content_filter", "max_tokens", "run_cancelled", "run_expired", "run_failed"], + Field(description="The reason the message is incomplete."), + ] class Attachment(BaseModel): @@ -2778,10 +2458,6 @@ class Attachment(BaseModel): ] -class Object25(Enum): - thread_message_delta = "thread.message.delta" - - class ModifyMessageRequest(BaseModel): model_config = ConfigDict( extra="forbid", @@ -2795,18 +2471,10 @@ class ModifyMessageRequest(BaseModel): ] -class Object26(Enum): - thread_message_deleted = "thread.message.deleted" - - class DeleteMessageResponse(BaseModel): id: str deleted: bool - object: Object26 - - -class Type26(Enum): - image_file = "image_file" + object: Literal["thread.message.deleted"] class ImageFile(BaseModel): @@ -2817,7 +2485,7 @@ class ImageFile(BaseModel): ), ] detail: Annotated[ - Optional[Detail], + Optional[Literal["auto", "low", "high"]], Field( "auto", description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", @@ -2826,7 +2494,7 @@ class ImageFile(BaseModel): class MessageContentImageFileObject(BaseModel): - type: Annotated[Type26, Field(description="Always `image_file`.")] + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] image_file: ImageFile @@ -2839,7 +2507,7 @@ class ImageFile1(BaseModel): ), ] detail: Annotated[ - Optional[Detail], + Optional[Literal["auto", "low", "high"]], Field( "auto", description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", @@ -2849,14 +2517,10 @@ class ImageFile1(BaseModel): class MessageDeltaContentImageFileObject(BaseModel): index: Annotated[int, Field(description="The index of the content part in the message.")] - type: Annotated[Type26, Field(description="Always `image_file`.")] + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] image_file: Optional[ImageFile1] = None -class Type28(Enum): - image_url = "image_url" - - class ImageUrl1(BaseModel): url: Annotated[ AnyUrl, @@ -2865,7 +2529,7 @@ class ImageUrl1(BaseModel): ), ] detail: Annotated[ - Optional[Detail], + Optional[Literal["auto", "low", "high"]], Field( "auto", description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`", @@ -2874,7 +2538,7 @@ class ImageUrl1(BaseModel): class MessageContentImageUrlObject(BaseModel): - type: Annotated[Type28, Field(description="The type of the content part.")] + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] image_url: ImageUrl1 @@ -2887,7 +2551,7 @@ class ImageUrl2(BaseModel): ), ] detail: Annotated[ - Optional[Detail], + Optional[Literal["auto", "low", "high"]], Field( "auto", description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`.", @@ -2897,29 +2561,26 @@ class ImageUrl2(BaseModel): class MessageDeltaContentImageUrlObject(BaseModel): index: Annotated[int, Field(description="The index of the content part in the message.")] - type: Annotated[Type28, Field(description="Always `image_url`.")] + type: Annotated[Literal["image_url"], Field(description="Always `image_url`.")] image_url: Optional[ImageUrl2] = None -class Type30(Enum): - text = "text" +class MessageContentRefusalObject(BaseModel): + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: str class MessageRequestContentTextObject(BaseModel): - type: Annotated[Type30, Field(description="Always `text`.")] + type: Annotated[Literal["text"], Field(description="Always `text`.")] text: Annotated[str, Field(description="Text content to be sent to the model")] -class Type32(Enum): - file_citation = "file_citation" - - class FileCitation(BaseModel): file_id: Annotated[str, Field(description="The ID of the specific File the citation is from.")] class MessageContentTextAnnotationsFileCitationObject(BaseModel): - type: Annotated[Type32, Field(description="Always `file_citation`.")] + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] text: Annotated[ str, Field(description="The text in the message content that needs to be replaced."), @@ -2929,16 +2590,12 @@ class MessageContentTextAnnotationsFileCitationObject(BaseModel): end_index: Annotated[int, Field(ge=0)] -class Type33(Enum): - file_path = "file_path" - - class FilePath(BaseModel): file_id: Annotated[str, Field(description="The ID of the file that was generated.")] class MessageContentTextAnnotationsFilePathObject(BaseModel): - type: Annotated[Type33, Field(description="Always `file_path`.")] + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] text: Annotated[ str, Field(description="The text in the message content that needs to be replaced."), @@ -2948,12 +2605,10 @@ class MessageContentTextAnnotationsFilePathObject(BaseModel): end_index: Annotated[int, Field(ge=0)] -class Type34(Enum): - text = "text" - - -class Type35(Enum): - file_citation = "file_citation" +class MessageDeltaContentRefusalObject(BaseModel): + index: Annotated[int, Field(description="The index of the refusal part in the message.")] + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: Optional[str] = None class FileCitation1(BaseModel): @@ -2968,7 +2623,7 @@ class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): index: Annotated[ int, Field(description="The index of the annotation in the text content part.") ] - type: Annotated[Type35, Field(description="Always `file_citation`.")] + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] text: Annotated[ Optional[str], Field( @@ -2981,10 +2636,6 @@ class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): end_index: Annotated[Optional[int], Field(None, ge=0)] -class Type36(Enum): - file_path = "file_path" - - class FilePath1(BaseModel): file_id: Annotated[ Optional[str], Field(None, description="The ID of the file that was generated.") @@ -2995,7 +2646,7 @@ class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): index: Annotated[ int, Field(description="The index of the annotation in the text content part.") ] - type: Annotated[Type36, Field(description="Always `file_path`.")] + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] text: Annotated[ Optional[str], Field( @@ -3008,41 +2659,14 @@ class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): end_index: Annotated[Optional[int], Field(None, ge=0)] -class Object27(Enum): - thread_run_step = "thread.run.step" - - -class Type37(Enum): - message_creation = "message_creation" - tool_calls = "tool_calls" - - -class Status5(Enum): - in_progress = "in_progress" - cancelled = "cancelled" - failed = "failed" - completed = "completed" - expired = "expired" - - -class Code1(Enum): - server_error = "server_error" - rate_limit_exceeded = "rate_limit_exceeded" - - class LastError1(BaseModel): - code: Annotated[Code1, Field(description="One of `server_error` or `rate_limit_exceeded`.")] + code: Annotated[ + Literal["server_error", "rate_limit_exceeded"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] message: Annotated[str, Field(description="A human-readable description of the error.")] -class Object28(Enum): - thread_run_step_delta = "thread.run.step.delta" - - -class Type38(Enum): - message_creation = "message_creation" - - class MessageCreation(BaseModel): message_id: Annotated[ str, @@ -3051,7 +2675,7 @@ class MessageCreation(BaseModel): class RunStepDetailsMessageCreationObject(BaseModel): - type: Annotated[Type38, Field(description="Always `message_creation`.")] + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] message_creation: MessageCreation @@ -3063,40 +2687,24 @@ class MessageCreation1(BaseModel): class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): - type: Annotated[Type38, Field(description="Always `message_creation`.")] + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] message_creation: Optional[MessageCreation1] = None -class Type40(Enum): - tool_calls = "tool_calls" - - -class Type42(Enum): - code_interpreter = "code_interpreter" - - -class Type44(Enum): - logs = "logs" - - class RunStepDetailsToolCallsCodeOutputLogsObject(BaseModel): - type: Annotated[Type44, Field(description="Always `logs`.")] + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] logs: Annotated[str, Field(description="The text output from the Code Interpreter tool call.")] class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): index: Annotated[int, Field(description="The index of the output in the outputs array.")] - type: Annotated[Type44, Field(description="Always `logs`.")] + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] logs: Annotated[ Optional[str], Field(None, description="The text output from the Code Interpreter tool call."), ] -class Type46(Enum): - image = "image" - - class Image1(BaseModel): file_id: Annotated[ str, Field(description="The [file](/docs/api-reference/files) ID of the image.") @@ -3104,7 +2712,7 @@ class Image1(BaseModel): class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): - type: Annotated[Type46, Field(description="Always `image`.")] + type: Annotated[Literal["image"], Field(description="Always `image`.")] image: Image1 @@ -3117,18 +2725,14 @@ class Image2(BaseModel): class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): index: Annotated[int, Field(description="The index of the output in the outputs array.")] - type: Annotated[Type46, Field(description="Always `image`.")] + type: Annotated[Literal["image"], Field(description="Always `image`.")] image: Optional[Image2] = None -class Type48(Enum): - file_search = "file_search" - - class RunStepDetailsToolCallsFileSearchObject(BaseModel): id: Annotated[str, Field(description="The ID of the tool call object.")] type: Annotated[ - Type48, + Literal["file_search"], Field( description="The type of tool call. This is always going to be `file_search` for this type of tool call." ), @@ -3143,7 +2747,7 @@ class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] type: Annotated[ - Type48, + Literal["file_search"], Field( description="The type of tool call. This is always going to be `file_search` for this type of tool call." ), @@ -3154,10 +2758,6 @@ class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): ] -class Type50(Enum): - function = "function" - - class Function5(BaseModel): name: Annotated[str, Field(description="The name of the function.")] arguments: Annotated[str, Field(description="The arguments passed to the function.")] @@ -3172,7 +2772,7 @@ class Function5(BaseModel): class RunStepDetailsToolCallsFunctionObject(BaseModel): id: Annotated[str, Field(description="The ID of the tool call object.")] type: Annotated[ - Type50, + Literal["function"], Field( description="The type of tool call. This is always going to be `function` for this type of tool call." ), @@ -3200,7 +2800,7 @@ class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] type: Annotated[ - Type50, + Literal["function"], Field( description="The type of tool call. This is always going to be `function` for this type of tool call." ), @@ -3211,13 +2811,9 @@ class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): ] -class Anchor(Enum): - last_active_at = "last_active_at" - - class VectorStoreExpirationAfter(BaseModel): anchor: Annotated[ - Anchor, + Literal["last_active_at"], Field( description="Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." ), @@ -3232,10 +2828,6 @@ class VectorStoreExpirationAfter(BaseModel): ] -class Object29(Enum): - vector_store = "vector_store" - - class FileCounts(BaseModel): in_progress: Annotated[ int, @@ -3250,19 +2842,14 @@ class FileCounts(BaseModel): total: Annotated[int, Field(description="The total number of files.")] -class Status6(Enum): - expired = "expired" - in_progress = "in_progress" - completed = "completed" - - class VectorStoreObject(BaseModel): id: Annotated[ str, Field(description="The identifier, which can be referenced in API endpoints."), ] object: Annotated[ - Object29, Field(description="The object type, which is always `vector_store`.") + Literal["vector_store"], + Field(description="The object type, which is always `vector_store`."), ] created_at: Annotated[ int, @@ -3275,7 +2862,7 @@ class VectorStoreObject(BaseModel): ] file_counts: FileCounts status: Annotated[ - Status6, + Literal["expired", "in_progress", "completed"], Field( description="The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use." ), @@ -3325,52 +2912,25 @@ class ListVectorStoresResponse(BaseModel): has_more: Annotated[bool, Field(examples=[False])] -class Object30(Enum): - vector_store_deleted = "vector_store.deleted" - - class DeleteVectorStoreResponse(BaseModel): id: str deleted: bool - object: Object30 - - -class Object31(Enum): - vector_store_file = "vector_store.file" - - -class Status7(Enum): - in_progress = "in_progress" - completed = "completed" - cancelled = "cancelled" - failed = "failed" - - -class Code2(Enum): - internal_error = "internal_error" - file_not_found = "file_not_found" - parsing_error = "parsing_error" - unhandled_mime_type = "unhandled_mime_type" + object: Literal["vector_store.deleted"] class LastError2(BaseModel): - code: Annotated[Code2, Field(description="One of `server_error` or `rate_limit_exceeded`.")] + code: Annotated[ + Literal["server_error", "unsupported_file", "invalid_file"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] message: Annotated[str, Field(description="A human-readable description of the error.")] -class Type52(Enum): - other = "other" - - class OtherChunkingStrategyResponseParam(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type52, Field(description="Always `other`.")] - - -class Type53(Enum): - static = "static" + type: Annotated[Literal["other"], Field(description="Always `other`.")] class StaticChunkingStrategy(BaseModel): @@ -3393,26 +2953,18 @@ class StaticChunkingStrategy(BaseModel): ] -class Type54(Enum): - auto = "auto" - - class AutoChunkingStrategyRequestParam(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type54, Field(description="Always `auto`.")] - - -class Type55(Enum): - static = "static" + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class StaticChunkingStrategyRequestParam(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type55, Field(description="Always `static`.")] + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: StaticChunkingStrategy @@ -3440,18 +2992,10 @@ class CreateVectorStoreFileRequest(BaseModel): chunking_strategy: Optional[ChunkingStrategyRequestParam] = None -class Object32(Enum): - vector_store_file_deleted = "vector_store.file.deleted" - - class DeleteVectorStoreFileResponse(BaseModel): id: str deleted: bool - object: Object32 - - -class Object33(Enum): - vector_store_files_batch = "vector_store.files_batch" + object: Literal["vector_store.file.deleted"] class FileCounts1(BaseModel): @@ -3471,7 +3015,7 @@ class VectorStoreFileBatchObject(BaseModel): Field(description="The identifier, which can be referenced in API endpoints."), ] object: Annotated[ - Object33, + Literal["vector_store.files_batch"], Field(description="The object type, which is always `vector_store.file_batch`."), ] created_at: Annotated[ @@ -3487,7 +3031,7 @@ class VectorStoreFileBatchObject(BaseModel): ), ] status: Annotated[ - Status7, + Literal["in_progress", "completed", "cancelled", "failed"], Field( description="The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`." ), @@ -3510,12 +3054,8 @@ class CreateVectorStoreFileBatchRequest(BaseModel): chunking_strategy: Optional[ChunkingStrategyRequestParam] = None -class Event(Enum): - thread_created = "thread.created" - - class ThreadStreamEvent1(BaseModel): - event: Event + event: Literal["thread.created"] data: ThreadObject @@ -3523,118 +3063,14 @@ class ThreadStreamEvent(RootModel[ThreadStreamEvent1]): root: ThreadStreamEvent1 -class Event1(Enum): - thread_run_created = "thread.run.created" - - -class Event2(Enum): - thread_run_queued = "thread.run.queued" - - -class Event3(Enum): - thread_run_in_progress = "thread.run.in_progress" - - -class Event4(Enum): - thread_run_requires_action = "thread.run.requires_action" - - -class Event5(Enum): - thread_run_completed = "thread.run.completed" - - -class Event6(Enum): - thread_run_incomplete = "thread.run.incomplete" - - -class Event7(Enum): - thread_run_failed = "thread.run.failed" - - -class Event8(Enum): - thread_run_cancelling = "thread.run.cancelling" - - -class Event9(Enum): - thread_run_cancelled = "thread.run.cancelled" - - -class Event10(Enum): - thread_run_expired = "thread.run.expired" - - -class Event11(Enum): - thread_run_step_created = "thread.run.step.created" - - -class Event12(Enum): - thread_run_step_in_progress = "thread.run.step.in_progress" - - -class Event13(Enum): - thread_run_step_delta = "thread.run.step.delta" - - -class Event14(Enum): - thread_run_step_completed = "thread.run.step.completed" - - -class Event15(Enum): - thread_run_step_failed = "thread.run.step.failed" - - -class Event16(Enum): - thread_run_step_cancelled = "thread.run.step.cancelled" - - -class Event17(Enum): - thread_run_step_expired = "thread.run.step.expired" - - -class Event18(Enum): - thread_message_created = "thread.message.created" - - -class Event19(Enum): - thread_message_in_progress = "thread.message.in_progress" - - -class Event20(Enum): - thread_message_delta = "thread.message.delta" - - -class Event21(Enum): - thread_message_completed = "thread.message.completed" - - -class Event22(Enum): - thread_message_incomplete = "thread.message.incomplete" - - -class Event23(Enum): - error = "error" - - class ErrorEvent(BaseModel): - event: Event23 + event: Literal["error"] data: Error -class Event24(Enum): - done = "done" - - -class Data(Enum): - field_DONE_ = "[DONE]" - - class DoneEvent(BaseModel): - event: Event24 - data: Data - - -class Object34(Enum): - batch = "batch" + event: Literal["done"] + data: Literal["[DONE]"] class Datum(BaseModel): @@ -3673,17 +3109,6 @@ class Errors(BaseModel): data: Optional[List[Datum]] = None -class Status9(Enum): - validating = "validating" - failed = "failed" - in_progress = "in_progress" - finalizing = "finalizing" - completed = "completed" - expired = "expired" - cancelling = "cancelling" - cancelled = "cancelled" - - class RequestCounts(BaseModel): total: Annotated[int, Field(description="Total number of requests in the batch.")] completed: Annotated[ @@ -3695,7 +3120,9 @@ class RequestCounts(BaseModel): class Batch(BaseModel): id: str - object: Annotated[Object34, Field(description="The object type, which is always `batch`.")] + object: Annotated[ + Literal["batch"], Field(description="The object type, which is always `batch`.") + ] endpoint: Annotated[str, Field(description="The OpenAI API endpoint used by the batch.")] errors: Optional[Errors] = None input_file_id: Annotated[str, Field(description="The ID of the input file for the batch.")] @@ -3703,7 +3130,19 @@ class Batch(BaseModel): str, Field(description="The time frame within which the batch should be processed."), ] - status: Annotated[Status9, Field(description="The current status of the batch.")] + status: Annotated[ + Literal[ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ], + Field(description="The current status of the batch."), + ] output_file_id: Annotated[ Optional[str], Field( @@ -3782,107 +3221,876 @@ class Batch(BaseModel): Optional[RequestCounts], Field( None, - description="The request counts for different statuses within the batch.", + description="The request counts for different statuses within the batch.", + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + None, + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + ), + ] + + +class BatchRequestInput(BaseModel): + custom_id: Annotated[ + Optional[str], + Field( + None, + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch.", + ), + ] + method: Annotated[ + Optional[Literal["POST"]], + Field( + None, + description="The HTTP method to be used for the request. Currently only `POST` is supported.", + ), + ] + url: Annotated[ + Optional[str], + Field( + None, + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.", + ), + ] + + +class Response(BaseModel): + status_code: Annotated[ + Optional[int], Field(None, description="The HTTP status code of the response") + ] + request_id: Annotated[ + Optional[str], + Field( + None, + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support.", + ), + ] + body: Annotated[ + Optional[Dict[str, Any]], + Field(None, description="The JSON body of the response"), + ] + + +class Error2(BaseModel): + code: Annotated[Optional[str], Field(None, description="A machine-readable error code.")] + message: Annotated[Optional[str], Field(None, description="A human-readable error message.")] + + +class BatchRequestOutput(BaseModel): + id: Optional[str] = None + custom_id: Annotated[ + Optional[str], + Field( + None, + description="A developer-provided per-request id that will be used to match outputs to inputs.", + ), + ] + response: Optional[Response] = None + error: Annotated[ + Optional[Error2], + Field( + None, + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure.", + ), + ] + + +class ListBatchesResponse(BaseModel): + data: List[Batch] + first_id: Annotated[Optional[str], Field(None, examples=["batch_abc123"])] + last_id: Annotated[Optional[str], Field(None, examples=["batch_abc456"])] + has_more: bool + object: Literal["list"] + + +class AuditLogActorServiceAccount(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account id.")] + + +class AuditLogActorUser(BaseModel): + id: Annotated[Optional[str], Field(None, description="The user id.")] + email: Annotated[Optional[str], Field(None, description="The user email.")] + + +class AuditLogActorApiKey(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking id of the API key.")] + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field( + None, + description="The type of API key. Can be either `user` or `service_account`.", + ), + ] + user: Optional[AuditLogActorUser] = None + service_account: Optional[AuditLogActorServiceAccount] = None + + +class AuditLogActorSession(BaseModel): + user: Optional[AuditLogActorUser] = None + ip_address: Annotated[ + Optional[str], + Field(None, description="The IP address from which the action was performed."), + ] + + +class AuditLogActor(BaseModel): + type: Annotated[ + Optional[Literal["session", "api_key"]], + Field(None, description="The type of actor. Is either `session` or `api_key`."), + ] + session: Optional[AuditLogActorSession] = None + api_key: Optional[AuditLogActorApiKey] = None + + +class AuditLogEventType( + RootModel[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ] + ] +): + root: Annotated[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ], + Field(description="The event type."), + ] + + +class Project(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + name: Annotated[Optional[str], Field(None, description="The project title.")] + + +class Data(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field( + None, + description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', + ), + ] + + +class ApiKeyCreated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + data: Annotated[ + Optional[Data], + Field(None, description="The payload used to create the API key."), + ] + + +class ChangesRequested(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field( + None, + description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', + ), + ] + + +class ApiKeyUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + changes_requested: Annotated[ + Optional[ChangesRequested], + Field(None, description="The payload used to update the API key."), + ] + + +class ApiKeyDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + + +class Data1(BaseModel): + email: Annotated[ + Optional[str], Field(None, description="The email invited to the organization.") + ] + role: Annotated[ + Optional[str], + Field( + None, + description="The role the email was invited to be. Is either `owner` or `member`.", + ), + ] + + +class InviteSent(BaseModel): + id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + data: Annotated[ + Optional[Data1], + Field(None, description="The payload used to create the invite."), + ] + + +class InviteAccepted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + + +class InviteDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + + +class LoginFailed(BaseModel): + error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_message: Annotated[ + Optional[str], Field(None, description="The error message of the failure.") + ] + + +class LogoutFailed(BaseModel): + error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_message: Annotated[ + Optional[str], Field(None, description="The error message of the failure.") + ] + + +class Settings(BaseModel): + threads_ui_visibility: Annotated[ + Optional[str], + Field( + None, + description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`.", + ), + ] + usage_dashboard_visibility: Annotated[ + Optional[str], + Field( + None, + description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`.", + ), + ] + + +class ChangesRequested1(BaseModel): + title: Annotated[Optional[str], Field(None, description="The organization title.")] + description: Annotated[Optional[str], Field(None, description="The organization description.")] + name: Annotated[Optional[str], Field(None, description="The organization name.")] + settings: Optional[Settings] = None + + +class OrganizationUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The organization ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested1], + Field(None, description="The payload used to update the organization settings."), + ] + + +class Data2(BaseModel): + name: Annotated[Optional[str], Field(None, description="The project name.")] + title: Annotated[ + Optional[str], + Field(None, description="The title of the project as seen on the dashboard."), + ] + + +class ProjectCreated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + data: Annotated[ + Optional[Data2], + Field(None, description="The payload used to create the project."), + ] + + +class ChangesRequested2(BaseModel): + title: Annotated[ + Optional[str], + Field(None, description="The title of the project as seen on the dashboard."), + ] + + +class ProjectUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested2], + Field(None, description="The payload used to update the project."), + ] + + +class ProjectArchived(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + + +class Data3(BaseModel): + role: Annotated[ + Optional[str], + Field( + None, + description="The role of the service account. Is either `owner` or `member`.", + ), + ] + + +class ServiceAccountCreated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account ID.")] + data: Annotated[ + Optional[Data3], + Field(None, description="The payload used to create the service account."), + ] + + +class ChangesRequested3(BaseModel): + role: Annotated[ + Optional[str], + Field( + None, + description="The role of the service account. Is either `owner` or `member`.", + ), + ] + + +class ServiceAccountUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested3], + Field(None, description="The payload used to updated the service account."), + ] + + +class ServiceAccountDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The service account ID.")] + + +class Data4(BaseModel): + role: Annotated[ + Optional[str], + Field(None, description="The role of the user. Is either `owner` or `member`."), + ] + + +class UserAdded(BaseModel): + id: Annotated[Optional[str], Field(None, description="The user ID.")] + data: Annotated[ + Optional[Data4], + Field(None, description="The payload used to add the user to the project."), + ] + + +class ChangesRequested4(BaseModel): + role: Annotated[ + Optional[str], + Field(None, description="The role of the user. Is either `owner` or `member`."), + ] + + +class UserUpdated(BaseModel): + id: Annotated[Optional[str], Field(None, description="The project ID.")] + changes_requested: Annotated[ + Optional[ChangesRequested4], + Field(None, description="The payload used to update the user."), + ] + + +class UserDeleted(BaseModel): + id: Annotated[Optional[str], Field(None, description="The user ID.")] + + +class AuditLog(BaseModel): + id: Annotated[str, Field(description="The ID of this log.")] + type: AuditLogEventType + effective_at: Annotated[int, Field(description="The Unix timestamp (in seconds) of the event.")] + project: Annotated[ + Optional[Project], + Field( + None, + description="The project that the action was scoped to. Absent for actions not scoped to projects.", + ), + ] + actor: AuditLogActor + api_key_created: Annotated[ + Optional[ApiKeyCreated], + Field( + None, + alias="api_key.created", + description="The details for events with this `type`.", + ), + ] + api_key_updated: Annotated[ + Optional[ApiKeyUpdated], + Field( + None, + alias="api_key.updated", + description="The details for events with this `type`.", + ), + ] + api_key_deleted: Annotated[ + Optional[ApiKeyDeleted], + Field( + None, + alias="api_key.deleted", + description="The details for events with this `type`.", + ), + ] + invite_sent: Annotated[ + Optional[InviteSent], + Field( + None, + alias="invite.sent", + description="The details for events with this `type`.", + ), + ] + invite_accepted: Annotated[ + Optional[InviteAccepted], + Field( + None, + alias="invite.accepted", + description="The details for events with this `type`.", + ), + ] + invite_deleted: Annotated[ + Optional[InviteDeleted], + Field( + None, + alias="invite.deleted", + description="The details for events with this `type`.", + ), + ] + login_failed: Annotated[ + Optional[LoginFailed], + Field( + None, + alias="login.failed", + description="The details for events with this `type`.", + ), + ] + logout_failed: Annotated[ + Optional[LogoutFailed], + Field( + None, + alias="logout.failed", + description="The details for events with this `type`.", + ), + ] + organization_updated: Annotated[ + Optional[OrganizationUpdated], + Field( + None, + alias="organization.updated", + description="The details for events with this `type`.", + ), + ] + project_created: Annotated[ + Optional[ProjectCreated], + Field( + None, + alias="project.created", + description="The details for events with this `type`.", + ), + ] + project_updated: Annotated[ + Optional[ProjectUpdated], + Field( + None, + alias="project.updated", + description="The details for events with this `type`.", + ), + ] + project_archived: Annotated[ + Optional[ProjectArchived], + Field( + None, + alias="project.archived", + description="The details for events with this `type`.", + ), + ] + service_account_created: Annotated[ + Optional[ServiceAccountCreated], + Field( + None, + alias="service_account.created", + description="The details for events with this `type`.", + ), + ] + service_account_updated: Annotated[ + Optional[ServiceAccountUpdated], + Field( + None, + alias="service_account.updated", + description="The details for events with this `type`.", + ), + ] + service_account_deleted: Annotated[ + Optional[ServiceAccountDeleted], + Field( + None, + alias="service_account.deleted", + description="The details for events with this `type`.", + ), + ] + user_added: Annotated[ + Optional[UserAdded], + Field( + None, + alias="user.added", + description="The details for events with this `type`.", + ), + ] + user_updated: Annotated[ + Optional[UserUpdated], + Field( + None, + alias="user.updated", + description="The details for events with this `type`.", ), ] - metadata: Annotated[ - Optional[Dict[str, Any]], + user_deleted: Annotated[ + Optional[UserDeleted], Field( None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + alias="user.deleted", + description="The details for events with this `type`.", ), ] -class Method(Enum): - POST = "POST" +class ListAuditLogsResponse(BaseModel): + object: Literal["list"] + data: List[AuditLog] + first_id: Annotated[str, Field(examples=["audit_log-defb456h8dks"])] + last_id: Annotated[str, Field(examples=["audit_log-hnbkd8s93s"])] + has_more: bool -class BatchRequestInput(BaseModel): - custom_id: Annotated[ - Optional[str], - Field( - None, - description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch.", - ), +class Invite(BaseModel): + object: Annotated[ + Literal["organization.invite"], + Field(description="The object type, which is always `organization.invite`"), ] - method: Annotated[ - Optional[Method], + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + email: Annotated[ + str, + Field(description="The email address of the individual to whom the invite was sent"), + ] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + status: Annotated[ + Literal["accepted", "expired", "pending"], + Field(description="`accepted`,`expired`, or `pending`"), + ] + invited_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite was sent."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite expires."), + ] + accepted_at: Annotated[ + Optional[int], Field( None, - description="The HTTP method to be used for the request. Currently only `POST` is supported.", + description="The Unix timestamp (in seconds) of when the invite was accepted.", ), ] - url: Annotated[ + + +class InviteListResponse(BaseModel): + object: Annotated[Literal["list"], Field(description="The object type, which is always `list`")] + data: List[Invite] + first_id: Annotated[ Optional[str], + Field(None, description="The first `invite_id` in the retrieved `list`"), + ] + last_id: Annotated[ + Optional[str], + Field(None, description="The last `invite_id` in the retrieved `list`"), + ] + has_more: Annotated[ + Optional[bool], Field( None, - description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.", + description="The `has_more` property is used for pagination to indicate there are additional results.", ), ] -class Response(BaseModel): - status_code: Annotated[ - Optional[int], Field(None, description="The HTTP status code of the response") +class InviteRequest(BaseModel): + email: Annotated[str, Field(description="Send an email to this address")] + role: Annotated[Literal["reader", "owner"], Field(description="`owner` or `reader`")] + + +class InviteDeleteResponse(BaseModel): + object: Annotated[ + Literal["organization.invite.deleted"], + Field(description="The object type, which is always `organization.invite.deleted`"), ] - request_id: Annotated[ - Optional[str], + id: str + deleted: bool + + +class User(BaseModel): + object: Annotated[ + Literal["organization.user"], + Field(description="The object type, which is always `organization.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the user was added."), + ] + + +class UserListResponse(BaseModel): + object: Literal["list"] + data: List[User] + first_id: str + last_id: str + has_more: bool + + +class UserRoleUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + + +class UserDeleteResponse(BaseModel): + object: Literal["organization.user.deleted"] + id: str + deleted: bool + + +class Project1(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + object: Annotated[ + Literal["organization.project"], + Field(description="The object type, which is always `organization.project`"), + ] + name: Annotated[str, Field(description="The name of the project. This appears in reporting.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was created."), + ] + archived_at: Annotated[ + Optional[int], Field( None, - description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support.", + description="The Unix timestamp (in seconds) of when the project was archived or `null`.", ), ] - body: Annotated[ - Optional[Dict[str, Any]], - Field(None, description="The JSON body of the response"), + status: Annotated[Literal["active", "archived"], Field(description="`active` or `archived`")] + + +class ProjectListResponse(BaseModel): + object: Literal["list"] + data: List[Project1] + first_id: str + last_id: str + has_more: bool + + +class ProjectCreateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The friendly name of the project, this name appears in reports."), ] -class Error2(BaseModel): - code: Annotated[Optional[str], Field(None, description="A machine-readable error code.")] - message: Annotated[Optional[str], Field(None, description="A human-readable error message.")] +class ProjectUpdateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The updated name of the project, this name appears in reports."), + ] -class BatchRequestOutput(BaseModel): - id: Optional[str] = None - custom_id: Annotated[ - Optional[str], +class DefaultProjectErrorResponse(BaseModel): + code: int + message: str + + +class ProjectUser(BaseModel): + object: Annotated[ + Literal["organization.project.user"], + Field(description="The object type, which is always `organization.project.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was added."), + ] + + +class ProjectUserListResponse(BaseModel): + object: str + data: List[ProjectUser] + first_id: str + last_id: str + has_more: bool + + +class ProjectUserCreateRequest(BaseModel): + user_id: Annotated[str, Field(description="The ID of the user.")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserDeleteResponse(BaseModel): + object: Literal["organization.project.user.deleted"] + id: str + deleted: bool + + +class ProjectServiceAccount(BaseModel): + object: Annotated[ + Literal["organization.project.service_account"], Field( - None, - description="A developer-provided per-request id that will be used to match outputs to inputs.", + description="The object type, which is always `organization.project.service_account`" ), ] - response: Optional[Response] = None - error: Annotated[ - Optional[Error2], + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the service account")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + created_at: Annotated[ + int, Field( - None, - description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure.", + description="The Unix timestamp (in seconds) of when the service account was created" ), ] -class Object35(Enum): - list = "list" +class ProjectServiceAccountListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectServiceAccount] + first_id: str + last_id: str + has_more: bool -class ListBatchesResponse(BaseModel): - data: List[Batch] - first_id: Annotated[Optional[str], Field(None, examples=["batch_abc123"])] - last_id: Annotated[Optional[str], Field(None, examples=["batch_abc456"])] +class ProjectServiceAccountCreateRequest(BaseModel): + name: Annotated[str, Field(description="The name of the service account being created.")] + + +class ProjectServiceAccountApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.service_account.api_key"], + Field( + description="The object type, which is always `organization.project.service_account.api_key`" + ), + ] + value: str + name: str + created_at: int + id: str + + +class ProjectServiceAccountDeleteResponse(BaseModel): + object: Literal["organization.project.service_account.deleted"] + id: str + deleted: bool + + +class Owner(BaseModel): + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field(None, description="`user` or `service_account`"), + ] + user: Optional[ProjectUser] = None + service_account: Optional[ProjectServiceAccount] = None + + +class ProjectApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.api_key"], + Field(description="The object type, which is always `organization.project.api_key`"), + ] + redacted_value: Annotated[str, Field(description="The redacted value of the API key")] + name: Annotated[str, Field(description="The name of the API key")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the API key was created"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + owner: Owner + + +class ProjectApiKeyListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectApiKey] + first_id: str + last_id: str has_more: bool - object: Object35 + + +class ProjectApiKeyDeleteResponse(BaseModel): + object: Literal["organization.project.api_key.deleted"] + id: str + deleted: bool class ListModelsResponse(BaseModel): - object: Object + object: Literal["list"] data: List[Model] class CreateCompletionRequest(BaseModel): model: Annotated[ - Union[str, Model1], + Union[str, Literal["gpt-3.5-turbo-instruct", "davinci-002", "babbage-002"]], Field( description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" ), @@ -4042,66 +4250,25 @@ class CreateCompletionResponse(BaseModel): ), ] object: Annotated[ - Object1, Field(description='The object type, which is always "text_completion"') + Literal["text_completion"], + Field(description='The object type, which is always "text_completion"'), ] usage: Optional[CompletionUsage] = None -class ChatCompletionRequestMessageContentPart( - RootModel[ - Union[ - ChatCompletionRequestMessageContentPartText, - ChatCompletionRequestMessageContentPartImage, - ] - ] -): - root: Union[ - ChatCompletionRequestMessageContentPartText, - ChatCompletionRequestMessageContentPartImage, - ] - - -class Content(RootModel[List[ChatCompletionRequestMessageContentPart]]): - root: Annotated[ - List[ChatCompletionRequestMessageContentPart], - Field( - description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", - min_length=1, - title="Array of content parts", - ), - ] - - -class ChatCompletionRequestUserMessage(BaseModel): - content: Annotated[ - Union[str, Content], Field(description="The contents of the user message.\n") - ] - role: Annotated[ - Role1, - Field(description="The role of the messages author, in this case `user`."), - ] - name: Annotated[ - Optional[str], - Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", - ), - ] - - class ChatCompletionTool(BaseModel): type: Annotated[ - Type2, + Literal["function"], Field(description="The type of the tool. Currently, only `function` is supported."), ] function: FunctionObject class ChatCompletionToolChoiceOption( - RootModel[Union[ChatCompletionToolChoiceOption1, ChatCompletionNamedToolChoice]] + RootModel[Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice]] ): root: Annotated[ - Union[ChatCompletionToolChoiceOption1, ChatCompletionNamedToolChoice], + Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice], Field( description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n' ), @@ -4117,8 +4284,12 @@ class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCal class ChatCompletionResponseMessage(BaseModel): content: Annotated[str, Field(description="The contents of the message.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] tool_calls: Optional[ChatCompletionMessageToolCalls] = None - role: Annotated[Role5, Field(description="The role of the author of this message.")] + role: Annotated[ + Literal["assistant"], + Field(description="The role of the author of this message."), + ] function_call: Annotated[ Optional[FunctionCall], Field( @@ -4130,7 +4301,7 @@ class ChatCompletionResponseMessage(BaseModel): class Choice1(BaseModel): finish_reason: Annotated[ - FinishReason1, + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], Field( description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" ), @@ -4156,7 +4327,7 @@ class CreateChatCompletionResponse(BaseModel): ] model: Annotated[str, Field(description="The model used for the chat completion.")] service_tier: Annotated[ - Optional[ServiceTier1], + Optional[Literal["scale", "default"]], Field( None, description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", @@ -4171,7 +4342,7 @@ class CreateChatCompletionResponse(BaseModel): ), ] object: Annotated[ - Object2, + Literal["chat.completion"], Field(description="The object type, which is always `chat.completion`."), ] usage: Optional[CompletionUsage] = None @@ -4179,7 +4350,7 @@ class CreateChatCompletionResponse(BaseModel): class Choice2(BaseModel): finish_reason: Annotated[ - FinishReason2, + Literal["stop", "length", "function_call", "content_filter"], Field( description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n" ), @@ -4211,7 +4382,7 @@ class CreateChatCompletionFunctionResponse(BaseModel): ), ] object: Annotated[ - Object2, + Literal["chat.completion"], Field(description="The object type, which is always `chat.completion`."), ] usage: Optional[CompletionUsage] = None @@ -4224,17 +4395,17 @@ class ImagesResponse(BaseModel): class ListFilesResponse(BaseModel): data: List[OpenAIFile] - object: Object6 + object: Literal["list"] class ListFineTuningJobEventsResponse(BaseModel): data: List[FineTuningJobEvent] - object: Object8 + object: Literal["list"] class ListFineTuningJobCheckpointsResponse(BaseModel): data: List[FineTuningJobCheckpoint] - object: Object8 + object: Literal["list"] first_id: Optional[str] = None last_id: Optional[str] = None has_more: bool @@ -4248,7 +4419,9 @@ class CreateEmbeddingResponse(BaseModel): model: Annotated[ str, Field(description="The name of the model used to generate the embedding.") ] - object: Annotated[Object8, Field(description='The object type, which is always "list".')] + object: Annotated[ + Literal["list"], Field(description='The object type, which is always "list".') + ] usage: Annotated[Usage1, Field(description="The usage information for the request.")] @@ -4289,7 +4462,7 @@ class FineTuningJob(BaseModel): ] model: Annotated[str, Field(description="The base model that is being fine-tuned.")] object: Annotated[ - Object16, + Literal["fine_tuning.job"], Field(description='The object type, which is always "fine_tuning.job".'), ] organization_id: Annotated[ @@ -4302,7 +4475,7 @@ class FineTuningJob(BaseModel): ), ] status: Annotated[ - Status2, + Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"], Field( description="The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`." ), @@ -4343,23 +4516,15 @@ class FineTuningJob(BaseModel): ] -class AssistantsApiResponseFormatOption( - RootModel[Union[AssistantsApiResponseFormatOption1, AssistantsApiResponseFormat]] -): - root: Annotated[ - Union[AssistantsApiResponseFormatOption1, AssistantsApiResponseFormat], - Field( - description='Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' - ), - ] - - class AssistantObject(BaseModel): id: Annotated[ str, Field(description="The identifier, which can be referenced in API endpoints."), ] - object: Annotated[Object19, Field(description="The object type, which is always `assistant`.")] + object: Annotated[ + Literal["assistant"], + Field(description="The object type, which is always `assistant`."), + ] created_at: Annotated[ int, Field(description="The Unix timestamp (in seconds) for when the assistant was created."), @@ -4439,10 +4604,38 @@ class CreateAssistantRequest(BaseModel): extra="forbid", ) model: Annotated[ - Union[str, Model12], + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], Field( description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", - examples=["gpt-4-turbo"], + examples=["gpt-4o"], ), ] name: Annotated[ @@ -4603,10 +4796,10 @@ class ListAssistantsResponse(BaseModel): class AssistantsApiToolChoiceOption( - RootModel[Union[AssistantsApiToolChoiceOption1, AssistantsNamedToolChoice]] + RootModel[Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice]] ): root: Annotated[ - Union[AssistantsApiToolChoiceOption1, AssistantsNamedToolChoice], + Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice], Field( description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tools and instead generates a message.\n`auto` is the default value and means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools before responding to the user.\nSpecifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n' ), @@ -4620,7 +4813,10 @@ class SubmitToolOutputs(BaseModel): class RequiredAction(BaseModel): - type: Annotated[Type20, Field(description="For now, this is always `submit_tool_outputs`.")] + type: Annotated[ + Literal["submit_tool_outputs"], + Field(description="For now, this is always `submit_tool_outputs`."), + ] submit_tool_outputs: Annotated[ SubmitToolOutputs, Field(description="Details on the tool outputs needed for this run to continue."), @@ -4632,7 +4828,10 @@ class RunObject(BaseModel): str, Field(description="The identifier, which can be referenced in API endpoints."), ] - object: Annotated[Object21, Field(description="The object type, which is always `thread.run`.")] + object: Annotated[ + Literal["thread.run"], + Field(description="The object type, which is always `thread.run`."), + ] created_at: Annotated[ int, Field(description="The Unix timestamp (in seconds) for when the run was created."), @@ -4650,7 +4849,17 @@ class RunObject(BaseModel): ), ] status: Annotated[ - Status3, + Literal[ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ], Field( description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`." ), @@ -4761,7 +4970,7 @@ class ListRunsResponse(BaseModel): has_more: Annotated[bool, Field(examples=[False])] -class Content1( +class Content4( RootModel[ List[ Union[ @@ -4793,12 +5002,12 @@ class CreateMessageRequest(BaseModel): extra="forbid", ) role: Annotated[ - Role7, + Literal["user", "assistant"], Field( description="The role of the entity that is creating the message. Allowed values include:\n- `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages.\n- `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation.\n" ), ] - content: Union[str, Content1] + content: Union[str, Content4] attachments: Annotated[ Optional[List[Attachment]], Field( @@ -4826,7 +5035,7 @@ class Text(BaseModel): class MessageContentTextObject(BaseModel): - type: Annotated[Type30, Field(description="Always `text`.")] + type: Annotated[Literal["text"], Field(description="Always `text`.")] text: Text @@ -4844,7 +5053,7 @@ class Text1(BaseModel): class MessageDeltaContentTextObject(BaseModel): index: Annotated[int, Field(description="The index of the content part in the message.")] - type: Annotated[Type34, Field(description="Always `text`.")] + type: Annotated[Literal["text"], Field(description="Always `text`.")] text: Optional[Text1] = None @@ -4866,7 +5075,7 @@ class CodeInterpreter7(BaseModel): class RunStepDetailsToolCallsCodeObject(BaseModel): id: Annotated[str, Field(description="The ID of the tool call.")] type: Annotated[ - Type42, + Literal["code_interpreter"], Field( description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." ), @@ -4902,7 +5111,7 @@ class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] type: Annotated[ - Type42, + Literal["code_interpreter"], Field( description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." ), @@ -4947,57 +5156,57 @@ class StaticChunkingStrategyResponseParam(BaseModel): model_config = ConfigDict( extra="forbid", ) - type: Annotated[Type53, Field(description="Always `static`.")] + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: StaticChunkingStrategy class RunStreamEvent1(BaseModel): - event: Event1 + event: Literal["thread.run.created"] data: RunObject class RunStreamEvent2(BaseModel): - event: Event2 + event: Literal["thread.run.queued"] data: RunObject class RunStreamEvent3(BaseModel): - event: Event3 + event: Literal["thread.run.in_progress"] data: RunObject class RunStreamEvent4(BaseModel): - event: Event4 + event: Literal["thread.run.requires_action"] data: RunObject class RunStreamEvent5(BaseModel): - event: Event5 + event: Literal["thread.run.completed"] data: RunObject class RunStreamEvent6(BaseModel): - event: Event6 + event: Literal["thread.run.incomplete"] data: RunObject class RunStreamEvent7(BaseModel): - event: Event7 + event: Literal["thread.run.failed"] data: RunObject class RunStreamEvent8(BaseModel): - event: Event8 + event: Literal["thread.run.cancelling"] data: RunObject class RunStreamEvent9(BaseModel): - event: Event9 + event: Literal["thread.run.cancelled"] data: RunObject class RunStreamEvent10(BaseModel): - event: Event10 + event: Literal["thread.run.expired"] data: RunObject @@ -5031,16 +5240,31 @@ class RunStreamEvent( ] +class ProjectServiceAccountCreateResponse(BaseModel): + object: Literal["organization.project.service_account"] + id: str + name: str + role: Annotated[ + Literal["member"], + Field(description="Service accounts can only have one role of type `member`"), + ] + created_at: int + api_key: ProjectServiceAccountApiKey + + class ChatCompletionRequestAssistantMessage(BaseModel): content: Annotated[ - Optional[str], + Optional[Union[str, Content2]], Field( None, description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", ), ] + refusal: Annotated[ + Optional[str], Field(None, description="The refusal message by the assistant.") + ] role: Annotated[ - Role2, + Literal["assistant"], Field(description="The role of the messages author, in this case `assistant`."), ] name: Annotated[ @@ -5062,14 +5286,14 @@ class ChatCompletionRequestAssistantMessage(BaseModel): class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): weight: Annotated[ - Optional[Weight], + Optional[Literal[0, 1]], Field( None, description="Controls whether the assistant message is trained against (0 or 1)", ), ] role: Annotated[ - Role2, + Literal["assistant"], Field(description="The role of the messages author, in this case `assistant`."), ] @@ -5077,7 +5301,7 @@ class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssista class ListPaginatedFineTuningJobsResponse(BaseModel): data: List[FineTuningJob] has_more: bool - object: Object4 + object: Literal["list"] class FinetuneChatRequestInput(BaseModel): @@ -5122,11 +5346,41 @@ class CreateRunRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Model12]], + Optional[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], Field( None, description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", - examples=["gpt-4-turbo"], + examples=["gpt-4o"], ), ] instructions: Annotated[ @@ -5247,7 +5501,7 @@ class MessageObject(BaseModel): Field(description="The identifier, which can be referenced in API endpoints."), ] object: Annotated[ - Object24, + Literal["thread.message"], Field(description="The object type, which is always `thread.message`."), ] created_at: Annotated[ @@ -5261,7 +5515,7 @@ class MessageObject(BaseModel): ), ] status: Annotated[ - Status4, + Literal["in_progress", "incomplete", "completed"], Field( description="The status of the message, which can be either `in_progress`, `incomplete`, or `completed`." ), @@ -5281,7 +5535,7 @@ class MessageObject(BaseModel): ), ] role: Annotated[ - Role7, + Literal["user", "assistant"], Field(description="The entity that produced the message. One of `user` or `assistant`."), ] content: Annotated[ @@ -5290,6 +5544,7 @@ class MessageObject(BaseModel): MessageContentImageFileObject, MessageContentImageUrlObject, MessageContentTextObject, + MessageContentRefusalObject, ] ], Field(description="The content of the message in array of text and/or images."), @@ -5322,7 +5577,7 @@ class MessageObject(BaseModel): class Delta(BaseModel): role: Annotated[ - Optional[Role7], + Optional[Literal["user", "assistant"]], Field( None, description="The entity that produced the message. One of `user` or `assistant`.", @@ -5334,6 +5589,7 @@ class Delta(BaseModel): Union[ MessageDeltaContentImageFileObject, MessageDeltaContentTextObject, + MessageDeltaContentRefusalObject, MessageDeltaContentImageUrlObject, ] ] @@ -5353,7 +5609,7 @@ class MessageDeltaObject(BaseModel): ), ] object: Annotated[ - Object25, + Literal["thread.message.delta"], Field(description="The object type, which is always `thread.message.delta`."), ] delta: Annotated[ @@ -5371,7 +5627,7 @@ class ListMessagesResponse(BaseModel): class RunStepDetailsToolCallsObject(BaseModel): - type: Annotated[Type40, Field(description="Always `tool_calls`.")] + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] tool_calls: Annotated[ List[ Union[ @@ -5387,7 +5643,7 @@ class RunStepDetailsToolCallsObject(BaseModel): class RunStepDeltaStepDetailsToolCallsObject(BaseModel): - type: Annotated[Type40, Field(description="Always `tool_calls`.")] + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] tool_calls: Annotated[ Optional[ List[ @@ -5411,7 +5667,7 @@ class VectorStoreFileObject(BaseModel): Field(description="The identifier, which can be referenced in API endpoints."), ] object: Annotated[ - Object31, + Literal["vector_store.file"], Field(description="The object type, which is always `vector_store.file`."), ] usage_bytes: Annotated[ @@ -5433,7 +5689,7 @@ class VectorStoreFileObject(BaseModel): ), ] status: Annotated[ - Status7, + Literal["in_progress", "completed", "cancelled", "failed"], Field( description="The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use." ), @@ -5459,27 +5715,27 @@ class ListVectorStoreFilesResponse(BaseModel): class MessageStreamEvent1(BaseModel): - event: Event18 + event: Literal["thread.message.created"] data: MessageObject class MessageStreamEvent2(BaseModel): - event: Event19 + event: Literal["thread.message.in_progress"] data: MessageObject class MessageStreamEvent3(BaseModel): - event: Event20 + event: Literal["thread.message.delta"] data: MessageDeltaObject class MessageStreamEvent4(BaseModel): - event: Event21 + event: Literal["thread.message.completed"] data: MessageObject class MessageStreamEvent5(BaseModel): - event: Event22 + event: Literal["thread.message.incomplete"] data: MessageObject @@ -5514,12 +5770,15 @@ class ChatCompletionRequestMessage( ] ] ): - root: Union[ - ChatCompletionRequestSystemMessage, - ChatCompletionRequestUserMessage, - ChatCompletionRequestAssistantMessage, - ChatCompletionRequestToolMessage, - ChatCompletionRequestFunctionMessage, + root: Annotated[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ], + Field(discriminator="role"), ] @@ -5532,10 +5791,39 @@ class CreateChatCompletionRequest(BaseModel): ), ] model: Annotated[ - Union[str, Model2], + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], Field( description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", - examples=["gpt-4-turbo"], + examples=["gpt-4o"], ), ] frequency_penalty: Annotated[ @@ -5597,10 +5885,10 @@ class CreateChatCompletionRequest(BaseModel): ), ] response_format: Annotated[ - Optional[ResponseFormat], + Optional[Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema]], Field( None, - description='An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', ), ] seed: Annotated[ @@ -5613,7 +5901,7 @@ class CreateChatCompletionRequest(BaseModel): ), ] service_tier: Annotated[ - Optional[ServiceTier], + Optional[Literal["auto", "default"]], Field( None, description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n", @@ -5672,7 +5960,7 @@ class CreateChatCompletionRequest(BaseModel): ), ] function_call: Annotated[ - Optional[Union[FunctionCall3, ChatCompletionFunctionCallOption]], + Optional[Union[Literal["none", "auto"], ChatCompletionFunctionCallOption]], Field( None, description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n', @@ -5707,11 +5995,41 @@ class CreateThreadAndRunRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Model12]], + Optional[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], Field( None, description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", - examples=["gpt-4-turbo"], + examples=["gpt-4o"], ), ] instructions: Annotated[ @@ -5800,7 +6118,7 @@ class RunStepObject(BaseModel): ), ] object: Annotated[ - Object27, + Literal["thread.run.step"], Field(description="The object type, which is always `thread.run.step`."), ] created_at: Annotated[ @@ -5824,13 +6142,13 @@ class RunStepObject(BaseModel): ), ] type: Annotated[ - Type37, + Literal["message_creation", "tool_calls"], Field( description="The type of run step, which can be either `message_creation` or `tool_calls`." ), ] status: Annotated[ - Status5, + Literal["in_progress", "cancelled", "failed", "completed", "expired"], Field( description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`." ), @@ -5892,7 +6210,7 @@ class RunStepDeltaObject(BaseModel): ), ] object: Annotated[ - Object28, + Literal["thread.run.step.delta"], Field(description="The object type, which is always `thread.run.step.delta`."), ] delta: Annotated[ @@ -5910,37 +6228,37 @@ class ListRunStepsResponse(BaseModel): class RunStepStreamEvent1(BaseModel): - event: Event11 + event: Literal["thread.run.step.created"] data: RunStepObject class RunStepStreamEvent2(BaseModel): - event: Event12 + event: Literal["thread.run.step.in_progress"] data: RunStepObject class RunStepStreamEvent3(BaseModel): - event: Event13 + event: Literal["thread.run.step.delta"] data: RunStepDeltaObject class RunStepStreamEvent4(BaseModel): - event: Event14 + event: Literal["thread.run.step.completed"] data: RunStepObject class RunStepStreamEvent5(BaseModel): - event: Event15 + event: Literal["thread.run.step.failed"] data: RunStepObject class RunStepStreamEvent6(BaseModel): - event: Event16 + event: Literal["thread.run.step.cancelled"] data: RunStepObject class RunStepStreamEvent7(BaseModel): - event: Event17 + event: Literal["thread.run.step.expired"] data: RunStepObject diff --git a/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py b/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py index 31c79bcf..ffc0eed9 100644 --- a/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py +++ b/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py @@ -3,7 +3,10 @@ from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import CreateBatchCompletionsEngineRequest -from model_engine_server.common.dtos.llms.batch_completion import BatchCompletionsJob +from model_engine_server.common.dtos.llms.batch_completion import ( + BatchCompletionsJob, + UpdateBatchCompletionsV2Request, +) from model_engine_server.core.auth.authentication_repository import User @@ -23,7 +26,6 @@ async def create_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, max_runtime_sec: int = 24 * 60 * 60, labels: Dict[str, str] = {}, - priority: Optional[int] = 0, num_workers: Optional[int] = 1, ) -> BatchCompletionsJob: """ @@ -45,7 +47,7 @@ async def create_batch_job( pass @abstractmethod - async def get_batch_job(self, batch_job_id: str) -> Optional[BatchCompletionsJob]: + async def get_batch_job(self, batch_job_id: str, user: User) -> Optional[BatchCompletionsJob]: """ Get a batch job. @@ -58,7 +60,22 @@ async def get_batch_job(self, batch_job_id: str) -> Optional[BatchCompletionsJob pass @abstractmethod - async def cancel_batch_job(self, batch_job_id: str) -> bool: + async def update_batch_job( + self, batch_job_id: str, request: UpdateBatchCompletionsV2Request, user: User + ) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + pass + + @abstractmethod + async def cancel_batch_job(self, batch_job_id: str, user: User) -> bool: """ Update a batch job. diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index eb108194..2f6c9e37 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -39,7 +39,13 @@ UpdateLLMModelEndpointV1Request, UpdateLLMModelEndpointV1Response, ) -from model_engine_server.common.dtos.llms.batch_completion import VLLMEngineAdditionalArgs +from model_engine_server.common.dtos.llms.batch_completion import ( + CancelBatchCompletionsV2Response, + GetBatchCompletionV2Response, + UpdateBatchCompletionsV2Request, + UpdateBatchCompletionsV2Response, + VLLMEngineAdditionalArgs, +) from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus @@ -265,10 +271,10 @@ SERVICE_NAME = "model-engine" SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") -if SERVICE_IDENTIFIER: - SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = f"{SERVICE_NAME}-inference-framework-latest-config" RECOMMENDED_HARDWARE_CONFIG_MAP_NAME = f"{SERVICE_NAME}-recommended-hardware-config" +if SERVICE_IDENTIFIER: + SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: @@ -279,12 +285,23 @@ def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRep return len(tokenizer.encode(input)) +async def _get_latest_batch_v2_tag(inference_framework: LLMInferenceFramework) -> str: + config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) + print(config_map) + batch_key = f"{inference_framework}_batch_v2" + if batch_key not in config_map: + raise LatestImageTagNotFoundException( + f"Could not find latest batch job tag for inference framework {inference_framework}. key: {batch_key}" + ) + return config_map[batch_key] + + async def _get_latest_batch_tag(inference_framework: LLMInferenceFramework) -> str: config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) batch_key = f"{inference_framework}_batch" if batch_key not in config_map: raise LatestImageTagNotFoundException( - f"Could not find latest batch job tag for inference framework {inference_framework}." + f"Could not find latest batch job tag for inference framework {inference_framework}. key: {batch_key}" ) return config_map[batch_key] @@ -2672,7 +2689,7 @@ def __init__( self.llm_artifact_gateway = llm_artifact_gateway async def execute( - self, user: User, request: CreateBatchCompletionsV2Request + self, request: CreateBatchCompletionsV2Request, user: User ) -> CreateBatchCompletionsV2Response: request.model_cfg.checkpoint_path = get_checkpoint_path( request.model_cfg.model, request.model_cfg.checkpoint_path @@ -2702,7 +2719,7 @@ async def execute( # Right now we only support VLLM for batch inference. Refactor this if we support more inference frameworks. image_repo = hmi_config.batch_inference_vllm_repository - image_tag = await _get_latest_batch_tag(LLMInferenceFramework.VLLM) + image_tag = await _get_latest_batch_v2_tag(LLMInferenceFramework.VLLM) additional_engine_args = infer_addition_engine_args_from_model_name( engine_request.model_cfg.model @@ -2720,6 +2737,66 @@ async def execute( resource_requests=hardware, labels=engine_request.labels, max_runtime_sec=engine_request.max_runtime_sec, - priority=engine_request.priority, num_workers=engine_request.data_parallelism, ) + + +class GetBatchCompletionV2UseCase: + def __init__(self, llm_batch_completions_service: LLMBatchCompletionsService): + self.llm_batch_completions_service = llm_batch_completions_service + + async def execute( + self, + batch_completion_id: str, + user: User, + ) -> GetBatchCompletionV2Response: + job = await self.llm_batch_completions_service.get_batch_job( + batch_completion_id, + user=user, + ) + + if not job: + raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") + + return GetBatchCompletionV2Response(job=job) + + +class UpdateBatchCompletionV2UseCase: + def __init__(self, llm_batch_completions_service: LLMBatchCompletionsService): + self.llm_batch_completions_service = llm_batch_completions_service + + async def execute( + self, + batch_completion_id: str, + request: UpdateBatchCompletionsV2Request, + user: User, + ) -> UpdateBatchCompletionsV2Response: + result = await self.llm_batch_completions_service.update_batch_job( + batch_completion_id, + user=user, + request=request, + ) + if not result: + raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") + + return UpdateBatchCompletionsV2Response( + **result.model_dump(by_alias=True, exclude_none=True), + success=True, + ) + + +class CancelBatchCompletionV2UseCase: + def __init__(self, llm_batch_completions_service: LLMBatchCompletionsService): + self.llm_batch_completions_service = llm_batch_completions_service + + async def execute( + self, + batch_completion_id: str, + user: User, + ) -> CancelBatchCompletionsV2Response: + return CancelBatchCompletionsV2Response( + success=await self.llm_batch_completions_service.cancel_batch_job( + batch_completion_id, + user=user, + ) + ) diff --git a/model-engine/model_engine_server/inference/utils.py b/model-engine/model_engine_server/inference/utils.py new file mode 100644 index 00000000..e6137415 --- /dev/null +++ b/model-engine/model_engine_server/inference/utils.py @@ -0,0 +1,117 @@ +import asyncio +import subprocess +import sys +import uuid +from typing import Any, AsyncIterator, Coroutine, Tuple, Union + +from typing_extensions import TypeVar + + +def get_cpu_cores_in_container() -> int: + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + try: + with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") as fp: + cfs_quota_us = int(fp.read()) + with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us") as fp: + cfs_period_us = int(fp.read()) + if cfs_quota_us != -1: + cpu_count = cfs_quota_us // cfs_period_us + except FileNotFoundError: + pass + return cpu_count + + +def get_gpu_free_memory(): # pragma: no cover + """Get GPU free memory using nvidia-smi.""" + try: + output = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + ).stdout + gpu_memory = [int(x) for x in output.strip().split("\n")] + return gpu_memory + except Exception as e: + print(f"Error getting GPU memory: {e}") + return None + + +def check_unknown_startup_memory_usage(): # pragma: no cover + """Check for unknown memory usage at startup.""" + gpu_free_memory = get_gpu_free_memory() + if gpu_free_memory is not None: + print(f"GPU free memory at startup in MB: {gpu_free_memory}") + min_mem = min(gpu_free_memory) + max_mem = max(gpu_free_memory) + if max_mem - min_mem > 10: + print( + f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." + ) + try: + output = subprocess.run( + ["fuser -v /dev/nvidia*"], + shell=True, # nosemgrep + capture_output=True, + text=True, + ).stdout + print(f"Processes using GPU: {output}") + except Exception as e: + print(f"Error getting processes using GPU: {e}") + + +def random_uuid() -> str: + return str(uuid.uuid4()) + + +T = TypeVar("T") + + +class ProducerFinished: + pass + + +def await_coroutines(*coroutines: Coroutine[Any, Any, T]) -> AsyncIterator[Tuple[int, T]]: + """Await multiple coroutines concurrently. + + Returns an async iterator that yields the results of the coroutines as they complete. + """ + queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, Exception]] = asyncio.Queue() + + async def producer(i: int, coroutine: Coroutine[Any, Any, T]): + try: + result = await coroutine + await queue.put((i, result)) + except Exception as e: + await queue.put(e) + # Signal to the consumer that we've finished + await queue.put(ProducerFinished()) + + _tasks = [asyncio.create_task(producer(i, coroutine)) for i, coroutine in enumerate(coroutines)] + + async def consumer(): + remaining = len(coroutines) + try: + while remaining or not queue.empty(): + item = await queue.get() + + if isinstance(item, ProducerFinished): + # Signal that a producer finished- not a real item + remaining -= 1 + continue + + if isinstance(item, Exception): + raise item + yield item + except (Exception, asyncio.CancelledError) as e: + for task in _tasks: + if sys.version_info >= (3, 9): + # msg parameter only supported in Python 3.9+ + task.cancel(e) + else: + task.cancel() + raise e + await asyncio.gather(*_tasks) + + return consumer() diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm index 4109939f..bb4bf801 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -39,5 +39,17 @@ COPY model-engine /workspace/model-engine RUN pip install -e /workspace/model-engine COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py +# Need to override entrypoint from parent image +ENTRYPOINT ["/bin/env"] + +FROM base AS vllm_batch_v2 + +COPY model-engine/model_engine_server/inference/vllm/requirements-batch.txt /workspace/requirements.txt +RUN pip install -r requirements.txt + +COPY model-engine /workspace/model-engine +RUN pip install -e /workspace/model-engine +COPY model-engine/model_engine_server/inference/vllm/vllm_batch.py /workspace/vllm_batch.py + # Need to override entrypoint from parent image ENTRYPOINT ["/bin/env"] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/README.md b/model-engine/model_engine_server/inference/vllm/README.md new file mode 100644 index 00000000..29b44d60 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/README.md @@ -0,0 +1,61 @@ +# VLLM + +## Building container + +There are three build targets for vLLM. +1. vLLM endpoint +2. vLLM batch job v1 +3. vLLM batch job v2 + +```bash +VLLM_VERSION=0.5.4 bash build_and_upload_image.sh $ACCOUNT_ID $IMAGE_TAG {BUILD_TARGET=vllm|vllm_batch|vllm_batch_v2} +``` + +## Running locally + +### Endpoint + +1. Download model weights to `model_files` +2. Run docker locally +```bash +IMAGE=${ACCOUNT_ID}.dkr.ecr.us-west-2.amazonaws.com/vllm:${IMAGE_TAG} +docker kill vllm; docker rm vllm; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=0"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -p 5005:5005 \ + --name vllm \ + ${IMAGE} \ + python -m vllm_server --model model_files --tensor-parallel-size 1 --port 5005 --disable-log-requests +``` + +3. Send curl requests +```bash +curl -X POST localhost:5005/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"messages":[{"role": "user", "content": "Hey, whats the temperature in Paris right now?"}],"model":"model_files","max_tokens":100,"temperature":0.2,"guided_regex":"Sean.*"}' +``` + +### Batch job v2 +```bash +IMAGE_BATCH=${ACCOUNT_ID}.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:${IMAGE_TAG} + +export MODEL=gemma-2-2b-it && export MODEL_PATH=/data/model_files/$MODEL +docker kill vllm_batch; docker rm vllm_batch; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=6,7"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ + -p 5005:5005 \ + -e CONFIG_FILE=/workspace/examples/sample_config_gemma.json \ + -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ + --name vllm_batch \ + ${IMAGE_BATCH} \ + python vllm_batch.py + +``` \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json b/model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json new file mode 100644 index 00000000..2b1c020d --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json @@ -0,0 +1,15 @@ +{ + "input_data_path": "./examples/v2/sample_data_chat_gemma.json", + "output_data_path": "./examples/v2/sample_output.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "my_path", + "num_shards": 1, + "response_role": "assistant", + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json b/model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json new file mode 100644 index 00000000..39722117 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json @@ -0,0 +1 @@ +[{"messages": [{"role": "user", "content": "What is a good place for travel in the US?"}, {"role": "assistant", "content": "California."}, {"role": "user", "content": "What can I do in California?"}], "logprobs": true}] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/gen_sample_data.py b/model-engine/model_engine_server/inference/vllm/gen_sample_data.py new file mode 100644 index 00000000..2b2e9367 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/gen_sample_data.py @@ -0,0 +1,36 @@ +import json + +EXAMPLES_DIR = "examples/v2" + +messages = [ + { + "role": "user", + "content": "What is a good place for travel in the US?", + }, + { + "role": "assistant", + "content": "California.", + }, + { + "role": "user", + "content": "What can I do in California?", + }, +] + +if __name__ == "__main__": + + completion_type = "chat" + model = "gemma" + target_file = f"{EXAMPLES_DIR}/sample_data_{completion_type}_{model}.json" + + # request = CompletionCreateParamsNonStreaming( + # messages=messages, + # logprobs=True, + # max_tokens=300, + # ) + request = { + "messages": messages, + "logprobs": True, + "max_tokens": 300, + } + json.dump([request], open(target_file, "w")) diff --git a/model-engine/model_engine_server/inference/vllm/requirements-batch.txt b/model-engine/model_engine_server/inference/vllm/requirements-batch.txt new file mode 100644 index 00000000..f2865af3 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements-batch.txt @@ -0,0 +1,6 @@ +pydantic>=2.8 +boto3==1.34.15 +smart-open==6.4.0 +ddtrace==2.11.0 +datadog==0.49.1 +dataclasses-json~=0.6.7 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt new file mode 100644 index 00000000..34cee62b --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt @@ -0,0 +1 @@ +vllm>=0.5.4 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index c6984a80..3381d938 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,2 +1 @@ -vllm>=0.5.4 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py new file mode 100644 index 00000000..99e50328 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -0,0 +1,320 @@ +import argparse +import asyncio +import json +import os +import subprocess +from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Dict, List, Optional, Union + +import smart_open +from model_engine_server.common.dtos.llms import ( + BatchCompletionContent, + BatchCompletionsModelConfig, + CompletionResponse, + CompletionV1Output, + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsV1RequestContent, + TokenOutput, +) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.utils import ( + await_coroutines, + check_unknown_startup_memory_usage, + get_cpu_cores_in_container, + random_uuid, +) +from pydantic import TypeAdapter +from tqdm import tqdm +from typing_extensions import TypeAlias, assert_never +from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams +from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.utils import merge_async_iterators + +CONFIG_FILE = os.getenv("CONFIG_FILE") +AWS_REGION = os.getenv("AWS_REGION", "us-west-2") +MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights") +os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + + +openai_serving_chat: OpenAIServingChat +openai_serving_completion: OpenAIServingCompletion + +CPU_COUNT = get_cpu_cores_in_container() + +_BatchCompletionContent: TypeAlias = Union[ + CreateBatchCompletionsV1RequestContent, + List[CompletionRequest], + List[ChatCompletionRequest], +] + + +async def download_model(checkpoint_path: str, target_dir: str) -> None: + s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}" + env = os.environ.copy() + env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + # Need to override these env vars so s5cmd uses AWS_PROFILE + env["AWS_ROLE_ARN"] = "" + env["AWS_WEB_IDENTITY_TOKEN_FILE"] = "" + process = subprocess.Popen( + s5cmd, + shell=True, # nosemgrep + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=env, + ) + if process.stdout: + for line in process.stdout: + print(line, flush=True) + + process.wait() + + if process.returncode != 0 and process.stderr: + stderr_lines = [] + for line in iter(process.stderr.readline, ""): + stderr_lines.append(line.strip()) + + print(f"Error downloading model weights: {stderr_lines}", flush=True) + + +async def generate_v1_completions( + engine: AsyncEngineClient, + content: CreateBatchCompletionsV1RequestContent, +) -> List[Optional[CompletionV1Output]]: + prompts = content.prompts + bar = tqdm(total=len(prompts), desc="Processed prompts") + sampling_params = SamplingParams( + max_tokens=content.max_new_tokens, + temperature=content.temperature, + stop=content.stop_sequences, + logprobs=1 if content.return_token_log_probs else None, + presence_penalty=content.presence_penalty or 0.0, + frequency_penalty=content.frequency_penalty or 0.0, + top_k=content.top_k or -1, + top_p=content.top_p or 1.0, + skip_special_tokens=( + content.skip_special_tokens if content.skip_special_tokens is not None else True + ), + ) + + results_generators: List[AsyncIterator[RequestOutput]] = [] + for prompt in prompts: + request_id = random_uuid() + results_generator = engine.generate( + prompt, + sampling_params=sampling_params, + request_id=request_id, + ) + results_generators.append(results_generator) + + return_token_log_probs = True + + generator = merge_async_iterators(*results_generators) + outputs: List[Optional[CompletionV1Output]] = [None] * len(prompts) + tokens: List[List[TokenOutput]] = [list()] * len(prompts) + async for i, res in generator: + # There should only be one output + output = res.outputs[-1] + + if return_token_log_probs and output.logprobs is not None: + # Sometime the logprobs are not present in the output + logprobs = output.logprobs[-1] + for token_id in logprobs.keys(): + tokens[i].append( + TokenOutput( + token=logprobs[token_id].decoded_token, + log_prob=logprobs[token_id].logprob, + ) + ) + + if res.finished: + outputs[i] = CompletionV1Output( + text=output.text, + num_prompt_tokens=len(res.prompt_token_ids), + num_completion_tokens=len(output.token_ids), + tokens=[ + token.model_dump() for token in tokens[i] + ], # Not sure why, but pydantic doesn't like when I pass it TokenOutput directly but works when I encode it as a dict... + ) + bar.update(1) + + return outputs + + +async def generate_v2_completions( + engine: AsyncEngineClient, + requests: Union[List[CompletionRequest], List[ChatCompletionRequest]], +) -> List[Union[CompletionResponse, ErrorResponse, None]]: + bar = tqdm(total=len(requests), desc="Processed requests") + results_generators: List[ + Coroutine[ + Any, + Any, + Union[ErrorResponse, AsyncGenerator[str, None], CompletionResponse], + ] + ] = [] + for request in requests: + if isinstance(request, CompletionRequest): + results_generators.append(openai_serving_completion.create_completion(request)) + elif isinstance(request, ChatCompletionRequest): + results_generators.append(openai_serving_chat.create_chat_completion(request)) + else: + assert_never(request) + + results_generator = await_coroutines(*results_generators) + outputs: List[Optional[CompletionResponse]] = [None] * len(requests) + + async for i, res in results_generator: + if isinstance(res, AsyncGenerator): + continue + outputs[i] = res + bar.update(1) + return outputs + + +async def generate_completions( + engine: AsyncEngineClient, request: _BatchCompletionContent +) -> Union[List[Optional[CompletionV1Output]], List[Optional[CompletionResponse]]]: + if isinstance(request, CreateBatchCompletionsV1RequestContent): + return await generate_v1_completions(engine, request) + elif isinstance(request, List): + return await generate_v2_completions(engine, request) + else: + assert_never(request) + + +async def init_engine( + model: str, + request: CreateBatchCompletionsEngineRequest, +) -> AsyncEngineClient: + global openai_serving_chat + global openai_serving_completion + + if request.attention_backend is not None: + os.environ["ATTENTION_BACKEND"] = request.attention_backend + + engine_args = AsyncEngineArgs( + model=model, + tensor_parallel_size=request.model_cfg.num_shards, + seed=request.model_cfg.seed or 0, + disable_log_requests=True, + gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, + max_model_len=request.model_cfg.max_context_length, + ) + + async_engine_client = AsyncLLMEngine.from_engine_args(engine_args) + model_config = await async_engine_client.get_model_config() + served_model_names = [model] + + openai_serving_chat = OpenAIServingChat( + async_engine_client, + model_config, + served_model_names, + response_role=request.model_cfg.response_role or "assistant", + lora_modules=None, + prompt_adapters=None, + request_logger=None, + chat_template=None, + ) + + openai_serving_completion = OpenAIServingCompletion( + async_engine_client, + model_config, + served_model_names, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + + return async_engine_client + + +def overwrite_request(request: Dict[str, Any], model: str) -> Dict[str, Any]: + request["model"] = model + request["stream"] = False + return request + + +def load_batch_content( + request: CreateBatchCompletionsEngineRequest, +) -> _BatchCompletionContent: + content = request.content + if content is None: + with smart_open.open(request.input_data_path, "r") as f: + data = json.load(f) + content = TypeAdapter(BatchCompletionContent).validate_python(data) + + # Recast the content to vLLMs schema + if isinstance(content, List) and len(content) > 0: + model = get_model_name(request.model_cfg) + return TypeAdapter( + Union[List[CompletionRequest], List[ChatCompletionRequest]] + ).validate_python( + [overwrite_request(req.model_dump(exclude_none=True), model) for req in content] + ) + + return content + + +def get_model_name(model_config: BatchCompletionsModelConfig) -> str: + return MODEL_WEIGHTS_FOLDER if model_config.checkpoint_path else model_config.model + + +async def handle_batch_job(request: CreateBatchCompletionsEngineRequest) -> None: + metrics_gateway = DatadogInferenceMonitoringMetricsGateway() + + model = get_model_name(request.model_cfg) + + if request.model_cfg.checkpoint_path: + await download_model( + checkpoint_path=request.model_cfg.checkpoint_path, + target_dir=MODEL_WEIGHTS_FOLDER, + ) + + content = load_batch_content(request) + engine = await init_engine( + model, + request=request, + ) + + outputs = await generate_completions(engine, content) + with smart_open.open(request.output_data_path, "w") as f: + f.write(json.dumps([output.model_dump() if output else None for output in outputs])) + + metrics_gateway.emit_batch_completions_metric( + model, + use_tool=False, + num_prompt_tokens=0, + num_completion_tokens=0, + is_finetuned=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-file-data", + "--config_file_data", + type=str, + default=None, + help="Optional override for the config file data, as a json string", + ) + + args = parser.parse_args() + + check_unknown_startup_memory_usage() + + config_file_data = args.config_file_data + if config_file_data is None: + if CONFIG_FILE is None or not os.path.exists(CONFIG_FILE): + raise FileNotFoundError(f"Config file {CONFIG_FILE} not found") + with open(CONFIG_FILE, "r") as f: + config_file_data = f.read() + + request = CreateBatchCompletionsEngineRequest.model_validate_json(config_file_data) + + asyncio.run(handle_batch_job(request)) diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 94ad089c..54e6436c 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -87,10 +87,10 @@ def get_default_supported_models_info() -> Dict[str, ModelInfo]: "vicuna-13b": ModelInfo("eachadea/vicuna-13b-1.1", None), "zephyr-7b-alpha": ModelInfo("HuggingFaceH4/zephyr-7b-alpha", None), "zephyr-7b-beta": ModelInfo("HuggingFaceH4/zephyr-7b-beta", None), - "gemma-2b": ModelInfo("google/gemma-2b", None), - "gemma-2b-instruct": ModelInfo("google/gemma-2b-it", None), - "gemma-7b": ModelInfo("google/gemma-7b", None), - "gemma-7b-instruct": ModelInfo("google/gemma-7b-it", None), + "gemma-2-2b": ModelInfo("google/gemma-2-2b", None), + "gemma-2-2b-instruct": ModelInfo("google/gemma-2-2b-it", None), + "gemma-2-7b": ModelInfo("google/gemma-2-7b", None), + "gemma-2-7b-instruct": ModelInfo("google/gemma-2-7b-it", None), "phi-3-mini-4k-instruct": ModelInfo("microsoft/phi-3-mini-4k-instruct", None), "phi-3-mini-128k-instruct": ModelInfo("microsoft/phi-3-mini-128k-instruct", None), "phi-3-small-8k-instruct": ModelInfo("microsoft/phi-3-small-8k-instruct", None), diff --git a/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py b/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py index 40c155f5..ad792365 100644 --- a/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py +++ b/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Dict, Optional @@ -7,7 +8,12 @@ BatchCompletionsJobStatus, CreateBatchCompletionsEngineRequest, ) +from model_engine_server.common.dtos.llms.batch_completion import ( + UpdateBatchCompletionsV2Request, + UpdateBatchCompletionsV2Response, +) from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities.batch_job_entity import BatchJobStatus from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( DockerImageBatchJobGateway, ) @@ -16,6 +22,40 @@ ) +def to_dto(status: BatchJobStatus) -> BatchCompletionsJobStatus: + if status == BatchJobStatus.PENDING: + return BatchCompletionsJobStatus.Queued + if status == BatchJobStatus.RUNNING: + return BatchCompletionsJobStatus.Running + if status == BatchJobStatus.FAILURE: + return BatchCompletionsJobStatus.Failed + if status == BatchJobStatus.SUCCESS: + return BatchCompletionsJobStatus.Completed + if status == BatchJobStatus.CANCELLED: + return BatchCompletionsJobStatus.Cancelled + if status == BatchJobStatus.TIMEOUT: + return BatchCompletionsJobStatus.Failed + + return BatchCompletionsJobStatus.Unknown + + +@dataclass +class CustomJobMetadata: + """ + This is a workaround to the current DockerImageBatchJobGateway implementation + which doesn't store additional metadata we need for batch completions v2 + """ + + input_data_path: Optional[str] + output_data_path: str + expires_at: str + priority: Optional[str] + labels: Dict[str, str] + + +NULL_TOKEN = "null" + + class LiveLLMBatchCompletionsService(LLMBatchCompletionsService): def __init__( self, @@ -23,6 +63,30 @@ def __init__( ): self.docker_image_batch_job_gateway = docker_image_batch_job_gateway + def encode_metadata(self, metadata: CustomJobMetadata) -> Dict[str, str]: + return { + "__INT_input_data_path": metadata.input_data_path or NULL_TOKEN, + "__INT_output_data_path": metadata.output_data_path, + "__INT_expires_at": metadata.expires_at, + "__INT_priority": metadata.priority or NULL_TOKEN, + **{f"__LABEL_{key}": value for key, value in metadata.labels.items()}, + } + + def decode_metadata(self, metadata: Dict[str, str]) -> CustomJobMetadata: + labels = { + key.replace("__LABEL_", ""): value + for key, value in metadata.items() + if key.startswith("__LABEL") + } + + return CustomJobMetadata( + input_data_path=metadata.get("__INT_input_data_path", "unknown"), + output_data_path=metadata.get("__INT_output_data_path", "unknown"), + expires_at=metadata.get("__INT_expires_at", "unknown"), + priority=metadata.get("__INT_priority", "unknown"), + labels=labels, + ) + async def create_batch_job( self, *, @@ -33,7 +97,6 @@ async def create_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, max_runtime_sec: int = 24 * 60 * 60, labels: Dict[str, str] = {}, - priority: Optional[int] = 0, num_workers: Optional[int] = 1, ): config_file_path = "/opt/config.json" @@ -46,6 +109,7 @@ async def create_batch_job( "ddtrace-run python vllm_batch.py", ] + expires_at = datetime.now() + timedelta(seconds=max_runtime_sec) job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=user.user_id, owner=user.team_id, @@ -59,6 +123,15 @@ async def create_batch_job( labels=labels, override_job_max_runtime_s=max_runtime_sec, num_workers=num_workers, + annotations=self.encode_metadata( + CustomJobMetadata( + input_data_path=job_request.input_data_path, + output_data_path=job_request.output_data_path, + expires_at=expires_at.isoformat(), + priority=job_request.priority, + labels=job_request.labels, + ) + ), ) return BatchCompletionsJob( job_id=job_id, @@ -68,14 +141,41 @@ async def create_batch_job( priority=job_request.priority, status=BatchCompletionsJobStatus.Queued, created_at=datetime.now().isoformat(), - expires_at=(datetime.now() + timedelta(seconds=max_runtime_sec)).isoformat(), + expires_at=expires_at.isoformat(), completed_at=None, metadata={"labels": job_request.labels}, ) - async def get_batch_job(self, batch_job_id: str) -> Optional[BatchCompletionsJob]: - raise NotImplementedError("Not implemented") + async def get_batch_job(self, batch_job_id: str, user: User) -> Optional[BatchCompletionsJob]: + job = await self.docker_image_batch_job_gateway.get_docker_image_batch_job( + batch_job_id=batch_job_id + ) + + if job is None: + return None + + custom_metadata = self.decode_metadata(job.annotations or {}) + model_config = "[Cannot retrieve] -- please check the job logs" - async def cancel_batch_job(self, batch_job_id: str) -> bool: - # TODO: implement - raise NotImplementedError("Not implemented") + return BatchCompletionsJob( + job_id=batch_job_id, + input_data_path=custom_metadata.input_data_path, + output_data_path=custom_metadata.output_data_path, + model_config=model_config, + priority=custom_metadata.priority, + status=to_dto(job.status), + created_at=job.created_at, + expires_at=custom_metadata.expires_at, + completed_at=job.completed_at, + metadata={"labels": custom_metadata.labels}, + ) + + async def update_batch_job( + self, batch_job_id: str, request: UpdateBatchCompletionsV2Request, user: User + ) -> UpdateBatchCompletionsV2Response: + raise NotImplementedError("Not supported") + + async def cancel_batch_job(self, batch_job_id: str, user: User) -> bool: + return await self.docker_image_batch_job_gateway.update_docker_image_batch_job( + batch_job_id=batch_job_id, cancel=True + ) diff --git a/model-engine/mypy.ini b/model-engine/mypy.ini index f5c39968..fc499b32 100644 --- a/model-engine/mypy.ini +++ b/model-engine/mypy.ini @@ -8,9 +8,6 @@ strict_optional = True plugins = pydantic.mypy exclude = clients|.*/triton_model_repo/.* -[mypy-model_engine_server.cli.*] -ignore_errors = True - [mypy-model_engine_server.core.*] ignore_errors = True diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 3e02e5bc..97beffb8 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -153,6 +153,9 @@ ) from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService from model_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.live_llm_batch_completions_service import ( + LiveLLMBatchCompletionsService, +) from model_engine_server.infra.services.live_llm_model_endpoint_service import ( LiveLLMModelEndpointService, ) @@ -2234,6 +2237,9 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: model_endpoint_record_repository=fake_model_endpoint_record_repository, model_endpoint_service=fake_model_endpoint_service, ) + fake_llm_batch_completions_service = LiveLLMBatchCompletionsService( + docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway + ) fake_llm_fine_tuning_service = FakeLLMFineTuningService( fake_llm_fine_tuning_service_contents ) @@ -2249,6 +2255,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + llm_batch_completions_service=fake_llm_batch_completions_service, batch_job_service=fake_batch_job_service, resource_gateway=FakeEndpointResourceGateway(), endpoint_creation_task_queue_gateway=FakeTaskQueueGateway(), diff --git a/requirements-docs.txt b/requirements-docs.txt index fdc1a843..01a02a52 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -5,8 +5,9 @@ mkdocs-material-extensions==1.1.1 mkdocs-render-swagger-plugin~=0.0.4 mkdocs-simple-hooks~=0.1.5 mkdocs-video~=1.5.0 -mkdocstrings[python]~=0.20.0 +mkdocstrings[python]~=0.24.0 pydantic==2.8.2 +griffe<1.0 neoteroi-mkdocs~=1.0.0 tabulate~=0.9.0 scale-llm-engine \ No newline at end of file diff --git a/scripts/generate-openai-types.sh b/scripts/generate-openai-types.sh index cbe3b323..b9b2717b 100755 --- a/scripts/generate-openai-types.sh +++ b/scripts/generate-openai-types.sh @@ -12,5 +12,6 @@ datamodel-codegen \ --input-file-type openapi \ --output ${DEST_DIR}/openai.py \ --output-model-type pydantic_v2.BaseModel \ + --enum-field-as-literal all \ --field-constraints \ --use-annotated \ No newline at end of file diff --git a/scripts/openai-spec.yaml b/scripts/openai-spec.yaml index aebd38a5..8a50ee2f 100644 --- a/scripts/openai-spec.yaml +++ b/scripts/openai-spec.yaml @@ -1,8 +1,9 @@ +# https://github.com/openai/openai-openapi/blob/423e672461b3d17f9829711e4a858e777252f077/openapi.yaml openapi: 3.0.0 info: title: OpenAI API description: The OpenAI REST API. Please see https://platform.openai.com/docs/api-reference for more details. - version: "2.1.0" + version: "2.3.0" termsOfService: https://openai.com/policies/terms-of-use contact: name: OpenAI Support @@ -37,6 +38,8 @@ tags: description: List and describe the various models available in the API. - name: Moderations description: Given a input text, outputs if the model classifies it as potentially harmful. + - name: Audit Logs + description: List user actions and configuration changes within this organization. paths: # Note: When adding an endpoint, make sure you also add it in the `groups` section, in the end of this file, # under the appropriate group @@ -143,7 +146,7 @@ paths: -H "Content-Type: application/json" \ -H "Authorization: Bearer $OPENAI_API_KEY" \ -d '{ - "model": "gpt-4-turbo", + "model": "gpt-4o", "messages": [ { "role": "user", @@ -169,7 +172,7 @@ paths: client = OpenAI() response = client.chat.completions.create( - model="gpt-4-turbo", + model="gpt-4o", messages=[ { "role": "user", @@ -193,7 +196,7 @@ paths: async function main() { const response = await openai.chat.completions.create({ - model: "gpt-4-turbo", + model: "gpt-4o", messages: [ { role: "user", @@ -305,7 +308,7 @@ paths: -H "Content-Type: application/json" \ -H "Authorization: Bearer $OPENAI_API_KEY" \ -d '{ - "model": "gpt-4-turbo", + "model": "gpt-4o", "messages": [ { "role": "user", @@ -399,7 +402,7 @@ paths: ]; const response = await openai.chat.completions.create({ - model: "gpt-4-turbo", + model: "gpt-4o", messages: messages, tools: tools, tool_choice: "auto", @@ -1973,7 +1976,7 @@ paths: -H "Authorization: Bearer $OPENAI_API_KEY" \ -d '{ "training_file": "file-BK7bzQj3FfZFXr7DbL6xJwfo", - "model": "gpt-3.5-turbo" + "model": "gpt-4o-mini" }' python: | from openai import OpenAI @@ -1981,7 +1984,7 @@ paths: client.fine_tuning.jobs.create( training_file="file-abc123", - model="gpt-3.5-turbo" + model="gpt-4o-mini" ) node.js: | import OpenAI from "openai"; @@ -2001,8 +2004,8 @@ paths: { "object": "fine_tuning.job", "id": "ftjob-abc123", - "model": "gpt-3.5-turbo-0125", - "created_at": 1614807352, + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, "fine_tuned_model": null, "organization_id": "org-123", "result_files": [], @@ -2018,7 +2021,7 @@ paths: -H "Authorization: Bearer $OPENAI_API_KEY" \ -d '{ "training_file": "file-abc123", - "model": "gpt-3.5-turbo", + "model": "gpt-4o-mini", "hyperparameters": { "n_epochs": 2 } @@ -2029,7 +2032,7 @@ paths: client.fine_tuning.jobs.create( training_file="file-abc123", - model="gpt-3.5-turbo", + model="gpt-4o-mini", hyperparameters={ "n_epochs":2 } @@ -2042,7 +2045,7 @@ paths: async function main() { const fineTune = await openai.fineTuning.jobs.create({ training_file: "file-abc123", - model: "gpt-3.5-turbo", + model: "gpt-4o-mini", hyperparameters: { n_epochs: 2 } }); @@ -2054,8 +2057,8 @@ paths: { "object": "fine_tuning.job", "id": "ftjob-abc123", - "model": "gpt-3.5-turbo-0125", - "created_at": 1614807352, + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, "fine_tuned_model": null, "organization_id": "org-123", "result_files": [], @@ -2073,7 +2076,7 @@ paths: -d '{ "training_file": "file-abc123", "validation_file": "file-abc123", - "model": "gpt-3.5-turbo" + "model": "gpt-4o-mini" }' python: | from openai import OpenAI @@ -2082,7 +2085,7 @@ paths: client.fine_tuning.jobs.create( training_file="file-abc123", validation_file="file-def456", - model="gpt-3.5-turbo" + model="gpt-4o-mini" ) node.js: | import OpenAI from "openai"; @@ -2103,8 +2106,8 @@ paths: { "object": "fine_tuning.job", "id": "ftjob-abc123", - "model": "gpt-3.5-turbo-0125", - "created_at": 1614807352, + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, "fine_tuned_model": null, "organization_id": "org-123", "result_files": [], @@ -2121,7 +2124,7 @@ paths: -d '{ "training_file": "file-abc123", "validation_file": "file-abc123", - "model": "gpt-3.5-turbo", + "model": "gpt-4o-mini", "integrations": [ { "type": "wandb", @@ -2139,8 +2142,8 @@ paths: { "object": "fine_tuning.job", "id": "ftjob-abc123", - "model": "gpt-3.5-turbo-0125", - "created_at": 1614807352, + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, "fine_tuned_model": null, "organization_id": "org-123", "result_files": [], @@ -2380,7 +2383,7 @@ paths: { "object": "fine_tuning.job.event", "id": "ft-event-ddTJfwuMVpfLXseO0Am0Gqjm", - "created_at": 1692407401, + "created_at": 1721764800, "level": "info", "message": "Fine tuning job successfully completed", "data": null, @@ -2389,9 +2392,9 @@ paths: { "object": "fine_tuning.job.event", "id": "ft-event-tyiGuB72evQncpH87xe505Sv", - "created_at": 1692407400, + "created_at": 1721764800, "level": "info", - "message": "New fine-tuned model created: ft:gpt-3.5-turbo:openai::7p4lURel", + "message": "New fine-tuned model created: ft:gpt-4o-mini:openai::7p4lURel", "data": null, "type": "message" } @@ -2450,8 +2453,8 @@ paths: { "object": "fine_tuning.job", "id": "ftjob-abc123", - "model": "gpt-3.5-turbo-0125", - "created_at": 1689376978, + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, "fine_tuned_model": null, "organization_id": "org-123", "result_files": [], @@ -2514,8 +2517,8 @@ paths: { "object": "fine_tuning.job.checkpoint", "id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB", - "created_at": 1519129973, - "fine_tuned_model_checkpoint": "ft:gpt-3.5-turbo-0125:my-org:custom-suffix:96olL566:ckpt-step-2000", + "created_at": 1721764867, + "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom-suffix:96olL566:ckpt-step-2000", "metrics": { "full_valid_loss": 0.134, "full_valid_mean_token_accuracy": 0.874 @@ -2526,8 +2529,8 @@ paths: { "object": "fine_tuning.job.checkpoint", "id": "ftckpt_enQCFmOTGj3syEpYVhBRLTSy", - "created_at": 1519129833, - "fine_tuned_model_checkpoint": "ft:gpt-3.5-turbo-0125:my-org:custom-suffix:7q8mpxmy:ckpt-step-1000", + "created_at": 1721764800, + "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom-suffix:7q8mpxmy:ckpt-step-1000", "metrics": { "full_valid_loss": 0.167, "full_valid_mean_token_accuracy": 0.781 @@ -2619,7 +2622,7 @@ paths: schema: type: string # ideally this will be an actual ID, so this will always work from browser - example: gpt-3.5-turbo + example: gpt-4o-mini description: The ID of the model to use for this request responses: "200": @@ -2672,7 +2675,7 @@ paths: required: true schema: type: string - example: ft:gpt-3.5-turbo:acemeco:suffix:abc123 + example: ft:gpt-4o-mini:acemeco:suffix:abc123 description: The model to delete responses: "200": @@ -2688,28 +2691,28 @@ paths: examples: request: curl: | - curl https://api.openai.com/v1/models/ft:gpt-3.5-turbo:acemeco:suffix:abc123 \ + curl https://api.openai.com/v1/models/ft:gpt-4o-mini:acemeco:suffix:abc123 \ -X DELETE \ -H "Authorization: Bearer $OPENAI_API_KEY" python: | from openai import OpenAI client = OpenAI() - client.models.delete("ft:gpt-3.5-turbo:acemeco:suffix:abc123") + client.models.delete("ft:gpt-4o-mini:acemeco:suffix:abc123") node.js: |- import OpenAI from "openai"; const openai = new OpenAI(); async function main() { - const model = await openai.models.del("ft:gpt-3.5-turbo:acemeco:suffix:abc123"); + const model = await openai.models.del("ft:gpt-4o-mini:acemeco:suffix:abc123"); console.log(model); } main(); response: | { - "id": "ft:gpt-3.5-turbo:acemeco:suffix:abc123", + "id": "ft:gpt-4o-mini:acemeco:suffix:abc123", "object": "model", "deleted": true } @@ -2888,7 +2891,7 @@ paths: "created_at": 1698982736, "name": "Coding Tutor", "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are a helpful assistant designed to make me better at coding!", "tools": [], "tool_resources": {}, @@ -2903,7 +2906,7 @@ paths: "created_at": 1698982718, "name": "My Assistant", "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are a helpful assistant designed to make me better at coding!", "tools": [], "tool_resources": {}, @@ -2918,7 +2921,7 @@ paths: "created_at": 1698982643, "name": null, "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "tools": [], "tool_resources": {}, @@ -2967,7 +2970,7 @@ paths: "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", "name": "Math Tutor", "tools": [{"type": "code_interpreter"}], - "model": "gpt-4-turbo" + "model": "gpt-4o" }' python: | @@ -2978,7 +2981,7 @@ paths: instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", name="Math Tutor", tools=[{"type": "code_interpreter"}], - model="gpt-4-turbo", + model="gpt-4o", ) print(my_assistant) node.js: |- @@ -2992,7 +2995,7 @@ paths: "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", name: "Math Tutor", tools: [{ type: "code_interpreter" }], - model: "gpt-4-turbo", + model: "gpt-4o", }); console.log(myAssistant); @@ -3006,7 +3009,7 @@ paths: "created_at": 1698984975, "name": "Math Tutor", "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", "tools": [ { @@ -3029,7 +3032,7 @@ paths: "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", "tools": [{"type": "file_search"}], "tool_resources": {"file_search": {"vector_store_ids": ["vs_123"]}}, - "model": "gpt-4-turbo" + "model": "gpt-4o" }' python: | from openai import OpenAI @@ -3040,7 +3043,7 @@ paths: name="HR Helper", tools=[{"type": "file_search"}], tool_resources={"file_search": {"vector_store_ids": ["vs_123"]}}, - model="gpt-4-turbo" + model="gpt-4o" ) print(my_assistant) node.js: |- @@ -3059,7 +3062,7 @@ paths: vector_store_ids: ["vs_123"] } }, - model: "gpt-4-turbo" + model: "gpt-4o" }); console.log(myAssistant); @@ -3073,7 +3076,7 @@ paths: "created_at": 1699009403, "name": "HR Helper", "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", "tools": [ { @@ -3150,7 +3153,7 @@ paths: "created_at": 1699009709, "name": "HR Helper", "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", "tools": [ { @@ -3202,7 +3205,7 @@ paths: -d '{ "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", "tools": [{"type": "file_search"}], - "model": "gpt-4-turbo" + "model": "gpt-4o" }' python: | from openai import OpenAI @@ -3213,7 +3216,7 @@ paths: instructions="You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", name="HR Helper", tools=[{"type": "file_search"}], - model="gpt-4-turbo" + model="gpt-4o" ) print(my_updated_assistant) @@ -3230,7 +3233,7 @@ paths: "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", name: "HR Helper", tools: [{ type: "file_search" }], - model: "gpt-4-turbo" + model: "gpt-4o" } ); @@ -3245,7 +3248,7 @@ paths: "created_at": 1699009709, "name": "HR Helper", "description": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", "tools": [ { @@ -4203,7 +4206,7 @@ paths: "completed_at": null, "required_action": null, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You are a helpful assistant.", "tools": [], "tool_resources": {}, @@ -4282,13 +4285,13 @@ paths: data: {"id":"thread_123","object":"thread","created_at":1710348075,"metadata":{}} event: thread.run.created - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} event: thread.run.queued - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} event: thread.run.in_progress - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} event: thread.run.step.created data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} @@ -4320,7 +4323,7 @@ paths: data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710348077,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} event: thread.run.completed - {"id":"run_123","object":"thread.run","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1713226836,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1713226837,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + {"id":"run_123","object":"thread.run","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1713226836,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1713226837,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} event: done data: [DONE] @@ -4451,13 +4454,13 @@ paths: data: {"id":"thread_123","object":"thread","created_at":1710351818,"metadata":{}} event: thread.run.created - data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.queued - data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.in_progress - data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.step.created data: {"id":"step_001","object":"thread.run.step","created_at":1710351819,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710352418,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[]},"usage":null} @@ -4483,7 +4486,7 @@ paths: data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"\"}"}}]}}} event: thread.run.requires_action - data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"requires_action","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":{"type":"submit_tool_outputs","submit_tool_outputs":{"tool_calls":[{"id":"call_XXNp8YGaFrjrSjgqxtC8JJ1B","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}"}}]}},"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"requires_action","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":{"type":"submit_tool_outputs","submit_tool_outputs":{"tool_calls":[{"id":"call_XXNp8YGaFrjrSjgqxtC8JJ1B","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}"}}]}},"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: done data: [DONE] @@ -4584,7 +4587,7 @@ paths: "failed_at": null, "completed_at": 1699075073, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "incomplete_details": null, "tools": [ @@ -4631,7 +4634,7 @@ paths: "failed_at": null, "completed_at": 1699063291, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "incomplete_details": null, "tools": [ @@ -4750,7 +4753,7 @@ paths: "failed_at": null, "completed_at": 1699063291, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "incomplete_details": null, "tools": [ @@ -4814,13 +4817,13 @@ paths: main(); response: | event: thread.run.created - data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.queued - data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.in_progress - data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710330641,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710330641,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.step.created data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} @@ -4852,7 +4855,7 @@ paths: data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710330642,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} event: thread.run.completed - data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710330641,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710330642,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710330641,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710330642,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: done data: [DONE] @@ -4969,13 +4972,13 @@ paths: main(); response: | event: thread.run.created - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.queued - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.in_progress - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710348075,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710348075,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.step.created data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} @@ -5007,7 +5010,7 @@ paths: data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710348077,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} event: thread.run.completed - data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710348075,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710348077,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710348075,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710348077,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: done data: [DONE] @@ -5088,7 +5091,7 @@ paths: "failed_at": null, "completed_at": 1699075073, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "incomplete_details": null, "tools": [ @@ -5207,7 +5210,7 @@ paths: "failed_at": null, "completed_at": 1699075073, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "incomplete_details": null, "tools": [ @@ -5351,7 +5354,7 @@ paths: "failed_at": null, "completed_at": null, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": null, "tools": [ { @@ -5455,10 +5458,10 @@ paths: data: {"id":"step_001","object":"thread.run.step","created_at":1710352449,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"completed","cancelled_at":null,"completed_at":1710352475,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[{"id":"call_iWr0kQ2EaYMaxNdl0v3KYkx7","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}","output":"70 degrees and sunny."}}]},"usage":{"prompt_tokens":291,"completion_tokens":24,"total_tokens":315}} event: thread.run.queued - data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":1710352448,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":1710352448,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.in_progress - data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710352475,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710352475,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: thread.run.step.created data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":null} @@ -5496,7 +5499,7 @@ paths: data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710352477,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":{"prompt_tokens":329,"completion_tokens":18,"total_tokens":347}} event: thread.run.completed - data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710352475,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710352477,"required_action":null,"last_error":null,"model":"gpt-4-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710352475,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710352477,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} event: done data: [DONE] @@ -5578,7 +5581,7 @@ paths: "failed_at": null, "completed_at": null, "last_error": null, - "model": "gpt-4-turbo", + "model": "gpt-4o", "instructions": "You summarize books.", "tools": [ { @@ -7227,6 +7230,1432 @@ paths: } } + # Organization + # Audit Logs List + /organization/audit_logs: + get: + summary: List user actions and configuration changes within this organization. + operationId: list-audit-logs + tags: + - Audit Logs + parameters: + - name: effective_at + in: query + description: Return only events whose `effective_at` (Unix seconds) is in this range. + required: false + schema: + type: object + properties: + gt: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is greater than this value. + gte: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is greater than or equal to this value. + lt: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is less than this value. + lte: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is less than or equal to this value. + - name: project_ids[] + in: query + description: Return only events for these projects. + required: false + schema: + type: array + items: + type: string + - name: event_types[] + in: query + description: Return only events with a `type` in one of these values. For example, `project.created`. For all options, see the documentation for the [audit log object](/docs/api-reference/audit-logs/object). + required: false + schema: + type: array + items: + $ref: "#/components/schemas/AuditLogEventType" + - name: actor_ids[] + in: query + description: Return only events performed by these actors. Can be a user ID, a service account ID, or an api key tracking ID. + required: false + schema: + type: array + items: + type: string + - name: actor_emails[] + in: query + description: Return only events performed by users with these emails. + required: false + schema: + type: array + items: + type: string + - name: resource_ids[] + in: query + description: Return only events performed on these targets. For example, a project ID updated. + required: false + schema: + type: array + items: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: Audit logs listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ListAuditLogsResponse" + x-oaiMeta: + name: List audit logs + group: audit-logs + returns: A list of paginated [Audit Log](/docs/api-reference/audit-logs/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/audit_logs \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + response: | + { + "object": "list", + "data": [ + { + "id": "audit_log-xxx_yyyymmdd", + "type": "project.archived", + "effective_at": 1722461446, + "actor": { + "type": "api_key", + "api_key": { + "type": "user", + "user": { + "id": "user-xxx", + "email": "user@example.com" + } + } + }, + "project.archived": { + "id": "proj_abc" + }, + }, + { + "id": "audit_log-yyy__20240101", + "type": "api_key.updated", + "effective_at": 1720804190, + "actor": { + "type": "session", + "session": { + "user": { + "id": "user-xxx", + "email": "user@example.com" + }, + "ip_address": "127.0.0.1", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + }, + "api_key.updated": { + "id": "key_xxxx", + "data": { + "scopes": ["resource_2.operation_2"] + } + }, + } + ], + "first_id": "audit_log-xxx__20240101", + "last_id": "audit_log_yyy__20240101", + "has_more": true + } + /organization/invites: + get: + summary: Returns a list of invites in the organization. + operationId: list-invites + tags: + - Invites + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Invites listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/InviteListResponse" + x-oaiMeta: + name: List invites + group: administration + returns: A list of [Invite](/docs/api-reference/invite/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/invites?after=invite-abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "status": "accepted", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": 1711471533 + } + ], + "first_id": "invite-abc", + "last_id": "invite-abc", + "has_more": false + } + + post: + summary: Create an invite for a user to the organization. The invite must be accepted by the user before they have access to the organization. + operationId: inviteUser + tags: + - Invites + requestBody: + description: The invite request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/InviteRequest" + responses: + "200": + description: User invited successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Invite" + x-oaiMeta: + name: Create invite + group: administration + returns: The created [Invite](/docs/api-reference/invite/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/invites \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "email": "user@example.com", + "role": "owner" + }' + response: + content: | + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": null + } + + /organization/invites/{invite_id}: + get: + summary: Retrieves an invite. + operationId: retrieve-invite + tags: + - Invites + parameters: + - in: path + name: invite_id + required: true + schema: + type: string + description: The ID of the invite to retrieve. + responses: + "200": + description: Invite retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Invite" + x-oaiMeta: + name: Retrieve invite + group: administration + returns: The [Invite](/docs/api-reference/invite/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/invites/invite-abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "status": "accepted", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": 1711471533 + } + delete: + summary: Delete an invite. If the invite has already been accepted, it cannot be deleted. + operationId: delete-invite + tags: + - Invites + parameters: + - in: path + name: invite_id + required: true + schema: + type: string + description: The ID of the invite to delete. + responses: + "200": + description: Invite deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/InviteDeleteResponse" + x-oaiMeta: + name: Delete invite + group: administration + returns: Confirmation that the invite has been deleted + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/invites/invite-abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.invite.deleted", + "id": "invite-abc", + "deleted": true + } + + /organization/users: + get: + summary: Lists all of the users in the organization. + operationId: list-users + tags: + - Users + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Users listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/UserListResponse" + x-oaiMeta: + name: List users + group: administration + returns: A list of [User](/docs/api-reference/users/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/users?after=user_abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + ], + "first_id": "user-abc", + "last_id": "user-xyz", + "has_more": false + } + + /organization/users/{user_id}: + get: + summary: Retrieves a user by their identifier. + operationId: retrieve-user + tags: + - Users + parameters: + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: User retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/User" + x-oaiMeta: + name: Retrieve user + group: administration + returns: The [User](/docs/api-reference/users/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + post: + summary: Modifies a user's role in the organization. + operationId: modify-user + tags: + - Users + requestBody: + description: The new user role to modify. This must be one of `owner` or `member`. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UserRoleUpdateRequest" + responses: + "200": + description: User role updated successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/User" + x-oaiMeta: + name: Modify user + group: administration + returns: The updated [User](/docs/api-reference/users/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "role": "owner" + }' + response: + content: | + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + delete: + summary: Deletes a user from the organization. + operationId: delete-user + tags: + - Users + parameters: + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: User deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/UserDeleteResponse" + x-oaiMeta: + name: Delete user + group: administration + returns: Confirmation of the deleted user + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.user.deleted", + "id": "user_abc", + "deleted": true + } + /organization/projects: + get: + summary: Returns a list of projects. + operationId: list-projects + tags: + - Projects + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + - name: include_archived + in: query + schema: + type: boolean + default: false + description: If `true` returns all projects including those that have been `archived`. Archived projects are not included by default. + responses: + "200": + description: Projects listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectListResponse" + x-oaiMeta: + name: List projects + group: administration + returns: A list of [Project](/docs/api-reference/projects/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects?after=proj_abc&limit=20&include_archived=false \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project example", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + ], + "first_id": "proj-abc", + "last_id": "proj-xyz", + "has_more": false + } + + post: + summary: Create a new project in the organization. Projects can be created and archived, but cannot be deleted. + operationId: create-project + tags: + - Projects + requestBody: + description: The project create request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectCreateRequest" + responses: + "200": + description: Project created successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + x-oaiMeta: + name: Create project + group: administration + returns: The created [Project](/docs/api-reference/projects/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Project ABC" + }' + response: + content: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project ABC", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + + /organization/projects/{project_id}: + get: + summary: Retrieves a project. + operationId: retrieve-project + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + responses: + "200": + description: Project retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + x-oaiMeta: + name: Retrieve project + group: administration + description: Retrieve a project. + returns: The [Project](/docs/api-reference/projects/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project example", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + + post: + summary: Modifies a project in the organization. + operationId: modify-project + tags: + - Projects + requestBody: + description: The project update request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUpdateRequest" + responses: + "200": + description: Project updated successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + "400": + description: Error response when updating the default project. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Modify project + group: administration + returns: The updated [Project](/docs/api-reference/projects/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Project DEF" + }' + + /organization/projects/{project_id}/archive: + post: + summary: Archives a project in the organization. Archived projects cannot be used or updated. + operationId: archive-project + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + responses: + "200": + description: Project archived successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + x-oaiMeta: + name: Archive project + group: administration + returns: The archived [Project](/docs/api-reference/projects/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/archive \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project DEF", + "created_at": 1711471533, + "archived_at": 1711471533, + "status": "archived" + } + + /organization/projects/{project_id}/users: + get: + summary: Returns a list of users in the project. + operationId: list-project-users + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Project users listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserListResponse" + "400": + description: Error response when project is archived. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: List project users + group: administration + returns: A list of [ProjectUser](/docs/api-reference/project-users/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/users?after=user_abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + ], + "first_id": "user-abc", + "last_id": "user-xyz", + "has_more": false + } + error_response: + content: | + { + "code": 400, + "message": "Project {name} is archived" + } + + post: + summary: Adds a user to the project. Users must already be members of the organization to be added to a project. + operationId: create-project-user + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + tags: + - Projects + requestBody: + description: The project user create request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserCreateRequest" + responses: + "200": + description: User added to project successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUser" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Create project user + group: administration + returns: The created [ProjectUser](/docs/api-reference/project-users/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/users \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "user_id": "user_abc", + "role": "member" + }' + response: + content: | + { + "object": "organization.project.user", + "id": "user_abc", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + error_response: + content: | + { + "code": 400, + "message": "Project {name} is archived" + } + + /organization/projects/{project_id}/users/{user_id}: + get: + summary: Retrieves a user in the project. + operationId: retrieve-project-user + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: Project user retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUser" + x-oaiMeta: + name: Retrieve project user + group: administration + returns: The [ProjectUser](/docs/api-reference/project-users/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + post: + summary: Modifies a user's role in the project. + operationId: modify-project-user + tags: + - Projects + requestBody: + description: The project user update request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserUpdateRequest" + responses: + "200": + description: Project user's role updated successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUser" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Modify project user + group: administration + returns: The updated [ProjectUser](/docs/api-reference/project-users/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "role": "owner" + }' + response: + content: | + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + delete: + summary: Deletes a user from the project. + operationId: delete-project-user + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: Project user deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserDeleteResponse" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Delete project user + group: administration + returns: Confirmation that project has been deleted or an error in case of an archived project, which has no users + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/projects/proj_abc/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.user.deleted", + "id": "user_abc", + "deleted": true + } + + /organization/projects/{project_id}/service_accounts: + get: + summary: Returns a list of service accounts in the project. + operationId: list-project-service-accounts + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Project service accounts listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountListResponse" + "400": + description: Error response when project is archived. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: List project service accounts + group: administration + returns: A list of [ProjectServiceAccount](/docs/api-reference/project-service-accounts/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/service_accounts?after=custom_id&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Service Account", + "role": "owner", + "created_at": 1711471533 + } + ], + "first_id": "svc_acct_abc", + "last_id": "svc_acct_xyz", + "has_more": false + } + + post: + summary: Creates a new service account in the project. This also returns an unredacted API key for the service account. + operationId: create-project-service-account + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + requestBody: + description: The project service account create request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountCreateRequest" + responses: + "200": + description: Project service account created successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountCreateResponse" + "400": + description: Error response when project is archived. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Create project service account + group: administration + returns: The created [ProjectServiceAccount](/docs/api-reference/project-service-accounts/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/service_accounts \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Production App" + }' + response: + content: | + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Production App", + "role": "member", + "created_at": 1711471533, + "api_key": { + "object": "organization.project.service_account.api_key", + "value": "sk-abcdefghijklmnop123", + "name": "Secret Key", + "created_at": 1711471533, + "id": "key_abc" + } + } + + /organization/projects/{project_id}/service_accounts/{service_account_id}: + get: + summary: Retrieves a service account in the project. + operationId: retrieve-project-service-account + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: service_account_id + in: path + description: The ID of the service account. + required: true + schema: + type: string + responses: + "200": + description: Project service account retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccount" + x-oaiMeta: + name: Retrieve project service account + group: administration + returns: The [ProjectServiceAccount](/docs/api-reference/project-service-accounts/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/service_accounts/svc_acct_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Service Account", + "role": "owner", + "created_at": 1711471533 + } + + delete: + summary: Deletes a service account from the project. + operationId: delete-project-service-account + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: service_account_id + in: path + description: The ID of the service account. + required: true + schema: + type: string + responses: + "200": + description: Project service account deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountDeleteResponse" + x-oaiMeta: + name: Delete project service account + group: administration + returns: Confirmation of service account being deleted, or an error in case of an archived project, which has no service accounts + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/projects/proj_abc/service_accounts/svc_acct_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.service_account.deleted", + "id": "svc_acct_abc", + "deleted": true + } + + /organization/projects/{project_id}/api_keys: + get: + summary: Returns a list of API keys in the project. + operationId: list-project-api-keys + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Project API keys listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectApiKeyListResponse" + + x-oaiMeta: + name: List project API keys + group: administration + returns: A list of [ProjectApiKey](/docs/api-reference/project-api-keys/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/api_keys?after=key_abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.project.api_key", + "redacted_value": "sk-abc...def", + "name": "My API Key", + "created_at": 1711471533, + "id": "key_abc", + "owner": { + "type": "user", + "user": { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + } + } + ], + "first_id": "key_abc", + "last_id": "key_xyz", + "has_more": false + } + error_response: + content: | + { + "code": 400, + "message": "Project {name} is archived" + } + + /organization/projects/{project_id}/api_keys/{key_id}: + get: + summary: Retrieves an API key in the project. + operationId: retrieve-project-api-key + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: key_id + in: path + description: The ID of the API key. + required: true + schema: + type: string + responses: + "200": + description: Project API key retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectApiKey" + x-oaiMeta: + name: Retrieve project API key + group: administration + returns: The [ProjectApiKey](/docs/api-reference/project-api-keys/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/api_keys/key_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.api_key", + "redacted_value": "sk-abc...def", + "name": "My API Key", + "created_at": 1711471533, + "id": "key_abc", + "owner": { + "type": "user", + "user": { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + } + } + + delete: + summary: Deletes an API key from the project. + operationId: delete-project-api-key + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: key_id + in: path + description: The ID of the API key. + required: true + schema: + type: string + responses: + "200": + description: Project API key deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectApiKeyDeleteResponse" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Delete project API key + group: administration + returns: Confirmation of the key's deletion or an error if the key belonged to a service account + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/projects/proj_abc/api_keys/key_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.api_key.deleted", + "id": "key_abc", + "deleted": true + } + error_response: + content: | + { + "code": 400, + "message": "API keys cannot be deleted for service accounts, please delete the service account" + } + components: securitySchemes: ApiKeyAuth: @@ -7585,11 +9014,20 @@ components: } } - ChatCompletionRequestMessageContentPart: - oneOf: - - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" - - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartImage" - x-oaiExpandable: true + ChatCompletionRequestMessageContentPartText: + type: object + title: Text content part + properties: + type: + type: string + enum: ["text"] + description: The type of the content part. + text: + type: string + description: The text content. + required: + - type + - text ChatCompletionRequestMessageContentPartImage: type: object @@ -7617,28 +9055,58 @@ components: - type - image_url - ChatCompletionRequestMessageContentPartText: + ChatCompletionRequestMessageContentPartRefusal: type: object - title: Text content part + title: Refusal content part properties: type: type: string - enum: ["text"] + enum: ["refusal"] description: The type of the content part. - text: + refusal: type: string - description: The text content. + description: The refusal message generated by the model. required: - type - - text + - refusal ChatCompletionRequestMessage: oneOf: - - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" - - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" - - $ref: "#/components/schemas/ChatCompletionRequestAssistantMessage" - - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" - - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" + - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" + - $ref: "#/components/schemas/ChatCompletionRequestAssistantMessage" + - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" + - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + x-oaiExpandable: true + discriminator: + propertyName: role + mapping: + system: "#/components/schemas/ChatCompletionRequestSystemMessage" + user: "#/components/schemas/ChatCompletionRequestUserMessage" + assistant: "#/components/schemas/ChatCompletionRequestAssistantMessage" + tool: "#/components/schemas/ChatCompletionRequestToolMessage" + function: "#/components/schemas/ChatCompletionRequestFunctionMessage" + + ChatCompletionRequestSystemMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + x-oaiExpandable: true + + ChatCompletionRequestUserMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartImage" + x-oaiExpandable: true + + ChatCompletionRequestAssistantMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartRefusal" + x-oaiExpandable: true + + ChatCompletionRequestToolMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" x-oaiExpandable: true ChatCompletionRequestSystemMessage: @@ -7647,7 +9115,16 @@ components: properties: content: description: The contents of the system message. - type: string + oneOf: + - type: string + description: The contents of the system message. + title: Text content + - type: array + description: An array of content parts with a defined type. For system messages, only type `text` is supported. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestSystemMessageContentPart" + minItems: 1 role: type: string enum: ["system"] @@ -7674,7 +9151,7 @@ components: description: An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model. title: Array of content parts items: - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPart" + $ref: "#/components/schemas/ChatCompletionRequestUserMessageContentPart" minItems: 1 x-oaiExpandable: true role: @@ -7694,9 +9171,22 @@ components: properties: content: nullable: true - type: string + oneOf: + - type: string + description: The contents of the assistant message. + title: Text content + - type: array + description: An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestAssistantMessageContentPart" + minItems: 1 description: | The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified. + refusal: + nullable: true + type: string + description: The refusal message by the assistant. role: type: string enum: ["assistant"] @@ -7747,7 +9237,16 @@ components: enum: ["tool"] description: The role of the messages author, in this case `tool`. content: - type: string + oneOf: + - type: string + description: The contents of the tool message. + title: Text content + - type: array + description: An array of content parts with a defined type. For tool messages, only type `text` is supported. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestToolMessageContentPart" + minItems: 1 description: The contents of the tool message. tool_call_id: type: string @@ -7833,9 +9332,69 @@ components: description: The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. parameters: $ref: "#/components/schemas/FunctionParameters" + strict: + type: boolean + nullable: true + default: false + description: Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling). required: - name + ResponseFormatText: + type: object + properties: + type: + type: string + description: "The type of response format being defined: `text`" + enum: ["text"] + required: + - type + + ResponseFormatJsonObject: + type: object + properties: + type: + type: string + description: "The type of response format being defined: `json_object`" + enum: ["json_object"] + required: + - type + + ResponseFormatJsonSchemaSchema: + type: object + description: "The schema for the response format, described as a JSON Schema object." + additionalProperties: true + + ResponseFormatJsonSchema: + type: object + properties: + type: + type: string + description: "The type of response format being defined: `json_schema`" + enum: ["json_schema"] + json_schema: + type: object + properties: + description: + type: string + description: A description of what the response format is for, used by the model to determine how to respond in the format. + name: + type: string + description: The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + schema: + $ref: "#/components/schemas/ResponseFormatJsonSchemaSchema" + strict: + type: boolean + nullable: true + default: false + description: Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs). + required: + - type + - name + required: + - type + - json_schema + ChatCompletionToolChoiceOption: description: | Controls which (if any) tool is called by the model. @@ -7970,6 +9529,10 @@ components: type: string description: The contents of the message. nullable: true + refusal: + type: string + description: The refusal message generated by the model. + nullable: true tool_calls: $ref: "#/components/schemas/ChatCompletionMessageToolCalls" role: @@ -7993,6 +9556,7 @@ components: required: - role - content + - refusal ChatCompletionStreamResponseDelta: type: object @@ -8021,6 +9585,10 @@ components: type: string enum: ["system", "user", "assistant", "tool"] description: The role of the author of this message. + refusal: + type: string + description: The refusal message generated by the model. + nullable: true CreateChatCompletionRequest: type: object @@ -8033,7 +9601,7 @@ components: $ref: "#/components/schemas/ChatCompletionRequestMessage" model: description: ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API. - example: "gpt-4-turbo" + example: "gpt-4o" anyOf: - type: string - type: string @@ -8041,6 +9609,8 @@ components: [ "gpt-4o", "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4-turbo", @@ -8116,20 +9686,19 @@ components: nullable: true description: *completions_presence_penalty_description response_format: - type: object description: | - An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`. + An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`. + + Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs). Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. - properties: - type: - type: string - enum: ["text", "json_object"] - example: "json_object" - default: "text" - description: Must be one of `text` or `json_object`. + oneOf: + - $ref: "#/components/schemas/ResponseFormatText" + - $ref: "#/components/schemas/ResponseFormatJsonObject" + - $ref: "#/components/schemas/ResponseFormatJsonSchema" + x-oaiExpandable: true seed: type: integer minimum: -9223372036854775808 @@ -8286,14 +9855,203 @@ components: items: $ref: "#/components/schemas/ChatCompletionTokenLogprob" nullable: true + refusal: + description: A list of message refusal tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + nullable: true required: - content + - refusal + + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. + model: + type: string + description: The model used for the chat completion. + service_tier: + description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + type: string + enum: ["scale", "default"] + example: "scale" + nullable: true + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion`. + enum: [chat.completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion object + group: chat + example: *chat_completion_example + + CreateChatCompletionFunctionResponse: + type: object + description: Represents a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: A list of chat completion choices. Can be more than one if `n` is greater than 1. + items: + type: object + required: + - finish_reason + - index + - message + - logprobs + properties: + finish_reason: + type: string + description: + &chat_completion_function_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function. + enum: ["stop", "length", "function_call", "content_filter"] + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/ChatCompletionResponseMessage" + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. + model: + type: string + description: The model used for the chat completion. + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion`. + enum: [chat.completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion object + group: chat + example: *chat_completion_function_example + + ChatCompletionTokenLogprob: + type: object + properties: + token: &chat_completion_response_logprobs_token + description: The token. + type: string + logprob: &chat_completion_response_logprobs_token_logprob + description: The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + type: number + bytes: &chat_completion_response_logprobs_bytes + description: A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. + type: array + items: + type: integer + nullable: true + top_logprobs: + description: List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned. + type: array + items: + type: object + properties: + token: *chat_completion_response_logprobs_token + logprob: *chat_completion_response_logprobs_token_logprob + bytes: *chat_completion_response_logprobs_bytes + required: + - token + - logprob + - bytes + required: + - token + - logprob + - bytes + - top_logprobs + + ListPaginatedFineTuningJobsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJob" + has_more: + type: boolean + object: + type: string + enum: [list] + required: + - object + - data + - has_more + + CreateChatCompletionStreamResponse: + type: object + description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. Each chunk has the same ID. + choices: + type: array + description: | + A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the + last chunk if you set `stream_options: {"include_usage": true}`. + items: + type: object + required: + - delta + - finish_reason + - index + properties: + delta: + $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" + logprobs: *chat_completion_response_logprobs + finish_reason: + type: string + description: *chat_completion_finish_reason_description + enum: + [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + ] + nullable: true + index: + type: integer + description: The index of the choice in the list of choices. created: type: integer - description: The Unix timestamp (in seconds) of when the chat completion was created. + description: The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. model: type: string - description: The model used for the chat completion. + description: The model to generate the completion. service_tier: description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. type: string @@ -8304,14 +10062,30 @@ components: type: string description: | This fingerprint represents the backend configuration that the model runs with. - Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. object: type: string - description: The object type, which is always `chat.completion`. - enum: [chat.completion] + description: The object type, which is always `chat.completion.chunk`. + enum: [chat.completion.chunk] usage: - $ref: "#/components/schemas/CompletionUsage" + type: object + description: | + An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. + When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. + properties: + completion_tokens: + type: integer + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + description: Number of tokens in the prompt. + total_tokens: + type: integer + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens required: - choices - created @@ -8319,1937 +10093,2723 @@ components: - model - object x-oaiMeta: - name: The chat completion object + name: The chat completion chunk object group: chat - example: *chat_completion_example + example: *chat_completion_chunk_example - CreateChatCompletionFunctionResponse: + CreateChatCompletionImageResponse: + type: object + description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + x-oaiMeta: + name: The chat completion chunk object + group: chat + example: *chat_completion_image_example + + CreateImageRequest: type: object - description: Represents a chat completion response returned by model, based on the provided input. properties: - id: + prompt: + description: A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`. type: string - description: A unique identifier for the chat completion. - choices: - type: array - description: A list of chat completion choices. Can be more than one if `n` is greater than 1. - items: - type: object - required: - - finish_reason - - index - - message - - logprobs - properties: - finish_reason: - type: string - description: - &chat_completion_function_finish_reason_description | - The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function. - enum: ["stop", "length", "function_call", "content_filter"] - index: - type: integer - description: The index of the choice in the list of choices. - message: - $ref: "#/components/schemas/ChatCompletionResponseMessage" + example: "A cute baby sea otter" + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2", "dall-e-3"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-3" + nullable: true + description: The model to use for image generation. + n: &images_n + type: integer + minimum: 1 + maximum: 10 + default: 1 + example: 1 + nullable: true + description: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. + quality: + type: string + enum: ["standard", "hd"] + default: "standard" + example: "standard" + description: The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`. + response_format: &images_response_format + type: string + enum: ["url", "b64_json"] + default: "url" + example: "url" + nullable: true + description: The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated. + size: &images_size + type: string + enum: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + default: "1024x1024" + example: "1024x1024" + nullable: true + description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. + style: + type: string + enum: ["vivid", "natural"] + default: "vivid" + example: "vivid" + nullable: true + description: The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`. + user: *end_user_param_configuration + required: + - prompt + + ImagesResponse: + properties: created: type: integer - description: The Unix timestamp (in seconds) of when the chat completion was created. - model: + data: + type: array + items: + $ref: "#/components/schemas/Image" + required: + - created + - data + + Image: + type: object + description: Represents the url or the content of an image generated by the OpenAI API. + properties: + b64_json: type: string - description: The model used for the chat completion. - system_fingerprint: + description: The base64-encoded JSON of the generated image, if `response_format` is `b64_json`. + url: type: string - description: | - This fingerprint represents the backend configuration that the model runs with. + description: The URL of the generated image, if `response_format` is `url` (default). + revised_prompt: + type: string + description: The prompt that was used to generate the image, if there was any revision to the prompt. + x-oaiMeta: + name: The image object + example: | + { + "url": "...", + "revised_prompt": "..." + } - Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. - object: + CreateImageEditRequest: + type: object + properties: + image: + description: The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask. type: string - description: The object type, which is always `chat.completion`. - enum: [chat.completion] - usage: - $ref: "#/components/schemas/CompletionUsage" + format: binary + prompt: + description: A text description of the desired image(s). The maximum length is 1000 characters. + type: string + example: "A cute baby sea otter wearing a beret" + mask: + description: An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. + type: string + format: binary + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-2" + nullable: true + description: The model to use for image generation. Only `dall-e-2` is supported at this time. + n: + type: integer + minimum: 1 + maximum: 10 + default: 1 + example: 1 + nullable: true + description: The number of images to generate. Must be between 1 and 10. + size: &dalle2_images_size + type: string + enum: ["256x256", "512x512", "1024x1024"] + default: "1024x1024" + example: "1024x1024" + nullable: true + description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + response_format: *images_response_format + user: *end_user_param_configuration + required: + - prompt + - image + + CreateImageVariationRequest: + type: object + properties: + image: + description: The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. + type: string + format: binary + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-2" + nullable: true + description: The model to use for image generation. Only `dall-e-2` is supported at this time. + n: *images_n + response_format: *images_response_format + size: *dalle2_images_size + user: *end_user_param_configuration + required: + - image + + CreateModerationRequest: + type: object + properties: + input: + description: The input text to classify + oneOf: + - type: string + default: "" + example: "I want to kill them." + - type: array + items: + type: string + default: "" + example: "I want to kill them." + model: + description: | + Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`. + + The default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`. + nullable: false + default: "text-moderation-latest" + example: "text-moderation-stable" + anyOf: + - type: string + - type: string + enum: ["text-moderation-latest", "text-moderation-stable"] + x-oaiTypeLabel: string required: - - choices - - created - - id - - model - - object - x-oaiMeta: - name: The chat completion object - group: chat - example: *chat_completion_function_example + - input - ChatCompletionTokenLogprob: + CreateModerationResponse: type: object + description: Represents if a given text input is potentially harmful. properties: - token: &chat_completion_response_logprobs_token - description: The token. + id: type: string - logprob: &chat_completion_response_logprobs_token_logprob - description: The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. - type: number - bytes: &chat_completion_response_logprobs_bytes - description: A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. - type: array - items: - type: integer - nullable: true - top_logprobs: - description: List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned. + description: The unique identifier for the moderation request. + model: + type: string + description: The model used to generate the moderation results. + results: type: array + description: A list of moderation objects. items: type: object properties: - token: *chat_completion_response_logprobs_token - logprob: *chat_completion_response_logprobs_token_logprob - bytes: *chat_completion_response_logprobs_bytes + flagged: + type: boolean + description: Whether any of the below categories are flagged. + categories: + type: object + description: A list of the categories, and whether they are flagged or not. + properties: + hate: + type: boolean + description: Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment. + hate/threatening: + type: boolean + description: Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. + harassment: + type: boolean + description: Content that expresses, incites, or promotes harassing language towards any target. + harassment/threatening: + type: boolean + description: Harassment content that also includes violence or serious harm towards any target. + self-harm: + type: boolean + description: Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. + self-harm/intent: + type: boolean + description: Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders. + self-harm/instructions: + type: boolean + description: Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts. + sexual: + type: boolean + description: Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). + sexual/minors: + type: boolean + description: Sexual content that includes an individual who is under 18 years old. + violence: + type: boolean + description: Content that depicts death, violence, or physical injury. + violence/graphic: + type: boolean + description: Content that depicts death, violence, or physical injury in graphic detail. + required: + - hate + - hate/threatening + - harassment + - harassment/threatening + - self-harm + - self-harm/intent + - self-harm/instructions + - sexual + - sexual/minors + - violence + - violence/graphic + category_scores: + type: object + description: A list of the categories along with their scores as predicted by model. + properties: + hate: + type: number + description: The score for the category 'hate'. + hate/threatening: + type: number + description: The score for the category 'hate/threatening'. + harassment: + type: number + description: The score for the category 'harassment'. + harassment/threatening: + type: number + description: The score for the category 'harassment/threatening'. + self-harm: + type: number + description: The score for the category 'self-harm'. + self-harm/intent: + type: number + description: The score for the category 'self-harm/intent'. + self-harm/instructions: + type: number + description: The score for the category 'self-harm/instructions'. + sexual: + type: number + description: The score for the category 'sexual'. + sexual/minors: + type: number + description: The score for the category 'sexual/minors'. + violence: + type: number + description: The score for the category 'violence'. + violence/graphic: + type: number + description: The score for the category 'violence/graphic'. + required: + - hate + - hate/threatening + - harassment + - harassment/threatening + - self-harm + - self-harm/intent + - self-harm/instructions + - sexual + - sexual/minors + - violence + - violence/graphic required: - - token - - logprob - - bytes + - flagged + - categories + - category_scores required: - - token - - logprob - - bytes - - top_logprobs + - id + - model + - results + x-oaiMeta: + name: The moderation object + example: *moderation_example - ListPaginatedFineTuningJobsResponse: + ListFilesResponse: type: object properties: data: type: array items: - $ref: "#/components/schemas/FineTuningJob" - has_more: - type: boolean + $ref: "#/components/schemas/OpenAIFile" object: type: string enum: [list] - required: - - object - - data - - has_more - - CreateChatCompletionStreamResponse: - type: object - description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. - properties: - id: - type: string - description: A unique identifier for the chat completion. Each chunk has the same ID. - choices: - type: array - description: | - A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the - last chunk if you set `stream_options: {"include_usage": true}`. - items: - type: object - required: - - delta - - finish_reason - - index - properties: - delta: - $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" - logprobs: *chat_completion_response_logprobs - finish_reason: - type: string - description: *chat_completion_finish_reason_description - enum: - [ - "stop", - "length", - "tool_calls", - "content_filter", - "function_call", - ] - nullable: true - index: - type: integer - description: The index of the choice in the list of choices. - created: - type: integer - description: The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. - model: - type: string - description: The model to generate the completion. - service_tier: - description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. - type: string - enum: ["scale", "default"] - example: "scale" - nullable: true - system_fingerprint: - type: string - description: | - This fingerprint represents the backend configuration that the model runs with. - Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. - object: - type: string - description: The object type, which is always `chat.completion.chunk`. - enum: [chat.completion.chunk] - usage: - type: object - description: | - An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. - When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. - properties: - completion_tokens: - type: integer - description: Number of tokens in the generated completion. - prompt_tokens: - type: integer - description: Number of tokens in the prompt. - total_tokens: - type: integer - description: Total number of tokens used in the request (prompt + completion). - required: - - prompt_tokens - - completion_tokens - - total_tokens - required: - - choices - - created - - id - - model + required: - object - x-oaiMeta: - name: The chat completion chunk object - group: chat - example: *chat_completion_chunk_example + - data - CreateChatCompletionImageResponse: + CreateFileRequest: type: object - description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. - x-oaiMeta: - name: The chat completion chunk object - group: chat - example: *chat_completion_image_example + additionalProperties: false + properties: + file: + description: | + The File object (not file name) to be uploaded. + type: string + format: binary + purpose: + description: | + The intended purpose of the uploaded file. - CreateImageRequest: + Use "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning). + type: string + enum: ["assistants", "batch", "fine-tune", "vision"] + required: + - file + - purpose + + DeleteFileResponse: type: object properties: - prompt: - description: A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`. + id: type: string - example: "A cute baby sea otter" - model: - anyOf: - - type: string - - type: string - enum: ["dall-e-2", "dall-e-3"] - x-oaiTypeLabel: string - default: "dall-e-2" - example: "dall-e-3" - nullable: true - description: The model to use for image generation. - n: &images_n - type: integer - minimum: 1 - maximum: 10 - default: 1 - example: 1 - nullable: true - description: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. - quality: + object: type: string - enum: ["standard", "hd"] - default: "standard" - example: "standard" - description: The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`. - response_format: &images_response_format + enum: [file] + deleted: + type: boolean + required: + - id + - object + - deleted + + CreateUploadRequest: + type: object + additionalProperties: false + properties: + filename: + description: | + The name of the file to upload. type: string - enum: ["url", "b64_json"] - default: "url" - example: "url" - nullable: true - description: The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated. - size: &images_size + purpose: + description: | + The intended purpose of the uploaded file. + + See the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose). type: string - enum: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] - default: "1024x1024" - example: "1024x1024" - nullable: true - description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. - style: + enum: ["assistants", "batch", "fine-tune", "vision"] + bytes: + description: | + The number of bytes in the file you are uploading. + type: integer + mime_type: + description: | + The MIME type of the file. + + This must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision. type: string - enum: ["vivid", "natural"] - default: "vivid" - example: "vivid" - nullable: true - description: The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`. - user: *end_user_param_configuration required: - - prompt + - filename + - purpose + - bytes + - mime_type - ImagesResponse: + AddUploadPartRequest: + type: object + additionalProperties: false properties: - created: - type: integer data: - type: array - items: - $ref: "#/components/schemas/Image" + description: | + The chunk of bytes for this Part. + type: string + format: binary required: - - created - data - Image: + CompleteUploadRequest: type: object - description: Represents the url or the content of an image generated by the OpenAI API. + additionalProperties: false properties: - b64_json: - type: string - description: The base64-encoded JSON of the generated image, if `response_format` is `b64_json`. - url: - type: string - description: The URL of the generated image, if `response_format` is `url` (default). - revised_prompt: + part_ids: + type: array + description: | + The ordered list of Part IDs. + items: + type: string + md5: + description: | + The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect. type: string - description: The prompt that was used to generate the image, if there was any revision to the prompt. - x-oaiMeta: - name: The image object - example: | - { - "url": "...", - "revised_prompt": "..." - } + required: + - part_ids - CreateImageEditRequest: + CancelUploadRequest: + type: object + additionalProperties: false + + CreateFineTuningJobRequest: type: object properties: - image: - description: The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask. - type: string - format: binary - prompt: - description: A text description of the desired image(s). The maximum length is 1000 characters. - type: string - example: "A cute baby sea otter wearing a beret" - mask: - description: An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. - type: string - format: binary model: + description: | + The name of the model to fine-tune. You can select one of the + [supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned). + example: "gpt-4o-mini" anyOf: - type: string - type: string - enum: ["dall-e-2"] + enum: + ["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"] x-oaiTypeLabel: string - default: "dall-e-2" - example: "dall-e-2" - nullable: true - description: The model to use for image generation. Only `dall-e-2` is supported at this time. - n: - type: integer - minimum: 1 - maximum: 10 - default: 1 - example: 1 + training_file: + description: | + The ID of an uploaded file that contains training data. + + See [upload file](/docs/api-reference/files/create) for how to upload a file. + + Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`. + + The contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format. + + See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + type: string + example: "file-abc123" + hyperparameters: + type: object + description: The hyperparameters used for the fine-tuning job. + properties: + batch_size: + description: | + Number of examples in each batch. A larger batch size means that model parameters + are updated less frequently, but with lower variance. + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 256 + default: auto + learning_rate_multiplier: + description: | + Scaling factor for the learning rate. A smaller learning rate may be useful to avoid + overfitting. + oneOf: + - type: string + enum: [auto] + - type: number + minimum: 0 + exclusiveMinimum: true + default: auto + n_epochs: + description: | + The number of epochs to train the model for. An epoch refers to one full cycle + through the training dataset. + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 50 + default: auto + suffix: + description: | + A string of up to 18 characters that will be added to your fine-tuned model name. + + For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`. + type: string + minLength: 1 + maxLength: 40 + default: null nullable: true - description: The number of images to generate. Must be between 1 and 10. - size: &dalle2_images_size + validation_file: + description: | + The ID of an uploaded file that contains validation data. + + If you provide this file, the data is used to generate validation + metrics periodically during fine-tuning. These metrics can be viewed in + the fine-tuning results file. + The same data should not be present in both train and validation files. + + Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`. + + See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. type: string - enum: ["256x256", "512x512", "1024x1024"] - default: "1024x1024" - example: "1024x1024" nullable: true - description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. - response_format: *images_response_format - user: *end_user_param_configuration - required: - - prompt - - image + example: "file-abc123" + integrations: + type: array + description: A list of integrations to enable for your fine-tuning job. + nullable: true + items: + type: object + required: + - type + - wandb + properties: + type: + description: | + The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported. + oneOf: + - type: string + enum: [wandb] + wandb: + type: object + description: | + The settings for your integration with Weights and Biases. This payload specifies the project that + metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + to your run, and set a default entity (team, username, etc) to be associated with your run. + required: + - project + properties: + project: + description: | + The name of the project that the new run will be created under. + type: string + example: "my-wandb-project" + name: + description: | + A display name to set for the run. If not set, we will use the Job ID as the name. + nullable: true + type: string + entity: + description: | + The entity to use for the run. This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered WandB API key is used. + nullable: true + type: string + tags: + description: | + A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + type: array + items: + type: string + example: "custom-tag" - CreateImageVariationRequest: - type: object - properties: - image: - description: The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. - type: string - format: binary - model: - anyOf: - - type: string - - type: string - enum: ["dall-e-2"] - x-oaiTypeLabel: string - default: "dall-e-2" - example: "dall-e-2" + seed: + description: | + The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. + If a seed is not specified, one will be generated for you. + type: integer nullable: true - description: The model to use for image generation. Only `dall-e-2` is supported at this time. - n: *images_n - response_format: *images_response_format - size: *dalle2_images_size - user: *end_user_param_configuration + minimum: 0 + maximum: 2147483647 + example: 42 required: - - image + - model + - training_file - CreateModerationRequest: + ListFineTuningJobEventsResponse: type: object properties: - input: - description: The input text to classify - oneOf: - - type: string - default: "" - example: "I want to kill them." - - type: array - items: - type: string - default: "" - example: "I want to kill them." - model: - description: | - Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`. - - The default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`. - nullable: false - default: "text-moderation-latest" - example: "text-moderation-stable" - anyOf: - - type: string - - type: string - enum: ["text-moderation-latest", "text-moderation-stable"] - x-oaiTypeLabel: string + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJobEvent" + object: + type: string + enum: [list] required: - - input + - object + - data - CreateModerationResponse: + ListFineTuningJobCheckpointsResponse: type: object - description: Represents if a given text input is potentially harmful. properties: - id: - type: string - description: The unique identifier for the moderation request. - model: - type: string - description: The model used to generate the moderation results. - results: + data: type: array - description: A list of moderation objects. items: - type: object - properties: - flagged: - type: boolean - description: Whether any of the below categories are flagged. - categories: - type: object - description: A list of the categories, and whether they are flagged or not. - properties: - hate: - type: boolean - description: Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment. - hate/threatening: - type: boolean - description: Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. - harassment: - type: boolean - description: Content that expresses, incites, or promotes harassing language towards any target. - harassment/threatening: - type: boolean - description: Harassment content that also includes violence or serious harm towards any target. - self-harm: - type: boolean - description: Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. - self-harm/intent: - type: boolean - description: Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders. - self-harm/instructions: - type: boolean - description: Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts. - sexual: - type: boolean - description: Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). - sexual/minors: - type: boolean - description: Sexual content that includes an individual who is under 18 years old. - violence: - type: boolean - description: Content that depicts death, violence, or physical injury. - violence/graphic: - type: boolean - description: Content that depicts death, violence, or physical injury in graphic detail. - required: - - hate - - hate/threatening - - harassment - - harassment/threatening - - self-harm - - self-harm/intent - - self-harm/instructions - - sexual - - sexual/minors - - violence - - violence/graphic - category_scores: - type: object - description: A list of the categories along with their scores as predicted by model. - properties: - hate: - type: number - description: The score for the category 'hate'. - hate/threatening: - type: number - description: The score for the category 'hate/threatening'. - harassment: - type: number - description: The score for the category 'harassment'. - harassment/threatening: - type: number - description: The score for the category 'harassment/threatening'. - self-harm: - type: number - description: The score for the category 'self-harm'. - self-harm/intent: - type: number - description: The score for the category 'self-harm/intent'. - self-harm/instructions: - type: number - description: The score for the category 'self-harm/instructions'. - sexual: - type: number - description: The score for the category 'sexual'. - sexual/minors: - type: number - description: The score for the category 'sexual/minors'. - violence: - type: number - description: The score for the category 'violence'. - violence/graphic: - type: number - description: The score for the category 'violence/graphic'. - required: - - hate - - hate/threatening - - harassment - - harassment/threatening - - self-harm - - self-harm/intent - - self-harm/instructions - - sexual - - sexual/minors - - violence - - violence/graphic - required: - - flagged - - categories - - category_scores + $ref: "#/components/schemas/FineTuningJobCheckpoint" + object: + type: string + enum: [list] + first_id: + type: string + nullable: true + last_id: + type: string + nullable: true + has_more: + type: boolean + required: + - object + - data + - has_more + + CreateEmbeddingRequest: + type: object + additionalProperties: false + properties: + input: + description: | + Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + example: "The quick brown fox jumped over the lazy dog" + oneOf: + - type: string + title: string + description: The string that will be turned into an embedding. + default: "" + example: "This is a test." + - type: array + title: array + description: The array of strings that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: string + default: "" + example: "['This is a test.']" + - type: array + title: array + description: The array of integers that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: integer + example: "[1212, 318, 257, 1332, 13]" + - type: array + title: array + description: The array of arrays containing integers that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: array + minItems: 1 + items: + type: integer + example: "[[1212, 318, 257, 1332, 13]]" + x-oaiExpandable: true + model: + description: *model_description + example: "text-embedding-3-small" + anyOf: + - type: string + - type: string + enum: + [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ] + x-oaiTypeLabel: string + encoding_format: + description: "The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/)." + example: "float" + default: "float" + type: string + enum: ["float", "base64"] + dimensions: + description: | + The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. + type: integer + minimum: 1 + user: *end_user_param_configuration required: - - id - model - - results - x-oaiMeta: - name: The moderation object - example: *moderation_example + - input - ListFilesResponse: + CreateEmbeddingResponse: type: object properties: data: type: array + description: The list of embeddings generated by the model. items: - $ref: "#/components/schemas/OpenAIFile" + $ref: "#/components/schemas/Embedding" + model: + type: string + description: The name of the model used to generate the embedding. object: type: string + description: The object type, which is always "list". enum: [list] + usage: + type: object + description: The usage information for the request. + properties: + prompt_tokens: + type: integer + description: The number of tokens used by the prompt. + total_tokens: + type: integer + description: The total number of tokens used by the request. + required: + - prompt_tokens + - total_tokens required: - object + - model - data + - usage - CreateFileRequest: + CreateTranscriptionRequest: type: object additionalProperties: false properties: file: description: | - The File object (not file name) to be uploaded. + The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. type: string + x-oaiTypeLabel: file format: binary - purpose: + model: description: | - The intended purpose of the uploaded file. - - Use "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning). + ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + example: whisper-1 + anyOf: + - type: string + - type: string + enum: ["whisper-1"] + x-oaiTypeLabel: string + language: + description: | + The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. type: string - enum: ["assistants", "batch", "fine-tune", "vision"] + prompt: + description: | + An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language. + type: string + response_format: + description: | + The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + type: string + enum: + - json + - text + - srt + - verbose_json + - vtt + default: json + temperature: + description: | + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + type: number + default: 0 + timestamp_granularities[]: + description: | + The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency. + type: array + items: + type: string + enum: + - word + - segment + default: [segment] required: - file - - purpose + - model - DeleteFileResponse: + # Note: This does not currently support the non-default response format types. + CreateTranscriptionResponseJson: type: object + description: Represents a transcription response returned by model, based on the provided input. properties: - id: + text: type: string - object: + description: The transcribed text. + required: + - text + x-oaiMeta: + name: The transcription object (JSON) + group: audio + example: *basic_transcription_response_example + + TranscriptionSegment: + type: object + properties: + id: + type: integer + description: Unique identifier of the segment. + seek: + type: integer + description: Seek offset of the segment. + start: + type: number + format: float + description: Start time of the segment in seconds. + end: + type: number + format: float + description: End time of the segment in seconds. + text: type: string - enum: [file] - deleted: - type: boolean + description: Text content of the segment. + tokens: + type: array + items: + type: integer + description: Array of token IDs for the text content. + temperature: + type: number + format: float + description: Temperature parameter used for generating the segment. + avg_logprob: + type: number + format: float + description: Average logprob of the segment. If the value is lower than -1, consider the logprobs failed. + compression_ratio: + type: number + format: float + description: Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed. + no_speech_prob: + type: number + format: float + description: Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent. required: - id - - object - - deleted + - seek + - start + - end + - text + - tokens + - temperature + - avg_logprob + - compression_ratio + - no_speech_prob + + TranscriptionWord: + type: object + properties: + word: + type: string + description: The text content of the word. + start: + type: number + format: float + description: Start time of the word in seconds. + end: + type: number + format: float + description: End time of the word in seconds. + required: [word, start, end] + + CreateTranscriptionResponseVerboseJson: + type: object + description: Represents a verbose json transcription response returned by model, based on the provided input. + properties: + language: + type: string + description: The language of the input audio. + duration: + type: string + description: The duration of the input audio. + text: + type: string + description: The transcribed text. + words: + type: array + description: Extracted words and their corresponding timestamps. + items: + $ref: "#/components/schemas/TranscriptionWord" + segments: + type: array + description: Segments of the transcribed text and their corresponding details. + items: + $ref: "#/components/schemas/TranscriptionSegment" + required: [language, duration, text] + x-oaiMeta: + name: The transcription object (Verbose JSON) + group: audio + example: *verbose_transcription_response_example - CreateUploadRequest: + CreateTranslationRequest: type: object additionalProperties: false properties: - filename: + file: description: | - The name of the file to upload. + The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. type: string - purpose: + x-oaiTypeLabel: file + format: binary + model: description: | - The intended purpose of the uploaded file. - - See the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose). - type: string - enum: ["assistants", "batch", "fine-tune", "vision"] - bytes: + ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + example: whisper-1 + anyOf: + - type: string + - type: string + enum: ["whisper-1"] + x-oaiTypeLabel: string + prompt: description: | - The number of bytes in the file you are uploading. - type: integer - mime_type: + An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English. + type: string + response_format: description: | - The MIME type of the file. - - This must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision. + The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. type: string + default: json + temperature: + description: | + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + type: number + default: 0 required: - - filename - - purpose - - bytes - - mime_type + - file + - model - AddUploadPartRequest: + # Note: This does not currently support the non-default response format types. + CreateTranslationResponseJson: type: object - additionalProperties: false properties: - data: - description: | - The chunk of bytes for this Part. + text: type: string - format: binary required: - - data + - text - CompleteUploadRequest: + CreateTranslationResponseVerboseJson: type: object - additionalProperties: false properties: - part_ids: + language: + type: string + description: The language of the output translation (always `english`). + duration: + type: string + description: The duration of the input audio. + text: + type: string + description: The translated text. + segments: type: array - description: | - The ordered list of Part IDs. + description: Segments of the translated text and their corresponding details. items: - type: string - md5: - description: | - The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect. - type: string - required: - - part_ids + $ref: "#/components/schemas/TranscriptionSegment" + required: [language, duration, text] - CancelUploadRequest: + CreateSpeechRequest: type: object additionalProperties: false - - CreateFineTuningJobRequest: - type: object properties: model: description: | - The name of the model to fine-tune. You can select one of the - [supported models](/docs/guides/fine-tuning/what-models-can-be-fine-tuned). - example: "gpt-3.5-turbo" + One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd` anyOf: - type: string - type: string - enum: ["babbage-002", "davinci-002", "gpt-3.5-turbo"] + enum: ["tts-1", "tts-1-hd"] x-oaiTypeLabel: string - training_file: - description: | - The ID of an uploaded file that contains training data. - - See [upload file](/docs/api-reference/files/create) for how to upload a file. - - Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`. - - The contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format. - - See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + input: type: string - example: "file-abc123" - hyperparameters: - type: object - description: The hyperparameters used for the fine-tuning job. - properties: - batch_size: - description: | - Number of examples in each batch. A larger batch size means that model parameters - are updated less frequently, but with lower variance. - oneOf: - - type: string - enum: [auto] - - type: integer - minimum: 1 - maximum: 256 - default: auto - learning_rate_multiplier: - description: | - Scaling factor for the learning rate. A smaller learning rate may be useful to avoid - overfitting. - oneOf: - - type: string - enum: [auto] - - type: number - minimum: 0 - exclusiveMinimum: true - default: auto - n_epochs: - description: | - The number of epochs to train the model for. An epoch refers to one full cycle - through the training dataset. - oneOf: - - type: string - enum: [auto] - - type: integer - minimum: 1 - maximum: 50 - default: auto - suffix: - description: | - A string of up to 18 characters that will be added to your fine-tuned model name. - - For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-3.5-turbo:openai:custom-model-name:7p4lURel`. + description: The text to generate audio for. The maximum length is 4096 characters. + maxLength: 4096 + voice: + description: The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options). type: string - minLength: 1 - maxLength: 40 - default: null - nullable: true - validation_file: - description: | - The ID of an uploaded file that contains validation data. - - If you provide this file, the data is used to generate validation - metrics periodically during fine-tuning. These metrics can be viewed in - the fine-tuning results file. - The same data should not be present in both train and validation files. - - Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`. - - See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + enum: ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + response_format: + description: "The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." + default: "mp3" type: string - nullable: true - example: "file-abc123" - integrations: - type: array - description: A list of integrations to enable for your fine-tuning job. - nullable: true - items: - type: object - required: - - type - - wandb - properties: - type: - description: | - The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported. - oneOf: - - type: string - enum: [wandb] - wandb: - type: object - description: | - The settings for your integration with Weights and Biases. This payload specifies the project that - metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags - to your run, and set a default entity (team, username, etc) to be associated with your run. - required: - - project - properties: - project: - description: | - The name of the project that the new run will be created under. - type: string - example: "my-wandb-project" - name: - description: | - A display name to set for the run. If not set, we will use the Job ID as the name. - nullable: true - type: string - entity: - description: | - The entity to use for the run. This allows you to set the team or username of the WandB user that you would - like associated with the run. If not set, the default entity for the registered WandB API key is used. - nullable: true - type: string - tags: - description: | - A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some - default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". - type: array - items: - type: string - example: "custom-tag" - - seed: - description: | - The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. - If a seed is not specified, one will be generated for you. - type: integer - nullable: true - minimum: 0 - maximum: 2147483647 - example: 42 + enum: ["mp3", "opus", "aac", "flac", "wav", "pcm"] + speed: + description: "The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default." + type: number + default: 1.0 + minimum: 0.25 + maximum: 4.0 required: - model - - training_file + - input + - voice - ListFineTuningJobEventsResponse: - type: object + Model: + title: Model + description: Describes an OpenAI model offering that can be used with the API. properties: - data: - type: array - items: - $ref: "#/components/schemas/FineTuningJobEvent" + id: + type: string + description: The model identifier, which can be referenced in the API endpoints. + created: + type: integer + description: The Unix timestamp (in seconds) when the model was created. object: type: string - enum: [list] + description: The object type, which is always "model". + enum: [model] + owned_by: + type: string + description: The organization that owns the model. required: + - id - object - - data + - created + - owned_by + x-oaiMeta: + name: The model object + example: *retrieve_model_response - ListFineTuningJobCheckpointsResponse: - type: object + OpenAIFile: + title: OpenAIFile + description: The `File` object represents a document that has been uploaded to OpenAI. properties: - data: - type: array - items: - $ref: "#/components/schemas/FineTuningJobCheckpoint" + id: + type: string + description: The file identifier, which can be referenced in the API endpoints. + bytes: + type: integer + description: The size of the file, in bytes. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the file was created. + filename: + type: string + description: The name of the file. object: type: string - enum: [list] - first_id: + description: The object type, which is always `file`. + enum: ["file"] + purpose: type: string - nullable: true - last_id: + description: The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`. + enum: + [ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ] + status: type: string - nullable: true - has_more: - type: boolean + deprecated: true + description: Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`. + enum: ["uploaded", "processed", "error"] + status_details: + type: string + deprecated: true + description: Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`. required: + - id - object - - data - - has_more - - CreateEmbeddingRequest: + - bytes + - created_at + - filename + - purpose + - status + x-oaiMeta: + name: The file object + example: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "salesOverview.pdf", + "purpose": "assistants", + } + Upload: type: object - additionalProperties: false + title: Upload + description: | + The Upload object can accept byte chunks in the form of Parts. properties: - input: - description: | - Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. - example: "The quick brown fox jumped over the lazy dog" - oneOf: - - type: string - title: string - description: The string that will be turned into an embedding. - default: "" - example: "This is a test." - - type: array - title: array - description: The array of strings that will be turned into an embedding. - minItems: 1 - maxItems: 2048 - items: - type: string - default: "" - example: "['This is a test.']" - - type: array - title: array - description: The array of integers that will be turned into an embedding. - minItems: 1 - maxItems: 2048 - items: - type: integer - example: "[1212, 318, 257, 1332, 13]" - - type: array - title: array - description: The array of arrays containing integers that will be turned into an embedding. - minItems: 1 - maxItems: 2048 - items: - type: array - minItems: 1 - items: - type: integer - example: "[[1212, 318, 257, 1332, 13]]" - x-oaiExpandable: true - model: - description: *model_description - example: "text-embedding-3-small" - anyOf: - - type: string - - type: string - enum: - [ - "text-embedding-ada-002", - "text-embedding-3-small", - "text-embedding-3-large", - ] - x-oaiTypeLabel: string - encoding_format: - description: "The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/)." - example: "float" - default: "float" + id: type: string - enum: ["float", "base64"] - dimensions: - description: | - The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. + description: The Upload unique identifier, which can be referenced in API endpoints. + created_at: type: integer - minimum: 1 - user: *end_user_param_configuration + description: The Unix timestamp (in seconds) for when the Upload was created. + filename: + type: string + description: The name of the file to be uploaded. + bytes: + type: integer + description: The intended number of bytes to be uploaded. + purpose: + type: string + description: The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values. + status: + type: string + description: The status of the Upload. + enum: ["pending", "completed", "cancelled", "expired"] + expires_at: + type: integer + description: The Unix timestamp (in seconds) for when the Upload was created. + object: + type: string + description: The object type, which is always "upload". + enum: [upload] + file: + $ref: "#/components/schemas/OpenAIFile" + nullable: true + description: The ready File object after the Upload is completed. required: - - model - - input - - CreateEmbeddingResponse: + - bytes + - created_at + - expires_at + - filename + - id + - purpose + - status + - step_number + x-oaiMeta: + name: The upload object + example: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "completed", + "expires_at": 1719127296, + "file": { + "id": "file-xyz321", + "object": "file", + "bytes": 2147483648, + "created_at": 1719186911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + } + } + UploadPart: + type: object + title: UploadPart + description: | + The upload Part represents a chunk of bytes we can add to an Upload object. + properties: + id: + type: string + description: The upload Part unique identifier, which can be referenced in API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the Part was created. + upload_id: + type: string + description: The ID of the Upload object that this Part was added to. + object: + type: string + description: The object type, which is always `upload.part`. + enum: ["upload.part"] + required: + - created_at + - id + - object + - upload_id + x-oaiMeta: + name: The upload part object + example: | + { + "id": "part_def456", + "object": "upload.part", + "created_at": 1719186911, + "upload_id": "upload_abc123" + } + Embedding: type: object + description: | + Represents an embedding vector returned by embedding endpoint. properties: - data: + index: + type: integer + description: The index of the embedding in the list of embeddings. + embedding: type: array - description: The list of embeddings generated by the model. + description: | + The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings). items: - $ref: "#/components/schemas/Embedding" - model: - type: string - description: The name of the model used to generate the embedding. + type: number object: type: string - description: The object type, which is always "list". - enum: [list] - usage: - type: object - description: The usage information for the request. - properties: - prompt_tokens: - type: integer - description: The number of tokens used by the prompt. - total_tokens: - type: integer - description: The total number of tokens used by the request. - required: - - prompt_tokens - - total_tokens + description: The object type, which is always "embedding". + enum: [embedding] required: + - index - object - - model - - data - - usage + - embedding + x-oaiMeta: + name: The embedding object + example: | + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } - CreateTranscriptionRequest: + FineTuningJob: type: object - additionalProperties: false + title: FineTuningJob + description: | + The `fine_tuning.job` object represents a fine-tuning job that has been created through the API. properties: - file: - description: | - The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + id: type: string - x-oaiTypeLabel: file - format: binary + description: The object identifier, which can be referenced in the API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the fine-tuning job was created. + error: + type: object + nullable: true + description: For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure. + properties: + code: + type: string + description: A machine-readable error code. + message: + type: string + description: A human-readable error message. + param: + type: string + description: The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific. + nullable: true + required: + - code + - message + - param + fine_tuned_model: + type: string + nullable: true + description: The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running. + finished_at: + type: integer + nullable: true + description: The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running. + hyperparameters: + type: object + description: The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + properties: + n_epochs: + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 50 + default: auto + description: + The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. + + "auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs. + required: + - n_epochs model: - description: | - ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. - example: whisper-1 - anyOf: - - type: string - - type: string - enum: ["whisper-1"] - x-oaiTypeLabel: string - language: - description: | - The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. type: string - prompt: - description: | - An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language. + description: The base model that is being fine-tuned. + object: type: string - response_format: - description: | - The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + description: The object type, which is always "fine_tuning.job". + enum: [fine_tuning.job] + organization_id: type: string - enum: - - json - - text - - srt - - verbose_json - - vtt - default: json - temperature: - description: | - The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. - type: number - default: 0 - timestamp_granularities[]: - description: | - The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency. + description: The organization that owns the fine-tuning job. + result_files: type: array + description: The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents). items: type: string - enum: - - word - - segment - default: [segment] + example: file-abc123 + status: + type: string + description: The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. + enum: + [ + "validating_files", + "queued", + "running", + "succeeded", + "failed", + "cancelled", + ] + trained_tokens: + type: integer + nullable: true + description: The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running. + training_file: + type: string + description: The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents). + validation_file: + type: string + nullable: true + description: The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents). + integrations: + type: array + nullable: true + description: A list of integrations to enable for this fine-tuning job. + maxItems: 5 + items: + oneOf: + - $ref: "#/components/schemas/FineTuningIntegration" + x-oaiExpandable: true + seed: + type: integer + description: The seed used for the fine-tuning job. + estimated_finish: + type: integer + nullable: true + description: The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running. required: - - file + - created_at + - error + - finished_at + - fine_tuned_model + - hyperparameters + - id - model + - object + - organization_id + - result_files + - status + - trained_tokens + - training_file + - validation_file + - seed + x-oaiMeta: + name: The fine-tuning job object + example: *fine_tuning_example + + FineTuningIntegration: + type: object + title: Fine-Tuning Job Integration + required: + - type + - wandb + properties: + type: + type: string + description: "The type of the integration being enabled for the fine-tuning job" + enum: ["wandb"] + wandb: + type: object + description: | + The settings for your integration with Weights and Biases. This payload specifies the project that + metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + to your run, and set a default entity (team, username, etc) to be associated with your run. + required: + - project + properties: + project: + description: | + The name of the project that the new run will be created under. + type: string + example: "my-wandb-project" + name: + description: | + A display name to set for the run. If not set, we will use the Job ID as the name. + nullable: true + type: string + entity: + description: | + The entity to use for the run. This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered WandB API key is used. + nullable: true + type: string + tags: + description: | + A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + type: array + items: + type: string + example: "custom-tag" - # Note: This does not currently support the non-default response format types. - CreateTranscriptionResponseJson: + FineTuningJobEvent: type: object - description: Represents a transcription response returned by model, based on the provided input. + description: Fine-tuning job event object properties: - text: + id: type: string - description: The transcribed text. + created_at: + type: integer + level: + type: string + enum: ["info", "warn", "error"] + message: + type: string + object: + type: string + enum: [fine_tuning.job.event] required: - - text + - id + - object + - created_at + - level + - message x-oaiMeta: - name: The transcription object (JSON) - group: audio - example: *basic_transcription_response_example + name: The fine-tuning job event object + example: | + { + "object": "fine_tuning.job.event", + "id": "ftevent-abc123" + "created_at": 1677610602, + "level": "info", + "message": "Created fine-tuning job" + } - TranscriptionSegment: + FineTuningJobCheckpoint: type: object + title: FineTuningJobCheckpoint + description: | + The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use. properties: id: + type: string + description: The checkpoint identifier, which can be referenced in the API endpoints. + created_at: type: integer - description: Unique identifier of the segment. - seek: + description: The Unix timestamp (in seconds) for when the checkpoint was created. + fine_tuned_model_checkpoint: + type: string + description: The name of the fine-tuned checkpoint model that is created. + step_number: type: integer - description: Seek offset of the segment. - start: - type: number - format: float - description: Start time of the segment in seconds. - end: - type: number - format: float - description: End time of the segment in seconds. - text: + description: The step number that the checkpoint was created at. + metrics: + type: object + description: Metrics at the step number during the fine-tuning job. + properties: + step: + type: number + train_loss: + type: number + train_mean_token_accuracy: + type: number + valid_loss: + type: number + valid_mean_token_accuracy: + type: number + full_valid_loss: + type: number + full_valid_mean_token_accuracy: + type: number + fine_tuning_job_id: type: string - description: Text content of the segment. - tokens: - type: array - items: - type: integer - description: Array of token IDs for the text content. - temperature: - type: number - format: float - description: Temperature parameter used for generating the segment. - avg_logprob: - type: number - format: float - description: Average logprob of the segment. If the value is lower than -1, consider the logprobs failed. - compression_ratio: - type: number - format: float - description: Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed. - no_speech_prob: - type: number - format: float - description: Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent. + description: The name of the fine-tuning job that this checkpoint was created from. + object: + type: string + description: The object type, which is always "fine_tuning.job.checkpoint". + enum: [fine_tuning.job.checkpoint] required: + - created_at + - fine_tuning_job_id + - fine_tuned_model_checkpoint - id - - seek - - start - - end - - text - - tokens - - temperature - - avg_logprob - - compression_ratio - - no_speech_prob - - TranscriptionWord: - type: object - properties: - word: - type: string - description: The text content of the word. - start: - type: number - format: float - description: Start time of the word in seconds. - end: - type: number - format: float - description: End time of the word in seconds. - required: [word, start, end] + - metrics + - object + - step_number + x-oaiMeta: + name: The fine-tuning job checkpoint object + example: | + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_qtZ5Gyk4BLq1SfLFWp3RtO3P", + "created_at": 1712211699, + "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom_suffix:9ABel2dg:ckpt-step-88", + "fine_tuning_job_id": "ftjob-fpbNQ3H1GrMehXRf8cO97xTN", + "metrics": { + "step": 88, + "train_loss": 0.478, + "train_mean_token_accuracy": 0.924, + "valid_loss": 10.112, + "valid_mean_token_accuracy": 0.145, + "full_valid_loss": 0.567, + "full_valid_mean_token_accuracy": 0.944 + }, + "step_number": 88 + } - CreateTranscriptionResponseVerboseJson: + FinetuneChatRequestInput: type: object - description: Represents a verbose json transcription response returned by model, based on the provided input. + description: The per-line training example of a fine-tuning input file for chat models properties: - language: - type: string - description: The language of the input audio. - duration: - type: string - description: The duration of the input audio. - text: - type: string - description: The transcribed text. - words: + messages: type: array - description: Extracted words and their corresponding timestamps. + minItems: 1 items: - $ref: "#/components/schemas/TranscriptionWord" - segments: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" + - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" + - $ref: "#/components/schemas/FineTuneChatCompletionRequestAssistantMessage" + - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" + - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + x-oaiExpandable: true + tools: type: array - description: Segments of the transcribed text and their corresponding details. + description: A list of tools the model may generate JSON inputs for. items: - $ref: "#/components/schemas/TranscriptionSegment" - required: [language, duration, text] + $ref: "#/components/schemas/ChatCompletionTool" + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + functions: + deprecated: true + description: A list of functions the model may generate JSON inputs for. + type: array + minItems: 1 + maxItems: 128 + items: + $ref: "#/components/schemas/ChatCompletionFunctions" x-oaiMeta: - name: The transcription object (Verbose JSON) - group: audio - example: *verbose_transcription_response_example + name: Training format for chat models + example: | + { + "messages": [ + { "role": "user", "content": "What is the weather in San Francisco?" }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_id", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}" + } + } + ] + } + ], + "parallel_tool_calls": false, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and country, eg. San Francisco, USA" + }, + "format": { "type": "string", "enum": ["celsius", "fahrenheit"] } + }, + "required": ["location", "format"] + } + } + } + ] + } - CreateTranslationRequest: + FinetuneCompletionRequestInput: type: object - additionalProperties: false + description: The per-line training example of a fine-tuning input file for completions models properties: - file: - description: | - The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - type: string - x-oaiTypeLabel: file - format: binary - model: - description: | - ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. - example: whisper-1 - anyOf: - - type: string - - type: string - enum: ["whisper-1"] - x-oaiTypeLabel: string prompt: - description: | - An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English. type: string - response_format: - description: | - The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + description: The input prompt for this training example. + completion: type: string - default: json - temperature: - description: | - The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. - type: number - default: 0 - required: - - file - - model + description: The desired completion for this training example. + x-oaiMeta: + name: Training format for completions models + example: | + { + "prompt": "What is the answer to 2+2", + "completion": "4" + } - # Note: This does not currently support the non-default response format types. - CreateTranslationResponseJson: + CompletionUsage: type: object + description: Usage statistics for the completion request. properties: - text: - type: string + completion_tokens: + type: integer + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + description: Number of tokens in the prompt. + total_tokens: + type: integer + description: Total number of tokens used in the request (prompt + completion). required: - - text + - prompt_tokens + - completion_tokens + - total_tokens - CreateTranslationResponseVerboseJson: + RunCompletionUsage: type: object + description: Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). properties: - language: - type: string - description: The language of the output translation (always `english`). - duration: - type: string - description: The duration of the input audio. - text: - type: string - description: The translated text. - segments: - type: array - description: Segments of the translated text and their corresponding details. - items: - $ref: "#/components/schemas/TranscriptionSegment" - required: [language, duration, text] + completion_tokens: + type: integer + description: Number of completion tokens used over the course of the run. + prompt_tokens: + type: integer + description: Number of prompt tokens used over the course of the run. + total_tokens: + type: integer + description: Total number of tokens used (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + nullable: true - CreateSpeechRequest: + RunStepCompletionUsage: type: object - additionalProperties: false + description: Usage statistics related to the run step. This value will be `null` while the run step's status is `in_progress`. properties: - model: - description: | - One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd` - anyOf: - - type: string - - type: string - enum: ["tts-1", "tts-1-hd"] - x-oaiTypeLabel: string - input: - type: string - description: The text to generate audio for. The maximum length is 4096 characters. - maxLength: 4096 - voice: - description: The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options). - type: string - enum: ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] - response_format: - description: "The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." - default: "mp3" - type: string - enum: ["mp3", "opus", "aac", "flac", "wav", "pcm"] - speed: - description: "The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default." - type: number - default: 1.0 - minimum: 0.25 - maximum: 4.0 + completion_tokens: + type: integer + description: Number of completion tokens used over the course of the run step. + prompt_tokens: + type: integer + description: Number of prompt tokens used over the course of the run step. + total_tokens: + type: integer + description: Total number of tokens used (prompt + completion). required: - - model - - input - - voice + - prompt_tokens + - completion_tokens + - total_tokens + nullable: true - Model: - title: Model - description: Describes an OpenAI model offering that can be used with the API. + AssistantsApiResponseFormatOption: + description: | + Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + + Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs). + + Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. + oneOf: + - type: string + description: > + `auto` is the default value + enum: [auto] + - $ref: "#/components/schemas/ResponseFormatText" + - $ref: "#/components/schemas/ResponseFormatJsonObject" + - $ref: "#/components/schemas/ResponseFormatJsonSchema" + x-oaiExpandable: true + + AssistantObject: + type: object + title: Assistant + description: Represents an `assistant` that can call the model and use tools. properties: id: + description: The identifier, which can be referenced in API endpoints. type: string - description: The model identifier, which can be referenced in the API endpoints. - created: - type: integer - description: The Unix timestamp (in seconds) when the model was created. object: + description: The object type, which is always `assistant`. type: string - description: The object type, which is always "model". - enum: [model] - owned_by: + enum: [assistant] + created_at: + description: The Unix timestamp (in seconds) for when the assistant was created. + type: integer + name: + description: &assistant_name_param_description | + The name of the assistant. The maximum length is 256 characters. type: string - description: The organization that owns the model. - required: - - id - - object - - created - - owned_by - x-oaiMeta: - name: The model object - example: *retrieve_model_response + maxLength: 256 + nullable: true + description: + description: &assistant_description_param_description | + The description of the assistant. The maximum length is 512 characters. + type: string + maxLength: 512 + nullable: true + model: + description: *model_description + type: string + instructions: + description: &assistant_instructions_param_description | + The system instructions that the assistant uses. The maximum length is 256,000 characters. + type: string + maxLength: 256000 + nullable: true + tools: + description: &assistant_tools_param_description | + A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: &metadata_description | + Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: &run_temperature_description | + What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: &run_top_p_description | + An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. - OpenAIFile: - title: OpenAIFile - description: The `File` object represents a document that has been uploaded to OpenAI. - properties: - id: - type: string - description: The file identifier, which can be referenced in the API endpoints. - bytes: - type: integer - description: The size of the file, in bytes. - created_at: - type: integer - description: The Unix timestamp (in seconds) for when the file was created. - filename: - type: string - description: The name of the file. - object: - type: string - description: The object type, which is always `file`. - enum: ["file"] - purpose: - type: string - description: The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`. - enum: - [ - "assistants", - "assistants_output", - "batch", - "batch_output", - "fine-tune", - "fine-tune-results", - "vision", - ] - status: - type: string - deprecated: true - description: Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`. - enum: ["uploaded", "processed", "error"] - status_details: - type: string - deprecated: true - description: Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`. + We generally recommend altering this or temperature but not both. + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true required: - id - object - - bytes - created_at - - filename - - purpose - - status + - name + - description + - model + - instructions + - tools + - metadata x-oaiMeta: - name: The file object - example: | - { - "id": "file-abc123", - "object": "file", - "bytes": 120000, - "created_at": 1677610602, - "filename": "salesOverview.pdf", - "purpose": "assistants", - } - Upload: + name: The assistant object + beta: true + example: *create_assistants_example + + CreateAssistantRequest: type: object - title: Upload - description: | - The Upload object can accept byte chunks in the form of Parts. + additionalProperties: false properties: - id: - type: string - description: The Upload unique identifier, which can be referenced in API endpoints. - created_at: - type: integer - description: The Unix timestamp (in seconds) for when the Upload was created. - filename: - type: string - description: The name of the file to be uploaded. - bytes: - type: integer - description: The intended number of bytes to be uploaded. - purpose: + model: + description: *model_description + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + name: + description: *assistant_name_param_description type: string - description: The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values. - status: + nullable: true + maxLength: 256 + description: + description: *assistant_description_param_description type: string - description: The status of the Upload. - enum: ["pending", "completed", "cancelled", "expired"] - expires_at: - type: integer - description: The Unix timestamp (in seconds) for when the Upload was created. - object: + nullable: true + maxLength: 512 + instructions: + description: *assistant_instructions_param_description type: string - description: The object type, which is always "upload". - enum: [upload] - file: - $ref: "#/components/schemas/OpenAIFile" nullable: true - description: The ready File object after the Upload is completed. + maxLength: 256000 + tools: + description: *assistant_tools_param_description + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + vector_stores: + type: array + description: | + A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. + maxItems: 10000 + items: + type: string + chunking_strategy: + # Ideally we'd reuse the chunking strategy schema here, but it doesn't expand properly + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + - type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + required: + - type + - static + x-oaiExpandable: true + metadata: + type: object + description: | + Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + x-oaiTypeLabel: map + oneOf: + - required: [vector_store_ids] + - required: [vector_stores] + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: *run_temperature_description + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true required: - - bytes - - created_at - - expires_at - - filename - - id - - purpose - - status - - step_number - x-oaiMeta: - name: The upload object - example: | - { - "id": "upload_abc123", - "object": "upload", - "bytes": 2147483648, - "created_at": 1719184911, - "filename": "training_examples.jsonl", - "purpose": "fine-tune", - "status": "completed", - "expires_at": 1719127296, - "file": { - "id": "file-xyz321", - "object": "file", - "bytes": 2147483648, - "created_at": 1719186911, - "filename": "training_examples.jsonl", - "purpose": "fine-tune", - } - } - UploadPart: + - model + + ModifyAssistantRequest: type: object - title: UploadPart - description: | - The upload Part represents a chunk of bytes we can add to an Upload object. + additionalProperties: false properties: - id: + model: + description: *model_description + anyOf: + - type: string + name: + description: *assistant_name_param_description type: string - description: The upload Part unique identifier, which can be referenced in API endpoints. - created_at: - type: integer - description: The Unix timestamp (in seconds) for when the Part was created. - upload_id: + nullable: true + maxLength: 256 + description: + description: *assistant_description_param_description type: string - description: The ID of the Upload object that this Part was added to. - object: + nullable: true + maxLength: 512 + instructions: + description: *assistant_instructions_param_description type: string - description: The object type, which is always `upload.part`. - enum: ["upload.part"] - required: - - created_at - - id - - object - - upload_id - x-oaiMeta: - name: The upload part object - example: | - { - "id": "part_def456", - "object": "upload.part", - "created_at": 1719186911, - "upload_id": "upload_abc123" - } - Embedding: - type: object - description: | - Represents an embedding vector returned by embedding endpoint. - properties: - index: - type: integer - description: The index of the embedding in the list of embeddings. - embedding: + nullable: true + maxLength: 256000 + tools: + description: *assistant_tools_param_description + default: [] type: array - description: | - The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings). + maxItems: 128 items: - type: number - object: - type: string - description: The object type, which is always "embedding". - enum: [embedding] - required: - - index - - object - - embedding - x-oaiMeta: - name: The embedding object - example: | - { - "object": "embedding", - "embedding": [ - 0.0023064255, - -0.009327292, - .... (1536 floats total for ada-002) - -0.0028842222, - ], - "index": 0 - } - - FineTuningJob: - type: object - title: FineTuningJob - description: | - The `fine_tuning.job` object represents a fine-tuning job that has been created through the API. - properties: - id: - type: string - description: The object identifier, which can be referenced in the API endpoints. - created_at: - type: integer - description: The Unix timestamp (in seconds) for when the fine-tuning job was created. - error: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: type: object - nullable: true - description: For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure. + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. properties: - code: - type: string - description: A machine-readable error code. - message: - type: string - description: A human-readable error message. - param: - type: string - description: The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific. - nullable: true - required: - - code - - message - - param - fine_tuned_model: - type: string - nullable: true - description: The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running. - finished_at: - type: integer + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string nullable: true - description: The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running. - hyperparameters: + metadata: + description: *metadata_description type: object - description: The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. - properties: - n_epochs: - oneOf: - - type: string - enum: [auto] - - type: integer - minimum: 1 - maximum: 50 - default: auto - description: - The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. + x-oaiTypeLabel: map + nullable: true + temperature: + description: *run_temperature_description + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true - "auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs. - required: - - n_epochs - model: + DeleteAssistantResponse: + type: object + properties: + id: type: string - description: The base model that is being fine-tuned. + deleted: + type: boolean object: type: string - description: The object type, which is always "fine_tuning.job". - enum: [fine_tuning.job] - organization_id: + enum: [assistant.deleted] + required: + - id + - object + - deleted + + ListAssistantsResponse: + type: object + properties: + object: type: string - description: The organization that owns the fine-tuning job. - result_files: + example: "list" + data: type: array - description: The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents). items: - type: string - example: file-abc123 - status: - type: string - description: The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. - enum: - [ - "validating_files", - "queued", - "running", - "succeeded", - "failed", - "cancelled", - ] - trained_tokens: - type: integer - nullable: true - description: The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running. - training_file: + $ref: "#/components/schemas/AssistantObject" + first_id: type: string - description: The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents). - validation_file: + example: "asst_abc123" + last_id: type: string - nullable: true - description: The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents). - integrations: - type: array - nullable: true - description: A list of integrations to enable for this fine-tuning job. - maxItems: 5 - items: - oneOf: - - $ref: "#/components/schemas/FineTuningIntegration" - x-oaiExpandable: true - seed: - type: integer - description: The seed used for the fine-tuning job. - estimated_finish: - type: integer - nullable: true - description: The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running. + example: "asst_abc456" + has_more: + type: boolean + example: false required: - - created_at - - error - - finished_at - - fine_tuned_model - - hyperparameters - - id - - model - object - - organization_id - - result_files - - status - - trained_tokens - - training_file - - validation_file - - seed + - data + - first_id + - last_id + - has_more x-oaiMeta: - name: The fine-tuning job object - example: *fine_tuning_example + name: List assistants response object + group: chat + example: *list_assistants_example - FineTuningIntegration: + AssistantToolsCode: type: object - title: Fine-Tuning Job Integration + title: Code interpreter tool + properties: + type: + type: string + description: "The type of tool being defined: `code_interpreter`" + enum: ["code_interpreter"] required: - type - - wandb + + AssistantToolsFileSearch: + type: object + title: FileSearch tool properties: type: type: string - description: "The type of the integration being enabled for the fine-tuning job" - enum: ["wandb"] - wandb: + description: "The type of tool being defined: `file_search`" + enum: ["file_search"] + file_search: type: object - description: | - The settings for your integration with Weights and Biases. This payload specifies the project that - metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags - to your run, and set a default entity (team, username, etc) to be associated with your run. - required: - - project + description: Overrides for the file search tool. properties: - project: - description: | - The name of the project that the new run will be created under. - type: string - example: "my-wandb-project" - name: - description: | - A display name to set for the run. If not set, we will use the Job ID as the name. - nullable: true - type: string - entity: - description: | - The entity to use for the run. This allows you to set the team or username of the WandB user that you would - like associated with the run. If not set, the default entity for the registered WandB API key is used. - nullable: true - type: string - tags: + max_num_results: + type: integer + minimum: 1 + maximum: 50 description: | - A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some - default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". - type: array - items: - type: string - example: "custom-tag" + The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive. - FineTuningJobEvent: + Note that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information. + required: + - type + + AssistantToolsFileSearchTypeOnly: type: object - description: Fine-tuning job event object + title: FileSearch tool properties: - id: + type: type: string - created_at: - type: integer - level: + description: "The type of tool being defined: `file_search`" + enum: ["file_search"] + required: + - type + + AssistantToolsFunction: + type: object + title: Function tool + properties: + type: type: string - enum: ["info", "warn", "error"] - message: + description: "The type of tool being defined: `function`" + enum: ["function"] + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + + TruncationObject: + type: object + title: Thread Truncation Controls + description: Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + properties: + type: type: string - object: + description: The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`. + enum: ["auto", "last_messages"] + last_messages: + type: integer + description: The number of most recent messages from the thread when constructing the context for the run. + minimum: 1 + nullable: true + required: + - type + + AssistantsApiToolChoiceOption: + description: | + Controls which (if any) tool is called by the model. + `none` means the model will not call any tools and instead generates a message. + `auto` is the default value and means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools before responding to the user. + Specifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. + + oneOf: + - type: string + description: > + `none` means the model will not call any tools and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools before responding to the user. + enum: [none, auto, required] + - $ref: "#/components/schemas/AssistantsNamedToolChoice" + x-oaiExpandable: true + + AssistantsNamedToolChoice: + type: object + description: Specifies a tool the model should use. Use to force the model to call a specific tool. + properties: + type: type: string - enum: [fine_tuning.job.event] + enum: ["function", "code_interpreter", "file_search"] + description: The type of the tool. If type is `function`, the function name must be set + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + required: + - name required: - - id - - object - - created_at - - level - - message - x-oaiMeta: - name: The fine-tuning job event object - example: | - { - "object": "fine_tuning.job.event", - "id": "ftevent-abc123" - "created_at": 1677610602, - "level": "info", - "message": "Created fine-tuning job" - } + - type - FineTuningJobCheckpoint: + RunObject: type: object - title: FineTuningJobCheckpoint - description: | - The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use. + title: A run on a thread + description: Represents an execution run on a [thread](/docs/api-reference/threads). properties: id: + description: The identifier, which can be referenced in API endpoints. type: string - description: The checkpoint identifier, which can be referenced in the API endpoints. + object: + description: The object type, which is always `thread.run`. + type: string + enum: ["thread.run"] created_at: + description: The Unix timestamp (in seconds) for when the run was created. type: integer - description: The Unix timestamp (in seconds) for when the checkpoint was created. - fine_tuned_model_checkpoint: + thread_id: + description: The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run. type: string - description: The name of the fine-tuned checkpoint model that is created. - step_number: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run. + type: string + status: + description: The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`. + type: string + enum: + [ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ] + required_action: + type: object + description: Details on the action required to continue the run. Will be `null` if no action is required. + nullable: true + properties: + type: + description: For now, this is always `submit_tool_outputs`. + type: string + enum: ["submit_tool_outputs"] + submit_tool_outputs: + type: object + description: Details on the tool outputs needed for this run to continue. + properties: + tool_calls: + type: array + description: A list of the relevant tool calls. + items: + $ref: "#/components/schemas/RunToolCallObject" + required: + - tool_calls + required: + - type + - submit_tool_outputs + last_error: + type: object + description: The last error associated with this run. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`. + enum: ["server_error", "rate_limit_exceeded", "invalid_prompt"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + expires_at: + description: The Unix timestamp (in seconds) for when the run will expire. type: integer - description: The step number that the checkpoint was created at. - metrics: + nullable: true + started_at: + description: The Unix timestamp (in seconds) for when the run was started. + type: integer + nullable: true + cancelled_at: + description: The Unix timestamp (in seconds) for when the run was cancelled. + type: integer + nullable: true + failed_at: + description: The Unix timestamp (in seconds) for when the run failed. + type: integer + nullable: true + completed_at: + description: The Unix timestamp (in seconds) for when the run was completed. + type: integer + nullable: true + incomplete_details: + description: Details on why the run is incomplete. Will be `null` if the run is not incomplete. type: object - description: Metrics at the step number during the fine-tuning job. + nullable: true properties: - step: - type: number - train_loss: - type: number - train_mean_token_accuracy: - type: number - valid_loss: - type: number - valid_mean_token_accuracy: - type: number - full_valid_loss: - type: number - full_valid_mean_token_accuracy: - type: number - fine_tuning_job_id: + reason: + description: The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run. + type: string + enum: ["max_completion_tokens", "max_prompt_tokens"] + model: + description: The model that the [assistant](/docs/api-reference/assistants) used for this run. type: string - description: The name of the fine-tuning job that this checkpoint was created from. - object: + instructions: + description: The instructions that the [assistant](/docs/api-reference/assistants) used for this run. type: string - description: The object type, which is always "fine_tuning.job.checkpoint". - enum: [fine_tuning.job.checkpoint] + tools: + description: The list of tools that the [assistant](/docs/api-reference/assistants) used for this run. + default: [] + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + usage: + $ref: "#/components/schemas/RunCompletionUsage" + temperature: + description: The sampling temperature used for this run. If not set, defaults to 1. + type: number + nullable: true + top_p: + description: The nucleus sampling value used for this run. If not set, defaults to 1. + type: number + nullable: true + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens specified to have been used over the course of the run. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens specified to have been used over the course of the run. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true required: - - created_at - - fine_tuning_job_id - - fine_tuned_model_checkpoint - id - - metrics - object - - step_number + - created_at + - thread_id + - assistant_id + - status + - required_action + - last_error + - expires_at + - started_at + - cancelled_at + - failed_at + - completed_at + - model + - instructions + - tools + - metadata + - usage + - incomplete_details + - max_prompt_tokens + - max_completion_tokens + - truncation_strategy + - tool_choice + - parallel_tool_calls + - response_format x-oaiMeta: - name: The fine-tuning job checkpoint object + name: The run object + beta: true example: | { - "object": "fine_tuning.job.checkpoint", - "id": "ftckpt_qtZ5Gyk4BLq1SfLFWp3RtO3P", - "created_at": 1712211699, - "fine_tuned_model_checkpoint": "ft:gpt-3.5-turbo-0125:my-org:custom_suffix:9ABel2dg:ckpt-step-88", - "fine_tuning_job_id": "ftjob-fpbNQ3H1GrMehXRf8cO97xTN", - "metrics": { - "step": 88, - "train_loss": 0.478, - "train_mean_token_accuracy": 0.924, - "valid_loss": 10.112, - "valid_mean_token_accuracy": 0.145, - "full_valid_loss": 0.567, - "full_valid_mean_token_accuracy": 0.944 + "id": "run_abc123", + "object": "thread.run", + "created_at": 1698107661, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699073476, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699073498, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "tools": [{"type": "file_search"}, {"type": "code_interpreter"}], + "metadata": {}, + "incomplete_details": null, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 }, - "step_number": 88 + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true } - - FinetuneChatRequestInput: + CreateRunRequest: type: object - description: The per-line training example of a fine-tuning input file for chat models + additionalProperties: false properties: - messages: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + type: string + model: + description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + nullable: true + instructions: + description: Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis. + type: string + nullable: true + additional_instructions: + description: Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. + type: string + nullable: true + additional_messages: + description: Adds additional messages to the thread before creating the run. type: array - minItems: 1 items: - oneOf: - - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" - - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" - - $ref: "#/components/schemas/FineTuneChatCompletionRequestAssistantMessage" - - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" - - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" - x-oaiExpandable: true + $ref: "#/components/schemas/CreateMessageRequest" + nullable: true tools: + description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + nullable: true type: array - description: A list of tools the model may generate JSON inputs for. + maxItems: 20 items: - $ref: "#/components/schemas/ChatCompletionTool" + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *run_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true parallel_tool_calls: $ref: "#/components/schemas/ParallelToolCalls" - functions: - deprecated: true - description: A list of functions the model may generate JSON inputs for. - type: array - minItems: 1 - maxItems: 128 - items: - $ref: "#/components/schemas/ChatCompletionFunctions" - x-oaiMeta: - name: Training format for chat models - example: | - { - "messages": [ - { "role": "user", "content": "What is the weather in San Francisco?" }, - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_id", - "type": "function", - "function": { - "name": "get_current_weather", - "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}" - } - } - ] - } - ], - "parallel_tool_calls": false, - "tools": [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and country, eg. San Francisco, USA" - }, - "format": { "type": "string", "enum": ["celsius", "fahrenheit"] } - }, - "required": ["location", "format"] - } - } - } - ] - } - - FinetuneCompletionRequestInput: + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - thread_id + - assistant_id + ListRunsResponse: type: object - description: The per-line training example of a fine-tuning input file for completions models properties: - prompt: + object: type: string - description: The input prompt for this training example. - completion: + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/RunObject" + first_id: type: string - description: The desired completion for this training example. - x-oaiMeta: - name: Training format for completions models - example: | - { - "prompt": "What is the answer to 2+2", - "completion": "4" - } - - CompletionUsage: - type: object - description: Usage statistics for the completion request. - properties: - completion_tokens: - type: integer - description: Number of tokens in the generated completion. - prompt_tokens: - type: integer - description: Number of tokens in the prompt. - total_tokens: - type: integer - description: Total number of tokens used in the request (prompt + completion). + example: "run_abc123" + last_id: + type: string + example: "run_abc456" + has_more: + type: boolean + example: false required: - - prompt_tokens - - completion_tokens - - total_tokens - - RunCompletionUsage: + - object + - data + - first_id + - last_id + - has_more + ModifyRunRequest: type: object - description: Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). + additionalProperties: false properties: - completion_tokens: - type: integer - description: Number of completion tokens used over the course of the run. - prompt_tokens: - type: integer - description: Number of prompt tokens used over the course of the run. - total_tokens: - type: integer - description: Total number of tokens used (prompt + completion). - required: - - prompt_tokens - - completion_tokens - - total_tokens - nullable: true - - RunStepCompletionUsage: + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + SubmitToolOutputsRunRequest: type: object - description: Usage statistics related to the run step. This value will be `null` while the run step's status is `in_progress`. + additionalProperties: false properties: - completion_tokens: - type: integer - description: Number of completion tokens used over the course of the run step. - prompt_tokens: - type: integer - description: Number of prompt tokens used over the course of the run step. - total_tokens: - type: integer - description: Total number of tokens used (prompt + completion). + tool_outputs: + description: A list of tools for which the outputs are being submitted. + type: array + items: + type: object + properties: + tool_call_id: + type: string + description: The ID of the tool call in the `required_action` object within the run object the output is being submitted for. + output: + type: string + description: The output of the tool call to be submitted to continue the run. + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. required: - - prompt_tokens - - completion_tokens - - total_tokens - nullable: true - - AssistantsApiResponseFormatOption: - description: | - Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. - - Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. - - **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. - oneOf: - - type: string - description: > - `auto` is the default value - enum: [none, auto] - - $ref: "#/components/schemas/AssistantsApiResponseFormat" - x-oaiExpandable: true + - tool_outputs - AssistantsApiResponseFormat: + RunToolCallObject: type: object - description: | - An object describing the expected output of the model. If `json_object` only `function` type `tools` are allowed to be passed to the Run. If `text` the model can return text or any value needed. + description: Tool call objects properties: + id: + type: string + description: The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint. type: type: string - enum: ["text", "json_object"] - example: "json_object" - default: "text" - description: Must be one of `text` or `json_object`. + description: The type of tool call the output is required for. For now, this is always `function`. + enum: ["function"] + function: + type: object + description: The function definition. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments that the model expects you to pass to the function. + required: + - name + - arguments + required: + - id + - type + - function - AssistantObject: + CreateThreadAndRunRequest: type: object - title: Assistant - description: Represents an `assistant` that can call the model and use tools. + additionalProperties: false properties: - id: - description: The identifier, which can be referenced in API endpoints. - type: string - object: - description: The object type, which is always `assistant`. - type: string - enum: [assistant] - created_at: - description: The Unix timestamp (in seconds) for when the assistant was created. - type: integer - name: - description: &assistant_name_param_description | - The name of the assistant. The maximum length is 256 characters. - type: string - maxLength: 256 - nullable: true - description: - description: &assistant_description_param_description | - The description of the assistant. The maximum length is 512 characters. + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. type: string - maxLength: 512 - nullable: true + thread: + $ref: "#/components/schemas/CreateThreadRequest" + description: If no thread is provided, an empty thread will be created. model: - description: *model_description - type: string + description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + nullable: true instructions: - description: &assistant_instructions_param_description | - The system instructions that the assistant uses. The maximum length is 256,000 characters. + description: Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis. type: string - maxLength: 256000 nullable: true tools: - description: &assistant_tools_param_description | - A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. - default: [] + description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + nullable: true type: array - maxItems: 128 + maxItems: 20 items: oneOf: - $ref: "#/components/schemas/AssistantToolsCode" - $ref: "#/components/schemas/AssistantToolsFileSearch" - $ref: "#/components/schemas/AssistantToolsFunction" - x-oaiExpandable: true tool_resources: type: object description: | @@ -10261,7 +12821,7 @@ components: file_ids: type: array description: | - A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. default: [] maxItems: 20 items: @@ -10278,20 +12838,18 @@ components: type: string nullable: true metadata: - description: &metadata_description | - Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + description: *metadata_description type: object x-oaiTypeLabel: map nullable: true temperature: - description: &run_temperature_description | - What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. type: number minimum: 0 maximum: 2 default: 1 example: 1 nullable: true + description: *run_temperature_description top_p: type: number minimum: 0 @@ -10299,94 +12857,116 @@ components: default: 1 example: 1 nullable: true - description: &run_top_p_description | - An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. - - We generally recommend altering this or temperature but not both. + description: *run_top_p_description + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" response_format: $ref: "#/components/schemas/AssistantsApiResponseFormatOption" nullable: true required: - - id - - object - - created_at - - name - - description - - model - - instructions - - tools - - metadata - x-oaiMeta: - name: The assistant object - beta: true - example: *create_assistants_example + - thread_id + - assistant_id - CreateAssistantRequest: + ThreadObject: type: object - additionalProperties: false + title: Thread + description: Represents a thread that contains [messages](/docs/api-reference/messages). properties: - model: - description: *model_description - example: "gpt-4-turbo" - anyOf: - - type: string - - type: string - enum: - [ - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-turbo-preview", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-16k-0613", - ] - x-oaiTypeLabel: string - name: - description: *assistant_name_param_description + id: + description: The identifier, which can be referenced in API endpoints. type: string - nullable: true - maxLength: 256 - description: - description: *assistant_description_param_description + object: + description: The object type, which is always `thread`. type: string + enum: ["thread"] + created_at: + description: The Unix timestamp (in seconds) for when the thread was created. + type: integer + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string nullable: true - maxLength: 512 - instructions: - description: *assistant_instructions_param_description - type: string + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map nullable: true - maxLength: 256000 - tools: - description: *assistant_tools_param_description - default: [] + required: + - id + - object + - created_at + - tool_resources + - metadata + x-oaiMeta: + name: The thread object + beta: true + example: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1698107661, + "metadata": {} + } + + CreateThreadRequest: + type: object + additionalProperties: false + properties: + messages: + description: A list of [messages](/docs/api-reference/messages) to start the thread with. type: array - maxItems: 128 items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearch" - - $ref: "#/components/schemas/AssistantToolsFunction" - x-oaiExpandable: true + $ref: "#/components/schemas/CreateMessageRequest" tool_resources: type: object description: | - A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. properties: code_interpreter: type: object @@ -10405,14 +12985,14 @@ components: vector_store_ids: type: array description: | - The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. maxItems: 1 items: type: string vector_stores: type: array description: | - A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant. + A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread. maxItems: 1 items: type: object @@ -10475,6 +13055,7 @@ components: description: | Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. x-oaiTypeLabel: map + x-oaiExpandable: true oneOf: - required: [vector_store_ids] - required: [vector_stores] @@ -10482,117 +13063,340 @@ components: metadata: description: *metadata_description type: object - x-oaiTypeLabel: map - nullable: true - temperature: - description: *run_temperature_description - type: number - minimum: 0 - maximum: 2 - default: 1 - example: 1 - nullable: true - top_p: - type: number - minimum: 0 - maximum: 1 - default: 1 - example: 1 - nullable: true - description: *run_top_p_description - response_format: - $ref: "#/components/schemas/AssistantsApiResponseFormatOption" - nullable: true + x-oaiTypeLabel: map + nullable: true + + ModifyThreadRequest: + type: object + additionalProperties: false + properties: + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + DeleteThreadResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [thread.deleted] + required: + - id + - object + - deleted + + ListThreadsResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/ThreadObject" + first_id: + type: string + example: "asst_abc123" + last_id: + type: string + example: "asst_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + MessageObject: + type: object + title: The message object + description: Represents a message within a [thread](/docs/api-reference/threads). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.message`. + type: string + enum: ["thread.message"] + created_at: + description: The Unix timestamp (in seconds) for when the message was created. + type: integer + thread_id: + description: The [thread](/docs/api-reference/threads) ID that this message belongs to. + type: string + status: + description: The status of the message, which can be either `in_progress`, `incomplete`, or `completed`. + type: string + enum: ["in_progress", "incomplete", "completed"] + incomplete_details: + description: On an incomplete message, details about why the message is incomplete. + type: object + properties: + reason: + type: string + description: The reason the message is incomplete. + enum: + [ + "content_filter", + "max_tokens", + "run_cancelled", + "run_expired", + "run_failed", + ] + nullable: true + required: + - reason + completed_at: + description: The Unix timestamp (in seconds) for when the message was completed. + type: integer + nullable: true + incomplete_at: + description: The Unix timestamp (in seconds) for when the message was marked as incomplete. + type: integer + nullable: true + role: + description: The entity that produced the message. One of `user` or `assistant`. + type: string + enum: ["user", "assistant"] + content: + description: The content of the message in array of text and/or images. + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageContentImageFileObject" + - $ref: "#/components/schemas/MessageContentImageUrlObject" + - $ref: "#/components/schemas/MessageContentTextObject" + - $ref: "#/components/schemas/MessageContentRefusalObject" + x-oaiExpandable: true + assistant_id: + description: If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message. + type: string + nullable: true + run_id: + description: The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints. + type: string + nullable: true + attachments: + type: array + items: + type: object + properties: + file_id: + type: string + description: The ID of the file to attach to the message. + tools: + description: The tools to add this file to. + type: array + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" + x-oaiExpandable: true + description: A list of files attached to the message, and the tools they were added to. + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - created_at + - thread_id + - status + - incomplete_details + - completed_at + - incomplete_at + - role + - content + - assistant_id + - run_id + - attachments + - metadata + x-oaiMeta: + name: The message object + beta: true + example: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1698983503, + "thread_id": "thread_abc123", + "role": "assistant", + "content": [ + { + "type": "text", + "text": { + "value": "Hi! How can I help you today?", + "annotations": [] + } + } + ], + "assistant_id": "asst_abc123", + "run_id": "run_abc123", + "attachments": [], + "metadata": {} + } + + MessageDeltaObject: + type: object + title: Message delta object + description: | + Represents a message delta i.e. any changed fields on a message during streaming. + properties: + id: + description: The identifier of the message, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.message.delta`. + type: string + enum: ["thread.message.delta"] + delta: + description: The delta containing the fields that have changed on the Message. + type: object + properties: + role: + description: The entity that produced the message. One of `user` or `assistant`. + type: string + enum: ["user", "assistant"] + content: + description: The content of the message in array of text and/or images. + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageDeltaContentImageFileObject" + - $ref: "#/components/schemas/MessageDeltaContentTextObject" + - $ref: "#/components/schemas/MessageDeltaContentRefusalObject" + - $ref: "#/components/schemas/MessageDeltaContentImageUrlObject" + x-oaiExpandable: true required: - - model + - id + - object + - delta + x-oaiMeta: + name: The message delta object + beta: true + example: | + { + "id": "msg_123", + "object": "thread.message.delta", + "delta": { + "content": [ + { + "index": 0, + "type": "text", + "text": { "value": "Hello", "annotations": [] } + } + ] + } + } - ModifyAssistantRequest: + CreateMessageRequest: type: object additionalProperties: false + required: + - role + - content properties: - model: - description: *model_description - anyOf: - - type: string - name: - description: *assistant_name_param_description - type: string - nullable: true - maxLength: 256 - description: - description: *assistant_description_param_description - type: string - nullable: true - maxLength: 512 - instructions: - description: *assistant_instructions_param_description + role: type: string - nullable: true - maxLength: 256000 - tools: - description: *assistant_tools_param_description - default: [] + enum: ["user", "assistant"] + description: | + The role of the entity that is creating the message. Allowed values include: + - `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages. + - `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation. + content: + oneOf: + - type: string + description: The text contents of the message. + title: Text content + - type: array + description: An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview). + title: Array of content parts + items: + oneOf: + - $ref: "#/components/schemas/MessageContentImageFileObject" + - $ref: "#/components/schemas/MessageContentImageUrlObject" + - $ref: "#/components/schemas/MessageRequestContentTextObject" + x-oaiExpandable: true + minItems: 1 + x-oaiExpandable: true + attachments: type: array - maxItems: 128 items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearch" - - $ref: "#/components/schemas/AssistantToolsFunction" - x-oaiExpandable: true - tool_resources: - type: object - description: | - A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. - properties: - code_interpreter: - type: object - properties: - file_ids: - type: array - description: | - Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. - default: [] - maxItems: 20 - items: - type: string - file_search: - type: object - properties: - vector_store_ids: - type: array - description: | - Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. - maxItems: 1 - items: - type: string + type: object + properties: + file_id: + type: string + description: The ID of the file to attach to the message. + tools: + description: The tools to add this file to. + type: array + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" + x-oaiExpandable: true + description: A list of files attached to the message, and the tools they should be added to. + required: + - file_id + - tools nullable: true metadata: description: *metadata_description type: object x-oaiTypeLabel: map nullable: true - temperature: - description: *run_temperature_description - type: number - minimum: 0 - maximum: 2 - default: 1 - example: 1 - nullable: true - top_p: - type: number - minimum: 0 - maximum: 1 - default: 1 - example: 1 - nullable: true - description: *run_top_p_description - response_format: - $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + + ModifyMessageRequest: + type: object + additionalProperties: false + properties: + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map nullable: true - DeleteAssistantResponse: + DeleteMessageResponse: type: object properties: id: @@ -10601,14 +13405,13 @@ components: type: boolean object: type: string - enum: [assistant.deleted] + enum: [thread.message.deleted] required: - id - object - deleted - ListAssistantsResponse: - type: object + ListMessagesResponse: properties: object: type: string @@ -10616,933 +13419,520 @@ components: data: type: array items: - $ref: "#/components/schemas/AssistantObject" + $ref: "#/components/schemas/MessageObject" first_id: type: string - example: "asst_abc123" - last_id: - type: string - example: "asst_abc456" - has_more: - type: boolean - example: false - required: - - object - - data - - first_id - - last_id - - has_more - x-oaiMeta: - name: List assistants response object - group: chat - example: *list_assistants_example - - AssistantToolsCode: - type: object - title: Code interpreter tool - properties: - type: - type: string - description: "The type of tool being defined: `code_interpreter`" - enum: ["code_interpreter"] - required: - - type - - AssistantToolsFileSearch: - type: object - title: FileSearch tool - properties: - type: - type: string - description: "The type of tool being defined: `file_search`" - enum: ["file_search"] - file_search: - type: object - description: Overrides for the file search tool. - properties: - max_num_results: - type: integer - minimum: 1 - maximum: 50 - description: | - The maximum number of results the file search tool should output. The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. This number should be between 1 and 50 inclusive. - - Note that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information. + example: "msg_abc123" + last_id: + type: string + example: "msg_abc123" + has_more: + type: boolean + example: false required: - - type + - object + - data + - first_id + - last_id + - has_more - AssistantToolsFileSearchTypeOnly: + MessageContentImageFileObject: + title: Image file type: object - title: FileSearch tool + description: References an image [File](/docs/api-reference/files) in the content of a message. properties: type: + description: Always `image_file`. type: string - description: "The type of tool being defined: `file_search`" - enum: ["file_search"] + enum: ["image_file"] + image_file: + type: object + properties: + file_id: + description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + type: string + detail: + type: string + description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - file_id required: - type + - image_file - AssistantToolsFunction: + MessageDeltaContentImageFileObject: + title: Image file type: object - title: Function tool + description: References an image [File](/docs/api-reference/files) in the content of a message. properties: + index: + type: integer + description: The index of the content part in the message. type: + description: Always `image_file`. type: string - description: "The type of tool being defined: `function`" - enum: ["function"] - function: - $ref: "#/components/schemas/FunctionObject" + enum: ["image_file"] + image_file: + type: object + properties: + file_id: + description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + type: string + detail: + type: string + description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" required: + - index - type - - function - TruncationObject: + MessageContentImageUrlObject: + title: Image URL type: object - title: Thread Truncation Controls - description: Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + description: References an image URL in the content of a message. properties: type: type: string - description: The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`. - enum: ["auto", "last_messages"] - last_messages: - type: integer - description: The number of most recent messages from the thread when constructing the context for the run. - minimum: 1 - nullable: true + enum: ["image_url"] + description: The type of the content part. + image_url: + type: object + properties: + url: + type: string + description: "The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + format: uri + detail: + type: string + description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto` + enum: ["auto", "low", "high"] + default: "auto" + required: + - url required: - type + - image_url - AssistantsApiToolChoiceOption: - description: | - Controls which (if any) tool is called by the model. - `none` means the model will not call any tools and instead generates a message. - `auto` is the default value and means the model can pick between generating a message or calling one or more tools. - `required` means the model must call one or more tools before responding to the user. - Specifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. - - oneOf: - - type: string - description: > - `none` means the model will not call any tools and instead generates a message. - `auto` means the model can pick between generating a message or calling one or more tools. - `required` means the model must call one or more tools before responding to the user. - enum: [none, auto, required] - - $ref: "#/components/schemas/AssistantsNamedToolChoice" - x-oaiExpandable: true - - AssistantsNamedToolChoice: + MessageDeltaContentImageUrlObject: + title: Image URL type: object - description: Specifies a tool the model should use. Use to force the model to call a specific tool. + description: References an image URL in the content of a message. properties: + index: + type: integer + description: The index of the content part in the message. type: + description: Always `image_url`. type: string - enum: ["function", "code_interpreter", "file_search"] - description: The type of the tool. If type is `function`, the function name must be set - function: + enum: ["image_url"] + image_url: type: object properties: - name: + url: + description: "The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." type: string - description: The name of the function to call. - required: - - name + detail: + type: string + description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" required: + - index - type - RunObject: + MessageContentTextObject: + title: Text type: object - title: A run on a thread - description: Represents an execution run on a [thread](/docs/api-reference/threads). + description: The text content that is part of a message. properties: - id: - description: The identifier, which can be referenced in API endpoints. - type: string - object: - description: The object type, which is always `thread.run`. - type: string - enum: ["thread.run"] - created_at: - description: The Unix timestamp (in seconds) for when the run was created. - type: integer - thread_id: - description: The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run. - type: string - assistant_id: - description: The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run. - type: string - status: - description: The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`. + type: + description: Always `text`. type: string - enum: - [ - "queued", - "in_progress", - "requires_action", - "cancelling", - "cancelled", - "failed", - "completed", - "incomplete", - "expired", - ] - required_action: - type: object - description: Details on the action required to continue the run. Will be `null` if no action is required. - nullable: true - properties: - type: - description: For now, this is always `submit_tool_outputs`. - type: string - enum: ["submit_tool_outputs"] - submit_tool_outputs: - type: object - description: Details on the tool outputs needed for this run to continue. - properties: - tool_calls: - type: array - description: A list of the relevant tool calls. - items: - $ref: "#/components/schemas/RunToolCallObject" - required: - - tool_calls - required: - - type - - submit_tool_outputs - last_error: + enum: ["text"] + text: type: object - description: The last error associated with this run. Will be `null` if there are no errors. - nullable: true properties: - code: - type: string - description: One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`. - enum: ["server_error", "rate_limit_exceeded", "invalid_prompt"] - message: + value: + description: The data that makes up the text. type: string - description: A human-readable description of the error. + annotations: + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageContentTextAnnotationsFileCitationObject" + - $ref: "#/components/schemas/MessageContentTextAnnotationsFilePathObject" + x-oaiExpandable: true required: - - code - - message - expires_at: - description: The Unix timestamp (in seconds) for when the run will expire. - type: integer - nullable: true - started_at: - description: The Unix timestamp (in seconds) for when the run was started. - type: integer - nullable: true - cancelled_at: - description: The Unix timestamp (in seconds) for when the run was cancelled. - type: integer - nullable: true - failed_at: - description: The Unix timestamp (in seconds) for when the run failed. - type: integer - nullable: true - completed_at: - description: The Unix timestamp (in seconds) for when the run was completed. - type: integer - nullable: true - incomplete_details: - description: Details on why the run is incomplete. Will be `null` if the run is not incomplete. - type: object - nullable: true - properties: - reason: - description: The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run. - type: string - enum: ["max_completion_tokens", "max_prompt_tokens"] - model: - description: The model that the [assistant](/docs/api-reference/assistants) used for this run. + - value + - annotations + required: + - type + - text + + MessageContentRefusalObject: + title: Refusal + type: object + description: The refusal content generated by the assistant. + properties: + type: + description: Always `refusal`. type: string - instructions: - description: The instructions that the [assistant](/docs/api-reference/assistants) used for this run. + enum: ["refusal"] + refusal: type: string - tools: - description: The list of tools that the [assistant](/docs/api-reference/assistants) used for this run. - default: [] - type: array - maxItems: 20 - items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearch" - - $ref: "#/components/schemas/AssistantToolsFunction" - x-oaiExpandable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true - usage: - $ref: "#/components/schemas/RunCompletionUsage" - temperature: - description: The sampling temperature used for this run. If not set, defaults to 1. - type: number - nullable: true - top_p: - description: The nucleus sampling value used for this run. If not set, defaults to 1. - type: number - nullable: true - max_prompt_tokens: - type: integer - nullable: true - description: | - The maximum number of prompt tokens specified to have been used over the course of the run. - minimum: 256 - max_completion_tokens: - type: integer - nullable: true - description: | - The maximum number of completion tokens specified to have been used over the course of the run. - minimum: 256 - truncation_strategy: - $ref: "#/components/schemas/TruncationObject" - nullable: true - tool_choice: - $ref: "#/components/schemas/AssistantsApiToolChoiceOption" - nullable: true - parallel_tool_calls: - $ref: "#/components/schemas/ParallelToolCalls" - response_format: - $ref: "#/components/schemas/AssistantsApiResponseFormatOption" - nullable: true + nullable: false required: - - id - - object - - created_at - - thread_id - - assistant_id - - status - - required_action - - last_error - - expires_at - - started_at - - cancelled_at - - failed_at - - completed_at - - model - - instructions - - tools - - metadata - - usage - - incomplete_details - - max_prompt_tokens - - max_completion_tokens - - truncation_strategy - - tool_choice - - parallel_tool_calls - - response_format - x-oaiMeta: - name: The run object - beta: true - example: | - { - "id": "run_abc123", - "object": "thread.run", - "created_at": 1698107661, - "assistant_id": "asst_abc123", - "thread_id": "thread_abc123", - "status": "completed", - "started_at": 1699073476, - "expires_at": null, - "cancelled_at": null, - "failed_at": null, - "completed_at": 1699073498, - "last_error": null, - "model": "gpt-4-turbo", - "instructions": null, - "tools": [{"type": "file_search"}, {"type": "code_interpreter"}], - "metadata": {}, - "incomplete_details": null, - "usage": { - "prompt_tokens": 123, - "completion_tokens": 456, - "total_tokens": 579 - }, - "temperature": 1.0, - "top_p": 1.0, - "max_prompt_tokens": 1000, - "max_completion_tokens": 1000, - "truncation_strategy": { - "type": "auto", - "last_messages": null - }, - "response_format": "auto", - "tool_choice": "auto", - "parallel_tool_calls": true - } - CreateRunRequest: + - type + - refusal + + MessageRequestContentTextObject: + title: Text type: object - additionalProperties: false + description: The text content that is part of a message. properties: - assistant_id: - description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + type: + description: Always `text`. type: string - model: - description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. - example: "gpt-4-turbo" - anyOf: - - type: string - - type: string - enum: - [ - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-turbo-preview", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-16k-0613", - ] - x-oaiTypeLabel: string - nullable: true - instructions: - description: Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis. + enum: ["text"] + text: type: string - nullable: true - additional_instructions: - description: Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. + description: Text content to be sent to the model + required: + - type + - text + + MessageContentTextAnnotationsFileCitationObject: + title: File citation + type: object + description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + properties: + type: + description: Always `file_citation`. type: string - nullable: true - additional_messages: - description: Adds additional messages to the thread before creating the run. - type: array - items: - $ref: "#/components/schemas/CreateMessageRequest" - nullable: true - tools: - description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. - nullable: true - type: array - maxItems: 20 - items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearch" - - $ref: "#/components/schemas/AssistantToolsFunction" - x-oaiExpandable: true - metadata: - description: *metadata_description + enum: ["file_citation"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_citation: type: object - x-oaiTypeLabel: map - nullable: true - temperature: - type: number - minimum: 0 - maximum: 2 - default: 1 - example: 1 - nullable: true - description: *run_temperature_description - top_p: - type: number - minimum: 0 - maximum: 1 - default: 1 - example: 1 - nullable: true - description: *run_top_p_description - stream: - type: boolean - nullable: true - description: | - If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. - max_prompt_tokens: - type: integer - nullable: true - description: | - The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. - minimum: 256 - max_completion_tokens: - type: integer - nullable: true - description: | - The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. - minimum: 256 - truncation_strategy: - $ref: "#/components/schemas/TruncationObject" - nullable: true - tool_choice: - $ref: "#/components/schemas/AssistantsApiToolChoiceOption" - nullable: true - parallel_tool_calls: - $ref: "#/components/schemas/ParallelToolCalls" - response_format: - $ref: "#/components/schemas/AssistantsApiResponseFormatOption" - nullable: true + properties: + file_id: + description: The ID of the specific File the citation is from. + type: string + required: + - file_id + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 required: - - thread_id - - assistant_id - ListRunsResponse: + - type + - text + - file_citation + - start_index + - end_index + + MessageContentTextAnnotationsFilePathObject: + title: File path type: object + description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. properties: - object: - type: string - example: "list" - data: - type: array - items: - $ref: "#/components/schemas/RunObject" - first_id: + type: + description: Always `file_path`. type: string - example: "run_abc123" - last_id: + enum: ["file_path"] + text: + description: The text in the message content that needs to be replaced. type: string - example: "run_abc456" - has_more: - type: boolean - example: false + file_path: + type: object + properties: + file_id: + description: The ID of the file that was generated. + type: string + required: + - file_id + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 required: - - object - - data - - first_id - - last_id - - has_more - ModifyRunRequest: + - type + - text + - file_path + - start_index + - end_index + + MessageDeltaContentTextObject: + title: Text type: object - additionalProperties: false + description: The text content that is part of a message. properties: - metadata: - description: *metadata_description + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `text`. + type: string + enum: ["text"] + text: type: object - x-oaiTypeLabel: map - nullable: true - SubmitToolOutputsRunRequest: + properties: + value: + description: The data that makes up the text. + type: string + annotations: + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFileCitationObject" + - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFilePathObject" + x-oaiExpandable: true + required: + - index + - type + + MessageDeltaContentRefusalObject: + title: Refusal type: object - additionalProperties: false + description: The refusal content that is part of a message. properties: - tool_outputs: - description: A list of tools for which the outputs are being submitted. - type: array - items: - type: object - properties: - tool_call_id: - type: string - description: The ID of the tool call in the `required_action` object within the run object the output is being submitted for. - output: - type: string - description: The output of the tool call to be submitted to continue the run. - stream: - type: boolean - nullable: true - description: | - If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + index: + type: integer + description: The index of the refusal part in the message. + type: + description: Always `refusal`. + type: string + enum: ["refusal"] + refusal: + type: string required: - - tool_outputs + - index + - type - RunToolCallObject: + MessageDeltaContentTextAnnotationsFileCitationObject: + title: File citation type: object - description: Tool call objects + description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. properties: - id: - type: string - description: The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint. + index: + type: integer + description: The index of the annotation in the text content part. type: + description: Always `file_citation`. type: string - description: The type of tool call the output is required for. For now, this is always `function`. - enum: ["function"] - function: + enum: ["file_citation"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_citation: type: object - description: The function definition. properties: - name: + file_id: + description: The ID of the specific File the citation is from. type: string - description: The name of the function. - arguments: + quote: + description: The specific quote in the file. type: string - description: The arguments that the model expects you to pass to the function. - required: - - name - - arguments + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 required: - - id + - index - type - - function - CreateThreadAndRunRequest: + MessageDeltaContentTextAnnotationsFilePathObject: + title: File path type: object - additionalProperties: false + description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. properties: - assistant_id: - description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + index: + type: integer + description: The index of the annotation in the text content part. + type: + description: Always `file_path`. type: string - thread: - $ref: "#/components/schemas/CreateThreadRequest" - description: If no thread is provided, an empty thread will be created. - model: - description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. - example: "gpt-4-turbo" - anyOf: - - type: string - - type: string - enum: - [ - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-turbo-preview", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-16k-0613", - ] - x-oaiTypeLabel: string - nullable: true - instructions: - description: Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis. + enum: ["file_path"] + text: + description: The text in the message content that needs to be replaced. type: string - nullable: true - tools: - description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. - nullable: true - type: array - maxItems: 20 - items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearch" - - $ref: "#/components/schemas/AssistantToolsFunction" - tool_resources: - type: object - description: | - A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. - properties: - code_interpreter: - type: object - properties: - file_ids: - type: array - description: | - A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. - default: [] - maxItems: 20 - items: - type: string - file_search: - type: object - properties: - vector_store_ids: - type: array - description: | - The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. - maxItems: 1 - items: - type: string - nullable: true - metadata: - description: *metadata_description + file_path: type: object - x-oaiTypeLabel: map - nullable: true - temperature: - type: number - minimum: 0 - maximum: 2 - default: 1 - example: 1 - nullable: true - description: *run_temperature_description - top_p: - type: number - minimum: 0 - maximum: 1 - default: 1 - example: 1 - nullable: true - description: *run_top_p_description - stream: - type: boolean - nullable: true - description: | - If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. - max_prompt_tokens: + properties: + file_id: + description: The ID of the file that was generated. + type: string + start_index: type: integer - nullable: true - description: | - The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. - minimum: 256 - max_completion_tokens: + minimum: 0 + end_index: type: integer - nullable: true - description: | - The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. - minimum: 256 - truncation_strategy: - $ref: "#/components/schemas/TruncationObject" - nullable: true - tool_choice: - $ref: "#/components/schemas/AssistantsApiToolChoiceOption" - nullable: true - parallel_tool_calls: - $ref: "#/components/schemas/ParallelToolCalls" - response_format: - $ref: "#/components/schemas/AssistantsApiResponseFormatOption" - nullable: true + minimum: 0 required: - - thread_id - - assistant_id + - index + - type - ThreadObject: + RunStepObject: type: object - title: Thread - description: Represents a thread that contains [messages](/docs/api-reference/messages). + title: Run steps + description: | + Represents a step in execution of a run. properties: id: - description: The identifier, which can be referenced in API endpoints. + description: The identifier of the run step, which can be referenced in API endpoints. type: string object: - description: The object type, which is always `thread`. + description: The object type, which is always `thread.run.step`. type: string - enum: ["thread"] + enum: ["thread.run.step"] created_at: - description: The Unix timestamp (in seconds) for when the thread was created. + description: The Unix timestamp (in seconds) for when the run step was created. type: integer - tool_resources: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) associated with the run step. + type: string + thread_id: + description: The ID of the [thread](/docs/api-reference/threads) that was run. + type: string + run_id: + description: The ID of the [run](/docs/api-reference/runs) that this run step is a part of. + type: string + type: + description: The type of run step, which can be either `message_creation` or `tool_calls`. + type: string + enum: ["message_creation", "tool_calls"] + status: + description: The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`. + type: string + enum: ["in_progress", "cancelled", "failed", "completed", "expired"] + step_details: type: object - description: | - A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. - properties: - code_interpreter: - type: object - properties: - file_ids: - type: array - description: | - A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. - default: [] - maxItems: 20 - items: - type: string - file_search: - type: object - properties: - vector_store_ids: - type: array - description: | - The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. - maxItems: 1 - items: - type: string - nullable: true - metadata: - description: *metadata_description + description: The details of the run step. + oneOf: + - $ref: "#/components/schemas/RunStepDetailsMessageCreationObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsObject" + x-oaiExpandable: true + last_error: type: object - x-oaiTypeLabel: map + description: The last error associated with this run step. Will be `null` if there are no errors. nullable: true - required: - - id - - object - - created_at - - tool_resources - - metadata - x-oaiMeta: - name: The thread object - beta: true - example: | - { - "id": "thread_abc123", - "object": "thread", - "created_at": 1698107661, - "metadata": {} - } - - CreateThreadRequest: - type: object - additionalProperties: false - properties: - messages: - description: A list of [messages](/docs/api-reference/messages) to start the thread with. - type: array - items: - $ref: "#/components/schemas/CreateMessageRequest" - tool_resources: - type: object - description: | - A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. properties: - code_interpreter: - type: object - properties: - file_ids: - type: array - description: | - A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. - default: [] - maxItems: 20 - items: - type: string - file_search: - type: object - properties: - vector_store_ids: - type: array - description: | - The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. - maxItems: 1 - items: - type: string - vector_stores: - type: array - description: | - A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread. - maxItems: 1 - items: - type: object - properties: - file_ids: - type: array - description: | - A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. - maxItems: 10000 - items: - type: string - chunking_strategy: - # Ideally we'd reuse the chunking strategy schema here, but it doesn't expand properly - type: object - description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. - oneOf: - - type: object - title: Auto Chunking Strategy - description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. - additionalProperties: false - properties: - type: - type: string - description: Always `auto`. - enum: ["auto"] - required: - - type - - type: object - title: Static Chunking Strategy - additionalProperties: false - properties: - type: - type: string - description: Always `static`. - enum: ["static"] - static: - type: object - additionalProperties: false - properties: - max_chunk_size_tokens: - type: integer - minimum: 100 - maximum: 4096 - description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. - chunk_overlap_tokens: - type: integer - description: | - The number of tokens that overlap between chunks. The default value is `400`. - - Note that the overlap must not exceed half of `max_chunk_size_tokens`. - required: - - max_chunk_size_tokens - - chunk_overlap_tokens - required: - - type - - static - x-oaiExpandable: true - metadata: - type: object - description: | - Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. - x-oaiTypeLabel: map - x-oaiExpandable: true - oneOf: - - required: [vector_store_ids] - - required: [vector_stores] + code: + type: string + description: One of `server_error` or `rate_limit_exceeded`. + enum: ["server_error", "rate_limit_exceeded"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + expired_at: + description: The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired. + type: integer nullable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map + cancelled_at: + description: The Unix timestamp (in seconds) for when the run step was cancelled. + type: integer nullable: true - - ModifyThreadRequest: - type: object - additionalProperties: false - properties: - tool_resources: - type: object - description: | - A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. - properties: - code_interpreter: - type: object - properties: - file_ids: - type: array - description: | - A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. - default: [] - maxItems: 20 - items: - type: string - file_search: - type: object - properties: - vector_store_ids: - type: array - description: | - The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. - maxItems: 1 - items: - type: string + failed_at: + description: The Unix timestamp (in seconds) for when the run step failed. + type: integer + nullable: true + completed_at: + description: The Unix timestamp (in seconds) for when the run step completed. + type: integer nullable: true metadata: description: *metadata_description type: object x-oaiTypeLabel: map nullable: true + usage: + $ref: "#/components/schemas/RunStepCompletionUsage" + required: + - id + - object + - created_at + - assistant_id + - thread_id + - run_id + - type + - status + - step_details + - last_error + - expired_at + - cancelled_at + - failed_at + - completed_at + - metadata + - usage + x-oaiMeta: + name: The run step object + beta: true + example: *run_step_object_example - DeleteThreadResponse: + RunStepDeltaObject: type: object + title: Run step delta object + description: | + Represents a run step delta i.e. any changed fields on a run step during streaming. properties: id: + description: The identifier of the run step, which can be referenced in API endpoints. type: string - deleted: - type: boolean object: + description: The object type, which is always `thread.run.step.delta`. type: string - enum: [thread.deleted] + enum: ["thread.run.step.delta"] + delta: + description: The delta containing the fields that have changed on the run step. + type: object + properties: + step_details: + type: object + description: The details of the run step. + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsMessageCreationObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsObject" + x-oaiExpandable: true required: - id - object - - deleted + - delta + x-oaiMeta: + name: The run step delta object + beta: true + example: | + { + "id": "step_123", + "object": "thread.run.step.delta", + "delta": { + "step_details": { + "type": "tool_calls", + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "code_interpreter", + "code_interpreter": { "input": "", "outputs": [] } + } + ] + } + } + } - ListThreadsResponse: + ListRunStepsResponse: properties: object: type: string @@ -11550,13 +13940,13 @@ components: data: type: array items: - $ref: "#/components/schemas/ThreadObject" + $ref: "#/components/schemas/RunStepObject" first_id: type: string - example: "asst_abc123" + example: "step_abc123" last_id: type: string - example: "asst_abc456" + example: "step_abc456" has_more: type: boolean example: false @@ -11567,760 +13957,718 @@ components: - last_id - has_more - MessageObject: + RunStepDetailsMessageCreationObject: + title: Message creation type: object - title: The message object - description: Represents a message within a [thread](/docs/api-reference/threads). + description: Details of the message creation by the run step. properties: - id: - description: The identifier, which can be referenced in API endpoints. - type: string - object: - description: The object type, which is always `thread.message`. - type: string - enum: ["thread.message"] - created_at: - description: The Unix timestamp (in seconds) for when the message was created. - type: integer - thread_id: - description: The [thread](/docs/api-reference/threads) ID that this message belongs to. - type: string - status: - description: The status of the message, which can be either `in_progress`, `incomplete`, or `completed`. + type: + description: Always `message_creation`. type: string - enum: ["in_progress", "incomplete", "completed"] - incomplete_details: - description: On an incomplete message, details about why the message is incomplete. + enum: ["message_creation"] + message_creation: type: object properties: - reason: + message_id: type: string - description: The reason the message is incomplete. - enum: - [ - "content_filter", - "max_tokens", - "run_cancelled", - "run_expired", - "run_failed", - ] - nullable: true + description: The ID of the message that was created by this run step. required: - - reason - completed_at: - description: The Unix timestamp (in seconds) for when the message was completed. - type: integer - nullable: true - incomplete_at: - description: The Unix timestamp (in seconds) for when the message was marked as incomplete. - type: integer - nullable: true - role: - description: The entity that produced the message. One of `user` or `assistant`. + - message_id + required: + - type + - message_creation + + RunStepDeltaStepDetailsMessageCreationObject: + title: Message creation + type: object + description: Details of the message creation by the run step. + properties: + type: + description: Always `message_creation`. type: string - enum: ["user", "assistant"] - content: - description: The content of the message in array of text and/or images. + enum: ["message_creation"] + message_creation: + type: object + properties: + message_id: + type: string + description: The ID of the message that was created by this run step. + required: + - type + + RunStepDetailsToolCallsObject: + title: Tool calls + type: object + description: Details of the tool call. + properties: + type: + description: Always `tool_calls`. + type: string + enum: ["tool_calls"] + tool_calls: type: array + description: | + An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. items: oneOf: - - $ref: "#/components/schemas/MessageContentImageFileObject" - - $ref: "#/components/schemas/MessageContentImageUrlObject" - - $ref: "#/components/schemas/MessageContentTextObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsFileSearchObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsFunctionObject" + x-oaiExpandable: true + required: + - type + - tool_calls + + RunStepDeltaStepDetailsToolCallsObject: + title: Tool calls + type: object + description: Details of the tool call. + properties: + type: + description: Always `tool_calls`. + type: string + enum: ["tool_calls"] + tool_calls: + type: array + description: | + An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + items: + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFileSearchObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFunctionObject" x-oaiExpandable: true - assistant_id: - description: If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message. - type: string - nullable: true - run_id: - description: The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints. - type: string - nullable: true - attachments: - type: array - items: - type: object - properties: - file_id: - type: string - description: The ID of the file to attach to the message. - tools: - description: The tools to add this file to. - type: array - items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" - x-oaiExpandable: true - description: A list of files attached to the message, and the tools they were added to. - nullable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true required: - - id - - object - - created_at - - thread_id - - status - - incomplete_details - - completed_at - - incomplete_at - - role - - content - - assistant_id - - run_id - - attachments - - metadata - x-oaiMeta: - name: The message object - beta: true - example: | - { - "id": "msg_abc123", - "object": "thread.message", - "created_at": 1698983503, - "thread_id": "thread_abc123", - "role": "assistant", - "content": [ - { - "type": "text", - "text": { - "value": "Hi! How can I help you today?", - "annotations": [] - } - } - ], - "assistant_id": "asst_abc123", - "run_id": "run_abc123", - "attachments": [], - "metadata": {} - } + - type - MessageDeltaObject: + RunStepDetailsToolCallsCodeObject: + title: Code Interpreter tool call type: object - title: Message delta object - description: | - Represents a message delta i.e. any changed fields on a message during streaming. + description: Details of the Code Interpreter tool call the run step was involved in. properties: id: - description: The identifier of the message, which can be referenced in API endpoints. type: string - object: - description: The object type, which is always `thread.message.delta`. + description: The ID of the tool call. + type: type: string - enum: ["thread.message.delta"] - delta: - description: The delta containing the fields that have changed on the Message. + description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. + enum: ["code_interpreter"] + code_interpreter: type: object + description: The Code Interpreter tool call definition. + required: + - input + - outputs properties: - role: - description: The entity that produced the message. One of `user` or `assistant`. + input: type: string - enum: ["user", "assistant"] - content: - description: The content of the message in array of text and/or images. + description: The input to the Code Interpreter tool call. + outputs: type: array + description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. items: + type: object oneOf: - - $ref: "#/components/schemas/MessageDeltaContentImageFileObject" - - $ref: "#/components/schemas/MessageDeltaContentTextObject" - - $ref: "#/components/schemas/MessageDeltaContentImageUrlObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputLogsObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputImageObject" x-oaiExpandable: true required: - id - - object - - delta - x-oaiMeta: - name: The message delta object - beta: true - example: | - { - "id": "msg_123", - "object": "thread.message.delta", - "delta": { - "content": [ - { - "index": 0, - "type": "text", - "text": { "value": "Hello", "annotations": [] } - } - ] - } - } + - type + - code_interpreter - CreateMessageRequest: + RunStepDeltaStepDetailsToolCallsCodeObject: + title: Code interpreter tool call type: object - additionalProperties: false - required: - - role - - content + description: Details of the Code Interpreter tool call the run step was involved in. properties: - role: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: type: string - enum: ["user", "assistant"] - description: | - The role of the entity that is creating the message. Allowed values include: - - `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages. - - `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation. - content: - oneOf: - - type: string - description: The text contents of the message. - title: Text content - - type: array - description: An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview). - title: Array of content parts + description: The ID of the tool call. + type: + type: string + description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. + enum: ["code_interpreter"] + code_interpreter: + type: object + description: The Code Interpreter tool call definition. + properties: + input: + type: string + description: The input to the Code Interpreter tool call. + outputs: + type: array + description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. items: + type: object oneOf: - - $ref: "#/components/schemas/MessageContentImageFileObject" - - $ref: "#/components/schemas/MessageContentImageUrlObject" - - $ref: "#/components/schemas/MessageRequestContentTextObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputImageObject" x-oaiExpandable: true - minItems: 1 - x-oaiExpandable: true - attachments: - type: array - items: - type: object - properties: - file_id: - type: string - description: The ID of the file to attach to the message. - tools: - description: The tools to add this file to. - type: array - items: - oneOf: - - $ref: "#/components/schemas/AssistantToolsCode" - - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" - x-oaiExpandable: true - description: A list of files attached to the message, and the tools they should be added to. - required: - - file_id - - tools - nullable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true - - ModifyMessageRequest: - type: object - additionalProperties: false - properties: - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true + required: + - index + - type - DeleteMessageResponse: + RunStepDetailsToolCallsCodeOutputLogsObject: + title: Code Interpreter log output type: object + description: Text output from the Code Interpreter tool call as part of a run step. properties: - id: + type: + description: Always `logs`. type: string - deleted: - type: boolean - object: + enum: ["logs"] + logs: type: string - enum: [thread.message.deleted] + description: The text output from the Code Interpreter tool call. required: - - id - - object - - deleted + - type + - logs - ListMessagesResponse: + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject: + title: Code interpreter log output + type: object + description: Text output from the Code Interpreter tool call as part of a run step. properties: - object: - type: string - example: "list" - data: - type: array - items: - $ref: "#/components/schemas/MessageObject" - first_id: + index: + type: integer + description: The index of the output in the outputs array. + type: + description: Always `logs`. type: string - example: "msg_abc123" - last_id: + enum: ["logs"] + logs: type: string - example: "msg_abc123" - has_more: - type: boolean - example: false + description: The text output from the Code Interpreter tool call. required: - - object - - data - - first_id - - last_id - - has_more + - index + - type - MessageContentImageFileObject: - title: Image file + RunStepDetailsToolCallsCodeOutputImageObject: + title: Code Interpreter image output type: object - description: References an image [File](/docs/api-reference/files) in the content of a message. properties: type: - description: Always `image_file`. + description: Always `image`. type: string - enum: ["image_file"] - image_file: + enum: ["image"] + image: type: object properties: file_id: - description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. - type: string - detail: + description: The [file](/docs/api-reference/files) ID of the image. type: string - description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. - enum: ["auto", "low", "high"] - default: "auto" required: - file_id required: - type - - image_file + - image - MessageDeltaContentImageFileObject: - title: Image file + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject: + title: Code interpreter image output type: object - description: References an image [File](/docs/api-reference/files) in the content of a message. properties: index: type: integer - description: The index of the content part in the message. + description: The index of the output in the outputs array. type: - description: Always `image_file`. + description: Always `image`. type: string - enum: ["image_file"] - image_file: + enum: ["image"] + image: type: object properties: file_id: - description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. - type: string - detail: + description: The [file](/docs/api-reference/files) ID of the image. type: string - description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. - enum: ["auto", "low", "high"] - default: "auto" required: - index - type - MessageContentImageUrlObject: - title: Image URL + RunStepDetailsToolCallsFileSearchObject: + title: File search tool call type: object - description: References an image URL in the content of a message. properties: + id: + type: string + description: The ID of the tool call object. type: type: string - enum: ["image_url"] - description: The type of the content part. - image_url: + description: The type of tool call. This is always going to be `file_search` for this type of tool call. + enum: ["file_search"] + file_search: type: object - properties: - url: - type: string - description: "The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." - format: uri - detail: - type: string - description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto` - enum: ["auto", "low", "high"] - default: "auto" - required: - - url + description: For now, this is always going to be an empty object. + x-oaiTypeLabel: map required: + - id - type - - image_url + - file_search - MessageDeltaContentImageUrlObject: - title: Image URL + RunStepDeltaStepDetailsToolCallsFileSearchObject: + title: File search tool call type: object - description: References an image URL in the content of a message. properties: index: type: integer - description: The index of the content part in the message. + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call object. type: - description: Always `image_url`. type: string - enum: ["image_url"] - image_url: + description: The type of tool call. This is always going to be `file_search` for this type of tool call. + enum: ["file_search"] + file_search: type: object - properties: - url: - description: "The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." - type: string - detail: - type: string - description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. - enum: ["auto", "low", "high"] - default: "auto" + description: For now, this is always going to be an empty object. + x-oaiTypeLabel: map required: - index - type + - file_search - MessageContentTextObject: - title: Text + RunStepDetailsToolCallsFunctionObject: type: object - description: The text content that is part of a message. + title: Function tool call properties: + id: + type: string + description: The ID of the tool call object. type: - description: Always `text`. type: string - enum: ["text"] - text: + description: The type of tool call. This is always going to be `function` for this type of tool call. + enum: ["function"] + function: type: object + description: The definition of the function that was called. properties: - value: - description: The data that makes up the text. + name: type: string - annotations: - type: array - items: - oneOf: - - $ref: "#/components/schemas/MessageContentTextAnnotationsFileCitationObject" - - $ref: "#/components/schemas/MessageContentTextAnnotationsFilePathObject" - x-oaiExpandable: true + description: The name of the function. + arguments: + type: string + description: The arguments passed to the function. + output: + type: string + description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. + nullable: true required: - - value - - annotations + - name + - arguments + - output required: + - id - type - - text + - function - MessageRequestContentTextObject: - title: Text + RunStepDeltaStepDetailsToolCallsFunctionObject: type: object - description: The text content that is part of a message. + title: Function tool call properties: - type: - description: Always `text`. + index: + type: integer + description: The index of the tool call in the tool calls array. + id: type: string - enum: ["text"] - text: + description: The ID of the tool call object. + type: type: string - description: Text content to be sent to the model + description: The type of tool call. This is always going to be `function` for this type of tool call. + enum: ["function"] + function: + type: object + description: The definition of the function that was called. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments passed to the function. + output: + type: string + description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. + nullable: true required: + - index - type - - text - MessageContentTextAnnotationsFileCitationObject: - title: File citation + VectorStoreExpirationAfter: type: object - description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + title: Vector store expiration policy + description: The expiration policy for a vector store. properties: - type: - description: Always `file_citation`. + anchor: + description: "Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." type: string - enum: ["file_citation"] - text: - description: The text in the message content that needs to be replaced. + enum: ["last_active_at"] + days: + description: The number of days after the anchor time that the vector store will expire. + type: integer + minimum: 1 + maximum: 365 + required: + - anchor + - days + + VectorStoreObject: + type: object + title: Vector store + description: A vector store is a collection of processed files can be used by the `file_search` tool. + properties: + id: + description: The identifier, which can be referenced in API endpoints. type: string - file_citation: + object: + description: The object type, which is always `vector_store`. + type: string + enum: ["vector_store"] + created_at: + description: The Unix timestamp (in seconds) for when the vector store was created. + type: integer + name: + description: The name of the vector store. + type: string + usage_bytes: + description: The total number of bytes used by the files in the vector store. + type: integer + file_counts: type: object properties: - file_id: - description: The ID of the specific File the citation is from. - type: string + in_progress: + description: The number of files that are currently being processed. + type: integer + completed: + description: The number of files that have been successfully processed. + type: integer + failed: + description: The number of files that have failed to process. + type: integer + cancelled: + description: The number of files that were cancelled. + type: integer + total: + description: The total number of files. + type: integer required: - - file_id - start_index: + - in_progress + - completed + - failed + - cancelled + - total + status: + description: The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use. + type: string + enum: ["expired", "in_progress", "completed"] + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + expires_at: + description: The Unix timestamp (in seconds) for when the vector store will expire. type: integer - minimum: 0 - end_index: + nullable: true + last_active_at: + description: The Unix timestamp (in seconds) for when the vector store was last active. type: integer - minimum: 0 + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true required: - - type - - text - - file_citation - - start_index - - end_index + - id + - object + - usage_bytes + - created_at + - status + - last_active_at + - name + - file_counts + - metadata + x-oaiMeta: + name: The vector store object + beta: true + example: | + { + "id": "vs_123", + "object": "vector_store", + "created_at": 1698107661, + "usage_bytes": 123456, + "last_active_at": 1698107661, + "name": "my_vector_store", + "status": "completed", + "file_counts": { + "in_progress": 0, + "completed": 100, + "cancelled": 0, + "failed": 0, + "total": 100 + }, + "metadata": {}, + "last_used_at": 1698107661 + } - MessageContentTextAnnotationsFilePathObject: - title: File path + CreateVectorStoreRequest: type: object - description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. + additionalProperties: false properties: - type: - description: Always `file_path`. - type: string - enum: ["file_path"] - text: - description: The text in the message content that needs to be replaced. + file_ids: + description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + type: array + maxItems: 500 + items: + type: string + name: + description: The name of the vector store. type: string - file_path: + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + chunking_strategy: type: object - properties: - file_id: - description: The ID of the file that was generated. - type: string - required: - - file_id - start_index: - type: integer - minimum: 0 - end_index: - type: integer - minimum: 0 - required: - - type - - text - - file_path - - start_index - - end_index + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty. + oneOf: + - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" + - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true - MessageDeltaContentTextObject: - title: Text + UpdateVectorStoreRequest: type: object - description: The text content that is part of a message. + additionalProperties: false properties: - index: - type: integer - description: The index of the content part in the message. - type: - description: Always `text`. + name: + description: The name of the vector store. type: string - enum: ["text"] - text: + nullable: true + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + nullable: true + metadata: + description: *metadata_description type: object - properties: - value: - description: The data that makes up the text. - type: string - annotations: - type: array - items: - oneOf: - - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFileCitationObject" - - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFilePathObject" - x-oaiExpandable: true - required: - - index - - type + x-oaiTypeLabel: map + nullable: true - MessageDeltaContentTextAnnotationsFileCitationObject: - title: File citation - type: object - description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + ListVectorStoresResponse: properties: - index: - type: integer - description: The index of the annotation in the text content part. - type: - description: Always `file_citation`. + object: type: string - enum: ["file_citation"] - text: - description: The text in the message content that needs to be replaced. + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/VectorStoreObject" + first_id: type: string - file_citation: - type: object - properties: - file_id: - description: The ID of the specific File the citation is from. - type: string - quote: - description: The specific quote in the file. - type: string - start_index: - type: integer - minimum: 0 - end_index: - type: integer - minimum: 0 + example: "vs_abc123" + last_id: + type: string + example: "vs_abc456" + has_more: + type: boolean + example: false required: - - index - - type + - object + - data + - first_id + - last_id + - has_more - MessageDeltaContentTextAnnotationsFilePathObject: - title: File path + DeleteVectorStoreResponse: type: object - description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. properties: - index: - type: integer - description: The index of the annotation in the text content part. - type: - description: Always `file_path`. + id: type: string - enum: ["file_path"] - text: - description: The text in the message content that needs to be replaced. + deleted: + type: boolean + object: type: string - file_path: - type: object - properties: - file_id: - description: The ID of the file that was generated. - type: string - start_index: - type: integer - minimum: 0 - end_index: - type: integer - minimum: 0 + enum: [vector_store.deleted] required: - - index - - type + - id + - object + - deleted - RunStepObject: + VectorStoreFileObject: type: object - title: Run steps - description: | - Represents a step in execution of a run. + title: Vector store files + description: A list of files attached to a vector store. properties: id: - description: The identifier of the run step, which can be referenced in API endpoints. + description: The identifier, which can be referenced in API endpoints. type: string object: - description: The object type, which is always `thread.run.step`. + description: The object type, which is always `vector_store.file`. type: string - enum: ["thread.run.step"] + enum: ["vector_store.file"] + usage_bytes: + description: The total vector store usage in bytes. Note that this may be different from the original file size. + type: integer created_at: - description: The Unix timestamp (in seconds) for when the run step was created. + description: The Unix timestamp (in seconds) for when the vector store file was created. type: integer - assistant_id: - description: The ID of the [assistant](/docs/api-reference/assistants) associated with the run step. - type: string - thread_id: - description: The ID of the [thread](/docs/api-reference/threads) that was run. - type: string - run_id: - description: The ID of the [run](/docs/api-reference/runs) that this run step is a part of. - type: string - type: - description: The type of run step, which can be either `message_creation` or `tool_calls`. + vector_store_id: + description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. type: string - enum: ["message_creation", "tool_calls"] status: - description: The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`. - type: string - enum: ["in_progress", "cancelled", "failed", "completed", "expired"] - step_details: - type: object - description: The details of the run step. - oneOf: - - $ref: "#/components/schemas/RunStepDetailsMessageCreationObject" - - $ref: "#/components/schemas/RunStepDetailsToolCallsObject" - x-oaiExpandable: true + description: The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use. + type: string + enum: ["in_progress", "completed", "cancelled", "failed"] last_error: type: object - description: The last error associated with this run step. Will be `null` if there are no errors. + description: The last error associated with this vector store file. Will be `null` if there are no errors. nullable: true properties: code: type: string description: One of `server_error` or `rate_limit_exceeded`. - enum: ["server_error", "rate_limit_exceeded"] + enum: ["server_error", "unsupported_file", "invalid_file"] message: type: string description: A human-readable description of the error. required: - code - message - expired_at: - description: The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired. - type: integer - nullable: true - cancelled_at: - description: The Unix timestamp (in seconds) for when the run step was cancelled. - type: integer - nullable: true - failed_at: - description: The Unix timestamp (in seconds) for when the run step failed. - type: integer - nullable: true - completed_at: - description: The Unix timestamp (in seconds) for when the run step completed. - type: integer - nullable: true - metadata: - description: *metadata_description + chunking_strategy: type: object - x-oaiTypeLabel: map - nullable: true - usage: - $ref: "#/components/schemas/RunStepCompletionUsage" + description: The strategy used to chunk the file. + oneOf: + - $ref: "#/components/schemas/StaticChunkingStrategyResponseParam" + - $ref: "#/components/schemas/OtherChunkingStrategyResponseParam" + x-oaiExpandable: true required: - id - object + - usage_bytes - created_at - - assistant_id - - thread_id - - run_id - - type + - vector_store_id - status - - step_details - last_error - - expired_at - - cancelled_at - - failed_at - - completed_at - - metadata - - usage x-oaiMeta: - name: The run step object + name: The vector store file object beta: true - example: *run_step_object_example + example: | + { + "id": "file-abc123", + "object": "vector_store.file", + "usage_bytes": 1234, + "created_at": 1698107661, + "vector_store_id": "vs_abc123", + "status": "completed", + "last_error": null, + "chunking_strategy": { + "type": "static", + "static": { + "max_chunk_size_tokens": 800, + "chunk_overlap_tokens": 400 + } + } + } - RunStepDeltaObject: + OtherChunkingStrategyResponseParam: type: object - title: Run step delta object - description: | - Represents a run step delta i.e. any changed fields on a run step during streaming. + title: Other Chunking Strategy + description: This is returned when the chunking strategy is unknown. Typically, this is because the file was indexed before the `chunking_strategy` concept was introduced in the API. + additionalProperties: false properties: - id: - description: The identifier of the run step, which can be referenced in API endpoints. + type: type: string - object: - description: The object type, which is always `thread.run.step.delta`. + description: Always `other`. + enum: ["other"] + required: + - type + + StaticChunkingStrategyResponseParam: + type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: type: string - enum: ["thread.run.step.delta"] - delta: - description: The delta containing the fields that have changed on the run step. - type: object - properties: - step_details: - type: object - description: The details of the run step. - oneOf: - - $ref: "#/components/schemas/RunStepDeltaStepDetailsMessageCreationObject" - - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsObject" - x-oaiExpandable: true + description: Always `static`. + enum: ["static"] + static: + $ref: "#/components/schemas/StaticChunkingStrategy" required: - - id - - object - - delta - x-oaiMeta: - name: The run step delta object - beta: true - example: | - { - "id": "step_123", - "object": "thread.run.step.delta", - "delta": { - "step_details": { - "type": "tool_calls", - "tool_calls": [ - { - "index": 0, - "id": "call_123", - "type": "code_interpreter", - "code_interpreter": { "input": "", "outputs": [] } - } - ] - } - } - } + - type + - static - ListRunStepsResponse: + StaticChunkingStrategy: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + + AutoChunkingStrategyRequestParam: + type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + + StaticChunkingStrategyRequestParam: + type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + $ref: "#/components/schemas/StaticChunkingStrategy" + required: + - type + - static + + ChunkingStrategyRequestParam: + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" + - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" + x-oaiExpandable: true + + CreateVectorStoreFileRequest: + type: object + additionalProperties: false + properties: + file_id: + description: A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files. + type: string + chunking_strategy: + $ref: "#/components/schemas/ChunkingStrategyRequestParam" + required: + - file_id + + ListVectorStoreFilesResponse: properties: object: type: string @@ -12328,13 +14676,13 @@ components: data: type: array items: - $ref: "#/components/schemas/RunStepObject" + $ref: "#/components/schemas/VectorStoreFileObject" first_id: type: string - example: "step_abc123" + example: "file-abc123" last_id: type: string - example: "step_abc456" + example: "file-abc456" has_more: type: boolean example: false @@ -12345,1429 +14693,1713 @@ components: - last_id - has_more - RunStepDetailsMessageCreationObject: - title: Message creation + DeleteVectorStoreFileResponse: type: object - description: Details of the message creation by the run step. properties: - type: - description: Always `message_creation`. + id: type: string - enum: ["message_creation"] - message_creation: - type: object - properties: - message_id: - type: string - description: The ID of the message that was created by this run step. - required: - - message_id + deleted: + type: boolean + object: + type: string + enum: [vector_store.file.deleted] required: - - type - - message_creation + - id + - object + - deleted - RunStepDeltaStepDetailsMessageCreationObject: - title: Message creation + VectorStoreFileBatchObject: type: object - description: Details of the message creation by the run step. + title: Vector store file batch + description: A batch of files attached to a vector store. properties: - type: - description: Always `message_creation`. + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store.file_batch`. + type: string + enum: ["vector_store.files_batch"] + created_at: + description: The Unix timestamp (in seconds) for when the vector store files batch was created. + type: integer + vector_store_id: + description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. + type: string + status: + description: The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`. type: string - enum: ["message_creation"] - message_creation: + enum: ["in_progress", "completed", "cancelled", "failed"] + file_counts: type: object properties: - message_id: - type: string - description: The ID of the message that was created by this run step. + in_progress: + description: The number of files that are currently being processed. + type: integer + completed: + description: The number of files that have been processed. + type: integer + failed: + description: The number of files that have failed to process. + type: integer + cancelled: + description: The number of files that where cancelled. + type: integer + total: + description: The total number of files. + type: integer + required: + - in_progress + - completed + - cancelled + - failed + - total required: - - type + - id + - object + - created_at + - vector_store_id + - status + - file_counts + x-oaiMeta: + name: The vector store files batch object + beta: true + example: | + { + "id": "vsfb_123", + "object": "vector_store.files_batch", + "created_at": 1698107661, + "vector_store_id": "vs_abc123", + "status": "completed", + "file_counts": { + "in_progress": 0, + "completed": 100, + "failed": 0, + "cancelled": 0, + "total": 100 + } + } - RunStepDetailsToolCallsObject: - title: Tool calls + CreateVectorStoreFileBatchRequest: type: object - description: Details of the tool call. + additionalProperties: false properties: - type: - description: Always `tool_calls`. - type: string - enum: ["tool_calls"] - tool_calls: + file_ids: + description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. type: array - description: | - An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + minItems: 1 + maxItems: 500 items: - oneOf: - - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeObject" - - $ref: "#/components/schemas/RunStepDetailsToolCallsFileSearchObject" - - $ref: "#/components/schemas/RunStepDetailsToolCallsFunctionObject" - x-oaiExpandable: true + type: string + chunking_strategy: + $ref: "#/components/schemas/ChunkingStrategyRequestParam" required: - - type - - tool_calls + - file_ids - RunStepDeltaStepDetailsToolCallsObject: - title: Tool calls - type: object - description: Details of the tool call. - properties: - type: - description: Always `tool_calls`. - type: string - enum: ["tool_calls"] - tool_calls: - type: array - description: | - An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. - items: - oneOf: - - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeObject" - - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFileSearchObject" - - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFunctionObject" - x-oaiExpandable: true - required: - - type + AssistantStreamEvent: + description: | + Represents an event emitted when streaming a Run. - RunStepDetailsToolCallsCodeObject: - title: Code Interpreter tool call - type: object - description: Details of the Code Interpreter tool call the run step was involved in. - properties: - id: - type: string - description: The ID of the tool call. - type: - type: string - description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. - enum: ["code_interpreter"] - code_interpreter: - type: object - description: The Code Interpreter tool call definition. + Each event in a server-sent events stream has an `event` and `data` property: + + ``` + event: thread.created + data: {"id": "thread_123", "object": "thread", ...} + ``` + + We emit events whenever a new object is created, transitions to a new state, or is being + streamed in parts (deltas). For example, we emit `thread.run.created` when a new run + is created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses + to create a message during a run, we emit a `thread.message.created event`, a + `thread.message.in_progress` event, many `thread.message.delta` events, and finally a + `thread.message.completed` event. + + We may add additional events over time, so we recommend handling unknown events gracefully + in your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to + integrate the Assistants API with streaming. + oneOf: + - $ref: "#/components/schemas/ThreadStreamEvent" + - $ref: "#/components/schemas/RunStreamEvent" + - $ref: "#/components/schemas/RunStepStreamEvent" + - $ref: "#/components/schemas/MessageStreamEvent" + - $ref: "#/components/schemas/ErrorEvent" + - $ref: "#/components/schemas/DoneEvent" + x-oaiMeta: + name: Assistant stream events + beta: true + + ThreadStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.created"] + data: + $ref: "#/components/schemas/ThreadObject" + required: + - event + - data + description: Occurs when a new [thread](/docs/api-reference/threads/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [thread](/docs/api-reference/threads/object)" + + RunStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.run.created"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a new [run](/docs/api-reference/runs/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.queued"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `queued` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.in_progress"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to an `in_progress` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.requires_action"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `requires_action` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.completed"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.incomplete"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) ends with status `incomplete`. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.failed"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) fails. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.cancelling"] + data: + $ref: "#/components/schemas/RunObject" required: - - input - - outputs + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `cancelling` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object properties: - input: + event: type: string - description: The input to the Code Interpreter tool call. - outputs: - type: array - description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. - items: - type: object - oneOf: - - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputLogsObject" - - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputImageObject" - x-oaiExpandable: true - required: - - id - - type - - code_interpreter - - RunStepDeltaStepDetailsToolCallsCodeObject: - title: Code interpreter tool call - type: object - description: Details of the Code Interpreter tool call the run step was involved in. - properties: - index: - type: integer - description: The index of the tool call in the tool calls array. - id: - type: string - description: The ID of the tool call. - type: - type: string - description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. - enum: ["code_interpreter"] - code_interpreter: - type: object - description: The Code Interpreter tool call definition. + enum: ["thread.run.cancelled"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) is cancelled. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object properties: - input: + event: type: string - description: The input to the Code Interpreter tool call. - outputs: - type: array - description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. - items: - type: object - oneOf: - - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject" - - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputImageObject" - x-oaiExpandable: true - required: - - index - - type - - RunStepDetailsToolCallsCodeOutputLogsObject: - title: Code Interpreter log output - type: object - description: Text output from the Code Interpreter tool call as part of a run step. - properties: - type: - description: Always `logs`. - type: string - enum: ["logs"] - logs: - type: string - description: The text output from the Code Interpreter tool call. - required: - - type - - logs - - RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject: - title: Code interpreter log output - type: object - description: Text output from the Code Interpreter tool call as part of a run step. - properties: - index: - type: integer - description: The index of the output in the outputs array. - type: - description: Always `logs`. - type: string - enum: ["logs"] - logs: - type: string - description: The text output from the Code Interpreter tool call. - required: - - index - - type + enum: ["thread.run.expired"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) expires. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - RunStepDetailsToolCallsCodeOutputImageObject: - title: Code Interpreter image output - type: object - properties: - type: - description: Always `image`. - type: string - enum: ["image"] - image: - type: object + RunStepStreamEvent: + oneOf: + - type: object properties: - file_id: - description: The [file](/docs/api-reference/files) ID of the image. + event: type: string + enum: ["thread.run.step.created"] + data: + $ref: "#/components/schemas/RunStepObject" required: - - file_id - required: - - type - - image - - RunStepDeltaStepDetailsToolCallsCodeOutputImageObject: - title: Code interpreter image output - type: object - properties: - index: - type: integer - description: The index of the output in the outputs array. - type: - description: Always `image`. - type: string - enum: ["image"] - image: - type: object + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is created. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object properties: - file_id: - description: The [file](/docs/api-reference/files) ID of the image. + event: type: string - required: - - index - - type - - RunStepDetailsToolCallsFileSearchObject: - title: File search tool call - type: object - properties: - id: - type: string - description: The ID of the tool call object. - type: - type: string - description: The type of tool call. This is always going to be `file_search` for this type of tool call. - enum: ["file_search"] - file_search: - type: object - description: For now, this is always going to be an empty object. - x-oaiTypeLabel: map - required: - - id - - type - - file_search - - RunStepDeltaStepDetailsToolCallsFileSearchObject: - title: File search tool call - type: object - properties: - index: - type: integer - description: The index of the tool call in the tool calls array. - id: - type: string - description: The ID of the tool call object. - type: - type: string - description: The type of tool call. This is always going to be `file_search` for this type of tool call. - enum: ["file_search"] - file_search: - type: object - description: For now, this is always going to be an empty object. - x-oaiTypeLabel: map - required: - - index - - type - - file_search - - RunStepDetailsToolCallsFunctionObject: - type: object - title: Function tool call - properties: - id: - type: string - description: The ID of the tool call object. - type: - type: string - description: The type of tool call. This is always going to be `function` for this type of tool call. - enum: ["function"] - function: - type: object - description: The definition of the function that was called. + enum: ["thread.run.step.in_progress"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) moves to an `in_progress` state. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.delta"] + data: + $ref: "#/components/schemas/RunStepDeltaObject" + required: + - event + - data + description: Occurs when parts of a [run step](/docs/api-reference/runs/step-object) are being streamed. + x-oaiMeta: + dataDescription: "`data` is a [run step delta](/docs/api-reference/assistants-streaming/run-step-delta-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.completed"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object properties: - name: + event: type: string - description: The name of the function. - arguments: + enum: ["thread.run.step.failed"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) fails. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: type: string - description: The arguments passed to the function. - output: + enum: ["thread.run.step.cancelled"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is cancelled. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: type: string - description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. - nullable: true + enum: ["thread.run.step.expired"] + data: + $ref: "#/components/schemas/RunStepObject" required: - - name - - arguments - - output - required: - - id - - type - - function + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) expires. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - RunStepDeltaStepDetailsToolCallsFunctionObject: - type: object - title: Function tool call - properties: - index: - type: integer - description: The index of the tool call in the tool calls array. - id: - type: string - description: The ID of the tool call object. - type: - type: string - description: The type of tool call. This is always going to be `function` for this type of tool call. - enum: ["function"] - function: - type: object - description: The definition of the function that was called. + MessageStreamEvent: + oneOf: + - type: object properties: - name: - type: string - description: The name of the function. - arguments: + event: type: string - description: The arguments passed to the function. - output: + enum: ["thread.message.created"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: type: string - description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. - nullable: true - required: - - index - - type - - VectorStoreExpirationAfter: - type: object - title: Vector store expiration policy - description: The expiration policy for a vector store. - properties: - anchor: - description: "Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." - type: string - enum: ["last_active_at"] - days: - description: The number of days after the anchor time that the vector store will expire. - type: integer - minimum: 1 - maximum: 365 - required: - - anchor - - days - - VectorStoreObject: - type: object - title: Vector store - description: A vector store is a collection of processed files can be used by the `file_search` tool. - properties: - id: - description: The identifier, which can be referenced in API endpoints. - type: string - object: - description: The object type, which is always `vector_store`. - type: string - enum: ["vector_store"] - created_at: - description: The Unix timestamp (in seconds) for when the vector store was created. - type: integer - name: - description: The name of the vector store. - type: string - usage_bytes: - description: The total number of bytes used by the files in the vector store. - type: integer - file_counts: - type: object + enum: ["thread.message.in_progress"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) moves to an `in_progress` state. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object properties: - in_progress: - description: The number of files that are currently being processed. - type: integer - completed: - description: The number of files that have been successfully processed. - type: integer - failed: - description: The number of files that have failed to process. - type: integer - cancelled: - description: The number of files that were cancelled. - type: integer - total: - description: The total number of files. - type: integer + event: + type: string + enum: ["thread.message.delta"] + data: + $ref: "#/components/schemas/MessageDeltaObject" required: - - in_progress - - completed - - failed - - cancelled - - total - status: - description: The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use. - type: string - enum: ["expired", "in_progress", "completed"] - expires_after: - $ref: "#/components/schemas/VectorStoreExpirationAfter" - expires_at: - description: The Unix timestamp (in seconds) for when the vector store will expire. - type: integer - nullable: true - last_active_at: - description: The Unix timestamp (in seconds) for when the vector store was last active. - type: integer - nullable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true - required: - - id - - object - - usage_bytes - - created_at - - status - - last_active_at - - name - - file_counts - - metadata - x-oaiMeta: - name: The vector store object - beta: true - example: | - { - "id": "vs_123", - "object": "vector_store", - "created_at": 1698107661, - "usage_bytes": 123456, - "last_active_at": 1698107661, - "name": "my_vector_store", - "status": "completed", - "file_counts": { - "in_progress": 0, - "completed": 100, - "cancelled": 0, - "failed": 0, - "total": 100 - }, - "metadata": {}, - "last_used_at": 1698107661 - } - - CreateVectorStoreRequest: - type: object - additionalProperties: false - properties: - file_ids: - description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. - type: array - maxItems: 500 - items: - type: string - name: - description: The name of the vector store. - type: string - expires_after: - $ref: "#/components/schemas/VectorStoreExpirationAfter" - chunking_strategy: - type: object - description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty. - oneOf: - - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" - - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" - x-oaiExpandable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true + - event + - data + description: Occurs when parts of a [Message](/docs/api-reference/messages/object) are being streamed. + x-oaiMeta: + dataDescription: "`data` is a [message delta](/docs/api-reference/assistants-streaming/message-delta-object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.completed"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.incomplete"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) ends before it is completed. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" - UpdateVectorStoreRequest: + ErrorEvent: type: object - additionalProperties: false properties: - name: - description: The name of the vector store. - type: string - nullable: true - expires_after: - $ref: "#/components/schemas/VectorStoreExpirationAfter" - nullable: true - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true - - ListVectorStoresResponse: - properties: - object: + event: type: string - example: "list" + enum: ["error"] data: - type: array - items: - $ref: "#/components/schemas/VectorStoreObject" - first_id: - type: string - example: "vs_abc123" - last_id: - type: string - example: "vs_abc456" - has_more: - type: boolean - example: false + $ref: "#/components/schemas/Error" required: - - object + - event - data - - first_id - - last_id - - has_more + description: Occurs when an [error](/docs/guides/error-codes/api-errors) occurs. This can happen due to an internal server error or a timeout. + x-oaiMeta: + dataDescription: "`data` is an [error](/docs/guides/error-codes/api-errors)" - DeleteVectorStoreResponse: + DoneEvent: type: object properties: - id: + event: type: string - deleted: - type: boolean - object: + enum: ["done"] + data: type: string - enum: [vector_store.deleted] + enum: ["[DONE]"] required: - - id - - object - - deleted + - event + - data + description: Occurs when a stream ends. + x-oaiMeta: + dataDescription: "`data` is `[DONE]`" - VectorStoreFileObject: + Batch: type: object - title: Vector store files - description: A list of files attached to a vector store. properties: id: - description: The identifier, which can be referenced in API endpoints. type: string object: - description: The object type, which is always `vector_store.file`. - type: string - enum: ["vector_store.file"] - usage_bytes: - description: The total vector store usage in bytes. Note that this may be different from the original file size. - type: integer - created_at: - description: The Unix timestamp (in seconds) for when the vector store file was created. - type: integer - vector_store_id: - description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. type: string - status: - description: The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use. + enum: [batch] + description: The object type, which is always `batch`. + endpoint: type: string - enum: ["in_progress", "completed", "cancelled", "failed"] - last_error: + description: The OpenAI API endpoint used by the batch. + + errors: type: object - description: The last error associated with this vector store file. Will be `null` if there are no errors. - nullable: true properties: - code: - type: string - description: One of `server_error` or `rate_limit_exceeded`. - enum: - [ - "internal_error", - "file_not_found", - "parsing_error", - "unhandled_mime_type", - ] - message: + object: type: string - description: A human-readable description of the error. - required: - - code - - message - chunking_strategy: - type: object - description: The strategy used to chunk the file. - oneOf: - - $ref: "#/components/schemas/StaticChunkingStrategyResponseParam" - - $ref: "#/components/schemas/OtherChunkingStrategyResponseParam" - x-oaiExpandable: true - required: - - id - - object - - usage_bytes - - created_at - - vector_store_id - - status - - last_error - x-oaiMeta: - name: The vector store file object - beta: true - example: | - { - "id": "file-abc123", - "object": "vector_store.file", - "usage_bytes": 1234, - "created_at": 1698107661, - "vector_store_id": "vs_abc123", - "status": "completed", - "last_error": null, - "chunking_strategy": { - "type": "static", - "static": { - "max_chunk_size_tokens": 800, - "chunk_overlap_tokens": 400 - } - } - } - - OtherChunkingStrategyResponseParam: - type: object - title: Other Chunking Strategy - description: This is returned when the chunking strategy is unknown. Typically, this is because the file was indexed before the `chunking_strategy` concept was introduced in the API. - additionalProperties: false - properties: - type: + description: The object type, which is always `list`. + data: + type: array + items: + type: object + properties: + code: + type: string + description: An error code identifying the error type. + message: + type: string + description: A human-readable message providing more details about the error. + param: + type: string + description: The name of the parameter that caused the error, if applicable. + nullable: true + line: + type: integer + description: The line number of the input file where the error occurred, if applicable. + nullable: true + input_file_id: type: string - description: Always `other`. - enum: ["other"] - required: - - type - - StaticChunkingStrategyResponseParam: - type: object - title: Static Chunking Strategy - additionalProperties: false - properties: - type: + description: The ID of the input file for the batch. + completion_window: type: string - description: Always `static`. - enum: ["static"] - static: - $ref: "#/components/schemas/StaticChunkingStrategy" - required: - - type - - static - - StaticChunkingStrategy: - type: object - additionalProperties: false - properties: - max_chunk_size_tokens: + description: The time frame within which the batch should be processed. + status: + type: string + description: The current status of the batch. + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + output_file_id: + type: string + description: The ID of the file containing the outputs of successfully executed requests. + error_file_id: + type: string + description: The ID of the file containing the outputs of requests with errors. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was created. + in_progress_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started processing. + expires_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch will expire. + finalizing_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started finalizing. + completed_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was completed. + failed_at: type: integer - minimum: 100 - maximum: 4096 - description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. - chunk_overlap_tokens: + description: The Unix timestamp (in seconds) for when the batch failed. + expired_at: type: integer - description: | - The number of tokens that overlap between chunks. The default value is `400`. - - Note that the overlap must not exceed half of `max_chunk_size_tokens`. + description: The Unix timestamp (in seconds) for when the batch expired. + cancelling_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started cancelling. + cancelled_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was cancelled. + request_counts: + type: object + properties: + total: + type: integer + description: Total number of requests in the batch. + completed: + type: integer + description: Number of requests that have been completed successfully. + failed: + type: integer + description: Number of requests that have failed. + required: + - total + - completed + - failed + description: The request counts for different statuses within the batch. + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true required: - - max_chunk_size_tokens - - chunk_overlap_tokens + - id + - object + - endpoint + - input_file_id + - completion_window + - status + - created_at + x-oaiMeta: + name: The batch object + example: *batch_object - AutoChunkingStrategyRequestParam: + BatchRequestInput: type: object - title: Auto Chunking Strategy - description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. - additionalProperties: false + description: The per-line object of the batch input file properties: - type: + custom_id: type: string - description: Always `auto`. - enum: ["auto"] - required: - - type - - StaticChunkingStrategyRequestParam: - type: object - title: Static Chunking Strategy - additionalProperties: false - properties: - type: + description: A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch. + method: type: string - description: Always `static`. - enum: ["static"] - static: - $ref: "#/components/schemas/StaticChunkingStrategy" - required: - - type - - static - - ChunkingStrategyRequestParam: - type: object - description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. - oneOf: - - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" - - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" - x-oaiExpandable: true + enum: ["POST"] + description: The HTTP method to be used for the request. Currently only `POST` is supported. + url: + type: string + description: The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. + x-oaiMeta: + name: The request input object + example: | + {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}} - CreateVectorStoreFileRequest: + BatchRequestOutput: type: object - additionalProperties: false + description: The per-line object of the batch output and error files properties: - file_id: - description: A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files. + id: type: string - chunking_strategy: - $ref: "#/components/schemas/ChunkingStrategyRequestParam" - required: - - file_id + custom_id: + type: string + description: A developer-provided per-request id that will be used to match outputs to inputs. + response: + type: object + nullable: true + properties: + status_code: + type: integer + description: The HTTP status code of the response + request_id: + type: string + description: An unique identifier for the OpenAI API request. Please include this request ID when contacting support. + body: + type: object + x-oaiTypeLabel: map + description: The JSON body of the response + error: + type: object + nullable: true + description: For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure. + properties: + code: + type: string + description: A machine-readable error code. + message: + type: string + description: A human-readable error message. + x-oaiMeta: + name: The request output object + example: | + {"id": "batch_req_wnaDys", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "req_c187b3", "body": {"id": "chatcmpl-9758Iw", "object": "chat.completion", "created": 1711475054, "model": "gpt-4o-mini", "choices": [{"index": 0, "message": {"role": "assistant", "content": "2 + 2 equals 4."}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 24, "completion_tokens": 15, "total_tokens": 39}, "system_fingerprint": null}}, "error": null} - ListVectorStoreFilesResponse: + ListBatchesResponse: + type: object properties: - object: - type: string - example: "list" data: type: array items: - $ref: "#/components/schemas/VectorStoreFileObject" + $ref: "#/components/schemas/Batch" first_id: type: string - example: "file-abc123" + example: "batch_abc123" last_id: type: string - example: "file-abc456" + example: "batch_abc456" has_more: type: boolean - example: false + object: + type: string + enum: [list] required: - object - data - - first_id - - last_id - has_more - DeleteVectorStoreFileResponse: + AuditLogActorServiceAccount: type: object + description: The service account that performed the audit logged action. properties: id: type: string - deleted: - type: boolean - object: - type: string - enum: [vector_store.file.deleted] - required: - - id - - object - - deleted + description: The service account id. - VectorStoreFileBatchObject: + AuditLogActorUser: type: object - title: Vector store file batch - description: A batch of files attached to a vector store. + description: The user who performed the audit logged action. properties: id: - description: The identifier, which can be referenced in API endpoints. - type: string - object: - description: The object type, which is always `vector_store.file_batch`. - type: string - enum: ["vector_store.files_batch"] - created_at: - description: The Unix timestamp (in seconds) for when the vector store files batch was created. - type: integer - vector_store_id: - description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. type: string - status: - description: The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`. + description: The user id. + email: type: string - enum: ["in_progress", "completed", "cancelled", "failed"] - file_counts: - type: object - properties: - in_progress: - description: The number of files that are currently being processed. - type: integer - completed: - description: The number of files that have been processed. - type: integer - failed: - description: The number of files that have failed to process. - type: integer - cancelled: - description: The number of files that where cancelled. - type: integer - total: - description: The total number of files. - type: integer - required: - - in_progress - - completed - - cancelled - - failed - - total - required: - - id - - object - - created_at - - vector_store_id - - status - - file_counts - x-oaiMeta: - name: The vector store files batch object - beta: true - example: | - { - "id": "vsfb_123", - "object": "vector_store.files_batch", - "created_at": 1698107661, - "vector_store_id": "vs_abc123", - "status": "completed", - "file_counts": { - "in_progress": 0, - "completed": 100, - "failed": 0, - "cancelled": 0, - "total": 100 - } - } + description: The user email. - CreateVectorStoreFileBatchRequest: + AuditLogActorApiKey: type: object - additionalProperties: false - properties: - file_ids: - description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. - type: array - minItems: 1 - maxItems: 500 - items: - type: string - chunking_strategy: - $ref: "#/components/schemas/ChunkingStrategyRequestParam" - required: - - file_ids - - AssistantStreamEvent: - description: | - Represents an event emitted when streaming a Run. - - Each event in a server-sent events stream has an `event` and `data` property: - - ``` - event: thread.created - data: {"id": "thread_123", "object": "thread", ...} - ``` + description: The API Key used to perform the audit logged action. + properties: + id: + type: string + description: The tracking id of the API key. + type: + type: string + description: The type of API key. Can be either `user` or `service_account`. + enum: ["user", "service_account"] + user: + $ref: "#/components/schemas/AuditLogActorUser" + service_account: + $ref: "#/components/schemas/AuditLogActorServiceAccount" - We emit events whenever a new object is created, transitions to a new state, or is being - streamed in parts (deltas). For example, we emit `thread.run.created` when a new run - is created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses - to create a message during a run, we emit a `thread.message.created event`, a - `thread.message.in_progress` event, many `thread.message.delta` events, and finally a - `thread.message.completed` event. + AuditLogActorSession: + type: object + description: The session in which the audit logged action was performed. + properties: + user: + $ref: "#/components/schemas/AuditLogActorUser" + ip_address: + type: string + description: The IP address from which the action was performed. - We may add additional events over time, so we recommend handling unknown events gracefully - in your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to - integrate the Assistants API with streaming. - oneOf: - - $ref: "#/components/schemas/ThreadStreamEvent" - - $ref: "#/components/schemas/RunStreamEvent" - - $ref: "#/components/schemas/RunStepStreamEvent" - - $ref: "#/components/schemas/MessageStreamEvent" - - $ref: "#/components/schemas/ErrorEvent" - - $ref: "#/components/schemas/DoneEvent" - x-oaiMeta: - name: Assistant stream events - beta: true + AuditLogActor: + type: object + description: The actor who performed the audit logged action. + properties: + type: + type: string + description: The type of actor. Is either `session` or `api_key`. + enum: ["session", "api_key"] + session: + type: object + $ref: "#/components/schemas/AuditLogActorSession" + api_key: + type: object + $ref: "#/components/schemas/AuditLogActorApiKey" - ThreadStreamEvent: - oneOf: - - type: object - properties: - event: - type: string - enum: ["thread.created"] - data: - $ref: "#/components/schemas/ThreadObject" - required: - - event - - data - description: Occurs when a new [thread](/docs/api-reference/threads/object) is created. - x-oaiMeta: - dataDescription: "`data` is a [thread](/docs/api-reference/threads/object)" + AuditLogEventType: + type: string + description: The event type. + x-oaiExpandable: true + enum: + - api_key.created + - api_key.updated + - api_key.deleted + - invite.sent + - invite.accepted + - invite.deleted + - login.succeeded + - login.failed + - logout.succeeded + - logout.failed + - organization.updated + - project.created + - project.updated + - project.archived + - service_account.created + - service_account.updated + - service_account.deleted + - user.added + - user.updated + - user.deleted + + AuditLog: + type: object + description: A log of a user action or configuration change within this organization. + properties: + id: + type: string + description: The ID of this log. + type: + $ref: "#/components/schemas/AuditLogEventType" - RunStreamEvent: - oneOf: - - type: object + effective_at: + type: integer + description: The Unix timestamp (in seconds) of the event. + project: + type: object + description: The project that the action was scoped to. Absent for actions not scoped to projects. properties: - event: + id: type: string - enum: ["thread.run.created"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a new [run](/docs/api-reference/runs/object) is created. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object - properties: - event: + description: The project ID. + name: type: string - enum: ["thread.run.queued"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `queued` status. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + description: The project title. + actor: + $ref: "#/components/schemas/AuditLogActor" + api_key.created: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.in_progress"] + description: The tracking ID of the API key. data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) moves to an `in_progress` status. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + type: object + description: The payload used to create the API key. + properties: + scopes: + type: array + items: + type: string + description: A list of scopes allowed for the API key, e.g. `["api.model.request"]` + api_key.updated: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.requires_action"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `requires_action` status. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + description: The tracking ID of the API key. + changes_requested: + type: object + description: The payload used to update the API key. + properties: + scopes: + type: array + items: + type: string + description: A list of scopes allowed for the API key, e.g. `["api.model.request"]` + api_key.deleted: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.completed"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) is completed. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + description: The tracking ID of the API key. + invite.sent: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.incomplete"] + description: The ID of the invite. data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) ends with status `incomplete`. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + type: object + description: The payload used to create the invite. + properties: + email: + type: string + description: The email invited to the organization. + role: + type: string + description: The role the email was invited to be. Is either `owner` or `member`. + invite.accepted: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.failed"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) fails. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + description: The ID of the invite. + invite.deleted: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.cancelling"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `cancelling` status. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object + description: The ID of the invite. + login.failed: + type: object + description: The details for events with this `type`. properties: - event: + error_code: type: string - enum: ["thread.run.cancelled"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) is cancelled. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - type: object - properties: - event: + description: The error code of the failure. + error_message: type: string - enum: ["thread.run.expired"] - data: - $ref: "#/components/schemas/RunObject" - required: - - event - - data - description: Occurs when a [run](/docs/api-reference/runs/object) expires. - x-oaiMeta: - dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" - - RunStepStreamEvent: - oneOf: - - type: object + description: The error message of the failure. + logout.failed: + type: object + description: The details for events with this `type`. properties: - event: + error_code: type: string - enum: ["thread.run.step.created"] - data: - $ref: "#/components/schemas/RunStepObject" - required: - - event - - data - description: Occurs when a [run step](/docs/api-reference/runs/step-object) is created. - x-oaiMeta: - dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - - type: object - properties: - event: + description: The error code of the failure. + error_message: type: string - enum: ["thread.run.step.in_progress"] - data: - $ref: "#/components/schemas/RunStepObject" - required: - - event - - data - description: Occurs when a [run step](/docs/api-reference/runs/step-object) moves to an `in_progress` state. - x-oaiMeta: - dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - - type: object + description: The error message of the failure. + organization.updated: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.step.delta"] - data: - $ref: "#/components/schemas/RunStepDeltaObject" - required: - - event - - data - description: Occurs when parts of a [run step](/docs/api-reference/runs/step-object) are being streamed. - x-oaiMeta: - dataDescription: "`data` is a [run step delta](/docs/api-reference/assistants-streaming/run-step-delta-object)" - - type: object + description: The organization ID. + changes_requested: + type: object + description: The payload used to update the organization settings. + properties: + title: + type: string + description: The organization title. + description: + type: string + description: The organization description. + name: + type: string + description: The organization name. + settings: + type: object + properties: + threads_ui_visibility: + type: string + description: Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`. + usage_dashboard_visibility: + type: string + description: Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`. + project.created: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.step.completed"] + description: The project ID. data: - $ref: "#/components/schemas/RunStepObject" - required: - - event - - data - description: Occurs when a [run step](/docs/api-reference/runs/step-object) is completed. - x-oaiMeta: - dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - - type: object + type: object + description: The payload used to create the project. + properties: + name: + type: string + description: The project name. + title: + type: string + description: The title of the project as seen on the dashboard. + project.updated: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.step.failed"] - data: - $ref: "#/components/schemas/RunStepObject" - required: - - event - - data - description: Occurs when a [run step](/docs/api-reference/runs/step-object) fails. - x-oaiMeta: - dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - - type: object + description: The project ID. + changes_requested: + type: object + description: The payload used to update the project. + properties: + title: + type: string + description: The title of the project as seen on the dashboard. + project.archived: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.step.cancelled"] - data: - $ref: "#/components/schemas/RunStepObject" - required: - - event - - data - description: Occurs when a [run step](/docs/api-reference/runs/step-object) is cancelled. - x-oaiMeta: - dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - - type: object + description: The project ID. + service_account.created: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.run.step.expired"] + description: The service account ID. data: - $ref: "#/components/schemas/RunStepObject" - required: - - event - - data - description: Occurs when a [run step](/docs/api-reference/runs/step-object) expires. - x-oaiMeta: - dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" - - MessageStreamEvent: - oneOf: - - type: object + type: object + description: The payload used to create the service account. + properties: + role: + type: string + description: The role of the service account. Is either `owner` or `member`. + service_account.updated: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.message.created"] - data: - $ref: "#/components/schemas/MessageObject" - required: - - event - - data - description: Occurs when a [message](/docs/api-reference/messages/object) is created. - x-oaiMeta: - dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" - - type: object + description: The service account ID. + changes_requested: + type: object + description: The payload used to updated the service account. + properties: + role: + type: string + description: The role of the service account. Is either `owner` or `member`. + service_account.deleted: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.message.in_progress"] - data: - $ref: "#/components/schemas/MessageObject" - required: - - event - - data - description: Occurs when a [message](/docs/api-reference/messages/object) moves to an `in_progress` state. - x-oaiMeta: - dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" - - type: object + description: The service account ID. + user.added: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.message.delta"] + description: The user ID. data: - $ref: "#/components/schemas/MessageDeltaObject" - required: - - event - - data - description: Occurs when parts of a [Message](/docs/api-reference/messages/object) are being streamed. - x-oaiMeta: - dataDescription: "`data` is a [message delta](/docs/api-reference/assistants-streaming/message-delta-object)" - - type: object + type: object + description: The payload used to add the user to the project. + properties: + role: + type: string + description: The role of the user. Is either `owner` or `member`. + user.updated: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.message.completed"] - data: - $ref: "#/components/schemas/MessageObject" - required: - - event - - data - description: Occurs when a [message](/docs/api-reference/messages/object) is completed. - x-oaiMeta: - dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" - - type: object + description: The project ID. + changes_requested: + type: object + description: The payload used to update the user. + properties: + role: + type: string + description: The role of the user. Is either `owner` or `member`. + user.deleted: + type: object + description: The details for events with this `type`. properties: - event: + id: type: string - enum: ["thread.message.incomplete"] - data: - $ref: "#/components/schemas/MessageObject" - required: - - event - - data - description: Occurs when a [message](/docs/api-reference/messages/object) ends before it is completed. - x-oaiMeta: - dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + description: The user ID. + required: + - id + - type + - effective_at + - actor + x-oaiMeta: + name: The audit log object + example: | + { + "id": "req_xxx_20240101", + "type": "api_key.created", + "effective_at": 1720804090, + "actor": { + "type": "session", + "session": { + "user": { + "id": "user-xxx", + "email": "user@example.com" + }, + "ip_address": "127.0.0.1", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + }, + "api_key.created": { + "id": "key_xxxx", + "data": { + "scopes": ["resource.operation"] + } + } + } + + ListAuditLogsResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/AuditLog" + first_id: + type: string + example: "audit_log-defb456h8dks" + last_id: + type: string + example: "audit_log-hnbkd8s93s" + has_more: + type: boolean + + required: + - object + - data + - first_id + - last_id + - has_more + + Invite: + type: object + description: Represents an individual `invite` to the organization. + properties: + object: + type: string + enum: [organization.invite] + description: The object type, which is always `organization.invite` + id: + type: string + description: The identifier, which can be referenced in API endpoints + email: + type: string + description: The email address of the individual to whom the invite was sent + role: + type: string + enum: [owner, reader] + description: "`owner` or `reader`" + status: + type: string + enum: [accepted, expired, pending] + description: "`accepted`,`expired`, or `pending`" + invited_at: + type: integer + description: The Unix timestamp (in seconds) of when the invite was sent. + expires_at: + type: integer + description: The Unix timestamp (in seconds) of when the invite expires. + accepted_at: + type: integer + description: The Unix timestamp (in seconds) of when the invite was accepted. + + required: + - object + - id + - email + - role + - status + - invited_at + - expires_at + x-oaiMeta: + name: The invite object + example: | + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "status": "accepted", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": 1711471533 + } + + InviteListResponse: + type: object + properties: + object: + type: string + enum: [list] + description: The object type, which is always `list` + data: + type: array + items: + $ref: "#/components/schemas/Invite" + first_id: + type: string + description: The first `invite_id` in the retrieved `list` + last_id: + type: string + description: The last `invite_id` in the retrieved `list` + has_more: + type: boolean + description: The `has_more` property is used for pagination to indicate there are additional results. + required: + - object + - data + + InviteRequest: + type: object + properties: + email: + type: string + description: "Send an email to this address" + role: + type: string + enum: [reader, owner] + description: "`owner` or `reader`" + required: + - email + - role + + InviteDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.invite.deleted] + description: The object type, which is always `organization.invite.deleted` + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + User: + type: object + description: Represents an individual `user` within an organization. + properties: + object: + type: string + enum: [organization.user] + description: The object type, which is always `organization.user` + id: + type: string + description: The identifier, which can be referenced in API endpoints + name: + type: string + description: The name of the user + email: + type: string + description: The email address of the user + role: + type: string + enum: [owner, reader] + description: "`owner` or `reader`" + added_at: + type: integer + description: The Unix timestamp (in seconds) of when the user was added. + required: + - object + - id + - name + - email + - role + - added_at + x-oaiMeta: + name: The user object + example: | + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + UserListResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/User" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + UserRoleUpdateRequest: + type: object + properties: + role: + type: string + enum: [owner, reader] + description: "`owner` or `reader`" + required: + - role - ErrorEvent: + UserDeleteResponse: type: object properties: - event: + object: type: string - enum: ["error"] + enum: [organization.user.deleted] + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + Project: + type: object + description: Represents an individual project. + properties: + id: + type: string + description: The identifier, which can be referenced in API endpoints + object: + type: string + enum: [organization.project] + description: The object type, which is always `organization.project` + name: + type: string + description: The name of the project. This appears in reporting. + created_at: + type: integer + description: The Unix timestamp (in seconds) of when the project was created. + archived_at: + type: integer + nullable: true + description: The Unix timestamp (in seconds) of when the project was archived or `null`. + status: + type: string + enum: [active, archived] + description: "`active` or `archived`" + required: + - id + - object + - name + - created_at + - status + x-oaiMeta: + name: The project object + example: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project example", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + + ProjectListResponse: + type: object + properties: + object: + type: string + enum: [list] data: - $ref: "#/components/schemas/Error" + type: array + items: + $ref: "#/components/schemas/Project" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean required: - - event + - object - data - description: Occurs when an [error](/docs/guides/error-codes/api-errors) occurs. This can happen due to an internal server error or a timeout. + - first_id + - last_id + - has_more + + ProjectCreateRequest: + type: object + properties: + name: + type: string + description: The friendly name of the project, this name appears in reports. + required: + - name + + ProjectUpdateRequest: + type: object + properties: + name: + type: string + description: The updated name of the project, this name appears in reports. + required: + - name + + DefaultProjectErrorResponse: + type: object + properties: + code: + type: integer + message: + type: string + required: + - code + - message + + ProjectUser: + type: object + description: Represents an individual user in a project. + properties: + object: + type: string + enum: [organization.project.user] + description: The object type, which is always `organization.project.user` + id: + type: string + description: The identifier, which can be referenced in API endpoints + name: + type: string + description: The name of the user + email: + type: string + description: The email address of the user + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + added_at: + type: integer + description: The Unix timestamp (in seconds) of when the project was added. + + required: + - object + - id + - name + - email + - role + - added_at x-oaiMeta: - dataDescription: "`data` is an [error](/docs/guides/error-codes/api-errors)" + name: The project user object + example: | + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } - DoneEvent: + ProjectUserListResponse: type: object properties: - event: + object: type: string - enum: ["done"] data: + type: array + items: + $ref: "#/components/schemas/ProjectUser" + first_id: type: string - enum: ["[DONE]"] + last_id: + type: string + has_more: + type: boolean required: - - event + - object - data - description: Occurs when a stream ends. - x-oaiMeta: - dataDescription: "`data` is `[DONE]`" + - first_id + - last_id + - has_more - Batch: + ProjectUserCreateRequest: + type: object + properties: + user_id: + type: string + description: The ID of the user. + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + required: + - user_id + - role + + ProjectUserUpdateRequest: + type: object + properties: + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + required: + - role + + ProjectUserDeleteResponse: type: object properties: + object: + type: string + enum: [organization.project.user.deleted] id: type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + ProjectServiceAccount: + type: object + description: Represents an individual service account in a project. + properties: object: type: string - enum: [batch] - description: The object type, which is always `batch`. - endpoint: + enum: [organization.project.service_account] + description: The object type, which is always `organization.project.service_account` + id: type: string - description: The OpenAI API endpoint used by the batch. + description: The identifier, which can be referenced in API endpoints + name: + type: string + description: The name of the service account + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + created_at: + type: integer + description: The Unix timestamp (in seconds) of when the service account was created + required: + - object + - id + - name + - role + - created_at + x-oaiMeta: + name: The project service account object + example: | + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Service Account", + "role": "owner", + "created_at": 1711471533 + } - errors: - type: object - properties: - object: - type: string - description: The object type, which is always `list`. - data: - type: array - items: - type: object - properties: - code: - type: string - description: An error code identifying the error type. - message: - type: string - description: A human-readable message providing more details about the error. - param: - type: string - description: The name of the parameter that caused the error, if applicable. - nullable: true - line: - type: integer - description: The line number of the input file where the error occurred, if applicable. - nullable: true - input_file_id: + ProjectServiceAccountListResponse: + type: object + properties: + object: type: string - description: The ID of the input file for the batch. - completion_window: + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/ProjectServiceAccount" + first_id: type: string - description: The time frame within which the batch should be processed. - status: + last_id: type: string - description: The current status of the batch. - enum: - - validating - - failed - - in_progress - - finalizing - - completed - - expired - - cancelling - - cancelled - output_file_id: + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + ProjectServiceAccountCreateRequest: + type: object + properties: + name: type: string - description: The ID of the file containing the outputs of successfully executed requests. - error_file_id: + description: The name of the service account being created. + required: + - name + + ProjectServiceAccountCreateResponse: + type: object + properties: + object: type: string - description: The ID of the file containing the outputs of requests with errors. + enum: [organization.project.service_account] + id: + type: string + name: + type: string + role: + type: string + enum: [member] + description: Service accounts can only have one role of type `member` created_at: type: integer - description: The Unix timestamp (in seconds) for when the batch was created. - in_progress_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch started processing. - expires_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch will expire. - finalizing_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch started finalizing. - completed_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch was completed. - failed_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch failed. - expired_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch expired. - cancelling_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch started cancelling. - cancelled_at: - type: integer - description: The Unix timestamp (in seconds) for when the batch was cancelled. - request_counts: - type: object - properties: - total: - type: integer - description: Total number of requests in the batch. - completed: - type: integer - description: Number of requests that have been completed successfully. - failed: - type: integer - description: Number of requests that have failed. - required: - - total - - completed - - failed - description: The request counts for different statuses within the batch. - metadata: - description: *metadata_description - type: object - x-oaiTypeLabel: map - nullable: true + api_key: + $ref: "#/components/schemas/ProjectServiceAccountApiKey" required: - - id - object - - endpoint - - input_file_id - - completion_window - - status + - id + - name + - role - created_at - x-oaiMeta: - name: The batch object - example: *batch_object + - api_key - BatchRequestInput: + ProjectServiceAccountApiKey: type: object - description: The per-line object of the batch input file properties: - custom_id: + object: type: string - description: A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch. - method: + enum: [organization.project.service_account.api_key] + description: The object type, which is always `organization.project.service_account.api_key` + + value: type: string - enum: ["POST"] - description: The HTTP method to be used for the request. Currently only `POST` is supported. - url: + name: type: string - description: The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. - x-oaiMeta: - name: The request input object - example: | - {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}} + created_at: + type: integer + id: + type: string + required: + - object + - value + - name + - created_at + - id - BatchRequestOutput: + ProjectServiceAccountDeleteResponse: type: object - description: The per-line object of the batch output and error files properties: + object: + type: string + enum: [organization.project.service_account.deleted] id: type: string - custom_id: + deleted: + type: boolean + required: + - object + - id + - deleted + + ProjectApiKey: + type: object + description: Represents an individual API key in a project. + properties: + object: type: string - description: A developer-provided per-request id that will be used to match outputs to inputs. - response: - type: object - nullable: true - properties: - status_code: - type: integer - description: The HTTP status code of the response - request_id: - type: string - description: An unique identifier for the OpenAI API request. Please include this request ID when contacting support. - body: - type: object - x-oaiTypeLabel: map - description: The JSON body of the response - error: + enum: [organization.project.api_key] + description: The object type, which is always `organization.project.api_key` + redacted_value: + type: string + description: The redacted value of the API key + name: + type: string + description: The name of the API key + created_at: + type: integer + description: The Unix timestamp (in seconds) of when the API key was created + id: + type: string + description: The identifier, which can be referenced in API endpoints + owner: type: object - nullable: true - description: For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure. properties: - code: - type: string - description: A machine-readable error code. - message: + type: type: string - description: A human-readable error message. + enum: [user, service_account] + description: "`user` or `service_account`" + user: + $ref: "#/components/schemas/ProjectUser" + service_account: + $ref: "#/components/schemas/ProjectServiceAccount" + required: + - object + - redacted_value + - name + - created_at + - id + - owner x-oaiMeta: - name: The request output object + name: The project API key object example: | - {"id": "batch_req_wnaDys", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "req_c187b3", "body": {"id": "chatcmpl-9758Iw", "object": "chat.completion", "created": 1711475054, "model": "gpt-3.5-turbo", "choices": [{"index": 0, "message": {"role": "assistant", "content": "2 + 2 equals 4."}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 24, "completion_tokens": 15, "total_tokens": 39}, "system_fingerprint": null}}, "error": null} + { + "object": "organization.project.api_key", + "redacted_value": "sk-abc...def", + "name": "My API Key", + "created_at": 1711471533, + "id": "key_abc", + "owner": { + "type": "user", + "user": { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + } + } - ListBatchesResponse: + ProjectApiKeyListResponse: type: object properties: + object: + type: string + enum: [list] data: type: array items: - $ref: "#/components/schemas/Batch" + $ref: "#/components/schemas/ProjectApiKey" first_id: type: string - example: "batch_abc123" last_id: type: string - example: "batch_abc456" has_more: type: boolean - object: - type: string - enum: [list] required: - object - data + - first_id + - last_id - has_more + ProjectApiKeyDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.project.api_key.deleted] + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + security: - ApiKeyAuth: [] @@ -13777,6 +16409,8 @@ x-oaiMeta: title: Endpoints - id: assistants title: Assistants + - id: administration + title: Administration - id: legacy title: Legacy groups: @@ -14025,6 +16659,7 @@ x-oaiMeta: - type: object key: CreateModerationResponse path: object + - id: assistants title: Assistants beta: true @@ -14252,6 +16887,175 @@ x-oaiMeta: - type: object key: AssistantStreamEvent path: events + + - id: administration + title: Overview + description: | + Programmatically manage your organization. + + The Audit Logs endpoint provides a log of all actions taken in the + organization for security and monitoring purposes. + + To access these endpoints please generate an Admin API Key through the [API Platform Organization overview](/organization/admin-keys). Admin API keys cannot be used for non-administration endpoints. + + For best practices on setting up your organization, please refer to this [guide](/docs/guides/production-best-practices/setting-up-your-organization) + navigationGroup: administration + + - id: invite + title: Invites + description: Invite and manage invitations for an organization. Invited users are automatically added to the Default project. + navigationGroup: administration + sections: + - type: endpoint + key: list-invites + path: list + - type: endpoint + key: inviteUser + path: create + - type: endpoint + key: retrieve-invite + path: retrieve + - type: endpoint + key: delete-invite + path: delete + - type: object + key: Invite + path: object + + - id: users + title: Users + description: | + Manage users and their role in an organization. Users will be automatically added to the Default project. + navigationGroup: administration + sections: + - type: endpoint + key: list-users + path: list + - type: endpoint + key: modify-user + path: modify + - type: endpoint + key: retrieve-user + path: retrieve + - type: endpoint + key: delete-user + path: delete + - type: object + key: User + path: object + + - id: projects + title: Projects + description: | + Manage the projects within an orgnanization includes creation, updating, and archiving or projects. + The Default project cannot be modified or archived. + navigationGroup: administration + sections: + - type: endpoint + key: list-projects + path: list + - type: endpoint + key: create-project + path: create + - type: endpoint + key: retrieve-project + path: retrieve + - type: endpoint + key: modify-project + path: modify + - type: endpoint + key: archive-project + path: archive + - type: object + key: Project + path: object + + - id: project-users + title: Project Users + description: | + Manage users within a project, including adding, updating roles, and removing users. + Users cannot be removed from the Default project, unless they are being removed from the organization. + navigationGroup: administration + sections: + - type: endpoint + key: list-project-users + path: list + - type: endpoint + key: create-project-user + path: creeate + - type: endpoint + key: retrieve-project-user + path: retrieve + - type: endpoint + key: modify-project-user + path: modify + - type: endpoint + key: delete-project-user + path: delete + - type: object + key: ProjectUser + path: object + + - id: project-service-accounts + title: Project Service Accounts + description: | + Manage service accounts within a project. A service account is a bot user that is not associated with a user. + If a user leaves an organization, their keys and membership in projects will no longer work. Service accounts + do not have this limitation. However, service accounts can also be deleted from a project. + navigationGroup: administration + sections: + - type: endpoint + key: list-project-service-accounts + path: list + - type: endpoint + key: create-project-service-account + path: create + - type: endpoint + key: retrieve-project-service-account + path: retrieve + - type: endpoint + key: delete-project-service-account + path: delete + - type: object + key: ProjectServiceAccount + path: object + + - id: project-api-keys + title: Project API Keys + description: | + Manage API keys for a given project. Supports listing and deleting keys for users. + This API does not allow issuing keys for users, as users need to authorize themselves to generate keys. + navigationGroup: administration + sections: + - type: endpoint + key: list-project-api-keys + path: list + - type: endpoint + key: retrieve-project-api-key + path: retrieve + - type: endpoint + key: delete-project-api-key + path: delete + - type: object + key: ProjectApiKey + path: object + + - id: audit-logs + title: Audit Logs + description: | + Logs of user actions and configuration changes within this organization. + + To log events, you must activate logging in the [Organization Settings](/settings/organization/general). + Once activated, for security reasons, logging cannot be deactivated. + navigationGroup: administration + sections: + - type: endpoint + key: list-audit-logs + path: list + - type: object + key: AuditLog + path: object + - id: completions title: Completions legacy: true From 1309815decbd845199fa364b571eb263f851c9d9 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 21 Aug 2024 14:58:20 -0700 Subject: [PATCH 369/425] Some cleanups (#604) --- model-engine/model_engine_server/api/v2/batch_completion.py | 2 +- model-engine/model_engine_server/api/v2/common.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/api/v2/batch_completion.py b/model-engine/model_engine_server/api/v2/batch_completion.py index 9412f945..70fa3dfc 100644 --- a/model-engine/model_engine_server/api/v2/batch_completion.py +++ b/model-engine/model_engine_server/api/v2/batch_completion.py @@ -37,7 +37,7 @@ ) -@batch_completions_router_v2.post("/", response_model=CreateBatchCompletionsV2Response) +@batch_completions_router_v2.post("", response_model=CreateBatchCompletionsV2Response) async def batch_completions( request: CreateBatchCompletionsV2Request, auth: User = Depends(verify_authentication), diff --git a/model-engine/model_engine_server/api/v2/common.py b/model-engine/model_engine_server/api/v2/common.py index 2099c3b6..50c61df2 100644 --- a/model-engine/model_engine_server/api/v2/common.py +++ b/model-engine/model_engine_server/api/v2/common.py @@ -19,8 +19,6 @@ async def get_metric_metadata( request: Request, auth: User = Depends(verify_authentication), ) -> MetricMetadata: - print("body") - print(request.body) model_name = request.query_params.get("model", None) return MetricMetadata(user=auth, model_name=model_name) From 96845864c18b9ea3c08fb03d7daf0297078b9da2 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 21 Aug 2024 17:23:34 -0700 Subject: [PATCH 370/425] More batch job cleanup (#605) * More batch job cleanup * Add priority field to client --- clients/python/llmengine/completion.py | 5 +++++ clients/python/llmengine/data_types/batch_completion.py | 4 ++-- model-engine/model_engine_server/api/v2/batch_completion.py | 3 +++ .../model_engine_server/common/dtos/llms/batch_completion.py | 4 ++-- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index bb0d5ffa..49683111 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -612,6 +612,7 @@ def batch_create_v2( data_parallelism: int = 1, max_runtime_sec: int = 24 * 3600, labels: Dict[str, str] = {}, + priority: Optional[str] = None, tool_config: Optional[ToolConfig] = None, request_headers: Optional[Dict[str, str]] = None, ) -> CreateBatchCompletionsV2Response: @@ -639,6 +640,9 @@ def batch_create_v2( max_runtime_sec (int): The maximum runtime of the batch completion in seconds. Defaults to 24 hours. + priority (str): + Priority of the batch inference job. Default to None. + tool_config (Optional[ToolConfig]): Configuration for tool use. NOTE: this config is highly experimental and signature will change significantly in future iterations. @@ -694,6 +698,7 @@ def batch_create_v2( labels=labels, max_runtime_sec=max_runtime_sec, tool_config=tool_config, + priority=priority, ).model_dump(exclude_none=True, by_alias=True) response = cls.post_sync( resource_name="v2/batch-completions", diff --git a/clients/python/llmengine/data_types/batch_completion.py b/clients/python/llmengine/data_types/batch_completion.py index cfb31248..1d163b2b 100644 --- a/clients/python/llmengine/data_types/batch_completion.py +++ b/clients/python/llmengine/data_types/batch_completion.py @@ -244,7 +244,7 @@ class BatchCompletionsJob(BaseModel): description="""Model configuration for the batch inference. Hardware configurations are inferred.""", ) - priority: Optional[int] = Field( + priority: Optional[str] = Field( default=None, description="Priority of the batch inference job. Default to None.", ) @@ -260,7 +260,7 @@ class BatchCompletionsJob(BaseModel): class UpdateBatchCompletionsV2Request(BaseModel): job_id: str = Field(description="ID of the batch completions job") - priority: Optional[int] = Field( + priority: Optional[str] = Field( default=None, description="Priority of the batch inference job. Default to None.", ) diff --git a/model-engine/model_engine_server/api/v2/batch_completion.py b/model-engine/model_engine_server/api/v2/batch_completion.py index 70fa3dfc..fb8262eb 100644 --- a/model-engine/model_engine_server/api/v2/batch_completion.py +++ b/model-engine/model_engine_server/api/v2/batch_completion.py @@ -16,6 +16,7 @@ from model_engine_server.core.auth.authentication_repository import User from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( + ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) @@ -51,6 +52,8 @@ async def batch_completions( ) return await use_case.execute(request, user=auth) + except ObjectHasInvalidValueException as exc: # pragma: no cover + raise HTTPException(status_code=400, detail=str(exc)) except ObjectNotFoundException as exc: raise HTTPException( status_code=404, diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 6b7ebfb5..71c762d8 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -251,7 +251,7 @@ class BatchCompletionsJob(BaseModel): description="""Model configuration for the batch inference. Hardware configurations are inferred.""", ) - priority: Optional[int] = Field( + priority: Optional[str] = Field( default=None, description="Priority of the batch inference job. Default to None.", ) @@ -267,7 +267,7 @@ class BatchCompletionsJob(BaseModel): class UpdateBatchCompletionsV2Request(BaseModel): job_id: str = Field(description="ID of the batch completions job") - priority: Optional[int] = Field( + priority: Optional[str] = Field( default=None, description="Priority of the batch inference job. Default to None.", ) From 47eefb1b9aed8033df8fbb7b8d7c157ec2b669da Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 22 Aug 2024 09:25:07 -0700 Subject: [PATCH 371/425] Relax pydantic constraint for client (#606) * Relax pydantic constraint for client * Relax pydantic constraint for client * Fix * Add get + cancel methods to client * Remove batch_create_v2 and consolidate into batch_create * Update v2 content to include v1 * cleanup --- .ruff.toml | 1 + clients/python/llmengine/__init__.py | 8 +- clients/python/llmengine/completion.py | 219 +- .../llmengine/data_types/batch_completion.py | 26 +- .../llmengine/data_types/chat_completion.py | 181 +- .../python/llmengine/data_types/completion.py | 50 +- .../python/llmengine/data_types/gen/openai.py | 2310 +++++++---------- .../llmengine/data_types/pydantic_types.py | 18 +- clients/python/llmengine/data_types/rest.py | 6 +- clients/python/pyproject.toml | 5 +- clients/python/setup.py | 2 +- .../common/dtos/llms/batch_completion.py | 10 +- scripts/generate-openai-types.sh | 33 +- 13 files changed, 1190 insertions(+), 1679 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index af1d91d6..3a61ae77 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -2,3 +2,4 @@ line-length = 100 ignore = ["E501"] +exclude = ["gen"] diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 6e201069..aada0ef5 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta37" +__version__ = "0.0.0beta38" import os from typing import Sequence @@ -38,6 +38,9 @@ CreateBatchCompletionsRequest, CreateBatchCompletionsRequestContent, CreateBatchCompletionsResponse, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1RequestContent, + CreateBatchCompletionsV1Response, CreateBatchCompletionsV2ModelConfig, CreateBatchCompletionsV2Request, CreateBatchCompletionsV2RequestContent, @@ -87,6 +90,9 @@ "CreateBatchCompletionsRequest", "CreateBatchCompletionsRequestContent", "CreateBatchCompletionsResponse", + "CreateBatchCompletionsV1Request", + "CreateBatchCompletionsV1RequestContent", + "CreateBatchCompletionsV1Response", "CreateBatchCompletionsV2Request", "CreateBatchCompletionsV2RequestContent", "CreateBatchCompletionsV2ModelConfig", diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 49683111..8a9dd5ec 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -1,18 +1,17 @@ -from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union, cast from llmengine.api_engine import APIEngine from llmengine.data_types import ( - BatchCompletionsModelConfig, + BatchCompletionContent, CompletionStreamResponse, CompletionStreamV1Request, CompletionSyncResponse, CompletionSyncV1Request, CreateBatchCompletionsModelConfig, - CreateBatchCompletionsRequest, - CreateBatchCompletionsRequestContent, - CreateBatchCompletionsResponse, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1RequestContent, + CreateBatchCompletionsV1Response, CreateBatchCompletionsV2Request, - CreateBatchCompletionsV2RequestContent, CreateBatchCompletionsV2Response, ToolConfig, ) @@ -479,13 +478,16 @@ def batch_create( cls, output_data_path: str, model_config: CreateBatchCompletionsModelConfig, - content: Optional[CreateBatchCompletionsRequestContent] = None, + content: Optional[BatchCompletionContent] = None, input_data_path: Optional[str] = None, data_parallelism: int = 1, max_runtime_sec: int = 24 * 3600, + labels: Optional[Dict[str, str]] = None, + priority: Optional[str] = None, + use_v2: bool = False, tool_config: Optional[ToolConfig] = None, request_headers: Optional[Dict[str, str]] = None, - ) -> CreateBatchCompletionsResponse: + ) -> Union[CreateBatchCompletionsV1Response, CreateBatchCompletionsV2Response]: """ Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. @@ -507,9 +509,15 @@ def batch_create( data_parallelism (int): The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1. + priority (str): + Priority of the batch inference job. Default to None. + max_runtime_sec (int): The maximum runtime of the batch completion in seconds. Defaults to 24 hours. + use_v2 (bool): + Whether to use the v2 batch completion API. Defaults to False. + tool_config (Optional[ToolConfig]): Configuration for tool use. NOTE: this config is highly experimental and signature will change significantly in future iterations. @@ -583,83 +591,11 @@ def batch_create( ) print(response.json()) ``` - """ - data = CreateBatchCompletionsRequest( - model_config=model_config, - content=content, - input_data_path=input_data_path, - output_data_path=output_data_path, - data_parallelism=data_parallelism, - max_runtime_sec=max_runtime_sec, - tool_config=tool_config, - ).dict() - response = cls.post_sync( - resource_name="v1/llm/batch-completions", - data=data, - timeout=HTTP_TIMEOUT, - headers=request_headers, - ) - return CreateBatchCompletionsResponse.parse_obj(response) - - @classmethod - def batch_create_v2( - cls, - *, - output_data_path: str, - model_config: BatchCompletionsModelConfig, - content: Optional[List[CreateBatchCompletionsV2RequestContent]] = None, - input_data_path: Optional[str] = None, - data_parallelism: int = 1, - max_runtime_sec: int = 24 * 3600, - labels: Dict[str, str] = {}, - priority: Optional[str] = None, - tool_config: Optional[ToolConfig] = None, - request_headers: Optional[Dict[str, str]] = None, - ) -> CreateBatchCompletionsV2Response: - """ - Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. - - Prompts can be passed in from an input file, or as a part of the request. - - Args: - output_data_path (str): - The path to the output file. The output file will be a JSON file containing the completions. - - model_config (BatchCompletionsModelConfig): - The model configuration to use for the batch completion. - - content (Optional[List[CreateBatchCompletionsV2RequestContent]]): - The content to use for the batch completion. Either one of `content` or `input_data_path` must be provided. - - input_data_path (Optional[str]): - The path to the input file. The input file should be a JSON file with data of type `BatchCompletionsRequestContent`. Either one of `content` or `input_data_path` must be provided. - - data_parallelism (int): - The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1. - - max_runtime_sec (int): - The maximum runtime of the batch completion in seconds. Defaults to 24 hours. - - priority (str): - Priority of the batch inference job. Default to None. - - tool_config (Optional[ToolConfig]): - Configuration for tool use. - NOTE: this config is highly experimental and signature will change significantly in future iterations. - Currently only Python code evaluator is supported. - Python code context starts with "\`\`\`python\\n" and ends with "\\n>>>\\n", data before "\\n\`\`\`\\n" and content end will be replaced by the Python execution results. - Please format prompts accordingly and provide examples so LLMs could properly generate Python code. - Returns: - response (CreateBatchCompletionsV2Response): The response containing the job id. - - === "Batch completions with prompts in the request" + === "V2 Batch completions with prompts in the request" ```python - from llmengine import ( - Completion, - ) from llmengine import Completion - from llmengine.data_types import CreateBatchCompletionsModelConfig, FilteredChatCompletionV2Request, + from llmengine.data_types import CreateBatchCompletionsModelConfig, FilteredChatCompletionV2Request model_config = CreateBatchCompletionsModelConfig( model="gemma-2-2b-it", @@ -678,32 +614,121 @@ def batch_create_v2( "logprobs": True, } - response = Completion.batch_create_v2( + response = Completion.batch_create( output_data_path="testoutput", model_config=model_config, content=[FilteredChatCompletionV2Request(**content)], + use_v2=True, labels={"team": "my-team", "product": "my-product"}, ) print(response.json()) + """ + labels = labels if labels else model_config.labels + if use_v2: + data = CreateBatchCompletionsV2Request( + model_config=model_config, + content=content, + input_data_path=input_data_path, + output_data_path=output_data_path, + data_parallelism=data_parallelism, + labels=labels, + max_runtime_sec=max_runtime_sec, + tool_config=tool_config, + priority=priority, + ).dict() + response = cls.post_sync( + resource_name="v2/batch-completions", + data=data, + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return CreateBatchCompletionsV2Response.parse_obj(response) + else: + if input_data_path is None and not isinstance( + content, CreateBatchCompletionsV1RequestContent + ): + raise ValueError( + "Either input_data_path or content must be provided. If content is provided, it must be of type CreateBatchCompletionsV1RequestContent." + ) + + content = cast(Optional[CreateBatchCompletionsV1RequestContent], content) + data = CreateBatchCompletionsV1Request( + model_config=model_config, + content=content, + input_data_path=input_data_path, + output_data_path=output_data_path, + data_parallelism=data_parallelism, + max_runtime_sec=max_runtime_sec, + tool_config=tool_config, + ).dict() + response = cls.post_sync( + resource_name="v1/llm/batch-completions", + data=data, + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return CreateBatchCompletionsV1Response.parse_obj(response) + + @classmethod + def get_batch_completion( + cls, + job_id: str, + request_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Get the status of a batch completion job. + + Args: + job_id (str): + The job id of the batch completion job. + + Returns: + response (Dict[str, Any]): The response containing the job status. + + === "Get batch completion status" + ```python + from llmengine import Completion + + response = Completion.get_batch_completion(job_id="job-id") + print(response) ``` + """ + response = cls._get( + resource_name=f"v2/batch-completions/{job_id}", + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return response + @classmethod + def cancel_batch_completion( + cls, + job_id: str, + request_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Cancel a batch completion job. + + Args: + job_id (str): + The job id of the batch completion job. + + Returns: + response (Dict[str, Any]): The response containing the job status. + + === "Cancel batch completion job" + ```python + from llmengine import Completion + + response = Completion.cancel_batch_completion(job_id="job-id") + print(response) + ``` """ - data = CreateBatchCompletionsV2Request( - model_config=model_config, - content=content, - input_data_path=input_data_path, - output_data_path=output_data_path, - data_parallelism=data_parallelism, - labels=labels, - max_runtime_sec=max_runtime_sec, - tool_config=tool_config, - priority=priority, - ).model_dump(exclude_none=True, by_alias=True) response = cls.post_sync( - resource_name="v2/batch-completions", - data=data, + resource_name=f"v2/batch-completions/{job_id}/actions/cancel", + data={}, timeout=HTTP_TIMEOUT, headers=request_headers, ) - return CreateBatchCompletionsV2Response.parse_obj(response) + return response diff --git a/clients/python/llmengine/data_types/batch_completion.py b/clients/python/llmengine/data_types/batch_completion.py index 1d163b2b..6c14fcce 100644 --- a/clients/python/llmengine/data_types/batch_completion.py +++ b/clients/python/llmengine/data_types/batch_completion.py @@ -165,12 +165,9 @@ class CreateBatchCompletionsV1Request(BatchCompletionsRequestBase): Either `input_data_path` or `content` needs to be provided. When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. """ - model_cfg: CreateBatchCompletionsV1ModelConfig = Field(alias="model_config") + model_config: CreateBatchCompletionsV1ModelConfig = Field(alias="model_config") """ Model configuration for the batch inference. Hardware configurations are inferred. - - We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which - reserves model_config as a keyword. """ @@ -196,13 +193,17 @@ class FilteredChatCompletionV2Request(ChatCompletionV2Request): ] CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig +BatchCompletionContent = Union[ + CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent +] + class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): """ Request object for batch completions. """ - content: Optional[CreateBatchCompletionsV2RequestContent] = Field( + content: Optional[BatchCompletionContent] = Field( default=None, description=""" Either `input_data_path` or `content` needs to be provided. @@ -210,10 +211,7 @@ class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): """, ) - # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which - # reserves model_config as a keyword. - model_cfg: BatchCompletionsModelConfig = Field( - alias="model_config", + model_config: BatchCompletionsModelConfig = Field( description="""Model configuration for the batch inference. Hardware configurations are inferred.""", ) @@ -237,10 +235,7 @@ class BatchCompletionsJob(BaseModel): description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." ) - # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which - # reserves model_config as a keyword. - model_cfg: BatchCompletionsModelConfig = Field( - alias="model_config", + model_config: BatchCompletionsModelConfig = Field( description="""Model configuration for the batch inference. Hardware configurations are inferred.""", ) @@ -284,8 +279,3 @@ class ListBatchCompletionV2Response(BaseModel): class GetBatchCompletionV2Response(BaseModel): job: BatchCompletionsJob - - -BatchCompletionContent = Union[ - CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent -] diff --git a/clients/python/llmengine/data_types/chat_completion.py b/clients/python/llmengine/data_types/chat_completion.py index ab2c94a0..adee0046 100644 --- a/clients/python/llmengine/data_types/chat_completion.py +++ b/clients/python/llmengine/data_types/chat_completion.py @@ -1,131 +1,92 @@ from typing import Any, Dict, List, Optional -from pydantic import Field -from typing_extensions import Annotated - from .gen.openai import CreateChatCompletionRequest, CreateChatCompletionResponse +from .pydantic_types import Field # Fields that are a part of OpenAI spec but are not supported by model engine UNSUPPORTED_FIELDS = ["service_tier"] class VLLMAdditionalFields: - chat_template: Annotated[ - Optional[str], - Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." ), - ] - chat_template_kwargs: Annotated[ - Optional[Dict[str, Any]], - Field( - default=None, - description=( - "Additional kwargs to pass to the template renderer. " - "Will be accessible by the chat template." - ), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." ), - ] - - guided_json: Annotated[ - Optional[Dict[str, Any]], - Field( - default=None, - description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" ), - ] + ) - guided_regex: Annotated[ - Optional[str], - Field( - default=None, - description="Regex for guided decoding. Only supported in vllm.", + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." ), - ] - guided_choice: Annotated[ - Optional[List[str]], - Field( - default=None, - description="Choices for guided decoding. Only supported in vllm.", - ), - ] + ) - guided_grammar: Annotated[ - Optional[str], - Field( - default=None, - description="Context-free grammar for guided decoding. Only supported in vllm.", - ), - ] - - guided_decoding_backend: Annotated[ - Optional[str], - Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be either " - "'outlines' / 'lm-format-enforcer'" - ), - ), - ] - - guided_whitespace_pattern: Annotated[ - Optional[str], - Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding." - ), - ), - ] - - skip_special_tokens: Annotated[ - Optional[bool], - Field( - True, - description="Whether to skip special tokens in the output. Only supported in vllm.", - ), - ] + skip_special_tokens: Optional[bool] = Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ) class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields): - model: Annotated[ - str, - Field( - description="ID of the model to use.", - examples=["mixtral-8x7b-instruct"], - ), - ] - - stream: Annotated[ - Optional[bool], - Field( - False, - description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", - ), - ] - - top_k: Annotated[ - Optional[int], - Field( - None, - ge=-1, - description="Controls the number of top tokens to consider. -1 means consider all tokens.", - ), - ] - - include_stop_str_in_output: Annotated[ - Optional[bool], - Field(None, description="Whether to include the stop strings in output text."), - ] + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) + + include_stop_str_in_output: Optional[bool] = Field( + None, description="Whether to include the stop strings in output text." + ) class ChatCompletionV2Response(CreateChatCompletionResponse): diff --git a/clients/python/llmengine/data_types/completion.py b/clients/python/llmengine/data_types/completion.py index 24978263..fc92f711 100644 --- a/clients/python/llmengine/data_types/completion.py +++ b/clients/python/llmengine/data_types/completion.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional -from typing_extensions import Annotated - from .gen.openai import CreateCompletionRequest, CreateCompletionResponse from .pydantic_types import BaseModel, Field @@ -288,35 +286,25 @@ def inter_token_latency(self) -> Optional[float]: # Only for streaming requests class CompletionV2Request(CreateCompletionRequest): - model: Annotated[ - str, - Field( - description="ID of the model to use.", - examples=["mixtral-8x7b-instruct"], - ), - ] - - stream: Annotated[ - Optional[bool], - Field( - False, - description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", - ), - ] - - top_k: Annotated[ - Optional[int], - Field( - None, - ge=-1, - description="Controls the number of top tokens to consider. -1 means consider all tokens.", - ), - ] - - include_stop_str_in_output: Annotated[ - Optional[bool], - Field(None, description="Whether to include the stop strings in output text."), - ] + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) + + include_stop_str_in_output: Optional[bool] = Field( + None, description="Whether to include the stop strings in output text." + ) class CompletionV2Response(CreateCompletionResponse): diff --git a/clients/python/llmengine/data_types/gen/openai.py b/clients/python/llmengine/data_types/gen/openai.py index b8222667..a97f0fd3 100644 --- a/clients/python/llmengine/data_types/gen/openai.py +++ b/clients/python/llmengine/data_types/gen/openai.py @@ -1,13 +1,19 @@ # mypy: ignore-errors # generated by datamodel-codegen: # filename: openai-spec.yaml -# timestamp: 2024-08-20T08:20:04+00:00 +# timestamp: 2024-08-22T02:56:18+00:00 from __future__ import annotations from typing import Any, Dict, List, Optional, Union -from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel +import pydantic + +PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.") +if PYDANTIC_V2: + from pydantic.v1 import AnyUrl, BaseModel, Extra, Field # noqa: F401 +else: + from pydantic import AnyUrl, BaseModel, Extra, Field # type: ignore # noqa: F401 from typing_extensions import Annotated, Literal @@ -28,42 +34,39 @@ class DeleteModelResponse(BaseModel): object: str -class Prompt(RootModel[Optional[List[int]]]): - root: Annotated[ +class Prompt(BaseModel): + __root__: Annotated[ Optional[List[int]], Field( - "<|endoftext|>", description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", - examples=["[1212, 318, 257, 1332, 13]"], - min_length=1, + example="[1212, 318, 257, 1332, 13]", + min_items=1, ), ] = "<|endoftext|>" -class Prompt1Item(RootModel[List[int]]): - root: Annotated[List[int], Field(min_length=1)] +class Prompt1Item(BaseModel): + __root__: Annotated[List[int], Field(min_items=1)] -class Prompt1(RootModel[Optional[List[Prompt1Item]]]): - root: Annotated[ +class Prompt1(BaseModel): + __root__: Annotated[ Optional[List[Prompt1Item]], Field( - "<|endoftext|>", description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", - examples=["[[1212, 318, 257, 1332, 13]]"], - min_length=1, + example="[[1212, 318, 257, 1332, 13]]", + min_items=1, ), ] = "<|endoftext|>" -class Stop(RootModel[Optional[List[str]]]): - root: Annotated[ +class Stop(BaseModel): + __root__: Annotated[ Optional[List[str]], Field( - None, description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", - max_length=4, - min_length=1, + max_items=4, + min_items=1, ), ] = None @@ -100,10 +103,9 @@ class ImageUrl(BaseModel): detail: Annotated[ Optional[Literal["auto", "low", "high"]], Field( - "auto", - description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding)." ), - ] + ] = "auto" class ChatCompletionRequestMessageContentPartImage(BaseModel): @@ -116,52 +118,34 @@ class ChatCompletionRequestMessageContentPartRefusal(BaseModel): refusal: Annotated[str, Field(description="The refusal message generated by the model.")] -class ChatCompletionRequestSystemMessageContentPart( - RootModel[ChatCompletionRequestMessageContentPartText] -): - root: ChatCompletionRequestMessageContentPartText +class ChatCompletionRequestSystemMessageContentPart(BaseModel): + __root__: ChatCompletionRequestMessageContentPartText -class ChatCompletionRequestUserMessageContentPart( - RootModel[ - Union[ - ChatCompletionRequestMessageContentPartText, - ChatCompletionRequestMessageContentPartImage, - ] - ] -): - root: Union[ +class ChatCompletionRequestUserMessageContentPart(BaseModel): + __root__: Union[ ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ] -class ChatCompletionRequestAssistantMessageContentPart( - RootModel[ - Union[ - ChatCompletionRequestMessageContentPartText, - ChatCompletionRequestMessageContentPartRefusal, - ] - ] -): - root: Union[ +class ChatCompletionRequestAssistantMessageContentPart(BaseModel): + __root__: Union[ ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartRefusal, ] -class ChatCompletionRequestToolMessageContentPart( - RootModel[ChatCompletionRequestMessageContentPartText] -): - root: ChatCompletionRequestMessageContentPartText +class ChatCompletionRequestToolMessageContentPart(BaseModel): + __root__: ChatCompletionRequestMessageContentPartText -class Content(RootModel[List[ChatCompletionRequestSystemMessageContentPart]]): - root: Annotated[ +class Content(BaseModel): + __root__: Annotated[ List[ChatCompletionRequestSystemMessageContentPart], Field( description="An array of content parts with a defined type. For system messages, only type `text` is supported.", - min_length=1, + min_items=1, title="Array of content parts", ), ] @@ -178,18 +162,17 @@ class ChatCompletionRequestSystemMessage(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." ), - ] + ] = None -class Content1(RootModel[List[ChatCompletionRequestUserMessageContentPart]]): - root: Annotated[ +class Content1(BaseModel): + __root__: Annotated[ List[ChatCompletionRequestUserMessageContentPart], Field( description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", - min_length=1, + min_items=1, title="Array of content parts", ), ] @@ -206,19 +189,17 @@ class ChatCompletionRequestUserMessage(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." ), - ] + ] = None -class Content2(RootModel[Optional[List[ChatCompletionRequestAssistantMessageContentPart]]]): - root: Annotated[ +class Content2(BaseModel): + __root__: Annotated[ Optional[List[ChatCompletionRequestAssistantMessageContentPart]], Field( - None, description="An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`.", - min_length=1, + min_items=1, title="Array of content parts", ), ] = None @@ -234,12 +215,12 @@ class FunctionCall(BaseModel): name: Annotated[str, Field(description="The name of the function to call.")] -class Content3(RootModel[List[ChatCompletionRequestToolMessageContentPart]]): - root: Annotated[ +class Content3(BaseModel): + __root__: Annotated[ List[ChatCompletionRequestToolMessageContentPart], Field( description="An array of content parts with a defined type. For tool messages, only type `text` is supported.", - min_length=1, + min_items=1, title="Array of content parts", ), ] @@ -265,19 +246,18 @@ class ChatCompletionRequestFunctionMessage(BaseModel): class FunctionParameters(BaseModel): pass - model_config = ConfigDict( - extra="allow", - ) + + class Config: + extra = Extra.allow class ChatCompletionFunctions(BaseModel): description: Annotated[ Optional[str], Field( - None, - description="A description of what the function does, used by the model to choose when and how to call the function.", + description="A description of what the function does, used by the model to choose when and how to call the function." ), - ] + ] = None name: Annotated[ str, Field( @@ -295,10 +275,9 @@ class FunctionObject(BaseModel): description: Annotated[ Optional[str], Field( - None, - description="A description of what the function does, used by the model to choose when and how to call the function.", + description="A description of what the function does, used by the model to choose when and how to call the function." ), - ] + ] = None name: Annotated[ str, Field( @@ -309,10 +288,9 @@ class FunctionObject(BaseModel): strict: Annotated[ Optional[bool], Field( - False, - description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling).", + description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling)." ), - ] + ] = False class ResponseFormatText(BaseModel): @@ -331,33 +309,31 @@ class ResponseFormatJsonObject(BaseModel): class ResponseFormatJsonSchemaSchema(BaseModel): pass - model_config = ConfigDict( - extra="allow", - ) + + class Config: + extra = Extra.allow class JsonSchema(BaseModel): description: Annotated[ Optional[str], Field( - None, - description="A description of what the response format is for, used by the model to determine how to respond in the format.", + description="A description of what the response format is for, used by the model to determine how to respond in the format." ), - ] + ] = None name: Annotated[ str, Field( description="The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." ), ] - schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(None, alias="schema")] + schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(alias="schema")] = None strict: Annotated[ Optional[bool], Field( - False, - description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs).", + description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs)." ), - ] + ] = False class ResponseFormatJsonSchema(BaseModel): @@ -380,8 +356,8 @@ class ChatCompletionNamedToolChoice(BaseModel): function: Function -class ParallelToolCalls(RootModel[bool]): - root: Annotated[ +class ParallelToolCalls(BaseModel): + __root__: Annotated[ bool, Field( description="Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use." @@ -409,31 +385,27 @@ class ChatCompletionMessageToolCall(BaseModel): class Function2(BaseModel): - name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None arguments: Annotated[ Optional[str], Field( - None, - description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." ), - ] + ] = None class ChatCompletionMessageToolCallChunk(BaseModel): index: int - id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None type: Annotated[ Optional[Literal["function"]], - Field( - None, - description="The type of the tool. Currently, only `function` is supported.", - ), - ] + Field(description="The type of the tool. Currently, only `function` is supported."), + ] = None function: Optional[Function2] = None -class ChatCompletionRole(RootModel[Literal["system", "user", "assistant", "tool", "function"]]): - root: Annotated[ +class ChatCompletionRole(BaseModel): + __root__: Annotated[ Literal["system", "user", "assistant", "tool", "function"], Field(description="The role of the author of a message"), ] @@ -443,50 +415,48 @@ class ChatCompletionStreamOptions(BaseModel): include_usage: Annotated[ Optional[bool], Field( - None, - description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n", + description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n" ), - ] + ] = None class FunctionCall2(BaseModel): arguments: Annotated[ Optional[str], Field( - None, - description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." ), - ] - name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + ] = None + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None class ChatCompletionStreamResponseDelta(BaseModel): - content: Annotated[Optional[str], Field(None, description="The contents of the chunk message.")] + content: Annotated[ + Optional[str], Field(description="The contents of the chunk message.") + ] = None function_call: Annotated[ Optional[FunctionCall2], Field( - None, - description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." ), - ] + ] = None tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None role: Annotated[ Optional[Literal["system", "user", "assistant", "tool"]], - Field(None, description="The role of the author of this message."), - ] + Field(description="The role of the author of this message."), + ] = None refusal: Annotated[ - Optional[str], - Field(None, description="The refusal message generated by the model."), - ] + Optional[str], Field(description="The refusal message generated by the model.") + ] = None -class Stop1(RootModel[List[str]]): - root: Annotated[ +class Stop1(BaseModel): + __root__: Annotated[ List[str], Field( description="Up to 4 sequences where the API will stop generating further tokens.\n", - max_length=4, - min_length=1, + max_items=4, + min_items=1, ), ] @@ -544,8 +514,8 @@ class Choice3(BaseModel): delta: ChatCompletionStreamResponseDelta logprobs: Annotated[ Optional[Logprobs2], - Field(None, description="Log probability information for the choice."), - ] + Field(description="Log probability information for the choice."), + ] = None finish_reason: Annotated[ Literal["stop", "length", "tool_calls", "content_filter", "function_call"], Field( @@ -589,18 +559,16 @@ class CreateChatCompletionStreamResponse(BaseModel): service_tier: Annotated[ Optional[Literal["scale", "default"]], Field( - None, description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", - examples=["scale"], + example="scale", ), - ] + ] = None system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["chat.completion.chunk"], Field(description="The object type, which is always `chat.completion.chunk`."), @@ -608,10 +576,9 @@ class CreateChatCompletionStreamResponse(BaseModel): usage: Annotated[ Optional[Usage], Field( - None, - description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n', + description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n' ), - ] + ] = None class CreateChatCompletionImageResponse(BaseModel): @@ -623,91 +590,78 @@ class CreateImageRequest(BaseModel): str, Field( description="A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`.", - examples=["A cute baby sea otter"], + example="A cute baby sea otter", ), ] model: Annotated[ Optional[Union[str, Literal["dall-e-2", "dall-e-3"]]], - Field( - "dall-e-2", - description="The model to use for image generation.", - examples=["dall-e-3"], - ), - ] + Field(description="The model to use for image generation.", example="dall-e-3"), + ] = "dall-e-2" n: Annotated[ Optional[int], Field( - 1, description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", - examples=[1], + example=1, ge=1, le=10, ), - ] + ] = 1 quality: Annotated[ Optional[Literal["standard", "hd"]], Field( - "standard", description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", - examples=["standard"], + example="standard", ), - ] + ] = "standard" response_format: Annotated[ Optional[Literal["url", "b64_json"]], Field( - "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", - examples=["url"], + example="url", ), - ] + ] = "url" size: Annotated[ Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]], Field( - "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", - examples=["1024x1024"], + example="1024x1024", ), - ] + ] = "1024x1024" style: Annotated[ Optional[Literal["vivid", "natural"]], Field( - "vivid", description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", - examples=["vivid"], + example="vivid", ), - ] + ] = "vivid" user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", - examples=["user-1234"], + example="user-1234", ), - ] + ] = None class Image(BaseModel): b64_json: Annotated[ Optional[str], Field( - None, - description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`.", + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`." ), - ] + ] = None url: Annotated[ Optional[str], Field( - None, - description="The URL of the generated image, if `response_format` is `url` (default).", + description="The URL of the generated image, if `response_format` is `url` (default)." ), - ] + ] = None revised_prompt: Annotated[ Optional[str], Field( - None, - description="The prompt that was used to generate the image, if there was any revision to the prompt.", + description="The prompt that was used to generate the image, if there was any revision to the prompt." ), - ] + ] = None class CreateImageEditRequest(BaseModel): @@ -721,58 +675,52 @@ class CreateImageEditRequest(BaseModel): str, Field( description="A text description of the desired image(s). The maximum length is 1000 characters.", - examples=["A cute baby sea otter wearing a beret"], + example="A cute baby sea otter wearing a beret", ), ] mask: Annotated[ Optional[bytes], Field( - None, - description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`.", + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`." ), - ] + ] = None model: Annotated[ Optional[Union[str, Literal["dall-e-2"]]], Field( - "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", - examples=["dall-e-2"], + example="dall-e-2", ), - ] + ] = "dall-e-2" n: Annotated[ Optional[int], Field( - 1, description="The number of images to generate. Must be between 1 and 10.", - examples=[1], + example=1, ge=1, le=10, ), - ] + ] = 1 size: Annotated[ Optional[Literal["256x256", "512x512", "1024x1024"]], Field( - "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", - examples=["1024x1024"], + example="1024x1024", ), - ] + ] = "1024x1024" response_format: Annotated[ Optional[Literal["url", "b64_json"]], Field( - "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", - examples=["url"], + example="url", ), - ] + ] = "url" user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", - examples=["user-1234"], + example="user-1234", ), - ] + ] = None class CreateImageVariationRequest(BaseModel): @@ -785,45 +733,40 @@ class CreateImageVariationRequest(BaseModel): model: Annotated[ Optional[Union[str, Literal["dall-e-2"]]], Field( - "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", - examples=["dall-e-2"], + example="dall-e-2", ), - ] + ] = "dall-e-2" n: Annotated[ Optional[int], Field( - 1, description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", - examples=[1], + example=1, ge=1, le=10, ), - ] + ] = 1 response_format: Annotated[ Optional[Literal["url", "b64_json"]], Field( - "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", - examples=["url"], + example="url", ), - ] + ] = "url" size: Annotated[ Optional[Literal["256x256", "512x512", "1024x1024"]], Field( - "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", - examples=["1024x1024"], + example="1024x1024", ), - ] + ] = "1024x1024" user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", - examples=["user-1234"], + example="user-1234", ), - ] + ] = None class CreateModerationRequest(BaseModel): @@ -831,11 +774,10 @@ class CreateModerationRequest(BaseModel): model: Annotated[ Optional[Union[str, Literal["text-moderation-latest", "text-moderation-stable"]]], Field( - "text-moderation-latest", description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", - examples=["text-moderation-stable"], + example="text-moderation-stable", ), - ] + ] = "text-moderation-latest" class Categories(BaseModel): @@ -986,9 +928,9 @@ class CreateModerationResponse(BaseModel): class CreateFileRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + file: Annotated[bytes, Field(description="The File object (not file name) to be uploaded.\n")] purpose: Annotated[ Literal["assistants", "batch", "fine-tune", "vision"], @@ -1005,9 +947,9 @@ class DeleteFileResponse(BaseModel): class CreateUploadRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + filename: Annotated[str, Field(description="The name of the file to upload.\n")] purpose: Annotated[ Literal["assistants", "batch", "fine-tune", "vision"], @@ -1025,35 +967,34 @@ class CreateUploadRequest(BaseModel): class AddUploadPartRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + data: Annotated[bytes, Field(description="The chunk of bytes for this Part.\n")] class CompleteUploadRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + part_ids: Annotated[List[str], Field(description="The ordered list of Part IDs.\n")] md5: Annotated[ Optional[str], Field( - None, - description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n", + description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n" ), - ] + ] = None class CancelUploadRequest(BaseModel): pass - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid -class BatchSize(RootModel[int]): - root: Annotated[ + +class BatchSize(BaseModel): + __root__: Annotated[ int, Field( description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", @@ -1063,8 +1004,8 @@ class BatchSize(RootModel[int]): ] -class LearningRateMultiplier(RootModel[float]): - root: Annotated[ +class LearningRateMultiplier(BaseModel): + __root__: Annotated[ float, Field( description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", @@ -1073,8 +1014,8 @@ class LearningRateMultiplier(RootModel[float]): ] -class NEpochs(RootModel[int]): - root: Annotated[ +class NEpochs(BaseModel): + __root__: Annotated[ int, Field( description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", @@ -1088,24 +1029,21 @@ class Hyperparameters(BaseModel): batch_size: Annotated[ Optional[Union[Literal["auto"], BatchSize]], Field( - "auto", - description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n" ), - ] + ] = "auto" learning_rate_multiplier: Annotated[ Optional[Union[Literal["auto"], LearningRateMultiplier]], Field( - "auto", - description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n" ), - ] + ] = "auto" n_epochs: Annotated[ Optional[Union[Literal["auto"], NEpochs]], Field( - "auto", - description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n" ), - ] + ] = "auto" class Wandb(BaseModel): @@ -1113,30 +1051,27 @@ class Wandb(BaseModel): str, Field( description="The name of the project that the new run will be created under.\n", - examples=["my-wandb-project"], + example="my-wandb-project", ), ] name: Annotated[ Optional[str], Field( - None, - description="A display name to set for the run. If not set, we will use the Job ID as the name.\n", + description="A display name to set for the run. If not set, we will use the Job ID as the name.\n" ), - ] + ] = None entity: Annotated[ Optional[str], Field( - None, - description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n", + description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n" ), - ] + ] = None tags: Annotated[ Optional[List[str]], Field( - None, - description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n', + description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n' ), - ] + ] = None class Integration(BaseModel): @@ -1159,108 +1094,102 @@ class CreateFineTuningJobRequest(BaseModel): Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"]], Field( description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned).\n", - examples=["gpt-4o-mini"], + example="gpt-4o-mini", ), ] training_file: Annotated[ str, Field( description="The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/create) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.\n\nThe contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", - examples=["file-abc123"], + example="file-abc123", ), ] hyperparameters: Annotated[ Optional[Hyperparameters], - Field(None, description="The hyperparameters used for the fine-tuning job."), - ] + Field(description="The hyperparameters used for the fine-tuning job."), + ] = None suffix: Annotated[ Optional[str], Field( - None, description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.\n', max_length=40, min_length=1, ), - ] + ] = None validation_file: Annotated[ Optional[str], Field( - None, description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", - examples=["file-abc123"], + example="file-abc123", ), - ] + ] = None integrations: Annotated[ Optional[List[Integration]], - Field( - None, - description="A list of integrations to enable for your fine-tuning job.", - ), - ] + Field(description="A list of integrations to enable for your fine-tuning job."), + ] = None seed: Annotated[ Optional[int], Field( - None, description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.\nIf a seed is not specified, one will be generated for you.\n", - examples=[42], + example=42, ge=0, le=2147483647, ), - ] + ] = None -class Input(RootModel[List[str]]): - root: Annotated[ +class Input(BaseModel): + __root__: Annotated[ List[str], Field( description="The array of strings that will be turned into an embedding.", - examples=["The quick brown fox jumped over the lazy dog"], - max_length=2048, - min_length=1, + example="The quick brown fox jumped over the lazy dog", + max_items=2048, + min_items=1, title="array", ), ] -class Input1(RootModel[List[int]]): - root: Annotated[ +class Input1(BaseModel): + __root__: Annotated[ List[int], Field( description="The array of integers that will be turned into an embedding.", - examples=["[1212, 318, 257, 1332, 13]"], - max_length=2048, - min_length=1, + example="[1212, 318, 257, 1332, 13]", + max_items=2048, + min_items=1, title="array", ), ] -class Input2Item(RootModel[List[int]]): - root: Annotated[List[int], Field(min_length=1)] +class Input2Item(BaseModel): + __root__: Annotated[List[int], Field(min_items=1)] -class Input2(RootModel[List[Input2Item]]): - root: Annotated[ +class Input2(BaseModel): + __root__: Annotated[ List[Input2Item], Field( description="The array of arrays containing integers that will be turned into an embedding.", - examples=["[[1212, 318, 257, 1332, 13]]"], - max_length=2048, - min_length=1, + example="[[1212, 318, 257, 1332, 13]]", + max_items=2048, + min_items=1, title="array", ), ] class CreateEmbeddingRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + input: Annotated[ Union[str, Input, Input1, Input2], Field( description="Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", - examples=["The quick brown fox jumped over the lazy dog"], + example="The quick brown fox jumped over the lazy dog", ), ] model: Annotated[ @@ -1274,33 +1203,30 @@ class CreateEmbeddingRequest(BaseModel): ], Field( description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", - examples=["text-embedding-3-small"], + example="text-embedding-3-small", ), ] encoding_format: Annotated[ Optional[Literal["float", "base64"]], Field( - "float", description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", - examples=["float"], + example="float", ), - ] + ] = "float" dimensions: Annotated[ Optional[int], Field( - None, description="The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.\n", ge=1, ), - ] + ] = None user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", - examples=["user-1234"], + example="user-1234", ), - ] + ] = None class Usage1(BaseModel): @@ -1311,9 +1237,9 @@ class Usage1(BaseModel): class CreateTranscriptionRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + file: Annotated[ bytes, Field( @@ -1324,45 +1250,40 @@ class CreateTranscriptionRequest(BaseModel): Union[str, Literal["whisper-1"]], Field( description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", - examples=["whisper-1"], + example="whisper-1", ), ] language: Annotated[ Optional[str], Field( - None, - description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n", + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n" ), - ] + ] = None prompt: Annotated[ Optional[str], Field( - None, - description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n", + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n" ), - ] + ] = None response_format: Annotated[ Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]], Field( - "json", - description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" ), - ] + ] = "json" temperature: Annotated[ Optional[float], Field( - 0, - description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" ), - ] + ] = 0 timestamp_granularities__: Annotated[ Optional[List[Literal["word", "segment"]]], Field( - ["segment"], alias="timestamp_granularities[]", description="The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.\n", ), - ] + ] = ["segment"] class CreateTranscriptionResponseJson(BaseModel): @@ -1412,21 +1333,18 @@ class CreateTranscriptionResponseVerboseJson(BaseModel): text: Annotated[str, Field(description="The transcribed text.")] words: Annotated[ Optional[List[TranscriptionWord]], - Field(None, description="Extracted words and their corresponding timestamps."), - ] + Field(description="Extracted words and their corresponding timestamps."), + ] = None segments: Annotated[ Optional[List[TranscriptionSegment]], - Field( - None, - description="Segments of the transcribed text and their corresponding details.", - ), - ] + Field(description="Segments of the transcribed text and their corresponding details."), + ] = None class CreateTranslationRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + file: Annotated[ bytes, Field( @@ -1437,30 +1355,27 @@ class CreateTranslationRequest(BaseModel): Union[str, Literal["whisper-1"]], Field( description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", - examples=["whisper-1"], + example="whisper-1", ), ] prompt: Annotated[ Optional[str], Field( - None, - description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n", + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n" ), - ] + ] = None response_format: Annotated[ Optional[str], Field( - "json", - description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" ), - ] + ] = "json" temperature: Annotated[ Optional[float], Field( - 0, - description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" ), - ] + ] = 0 class CreateTranslationResponseJson(BaseModel): @@ -1476,17 +1391,14 @@ class CreateTranslationResponseVerboseJson(BaseModel): text: Annotated[str, Field(description="The translated text.")] segments: Annotated[ Optional[List[TranscriptionSegment]], - Field( - None, - description="Segments of the translated text and their corresponding details.", - ), - ] + Field(description="Segments of the translated text and their corresponding details."), + ] = None class CreateSpeechRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + model: Annotated[ Union[str, Literal["tts-1", "tts-1-hd"]], Field( @@ -1509,19 +1421,17 @@ class CreateSpeechRequest(BaseModel): response_format: Annotated[ Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]], Field( - "mp3", - description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.", + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." ), - ] + ] = "mp3" speed: Annotated[ Optional[float], Field( - 1.0, description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", ge=0.25, le=4.0, ), - ] + ] = 1.0 class Model(BaseModel): @@ -1576,10 +1486,9 @@ class OpenAIFile(BaseModel): status_details: Annotated[ Optional[str], Field( - None, - description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`.", + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`." ), - ] + ] = None class Upload(BaseModel): @@ -1611,12 +1520,12 @@ class Upload(BaseModel): ] object: Annotated[ Optional[Literal["upload"]], - Field(None, description='The object type, which is always "upload".'), - ] + Field(description='The object type, which is always "upload".'), + ] = None file: Annotated[ Optional[OpenAIFile], - Field(None, description="The ready File object after the Upload is completed."), - ] + Field(description="The ready File object after the Upload is completed."), + ] = None class UploadPart(BaseModel): @@ -1667,8 +1576,8 @@ class Error1(BaseModel): ] -class NEpochs1(RootModel[int]): - root: Annotated[ +class NEpochs1(BaseModel): + __root__: Annotated[ int, Field( description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.', @@ -1752,13 +1661,12 @@ class FineTuningJobCheckpoint(BaseModel): class FinetuneCompletionRequestInput(BaseModel): prompt: Annotated[ - Optional[str], - Field(None, description="The input prompt for this training example."), - ] + Optional[str], Field(description="The input prompt for this training example.") + ] = None completion: Annotated[ Optional[str], - Field(None, description="The desired completion for this training example."), - ] + Field(description="The desired completion for this training example."), + ] = None class CompletionUsage(BaseModel): @@ -1800,17 +1708,8 @@ class RunStepCompletionUsage(BaseModel): ] -class AssistantsApiResponseFormatOption( - RootModel[ - Union[ - Literal["auto"], - ResponseFormatText, - ResponseFormatJsonObject, - ResponseFormatJsonSchema, - ] - ] -): - root: Annotated[ +class AssistantsApiResponseFormatOption(BaseModel): + __root__: Annotated[ Union[ Literal["auto"], ResponseFormatText, @@ -1827,22 +1726,20 @@ class CodeInterpreter(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", - max_length=20, + max_items=20, ), - ] + ] = [] class FileSearch(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ToolResources(BaseModel): @@ -1854,24 +1751,23 @@ class CodeInterpreter1(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", - max_length=20, + max_items=20, ), - ] + ] = [] class ChunkingStrategy(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class Static(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + max_chunk_size_tokens: Annotated[ int, Field( @@ -1889,9 +1785,9 @@ class Static(BaseModel): class ChunkingStrategy1(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -1900,25 +1796,22 @@ class VectorStore(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", - max_length=10000, + max_items=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy, ChunkingStrategy1]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch1(BaseModel): @@ -1926,30 +1819,29 @@ class FileSearch1(BaseModel): List[str], Field( description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), ] vector_stores: Annotated[ Optional[List[VectorStore]], Field( - None, description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ChunkingStrategy2(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class ChunkingStrategy3(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -1958,41 +1850,37 @@ class VectorStore1(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", - max_length=10000, + max_items=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy2, ChunkingStrategy3]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch2(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), - ] + ] = None vector_stores: Annotated[ List[VectorStore1], Field( description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), ] @@ -2006,22 +1894,20 @@ class CodeInterpreter2(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - [], description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", - max_length=20, + max_items=20, ), - ] + ] = [] class FileSearch3(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ToolResources2(BaseModel): @@ -2046,12 +1932,11 @@ class FileSearch4(BaseModel): max_num_results: Annotated[ Optional[int], Field( - None, description="The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", ge=1, le=50, ), - ] + ] = None class AssistantToolsFileSearch(BaseModel): @@ -2060,9 +1945,8 @@ class AssistantToolsFileSearch(BaseModel): Field(description="The type of tool being defined: `file_search`"), ] file_search: Annotated[ - Optional[FileSearch4], - Field(None, description="Overrides for the file search tool."), - ] + Optional[FileSearch4], Field(description="Overrides for the file search tool.") + ] = None class AssistantToolsFileSearchTypeOnly(BaseModel): @@ -2090,11 +1974,10 @@ class TruncationObject(BaseModel): last_messages: Annotated[ Optional[int], Field( - None, description="The number of most recent messages from the thread when constructing the context for the run.", ge=1, ), - ] + ] = None class Function3(BaseModel): @@ -2123,46 +2006,40 @@ class IncompleteDetails(BaseModel): reason: Annotated[ Optional[Literal["max_completion_tokens", "max_prompt_tokens"]], Field( - None, - description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run.", + description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run." ), - ] + ] = None class ModifyRunRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class ToolOutput(BaseModel): tool_call_id: Annotated[ Optional[str], Field( - None, - description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for.", + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for." ), - ] + ] = None output: Annotated[ Optional[str], - Field( - None, - description="The output of the tool call to be submitted to continue the run.", - ), - ] + Field(description="The output of the tool call to be submitted to continue the run."), + ] = None class SubmitToolOutputsRunRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + tool_outputs: Annotated[ List[ToolOutput], Field(description="A list of tools for which the outputs are being submitted."), @@ -2170,10 +2047,9 @@ class SubmitToolOutputsRunRequest(BaseModel): stream: Annotated[ Optional[bool], Field( - None, - description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" ), - ] + ] = None class Function4(BaseModel): @@ -2204,22 +2080,20 @@ class CodeInterpreter3(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", - max_length=20, + max_items=20, ), - ] + ] = [] class FileSearch5(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ToolResources3(BaseModel): @@ -2231,11 +2105,10 @@ class FileSearch6(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ToolResources4(BaseModel): @@ -2271,16 +2144,16 @@ class ThreadObject(BaseModel): class ChunkingStrategy4(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class ChunkingStrategy5(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -2289,25 +2162,22 @@ class VectorStore2(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", - max_length=10000, + max_items=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy4, ChunkingStrategy5]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch7(BaseModel): @@ -2315,30 +2185,29 @@ class FileSearch7(BaseModel): List[str], Field( description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", - max_length=1, + max_items=1, ), ] vector_stores: Annotated[ Optional[List[VectorStore2]], Field( - None, description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ChunkingStrategy6(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class ChunkingStrategy7(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: Static @@ -2347,41 +2216,37 @@ class VectorStore3(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", - max_length=10000, + max_items=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy6, ChunkingStrategy7]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch8(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", - max_length=1, + max_items=1, ), - ] + ] = None vector_stores: Annotated[ List[VectorStore3], Field( description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", - max_length=1, + max_items=1, ), ] @@ -2395,11 +2260,10 @@ class FileSearch9(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", - max_length=1, + max_items=1, ), - ] + ] = None class ToolResources6(BaseModel): @@ -2408,23 +2272,21 @@ class ToolResources6(BaseModel): class ModifyThreadRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + tool_resources: Annotated[ Optional[ToolResources6], Field( - None, - description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class DeleteThreadResponse(BaseModel): @@ -2434,11 +2296,11 @@ class DeleteThreadResponse(BaseModel): class ListThreadsResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[ThreadObject] - first_id: Annotated[str, Field(examples=["asst_abc123"])] - last_id: Annotated[str, Field(examples=["asst_abc456"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="asst_abc123")] + last_id: Annotated[str, Field(example="asst_abc456")] + has_more: Annotated[bool, Field(example=False)] class IncompleteDetails1(BaseModel): @@ -2450,26 +2312,24 @@ class IncompleteDetails1(BaseModel): class Attachment(BaseModel): file_id: Annotated[ - Optional[str], - Field(None, description="The ID of the file to attach to the message."), - ] + Optional[str], Field(description="The ID of the file to attach to the message.") + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearchTypeOnly]]], - Field(None, description="The tools to add this file to."), - ] + Field(description="The tools to add this file to."), + ] = None class ModifyMessageRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class DeleteMessageResponse(BaseModel): @@ -2488,10 +2348,9 @@ class ImageFile(BaseModel): detail: Annotated[ Optional[Literal["auto", "low", "high"]], Field( - "auto", - description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." ), - ] + ] = "auto" class MessageContentImageFileObject(BaseModel): @@ -2503,17 +2362,15 @@ class ImageFile1(BaseModel): file_id: Annotated[ Optional[str], Field( - None, - description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.', + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' ), - ] + ] = None detail: Annotated[ Optional[Literal["auto", "low", "high"]], Field( - "auto", - description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." ), - ] + ] = "auto" class MessageDeltaContentImageFileObject(BaseModel): @@ -2532,10 +2389,9 @@ class ImageUrl1(BaseModel): detail: Annotated[ Optional[Literal["auto", "low", "high"]], Field( - "auto", - description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`" ), - ] + ] = "auto" class MessageContentImageUrlObject(BaseModel): @@ -2547,17 +2403,15 @@ class ImageUrl2(BaseModel): url: Annotated[ Optional[str], Field( - None, - description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp.", + description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." ), - ] + ] = None detail: Annotated[ Optional[Literal["auto", "low", "high"]], Field( - "auto", - description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`." ), - ] + ] = "auto" class MessageDeltaContentImageUrlObject(BaseModel): @@ -2615,9 +2469,9 @@ class MessageDeltaContentRefusalObject(BaseModel): class FileCitation1(BaseModel): file_id: Annotated[ Optional[str], - Field(None, description="The ID of the specific File the citation is from."), - ] - quote: Annotated[Optional[str], Field(None, description="The specific quote in the file.")] + Field(description="The ID of the specific File the citation is from."), + ] = None + quote: Annotated[Optional[str], Field(description="The specific quote in the file.")] = None class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): @@ -2627,20 +2481,17 @@ class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] text: Annotated[ Optional[str], - Field( - None, - description="The text in the message content that needs to be replaced.", - ), - ] + Field(description="The text in the message content that needs to be replaced."), + ] = None file_citation: Optional[FileCitation1] = None - start_index: Annotated[Optional[int], Field(None, ge=0)] - end_index: Annotated[Optional[int], Field(None, ge=0)] + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None class FilePath1(BaseModel): file_id: Annotated[ - Optional[str], Field(None, description="The ID of the file that was generated.") - ] + Optional[str], Field(description="The ID of the file that was generated.") + ] = None class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): @@ -2650,14 +2501,11 @@ class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] text: Annotated[ Optional[str], - Field( - None, - description="The text in the message content that needs to be replaced.", - ), - ] + Field(description="The text in the message content that needs to be replaced."), + ] = None file_path: Optional[FilePath1] = None - start_index: Annotated[Optional[int], Field(None, ge=0)] - end_index: Annotated[Optional[int], Field(None, ge=0)] + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None class LastError1(BaseModel): @@ -2683,8 +2531,8 @@ class RunStepDetailsMessageCreationObject(BaseModel): class MessageCreation1(BaseModel): message_id: Annotated[ Optional[str], - Field(None, description="The ID of the message that was created by this run step."), - ] + Field(description="The ID of the message that was created by this run step."), + ] = None class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): @@ -2702,8 +2550,8 @@ class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): type: Annotated[Literal["logs"], Field(description="Always `logs`.")] logs: Annotated[ Optional[str], - Field(None, description="The text output from the Code Interpreter tool call."), - ] + Field(description="The text output from the Code Interpreter tool call."), + ] = None class Image1(BaseModel): @@ -2720,8 +2568,8 @@ class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): class Image2(BaseModel): file_id: Annotated[ Optional[str], - Field(None, description="The [file](/docs/api-reference/files) ID of the image."), - ] + Field(description="The [file](/docs/api-reference/files) ID of the image."), + ] = None class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): @@ -2746,7 +2594,7 @@ class RunStepDetailsToolCallsFileSearchObject(BaseModel): class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] - id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None type: Annotated[ Literal["file_search"], Field( @@ -2784,22 +2632,21 @@ class RunStepDetailsToolCallsFunctionObject(BaseModel): class Function6(BaseModel): - name: Annotated[Optional[str], Field(None, description="The name of the function.")] + name: Annotated[Optional[str], Field(description="The name of the function.")] = None arguments: Annotated[ - Optional[str], Field(None, description="The arguments passed to the function.") - ] + Optional[str], Field(description="The arguments passed to the function.") + ] = None output: Annotated[ Optional[str], Field( - None, - description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet.", + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." ), - ] + ] = None class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] - id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None type: Annotated[ Literal["function"], Field( @@ -2808,8 +2655,8 @@ class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): ] function: Annotated[ Optional[Function6], - Field(None, description="The definition of the function that was called."), - ] + Field(description="The definition of the function that was called."), + ] = None class VectorStoreExpirationAfter(BaseModel): @@ -2871,11 +2718,8 @@ class VectorStoreObject(BaseModel): expires_after: Optional[VectorStoreExpirationAfter] = None expires_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the vector store will expire.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the vector store will expire."), + ] = None last_active_at: Annotated[ int, Field( @@ -2891,26 +2735,25 @@ class VectorStoreObject(BaseModel): class UpdateVectorStoreRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + class Config: + extra = Extra.forbid + + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None expires_after: Optional[VectorStoreExpirationAfter] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class ListVectorStoresResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[VectorStoreObject] - first_id: Annotated[str, Field(examples=["vs_abc123"])] - last_id: Annotated[str, Field(examples=["vs_abc456"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="vs_abc123")] + last_id: Annotated[str, Field(example="vs_abc456")] + has_more: Annotated[bool, Field(example=False)] class DeleteVectorStoreResponse(BaseModel): @@ -2928,16 +2771,16 @@ class LastError2(BaseModel): class OtherChunkingStrategyResponseParam(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["other"], Field(description="Always `other`.")] class StaticChunkingStrategy(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + max_chunk_size_tokens: Annotated[ int, Field( @@ -2955,24 +2798,22 @@ class StaticChunkingStrategy(BaseModel): class AutoChunkingStrategyRequestParam(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] class StaticChunkingStrategyRequestParam(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: StaticChunkingStrategy -class ChunkingStrategyRequestParam( - RootModel[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]] -): - root: Annotated[ +class ChunkingStrategyRequestParam(BaseModel): + __root__: Annotated[ Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam], Field( description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." @@ -2981,9 +2822,9 @@ class ChunkingStrategyRequestParam( class CreateVectorStoreFileRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + file_id: Annotated[ str, Field( @@ -3041,15 +2882,15 @@ class VectorStoreFileBatchObject(BaseModel): class CreateVectorStoreFileBatchRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + file_ids: Annotated[ List[str], Field( description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", - max_length=500, - min_length=1, + max_items=500, + min_items=1, ), ] chunking_strategy: Optional[ChunkingStrategyRequestParam] = None @@ -3060,8 +2901,8 @@ class ThreadStreamEvent1(BaseModel): data: ThreadObject -class ThreadStreamEvent(RootModel[ThreadStreamEvent1]): - root: ThreadStreamEvent1 +class ThreadStreamEvent(BaseModel): + __root__: ThreadStreamEvent1 class ErrorEvent(BaseModel): @@ -3076,37 +2917,28 @@ class DoneEvent(BaseModel): class Datum(BaseModel): code: Annotated[ - Optional[str], - Field(None, description="An error code identifying the error type."), - ] + Optional[str], Field(description="An error code identifying the error type.") + ] = None message: Annotated[ Optional[str], - Field( - None, - description="A human-readable message providing more details about the error.", - ), - ] + Field(description="A human-readable message providing more details about the error."), + ] = None param: Annotated[ Optional[str], - Field( - None, - description="The name of the parameter that caused the error, if applicable.", - ), - ] + Field(description="The name of the parameter that caused the error, if applicable."), + ] = None line: Annotated[ Optional[int], Field( - None, - description="The line number of the input file where the error occurred, if applicable.", + description="The line number of the input file where the error occurred, if applicable." ), - ] + ] = None class Errors(BaseModel): object: Annotated[ - Optional[str], - Field(None, description="The object type, which is always `list`."), - ] + Optional[str], Field(description="The object type, which is always `list`.") + ] = None data: Optional[List[Datum]] = None @@ -3147,137 +2979,100 @@ class Batch(BaseModel): output_file_id: Annotated[ Optional[str], Field( - None, - description="The ID of the file containing the outputs of successfully executed requests.", + description="The ID of the file containing the outputs of successfully executed requests." ), - ] + ] = None error_file_id: Annotated[ Optional[str], - Field( - None, - description="The ID of the file containing the outputs of requests with errors.", - ), - ] + Field(description="The ID of the file containing the outputs of requests with errors."), + ] = None created_at: Annotated[ int, Field(description="The Unix timestamp (in seconds) for when the batch was created."), ] in_progress_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch started processing.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch started processing."), + ] = None expires_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch will expire.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch will expire."), + ] = None finalizing_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch started finalizing.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch started finalizing."), + ] = None completed_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch was completed.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch was completed."), + ] = None failed_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch failed.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch failed."), + ] = None expired_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch expired.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch expired."), + ] = None cancelling_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch started cancelling.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch started cancelling."), + ] = None cancelled_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch was cancelled.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch was cancelled."), + ] = None request_counts: Annotated[ Optional[RequestCounts], - Field( - None, - description="The request counts for different statuses within the batch.", - ), - ] + Field(description="The request counts for different statuses within the batch."), + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class BatchRequestInput(BaseModel): custom_id: Annotated[ Optional[str], Field( - None, - description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch.", + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch." ), - ] + ] = None method: Annotated[ Optional[Literal["POST"]], Field( - None, - description="The HTTP method to be used for the request. Currently only `POST` is supported.", + description="The HTTP method to be used for the request. Currently only `POST` is supported." ), - ] + ] = None url: Annotated[ Optional[str], Field( - None, - description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.", + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported." ), - ] + ] = None class Response(BaseModel): status_code: Annotated[ - Optional[int], Field(None, description="The HTTP status code of the response") - ] + Optional[int], Field(description="The HTTP status code of the response") + ] = None request_id: Annotated[ Optional[str], Field( - None, - description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support.", + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support." ), - ] + ] = None body: Annotated[ - Optional[Dict[str, Any]], - Field(None, description="The JSON body of the response"), - ] + Optional[Dict[str, Any]], Field(description="The JSON body of the response") + ] = None class Error2(BaseModel): - code: Annotated[Optional[str], Field(None, description="A machine-readable error code.")] - message: Annotated[Optional[str], Field(None, description="A human-readable error message.")] + code: Annotated[Optional[str], Field(description="A machine-readable error code.")] = None + message: Annotated[Optional[str], Field(description="A human-readable error message.")] = None class BatchRequestOutput(BaseModel): @@ -3285,46 +3080,41 @@ class BatchRequestOutput(BaseModel): custom_id: Annotated[ Optional[str], Field( - None, - description="A developer-provided per-request id that will be used to match outputs to inputs.", + description="A developer-provided per-request id that will be used to match outputs to inputs." ), - ] + ] = None response: Optional[Response] = None error: Annotated[ Optional[Error2], Field( - None, - description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure.", + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure." ), - ] + ] = None class ListBatchesResponse(BaseModel): data: List[Batch] - first_id: Annotated[Optional[str], Field(None, examples=["batch_abc123"])] - last_id: Annotated[Optional[str], Field(None, examples=["batch_abc456"])] + first_id: Annotated[Optional[str], Field(example="batch_abc123")] = None + last_id: Annotated[Optional[str], Field(example="batch_abc456")] = None has_more: bool object: Literal["list"] class AuditLogActorServiceAccount(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account id.")] + id: Annotated[Optional[str], Field(description="The service account id.")] = None class AuditLogActorUser(BaseModel): - id: Annotated[Optional[str], Field(None, description="The user id.")] - email: Annotated[Optional[str], Field(None, description="The user email.")] + id: Annotated[Optional[str], Field(description="The user id.")] = None + email: Annotated[Optional[str], Field(description="The user email.")] = None class AuditLogActorApiKey(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking id of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking id of the API key.")] = None type: Annotated[ Optional[Literal["user", "service_account"]], - Field( - None, - description="The type of API key. Can be either `user` or `service_account`.", - ), - ] + Field(description="The type of API key. Can be either `user` or `service_account`."), + ] = None user: Optional[AuditLogActorUser] = None service_account: Optional[AuditLogActorServiceAccount] = None @@ -3333,46 +3123,21 @@ class AuditLogActorSession(BaseModel): user: Optional[AuditLogActorUser] = None ip_address: Annotated[ Optional[str], - Field(None, description="The IP address from which the action was performed."), - ] + Field(description="The IP address from which the action was performed."), + ] = None class AuditLogActor(BaseModel): type: Annotated[ Optional[Literal["session", "api_key"]], - Field(None, description="The type of actor. Is either `session` or `api_key`."), - ] + Field(description="The type of actor. Is either `session` or `api_key`."), + ] = None session: Optional[AuditLogActorSession] = None api_key: Optional[AuditLogActorApiKey] = None -class AuditLogEventType( - RootModel[ - Literal[ - "api_key.created", - "api_key.updated", - "api_key.deleted", - "invite.sent", - "invite.accepted", - "invite.deleted", - "login.succeeded", - "login.failed", - "logout.succeeded", - "logout.failed", - "organization.updated", - "project.created", - "project.updated", - "project.archived", - "service_account.created", - "service_account.updated", - "service_account.deleted", - "user.added", - "user.updated", - "user.deleted", - ] - ] -): - root: Annotated[ +class AuditLogEventType(BaseModel): + __root__: Annotated[ Literal[ "api_key.created", "api_key.updated", @@ -3400,232 +3165,212 @@ class AuditLogEventType( class Project(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] - name: Annotated[Optional[str], Field(None, description="The project title.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None + name: Annotated[Optional[str], Field(description="The project title.")] = None class Data(BaseModel): scopes: Annotated[ Optional[List[str]], - Field( - None, - description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', - ), - ] + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None class ApiKeyCreated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None data: Annotated[ - Optional[Data], - Field(None, description="The payload used to create the API key."), - ] + Optional[Data], Field(description="The payload used to create the API key.") + ] = None class ChangesRequested(BaseModel): scopes: Annotated[ Optional[List[str]], - Field( - None, - description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', - ), - ] + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None class ApiKeyUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None changes_requested: Annotated[ Optional[ChangesRequested], - Field(None, description="The payload used to update the API key."), - ] + Field(description="The payload used to update the API key."), + ] = None class ApiKeyDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None class Data1(BaseModel): email: Annotated[ - Optional[str], Field(None, description="The email invited to the organization.") - ] + Optional[str], Field(description="The email invited to the organization.") + ] = None role: Annotated[ Optional[str], - Field( - None, - description="The role the email was invited to be. Is either `owner` or `member`.", - ), - ] + Field(description="The role the email was invited to be. Is either `owner` or `member`."), + ] = None class InviteSent(BaseModel): - id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None data: Annotated[ - Optional[Data1], - Field(None, description="The payload used to create the invite."), - ] + Optional[Data1], Field(description="The payload used to create the invite.") + ] = None class InviteAccepted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None class InviteDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None class LoginFailed(BaseModel): - error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None error_message: Annotated[ - Optional[str], Field(None, description="The error message of the failure.") - ] + Optional[str], Field(description="The error message of the failure.") + ] = None class LogoutFailed(BaseModel): - error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None error_message: Annotated[ - Optional[str], Field(None, description="The error message of the failure.") - ] + Optional[str], Field(description="The error message of the failure.") + ] = None class Settings(BaseModel): threads_ui_visibility: Annotated[ Optional[str], Field( - None, - description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`.", + description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`." ), - ] + ] = None usage_dashboard_visibility: Annotated[ Optional[str], Field( - None, - description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`.", + description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`." ), - ] + ] = None class ChangesRequested1(BaseModel): - title: Annotated[Optional[str], Field(None, description="The organization title.")] - description: Annotated[Optional[str], Field(None, description="The organization description.")] - name: Annotated[Optional[str], Field(None, description="The organization name.")] + title: Annotated[Optional[str], Field(description="The organization title.")] = None + description: Annotated[Optional[str], Field(description="The organization description.")] = None + name: Annotated[Optional[str], Field(description="The organization name.")] = None settings: Optional[Settings] = None class OrganizationUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The organization ID.")] + id: Annotated[Optional[str], Field(description="The organization ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested1], - Field(None, description="The payload used to update the organization settings."), - ] + Field(description="The payload used to update the organization settings."), + ] = None class Data2(BaseModel): - name: Annotated[Optional[str], Field(None, description="The project name.")] + name: Annotated[Optional[str], Field(description="The project name.")] = None title: Annotated[ Optional[str], - Field(None, description="The title of the project as seen on the dashboard."), - ] + Field(description="The title of the project as seen on the dashboard."), + ] = None class ProjectCreated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None data: Annotated[ - Optional[Data2], - Field(None, description="The payload used to create the project."), - ] + Optional[Data2], Field(description="The payload used to create the project.") + ] = None class ChangesRequested2(BaseModel): title: Annotated[ Optional[str], - Field(None, description="The title of the project as seen on the dashboard."), - ] + Field(description="The title of the project as seen on the dashboard."), + ] = None class ProjectUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested2], - Field(None, description="The payload used to update the project."), - ] + Field(description="The payload used to update the project."), + ] = None class ProjectArchived(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None class Data3(BaseModel): role: Annotated[ Optional[str], - Field( - None, - description="The role of the service account. Is either `owner` or `member`.", - ), - ] + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None class ServiceAccountCreated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account ID.")] + id: Annotated[Optional[str], Field(description="The service account ID.")] = None data: Annotated[ Optional[Data3], - Field(None, description="The payload used to create the service account."), - ] + Field(description="The payload used to create the service account."), + ] = None class ChangesRequested3(BaseModel): role: Annotated[ Optional[str], - Field( - None, - description="The role of the service account. Is either `owner` or `member`.", - ), - ] + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None class ServiceAccountUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account ID.")] + id: Annotated[Optional[str], Field(description="The service account ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested3], - Field(None, description="The payload used to updated the service account."), - ] + Field(description="The payload used to updated the service account."), + ] = None class ServiceAccountDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account ID.")] + id: Annotated[Optional[str], Field(description="The service account ID.")] = None class Data4(BaseModel): role: Annotated[ Optional[str], - Field(None, description="The role of the user. Is either `owner` or `member`."), - ] + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None class UserAdded(BaseModel): - id: Annotated[Optional[str], Field(None, description="The user ID.")] + id: Annotated[Optional[str], Field(description="The user ID.")] = None data: Annotated[ Optional[Data4], - Field(None, description="The payload used to add the user to the project."), - ] + Field(description="The payload used to add the user to the project."), + ] = None class ChangesRequested4(BaseModel): role: Annotated[ Optional[str], - Field(None, description="The role of the user. Is either `owner` or `member`."), - ] + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None class UserUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested4], - Field(None, description="The payload used to update the user."), - ] + Field(description="The payload used to update the user."), + ] = None class UserDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The user ID.")] + id: Annotated[Optional[str], Field(description="The user ID.")] = None class AuditLog(BaseModel): @@ -3635,162 +3380,128 @@ class AuditLog(BaseModel): project: Annotated[ Optional[Project], Field( - None, - description="The project that the action was scoped to. Absent for actions not scoped to projects.", + description="The project that the action was scoped to. Absent for actions not scoped to projects." ), - ] + ] = None actor: AuditLogActor api_key_created: Annotated[ Optional[ApiKeyCreated], Field( - None, alias="api_key.created", description="The details for events with this `type`.", ), - ] + ] = None api_key_updated: Annotated[ Optional[ApiKeyUpdated], Field( - None, alias="api_key.updated", description="The details for events with this `type`.", ), - ] + ] = None api_key_deleted: Annotated[ Optional[ApiKeyDeleted], Field( - None, alias="api_key.deleted", description="The details for events with this `type`.", ), - ] + ] = None invite_sent: Annotated[ Optional[InviteSent], - Field( - None, - alias="invite.sent", - description="The details for events with this `type`.", - ), - ] + Field(alias="invite.sent", description="The details for events with this `type`."), + ] = None invite_accepted: Annotated[ Optional[InviteAccepted], Field( - None, alias="invite.accepted", description="The details for events with this `type`.", ), - ] + ] = None invite_deleted: Annotated[ Optional[InviteDeleted], Field( - None, alias="invite.deleted", description="The details for events with this `type`.", ), - ] + ] = None login_failed: Annotated[ Optional[LoginFailed], - Field( - None, - alias="login.failed", - description="The details for events with this `type`.", - ), - ] + Field(alias="login.failed", description="The details for events with this `type`."), + ] = None logout_failed: Annotated[ Optional[LogoutFailed], Field( - None, alias="logout.failed", description="The details for events with this `type`.", ), - ] + ] = None organization_updated: Annotated[ Optional[OrganizationUpdated], Field( - None, alias="organization.updated", description="The details for events with this `type`.", ), - ] + ] = None project_created: Annotated[ Optional[ProjectCreated], Field( - None, alias="project.created", description="The details for events with this `type`.", ), - ] + ] = None project_updated: Annotated[ Optional[ProjectUpdated], Field( - None, alias="project.updated", description="The details for events with this `type`.", ), - ] + ] = None project_archived: Annotated[ Optional[ProjectArchived], Field( - None, alias="project.archived", description="The details for events with this `type`.", ), - ] + ] = None service_account_created: Annotated[ Optional[ServiceAccountCreated], Field( - None, alias="service_account.created", description="The details for events with this `type`.", ), - ] + ] = None service_account_updated: Annotated[ Optional[ServiceAccountUpdated], Field( - None, alias="service_account.updated", description="The details for events with this `type`.", ), - ] + ] = None service_account_deleted: Annotated[ Optional[ServiceAccountDeleted], Field( - None, alias="service_account.deleted", description="The details for events with this `type`.", ), - ] + ] = None user_added: Annotated[ Optional[UserAdded], - Field( - None, - alias="user.added", - description="The details for events with this `type`.", - ), - ] + Field(alias="user.added", description="The details for events with this `type`."), + ] = None user_updated: Annotated[ Optional[UserUpdated], - Field( - None, - alias="user.updated", - description="The details for events with this `type`.", - ), - ] + Field(alias="user.updated", description="The details for events with this `type`."), + ] = None user_deleted: Annotated[ Optional[UserDeleted], - Field( - None, - alias="user.deleted", - description="The details for events with this `type`.", - ), - ] + Field(alias="user.deleted", description="The details for events with this `type`."), + ] = None class ListAuditLogsResponse(BaseModel): object: Literal["list"] data: List[AuditLog] - first_id: Annotated[str, Field(examples=["audit_log-defb456h8dks"])] - last_id: Annotated[str, Field(examples=["audit_log-hnbkd8s93s"])] + first_id: Annotated[str, Field(example="audit_log-defb456h8dks")] + last_id: Annotated[str, Field(example="audit_log-hnbkd8s93s")] has_more: bool @@ -3822,11 +3533,8 @@ class Invite(BaseModel): ] accepted_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) of when the invite was accepted.", - ), - ] + Field(description="The Unix timestamp (in seconds) of when the invite was accepted."), + ] = None class InviteListResponse(BaseModel): @@ -3834,19 +3542,17 @@ class InviteListResponse(BaseModel): data: List[Invite] first_id: Annotated[ Optional[str], - Field(None, description="The first `invite_id` in the retrieved `list`"), - ] + Field(description="The first `invite_id` in the retrieved `list`"), + ] = None last_id: Annotated[ - Optional[str], - Field(None, description="The last `invite_id` in the retrieved `list`"), - ] + Optional[str], Field(description="The last `invite_id` in the retrieved `list`") + ] = None has_more: Annotated[ Optional[bool], Field( - None, - description="The `has_more` property is used for pagination to indicate there are additional results.", + description="The `has_more` property is used for pagination to indicate there are additional results." ), - ] + ] = None class InviteRequest(BaseModel): @@ -3916,10 +3622,9 @@ class Project1(BaseModel): archived_at: Annotated[ Optional[int], Field( - None, - description="The Unix timestamp (in seconds) of when the project was archived or `null`.", + description="The Unix timestamp (in seconds) of when the project was archived or `null`." ), - ] + ] = None status: Annotated[Literal["active", "archived"], Field(description="`active` or `archived`")] @@ -4046,8 +3751,8 @@ class ProjectServiceAccountDeleteResponse(BaseModel): class Owner(BaseModel): type: Annotated[ Optional[Literal["user", "service_account"]], - Field(None, description="`user` or `service_account`"), - ] + Field(description="`user` or `service_account`"), + ] = None user: Optional[ProjectUser] = None service_account: Optional[ProjectServiceAccount] = None @@ -4105,129 +3810,115 @@ class CreateCompletionRequest(BaseModel): best_of: Annotated[ Optional[int], Field( - 1, description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', ge=0, le=20, ), - ] + ] = 1 echo: Annotated[ Optional[bool], - Field(False, description="Echo back the prompt in addition to the completion\n"), - ] + Field(description="Echo back the prompt in addition to the completion\n"), + ] = False frequency_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 logit_bias: Annotated[ Optional[Dict[str, int]], Field( - None, - description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n', + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n' ), - ] + ] = None logprobs: Annotated[ Optional[int], Field( - None, description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", ge=0, le=5, ), - ] + ] = None max_tokens: Annotated[ Optional[int], Field( - 16, description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", - examples=[16], + example=16, ge=0, ), - ] + ] = 16 n: Annotated[ Optional[int], Field( - 1, description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", - examples=[1], + example=1, ge=1, le=128, ), - ] + ] = 1 presence_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 seed: Annotated[ Optional[int], Field( - None, description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", ge=-9223372036854775808, le=9223372036854775807, ), - ] + ] = None stop: Annotated[ Optional[Union[str, Stop]], Field( - None, - description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n" ), - ] + ] = None stream: Annotated[ Optional[bool], Field( - False, - description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" ), - ] + ] = False stream_options: Optional[ChatCompletionStreamOptions] = None suffix: Annotated[ Optional[str], Field( - None, description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", - examples=["test."], + example="test.", ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", - examples=["user-1234"], + example="user-1234", ), - ] + ] = None class CreateCompletionResponse(BaseModel): @@ -4246,10 +3937,9 @@ class CreateCompletionResponse(BaseModel): system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["text_completion"], Field(description='The object type, which is always "text_completion"'), @@ -4265,10 +3955,8 @@ class ChatCompletionTool(BaseModel): function: FunctionObject -class ChatCompletionToolChoiceOption( - RootModel[Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice]] -): - root: Annotated[ +class ChatCompletionToolChoiceOption(BaseModel): + __root__: Annotated[ Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice], Field( description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n' @@ -4276,8 +3964,8 @@ class ChatCompletionToolChoiceOption( ] -class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCall]]): - root: Annotated[ +class ChatCompletionMessageToolCalls(BaseModel): + __root__: Annotated[ List[ChatCompletionMessageToolCall], Field(description="The tool calls generated by the model, such as function calls."), ] @@ -4294,10 +3982,9 @@ class ChatCompletionResponseMessage(BaseModel): function_call: Annotated[ Optional[FunctionCall], Field( - None, - description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." ), - ] + ] = None class Choice1(BaseModel): @@ -4330,18 +4017,16 @@ class CreateChatCompletionResponse(BaseModel): service_tier: Annotated[ Optional[Literal["scale", "default"]], Field( - None, description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", - examples=["scale"], + example="scale", ), - ] + ] = None system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["chat.completion"], Field(description="The object type, which is always `chat.completion`."), @@ -4378,10 +4063,9 @@ class CreateChatCompletionFunctionResponse(BaseModel): system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["chat.completion"], Field(description="The object type, which is always `chat.completion`."), @@ -4502,19 +4186,17 @@ class FineTuningJob(BaseModel): integrations: Annotated[ Optional[List[FineTuningIntegration]], Field( - None, description="A list of integrations to enable for this fine-tuning job.", - max_length=5, + max_items=5, ), - ] + ] = None seed: Annotated[int, Field(description="The seed used for the fine-tuning job.")] estimated_finish: Annotated[ Optional[int], Field( - None, - description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.", + description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running." ), - ] + ] = None class AssistantObject(BaseModel): @@ -4561,16 +4243,15 @@ class AssistantObject(BaseModel): List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", - max_length=128, + max_items=128, ), ] tool_resources: Annotated[ Optional[ToolResources], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Dict[str, Any], Field( @@ -4580,30 +4261,28 @@ class AssistantObject(BaseModel): temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 response_format: Optional[AssistantsApiResponseFormatOption] = None class CreateAssistantRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + model: Annotated[ Union[ str, @@ -4636,170 +4315,151 @@ class CreateAssistantRequest(BaseModel): ], Field( description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", - examples=["gpt-4o"], + example="gpt-4o", ), ] name: Annotated[ Optional[str], Field( - None, description="The name of the assistant. The maximum length is 256 characters.\n", max_length=256, ), - ] + ] = None description: Annotated[ Optional[str], Field( - None, description="The description of the assistant. The maximum length is 512 characters.\n", max_length=512, ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", max_length=256000, ), - ] + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], Field( - [], description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", - max_length=128, + max_items=128, ), - ] + ] = [] tool_resources: Annotated[ Optional[ToolResources1], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 response_format: Optional[AssistantsApiResponseFormatOption] = None class ModifyAssistantRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + model: Annotated[ Optional[str], Field( - None, - description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" ), - ] + ] = None name: Annotated[ Optional[str], Field( - None, description="The name of the assistant. The maximum length is 256 characters.\n", max_length=256, ), - ] + ] = None description: Annotated[ Optional[str], Field( - None, description="The description of the assistant. The maximum length is 512 characters.\n", max_length=512, ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", max_length=256000, ), - ] + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], Field( - [], description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", - max_length=128, + max_items=128, ), - ] + ] = [] tool_resources: Annotated[ Optional[ToolResources2], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 response_format: Optional[AssistantsApiResponseFormatOption] = None class ListAssistantsResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[AssistantObject] - first_id: Annotated[str, Field(examples=["asst_abc123"])] - last_id: Annotated[str, Field(examples=["asst_abc456"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="asst_abc123")] + last_id: Annotated[str, Field(example="asst_abc456")] + has_more: Annotated[bool, Field(example=False)] -class AssistantsApiToolChoiceOption( - RootModel[Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice]] -): - root: Annotated[ +class AssistantsApiToolChoiceOption(BaseModel): + __root__: Annotated[ Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice], Field( description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tools and instead generates a message.\n`auto` is the default value and means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools before responding to the user.\nSpecifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n' @@ -4919,7 +4579,7 @@ class RunObject(BaseModel): List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", - max_length=20, + max_items=20, ), ] metadata: Annotated[ @@ -4931,18 +4591,14 @@ class RunObject(BaseModel): usage: RunCompletionUsage temperature: Annotated[ Optional[float], - Field( - None, - description="The sampling temperature used for this run. If not set, defaults to 1.", - ), - ] + Field(description="The sampling temperature used for this run. If not set, defaults to 1."), + ] = None top_p: Annotated[ Optional[float], Field( - None, - description="The nucleus sampling value used for this run. If not set, defaults to 1.", + description="The nucleus sampling value used for this run. If not set, defaults to 1." ), - ] + ] = None max_prompt_tokens: Annotated[ int, Field( @@ -4964,25 +4620,15 @@ class RunObject(BaseModel): class ListRunsResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[RunObject] - first_id: Annotated[str, Field(examples=["run_abc123"])] - last_id: Annotated[str, Field(examples=["run_abc456"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="run_abc123")] + last_id: Annotated[str, Field(example="run_abc456")] + has_more: Annotated[bool, Field(example=False)] -class Content4( - RootModel[ - List[ - Union[ - MessageContentImageFileObject, - MessageContentImageUrlObject, - MessageRequestContentTextObject, - ] - ] - ] -): - root: Annotated[ +class Content4(BaseModel): + __root__: Annotated[ List[ Union[ MessageContentImageFileObject, @@ -4992,16 +4638,16 @@ class Content4( ], Field( description="An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview).", - min_length=1, + min_items=1, title="Array of content parts", ), ] class CreateMessageRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + role: Annotated[ Literal["user", "assistant"], Field( @@ -5012,17 +4658,15 @@ class CreateMessageRequest(BaseModel): attachments: Annotated[ Optional[List[Attachment]], Field( - None, - description="A list of files attached to the message, and the tools they should be added to.", + description="A list of files attached to the message, and the tools they should be added to." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class Text(BaseModel): @@ -5041,7 +4685,7 @@ class MessageContentTextObject(BaseModel): class Text1(BaseModel): - value: Annotated[Optional[str], Field(None, description="The data that makes up the text.")] + value: Annotated[Optional[str], Field(description="The data that makes up the text.")] = None annotations: Optional[ List[ Union[ @@ -5089,9 +4733,8 @@ class RunStepDetailsToolCallsCodeObject(BaseModel): class CodeInterpreter8(BaseModel): input: Annotated[ - Optional[str], - Field(None, description="The input to the Code Interpreter tool call."), - ] + Optional[str], Field(description="The input to the Code Interpreter tool call.") + ] = None outputs: Annotated[ Optional[ List[ @@ -5102,15 +4745,14 @@ class CodeInterpreter8(BaseModel): ] ], Field( - None, - description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type.", + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." ), - ] + ] = None class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] - id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None type: Annotated[ Literal["code_interpreter"], Field( @@ -5119,44 +4761,41 @@ class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): ] code_interpreter: Annotated[ Optional[CodeInterpreter8], - Field(None, description="The Code Interpreter tool call definition."), - ] + Field(description="The Code Interpreter tool call definition."), + ] = None class CreateVectorStoreRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", - max_length=500, + max_items=500, ), - ] - name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + ] = None + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None expires_after: Optional[VectorStoreExpirationAfter] = None chunking_strategy: Annotated[ Optional[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class StaticChunkingStrategyResponseParam(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + type: Annotated[Literal["static"], Field(description="Always `static`.")] static: StaticChunkingStrategy @@ -5211,23 +4850,8 @@ class RunStreamEvent10(BaseModel): data: RunObject -class RunStreamEvent( - RootModel[ - Union[ - RunStreamEvent1, - RunStreamEvent2, - RunStreamEvent3, - RunStreamEvent4, - RunStreamEvent5, - RunStreamEvent6, - RunStreamEvent7, - RunStreamEvent8, - RunStreamEvent9, - RunStreamEvent10, - ] - ] -): - root: Union[ +class RunStreamEvent(BaseModel): + __root__: Union[ RunStreamEvent1, RunStreamEvent2, RunStreamEvent3, @@ -5257,13 +4881,12 @@ class ChatCompletionRequestAssistantMessage(BaseModel): content: Annotated[ Optional[Union[str, Content2]], Field( - None, - description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n" ), - ] + ] = None refusal: Annotated[ - Optional[str], Field(None, description="The refusal message by the assistant.") - ] + Optional[str], Field(description="The refusal message by the assistant.") + ] = None role: Annotated[ Literal["assistant"], Field(description="The role of the messages author, in this case `assistant`."), @@ -5271,28 +4894,23 @@ class ChatCompletionRequestAssistantMessage(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." ), - ] + ] = None tool_calls: Optional[ChatCompletionMessageToolCalls] = None function_call: Annotated[ Optional[FunctionCall], Field( - None, - description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." ), - ] + ] = None class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): weight: Annotated[ Optional[Literal[0, 1]], - Field( - None, - description="Controls whether the assistant message is trained against (0 or 1)", - ), - ] + Field(description="Controls whether the assistant message is trained against (0 or 1)"), + ] = None role: Annotated[ Literal["assistant"], Field(description="The role of the messages author, in this case `assistant`."), @@ -5318,28 +4936,27 @@ class FinetuneChatRequestInput(BaseModel): ] ] ], - Field(None, min_length=1), - ] + Field(min_items=1), + ] = None tools: Annotated[ Optional[List[ChatCompletionTool]], - Field(None, description="A list of tools the model may generate JSON inputs for."), - ] + Field(description="A list of tools the model may generate JSON inputs for."), + ] = None parallel_tool_calls: Optional[ParallelToolCalls] = None functions: Annotated[ Optional[List[ChatCompletionFunctions]], Field( - None, description="A list of functions the model may generate JSON inputs for.", - max_length=128, - min_length=1, + max_items=128, + min_items=1, ), - ] + ] = None class CreateRunRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + assistant_id: Annotated[ str, Field( @@ -5379,90 +4996,77 @@ class CreateRunRequest(BaseModel): ] ], Field( - None, description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", - examples=["gpt-4o"], + example="gpt-4o", ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, - description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis.", + description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis." ), - ] + ] = None additional_instructions: Annotated[ Optional[str], Field( - None, - description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions.", + description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions." ), - ] + ] = None additional_messages: Annotated[ Optional[List[CreateMessageRequest]], - Field( - None, - description="Adds additional messages to the thread before creating the run.", - ), - ] + Field(description="Adds additional messages to the thread before creating the run."), + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], Field( - None, description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", - max_length=20, + max_items=20, ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 stream: Annotated[ Optional[bool], Field( - None, - description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" ), - ] + ] = None max_prompt_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None max_completion_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None truncation_strategy: Optional[TruncationObject] = None tool_choice: Optional[AssistantsApiToolChoiceOption] = None parallel_tool_calls: Optional[ParallelToolCalls] = None @@ -5470,30 +5074,27 @@ class CreateRunRequest(BaseModel): class CreateThreadRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + messages: Annotated[ Optional[List[CreateMessageRequest]], Field( - None, - description="A list of [messages](/docs/api-reference/messages) to start the thread with.", + description="A list of [messages](/docs/api-reference/messages) to start the thread with." ), - ] + ] = None tool_resources: Annotated[ Optional[ToolResources5], Field( - None, - description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class MessageObject(BaseModel): @@ -5579,11 +5180,8 @@ class MessageObject(BaseModel): class Delta(BaseModel): role: Annotated[ Optional[Literal["user", "assistant"]], - Field( - None, - description="The entity that produced the message. One of `user` or `assistant`.", - ), - ] + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] = None content: Annotated[ Optional[ List[ @@ -5595,11 +5193,8 @@ class Delta(BaseModel): ] ] ], - Field( - None, - description="The content of the message in array of text and/or images.", - ), - ] + Field(description="The content of the message in array of text and/or images."), + ] = None class MessageDeltaObject(BaseModel): @@ -5620,11 +5215,11 @@ class MessageDeltaObject(BaseModel): class ListMessagesResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[MessageObject] - first_id: Annotated[str, Field(examples=["msg_abc123"])] - last_id: Annotated[str, Field(examples=["msg_abc123"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="msg_abc123")] + last_id: Annotated[str, Field(example="msg_abc123")] + has_more: Annotated[bool, Field(example=False)] class RunStepDetailsToolCallsObject(BaseModel): @@ -5656,10 +5251,9 @@ class RunStepDeltaStepDetailsToolCallsObject(BaseModel): ] ], Field( - None, - description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n", + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" ), - ] + ] = None class VectorStoreFileObject(BaseModel): @@ -5703,16 +5297,16 @@ class VectorStoreFileObject(BaseModel): ] chunking_strategy: Annotated[ Optional[Union[StaticChunkingStrategyResponseParam, OtherChunkingStrategyResponseParam]], - Field(None, description="The strategy used to chunk the file."), - ] + Field(description="The strategy used to chunk the file."), + ] = None class ListVectorStoreFilesResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[VectorStoreFileObject] - first_id: Annotated[str, Field(examples=["file-abc123"])] - last_id: Annotated[str, Field(examples=["file-abc456"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="file-abc123")] + last_id: Annotated[str, Field(example="file-abc456")] + has_more: Annotated[bool, Field(example=False)] class MessageStreamEvent1(BaseModel): @@ -5740,18 +5334,8 @@ class MessageStreamEvent5(BaseModel): data: MessageObject -class MessageStreamEvent( - RootModel[ - Union[ - MessageStreamEvent1, - MessageStreamEvent2, - MessageStreamEvent3, - MessageStreamEvent4, - MessageStreamEvent5, - ] - ] -): - root: Union[ +class MessageStreamEvent(BaseModel): + __root__: Union[ MessageStreamEvent1, MessageStreamEvent2, MessageStreamEvent3, @@ -5760,18 +5344,8 @@ class MessageStreamEvent( ] -class ChatCompletionRequestMessage( - RootModel[ - Union[ - ChatCompletionRequestSystemMessage, - ChatCompletionRequestUserMessage, - ChatCompletionRequestAssistantMessage, - ChatCompletionRequestToolMessage, - ChatCompletionRequestFunctionMessage, - ] - ] -): - root: Annotated[ +class ChatCompletionRequestMessage(BaseModel): + __root__: Annotated[ Union[ ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, @@ -5788,7 +5362,7 @@ class CreateChatCompletionRequest(BaseModel): List[ChatCompletionRequestMessage], Field( description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", - min_length=1, + min_items=1, ), ] model: Annotated[ @@ -5824,164 +5398,144 @@ class CreateChatCompletionRequest(BaseModel): ], Field( description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", - examples=["gpt-4o"], + example="gpt-4o", ), ] frequency_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 logit_bias: Annotated[ Optional[Dict[str, int]], Field( - None, - description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n", + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n" ), - ] + ] = None logprobs: Annotated[ Optional[bool], Field( - False, - description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.", + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`." ), - ] + ] = False top_logprobs: Annotated[ Optional[int], Field( - None, description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", ge=0, le=20, ), - ] + ] = None max_tokens: Annotated[ Optional[int], Field( - None, - description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n" ), - ] + ] = None n: Annotated[ Optional[int], Field( - 1, description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", - examples=[1], + example=1, ge=1, le=128, ), - ] + ] = 1 presence_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 response_format: Annotated[ Optional[Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema]], Field( - None, - description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' ), - ] + ] = None seed: Annotated[ Optional[int], Field( - None, description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", ge=-9223372036854775808, le=9223372036854775807, ), - ] + ] = None service_tier: Annotated[ Optional[Literal["auto", "default"]], Field( - None, - description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n", + description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n" ), - ] + ] = None stop: Annotated[ Optional[Union[str, Stop1]], - Field( - None, - description="Up to 4 sequences where the API will stop generating further tokens.\n", - ), - ] + Field(description="Up to 4 sequences where the API will stop generating further tokens.\n"), + ] = None stream: Annotated[ Optional[bool], Field( - False, - description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" ), - ] + ] = False stream_options: Optional[ChatCompletionStreamOptions] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 tools: Annotated[ Optional[List[ChatCompletionTool]], Field( - None, - description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n", + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n" ), - ] + ] = None tool_choice: Optional[ChatCompletionToolChoiceOption] = None parallel_tool_calls: Optional[ParallelToolCalls] = None user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", - examples=["user-1234"], + example="user-1234", ), - ] + ] = None function_call: Annotated[ Optional[Union[Literal["none", "auto"], ChatCompletionFunctionCallOption]], Field( - None, - description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n', + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n' ), - ] + ] = None functions: Annotated[ Optional[List[ChatCompletionFunctions]], Field( - None, description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", - max_length=128, - min_length=1, + max_items=128, + min_items=1, ), - ] + ] = None class CreateThreadAndRunRequest(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) + class Config: + extra = Extra.forbid + assistant_id: Annotated[ str, Field( @@ -5990,11 +5544,8 @@ class CreateThreadAndRunRequest(BaseModel): ] thread: Annotated[ Optional[CreateThreadRequest], - Field( - None, - description="If no thread is provided, an empty thread will be created.", - ), - ] + Field(description="If no thread is provided, an empty thread will be created."), + ] = None model: Annotated[ Optional[ Union[ @@ -6028,83 +5579,73 @@ class CreateThreadAndRunRequest(BaseModel): ] ], Field( - None, description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", - examples=["gpt-4o"], + example="gpt-4o", ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, - description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis.", + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis." ), - ] + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], Field( - None, description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", - max_length=20, + max_items=20, ), - ] + ] = None tool_resources: Annotated[ Optional[ToolResources3], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", - examples=[1], + example=1, ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", - examples=[1], + example=1, ge=0.0, le=1.0, ), - ] + ] = 1 stream: Annotated[ Optional[bool], Field( - None, - description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" ), - ] + ] = None max_prompt_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None max_completion_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None truncation_strategy: Optional[TruncationObject] = None tool_choice: Optional[AssistantsApiToolChoiceOption] = None parallel_tool_calls: Optional[ParallelToolCalls] = None @@ -6199,8 +5740,8 @@ class Delta1(BaseModel): RunStepDeltaStepDetailsToolCallsObject, ] ], - Field(None, description="The details of the run step."), - ] + Field(description="The details of the run step."), + ] = None class RunStepDeltaObject(BaseModel): @@ -6221,11 +5762,11 @@ class RunStepDeltaObject(BaseModel): class ListRunStepsResponse(BaseModel): - object: Annotated[str, Field(examples=["list"])] + object: Annotated[str, Field(example="list")] data: List[RunStepObject] - first_id: Annotated[str, Field(examples=["step_abc123"])] - last_id: Annotated[str, Field(examples=["step_abc456"])] - has_more: Annotated[bool, Field(examples=[False])] + first_id: Annotated[str, Field(example="step_abc123")] + last_id: Annotated[str, Field(example="step_abc456")] + has_more: Annotated[bool, Field(example=False)] class RunStepStreamEvent1(BaseModel): @@ -6263,20 +5804,8 @@ class RunStepStreamEvent7(BaseModel): data: RunStepObject -class RunStepStreamEvent( - RootModel[ - Union[ - RunStepStreamEvent1, - RunStepStreamEvent2, - RunStepStreamEvent3, - RunStepStreamEvent4, - RunStepStreamEvent5, - RunStepStreamEvent6, - RunStepStreamEvent7, - ] - ] -): - root: Union[ +class RunStepStreamEvent(BaseModel): + __root__: Union[ RunStepStreamEvent1, RunStepStreamEvent2, RunStepStreamEvent3, @@ -6287,19 +5816,8 @@ class RunStepStreamEvent( ] -class AssistantStreamEvent( - RootModel[ - Union[ - ThreadStreamEvent, - RunStreamEvent, - RunStepStreamEvent, - MessageStreamEvent, - ErrorEvent, - DoneEvent, - ] - ] -): - root: Annotated[ +class AssistantStreamEvent(BaseModel): + __root__: Annotated[ Union[ ThreadStreamEvent, RunStreamEvent, diff --git a/clients/python/llmengine/data_types/pydantic_types.py b/clients/python/llmengine/data_types/pydantic_types.py index 64d89c3d..902f42ce 100644 --- a/clients/python/llmengine/data_types/pydantic_types.py +++ b/clients/python/llmengine/data_types/pydantic_types.py @@ -1,15 +1,9 @@ -from pydantic import BaseModel as PydanticBaseModel -from pydantic import ( # noqa: F401 - ConfigDict, - Field, - HttpUrl, - RootModel, - ValidationError, - model_validator, -) +import pydantic +PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.") -class BaseModel(PydanticBaseModel): - """Common pydantic configurations for model engine""" +if PYDANTIC_V2: + from pydantic.v1 import BaseModel, Field, HttpUrl # noqa: F401 - model_config = ConfigDict(protected_namespaces=()) +else: + from pydantic import BaseModel, Field, HttpUrl # type: ignore # noqa: F401 diff --git a/clients/python/llmengine/data_types/rest.py b/clients/python/llmengine/data_types/rest.py index e7f80189..33f62750 100644 --- a/clients/python/llmengine/data_types/rest.py +++ b/clients/python/llmengine/data_types/rest.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from .pydantic_types import BaseModel, Field, HttpUrl, RootModel +from .pydantic_types import BaseModel, Field, HttpUrl CpuSpecificationType = Union[str, int, float] StorageSpecificationType = Union[str, int, float] @@ -67,8 +67,8 @@ class CallbackmTLSAuth(BaseModel): key: str -class CallbackAuth(RootModel): - root: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") +class CallbackAuth(BaseModel): + __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") class ModelEndpointDeploymentState(BaseModel): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2a9a8a1c..3941db63 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta37" +version = "0.0.0.beta38" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] @@ -13,7 +13,7 @@ packages = [{include = "llmengine"}] [tool.poetry.dependencies] python = "^3.8" -pydantic = ">=2.0" +pydantic = ">=1.10.17" aiohttp = "^3.8" requests = "^2.31.0" openai = "^1.30.0" @@ -29,6 +29,7 @@ pytest-mypy-plugins = "^1.10.1" [tool.pytest.ini_options] asyncio_mode = "auto" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/clients/python/setup.py b/clients/python/setup.py index 4aa4832a..43299efc 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.8", - version="0.0.0.beta37", + version="0.0.0.beta38", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 71c762d8..1c471851 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -202,6 +202,9 @@ class FilteredChatCompletionV2Request(ChatCompletionV2Request): List[FilteredCompletionV2Request], List[FilteredChatCompletionV2Request] ] CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig +BatchCompletionContent = Union[ + CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent +] class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): @@ -209,7 +212,7 @@ class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): Request object for batch completions. """ - content: Optional[CreateBatchCompletionsV2RequestContent] = Field( + content: Optional[BatchCompletionContent] = Field( default=None, description=""" Either `input_data_path` or `content` needs to be provided. @@ -293,11 +296,6 @@ class GetBatchCompletionV2Response(BaseModel): job: BatchCompletionsJob -BatchCompletionContent = Union[ - CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent -] - - class VLLMEngineAdditionalArgs(BaseModel): max_gpu_memory_utilization: Optional[float] = Field( default=0.9, diff --git a/scripts/generate-openai-types.sh b/scripts/generate-openai-types.sh index b9b2717b..bdfd0050 100755 --- a/scripts/generate-openai-types.sh +++ b/scripts/generate-openai-types.sh @@ -6,7 +6,7 @@ BASE_DIR=${SCRIPT_DIR}/.. DEST_DIR=${BASE_DIR}/model-engine/model_engine_server/common/types/gen OPENAI_SPEC=${SCRIPT_DIR}/openai-spec.yaml -# Generate OpenAPI types +# Generate OpenAPI types for server datamodel-codegen \ --input ${OPENAI_SPEC} \ --input-file-type openapi \ @@ -14,4 +14,33 @@ datamodel-codegen \ --output-model-type pydantic_v2.BaseModel \ --enum-field-as-literal all \ --field-constraints \ - --use-annotated \ No newline at end of file + --use-annotated + +CLIENT_DIR=${BASE_DIR}/clients/python/llmengine/data_types/gen + +# Generate OpenAPI types for client +# Client is using pydantic v1 +datamodel-codegen \ + --input ${OPENAI_SPEC} \ + --input-file-type openapi \ + --output ${CLIENT_DIR}/openai.py \ + --output-model-type pydantic.BaseModel \ + --enum-field-as-literal all \ + --field-constraints \ + --use-annotated + +# Ignore mypy for this file +# I tried updating mypy.ini to ignore this file, but it didn't work +sed -i '1s/^/# mypy: ignore-errors\n/' ${CLIENT_DIR}/openai.py + +# Add conditional import for pydantic v1 and v2 +# replace line starting with 'from pydantic ' with the following multiline python code +# import pydantic +# PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.") +# +# if PYDANTIC_V2: +# from pydantic.v1 +# +# else: +# from pydantic +sed -i -E '/^from pydantic import /{s/^from pydantic import (.*)$/import pydantic\nPYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.")\nif PYDANTIC_V2:\n from pydantic.v1 import \1 # noqa: F401\nelse:\n from pydantic import \1 # type: ignore # noqa: F401/}' ${CLIENT_DIR}/openai.py \ No newline at end of file From 49af08975494adebbd03a42c808c6e4688b87121 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 26 Aug 2024 13:14:24 -0700 Subject: [PATCH 372/425] Fix list initialization (#607) --- model-engine/model_engine_server/inference/vllm/vllm_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py index 99e50328..f84674f7 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_batch.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -115,7 +115,7 @@ async def generate_v1_completions( generator = merge_async_iterators(*results_generators) outputs: List[Optional[CompletionV1Output]] = [None] * len(prompts) - tokens: List[List[TokenOutput]] = [list()] * len(prompts) + tokens: List[List[TokenOutput]] = [list() for _ in prompts] async for i, res in generator: # There should only be one output output = res.outputs[-1] From f425d1fa44b027c714d2700005ee5ac2c3020f51 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 28 Aug 2024 19:56:16 -0700 Subject: [PATCH 373/425] Docs for qwen2 72b instruct (#601) * Docs for qwen2 72b instruct * use default image * try again --- .circleci/config.yml | 3 ++- docs/model_zoo.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5b636bee..cef79018 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -175,6 +175,7 @@ jobs: - run: name: Install integration test dependencies command: | + export DEBIAN_FRONTEND=noninteractive sudo apt-get update && sudo apt-get install -y libcurl4-openssl-dev libssl-dev python3-dev pip install -r model-engine/requirements.txt - install_client @@ -190,7 +191,7 @@ jobs: executors: ubuntu-large: machine: - image: "ubuntu-2004:202201-02" + image: default resource_class: 2xlarge commands: diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 63c5bd1f..538f87e6 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -51,6 +51,7 @@ Scale hosts the following models in the LLM Engine Model Zoo: | `deepseek-coder-v2-instruct` | ✅ | | vllm | 131072 | | `deepseek-coder-v2-lite` | ✅ | | vllm | 131072 | | `deepseek-coder-v2-lite-instruct` | ✅ | | vllm | 131072 | +| `qwen2-72b-instruct` | ✅ | | vllm | 32768 | ## Usage From cff524c7d15724b10c89c91cd4b8b1d5d30285da Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:45:02 -0700 Subject: [PATCH 374/425] MLI-2847 Replace instead of patch PDB (#603) * Replace instead of patch PDB * fix --- .../k8s_endpoint_resource_delegate.py | 18 +++++-- .../test_k8s_endpoint_resource_delegate.py | 53 +++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index af054dba..153b8a9f 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -631,10 +631,18 @@ async def _create_pdb(pdb: Dict[str, Any], name: str) -> None: if exc.status == 409: logger.info(f"PodDisruptionBudget {name} already exists, replacing") - await policy_api.patch_namespaced_pod_disruption_budget( + existing_pdb = await policy_api.read_namespaced_pod_disruption_budget( + name=name, namespace=hmi_config.endpoint_namespace + ) + replace_pdb = pdb.copy() + if "metadata" not in replace_pdb: + replace_pdb["metadata"] = {} + replace_pdb["metadata"]["resourceVersion"] = existing_pdb.metadata.resource_version + + await policy_api.replace_namespaced_pod_disruption_budget( name=name, namespace=hmi_config.endpoint_namespace, - body=pdb, + body=replace_pdb, ) else: logger.exception("Got an exception when trying to apply the PodDisruptionBudget") @@ -834,9 +842,9 @@ async def _get_config_maps( return config_maps.items @staticmethod - async def _get_all_config_maps() -> List[ - kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap - ]: + async def _get_all_config_maps() -> ( + List[kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap] + ): k8s_core_api = get_kubernetes_core_client() config_maps = await k8s_core_api.list_namespaced_config_map( namespace=hmi_config.endpoint_namespace diff --git a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index acd298b3..4e3c6415 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -706,3 +706,56 @@ async def test_delete_resources_sync_success( endpoint_id="", deployment_name="", endpoint_type=ModelEndpointType.SYNC ) assert deleted + + +@pytest.mark.asyncio +async def test_create_pdb( + k8s_endpoint_resource_delegate, + mock_policy_client, +): + # Mock the necessary objects and functions + pdb = { + "metadata": {"name": "test-pdb", "namespace": "test-namespace"}, + "spec": {"maxUnavailable": "50%"}, + } + name = "test-pdb" + + # Test successful creation + await k8s_endpoint_resource_delegate._create_pdb(pdb, name) + + mock_policy_client.create_namespaced_pod_disruption_budget.assert_called_once_with( + namespace=hmi_config.endpoint_namespace, + body=pdb, + ) + + # Test creation when PDB already exists + mock_policy_client.create_namespaced_pod_disruption_budget.side_effect = ApiException( + status=409 + ) + + existing_pdb = Mock() + existing_pdb.metadata.resource_version = "123" + mock_policy_client.read_namespaced_pod_disruption_budget.return_value = existing_pdb + + await k8s_endpoint_resource_delegate._create_pdb(pdb, name) + + mock_policy_client.read_namespaced_pod_disruption_budget.assert_called_once_with( + name=name, namespace=hmi_config.endpoint_namespace + ) + + expected_replace_pdb = pdb.copy() + expected_replace_pdb["metadata"]["resourceVersion"] = "123" + + mock_policy_client.replace_namespaced_pod_disruption_budget.assert_called_once_with( + name=name, + namespace=hmi_config.endpoint_namespace, + body=expected_replace_pdb, + ) + + # Test creation with other API exception + mock_policy_client.create_namespaced_pod_disruption_budget.side_effect = ApiException( + status=500 + ) + + with pytest.raises(ApiException): + await k8s_endpoint_resource_delegate._create_pdb(pdb, name) From 0600c1041279039ca53d964f46cd239ec3dcc49e Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 28 Aug 2024 23:33:39 -0700 Subject: [PATCH 375/425] Use maxUnavailale for endpoint PDB (#596) --- charts/model-engine/templates/service_template_config_map.yaml | 2 +- .../templates/service_template_config_map_circleci.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index f721eb46..a756e99e 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -557,7 +557,7 @@ data: labels: {{- $service_template_labels | nindent 8 }} spec: - minAvailable: 1 + maxUnavailable: 50% selector: matchLabels: app: ${RESOURCE_NAME} diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 3311a509..1b35ad6a 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -2749,7 +2749,7 @@ data: endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} spec: - minAvailable: 1 + maxUnavailable: 50% selector: matchLabels: app: ${RESOURCE_NAME} From 370b111a9d473d6a12fabe400ff78f27ffcb96da Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 5 Sep 2024 09:25:38 -0700 Subject: [PATCH 376/425] Http Forwarder updates (#608) * Http forwarder updates * Refactor http forwarder to enable dynamic instantiation * Fix test * Add tests for http_forwarder * Fix * nocover --- .../common/service_requests.py | 5 +- .../configs/service--http_forwarder.yaml | 2 + .../inference/forwarding/forwarding.py | 32 ++- .../inference/forwarding/http_forwarder.py | 194 +++++++++++++----- .../inference/requirements_base.txt | 2 +- model-engine/requirements-test.txt | 1 + model-engine/requirements.in | 2 +- model-engine/requirements.txt | 6 +- model-engine/tests/unit/inference/conftest.py | 17 +- .../tests/unit/inference/test_forwarding.py | 84 +++++++- .../unit/inference/test_http_forwarder.py | 174 +++++++++++++++- 11 files changed, 449 insertions(+), 70 deletions(-) diff --git a/model-engine/model_engine_server/common/service_requests.py b/model-engine/model_engine_server/common/service_requests.py index ea77f2b9..d709bdec 100644 --- a/model-engine/model_engine_server/common/service_requests.py +++ b/model-engine/model_engine_server/common/service_requests.py @@ -37,9 +37,8 @@ def make_sync_request_with_retries( wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), ): with attempt: - logger.debug( - f"Retry number {attempt.retry_state.attempt_number}" - ) # pragma: no cover + if attempt.retry_state.attempt_number > 1: # pragma: no cover + logger.info(f"Retry number {attempt.retry_state.attempt_number}") resp = requests.post( request_url, json=payload_json, diff --git a/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml index 10052970..bfdb6553 100644 --- a/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml +++ b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml @@ -9,6 +9,7 @@ forwarder: model_engine_unwrap: true serialize_results_as_string: true forward_http_status: true + extra_routes: [] stream: user_port: 5005 user_hostname: "localhost" @@ -17,4 +18,5 @@ forwarder: batch_route: null model_engine_unwrap: true serialize_results_as_string: false + extra_routes: [] max_concurrency: 100 diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 38b4e8cc..20476339 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -104,6 +104,21 @@ def get_response_payload(using_serialize_results_as_string: bool, response: Any) return {"result": response} + @staticmethod + def get_response_payload_stream(using_serialize_results_as_string: bool, response: str): + """Event stream is needs to be treated as a stream of strings, not JSON objects""" + if using_serialize_results_as_string: + return {"result": response} + + return {"result": parse_to_object_or_string(response)} + + +def parse_to_object_or_string(value: str) -> object: + try: + return json.loads(value) + except json.JSONDecodeError: + return value + @dataclass class Forwarder(ModelEngineSerializationMixin): @@ -123,9 +138,9 @@ class Forwarder(ModelEngineSerializationMixin): predict_endpoint: str model_engine_unwrap: bool serialize_results_as_string: bool - post_inference_hooks_handler: PostInferenceHooksHandler wrap_response: bool forward_http_status: bool + post_inference_hooks_handler: PostInferenceHooksHandler def __call__(self, json_payload: Any) -> Any: json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) @@ -344,9 +359,7 @@ def __call__(self, json_payload: Any) -> Iterator[Any]: client = sseclient.SSEClient(response) for event in client.events(): - yield self.get_response_payload( - using_serialize_results_as_string, json.loads(event.data) - ) + yield self.get_response_payload_stream(using_serialize_results_as_string, event.data) @dataclass(frozen=True) @@ -509,13 +522,22 @@ def _substitute_config_overrides(config: dict, config_overrides: List[str]) -> N raise ValueError(f"Error setting {key_path} to {value} in {config}") from e +def _cast_value(value: Any) -> Any: + if value.isdigit(): + return int(value) + elif value.startswith("[") and value.endswith("]"): + return [_cast_value(v) for v in value[1:-1].split(",")] + else: + return value + + def _set_value(config: dict, key_path: List[str], value: Any) -> None: """ Modifies config by setting the value at config[key_path[0]][key_path[1]]... to be `value`. """ key = key_path[0] if len(key_path) == 1: - config[key] = value if not value.isdigit() else int(value) + config[key] = _cast_value(value) else: if key not in config: config[key] = dict() diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index 2f6ad755..adfbde59 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -1,28 +1,26 @@ import argparse +import asyncio import json import os -import subprocess +import signal from functools import lru_cache +from typing import Any, Dict, Optional +import uvicorn from fastapi import BackgroundTasks, Depends, FastAPI from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter from model_engine_server.common.dtos.tasks import EndpointPredictV1Request from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.forwarding.forwarding import ( + Forwarder, LoadForwarder, LoadStreamingForwarder, + StreamingForwarder, load_named_config, ) -from sse_starlette.sse import EventSourceResponse +from sse_starlette import EventSourceResponse logger = make_logger(logger_name()) -app = FastAPI() - - -@app.get("/healthz") -@app.get("/readyz") -def healthcheck(): - return "OK" def get_config(): @@ -36,15 +34,23 @@ def get_config(): ) -def get_forwarder_loader(): - config = get_config() - forwarder_loader = LoadForwarder(**config["sync"]) +def get_forwarder_loader(destination_path: Optional[str] = None): + config = get_config()["sync"] + if "extra_routes" in config: + del config["extra_routes"] + if destination_path: + config["predict_route"] = destination_path + forwarder_loader = LoadForwarder(**config) return forwarder_loader -def get_streaming_forwarder_loader(): - config = get_config() - streaming_forwarder_loader = LoadStreamingForwarder(**config["stream"]) +def get_streaming_forwarder_loader(destination_path: Optional[str] = None): + config = get_config()["stream"] + if "extra_routes" in config: + del config["extra_routes"] + if destination_path: + config["predict_route"] = destination_path + streaming_forwarder_loader = LoadStreamingForwarder(**config) return streaming_forwarder_loader @@ -58,16 +64,15 @@ def get_concurrency_limiter(): @lru_cache() -def load_forwarder(): - return get_forwarder_loader().load(None, None) +def load_forwarder(destination_path: Optional[str] = None): + return get_forwarder_loader(destination_path).load(None, None) @lru_cache() -def load_streaming_forwarder(): - return get_streaming_forwarder_loader().load(None, None) +def load_streaming_forwarder(destination_path: Optional[str] = None): + return get_streaming_forwarder_loader(destination_path).load(None, None) -@app.post("/predict") def predict( request: EndpointPredictV1Request, background_tasks: BackgroundTasks, @@ -76,7 +81,7 @@ def predict( ): with limiter: try: - response = forwarder(request.dict()) + response = forwarder(request.model_dump()) background_tasks.add_task( forwarder.post_inference_hooks_handler.handle, request, response ) @@ -86,7 +91,6 @@ def predict( raise -@app.post("/stream") async def stream( request: EndpointPredictV1Request, forwarder=Depends(load_streaming_forwarder), @@ -94,7 +98,7 @@ async def stream( ): with limiter: try: - payload = request.dict() + payload = request.model_dump() except Exception: logger.error(f"Failed to decode payload from: {request}") raise @@ -111,43 +115,133 @@ async def event_generator(): return EventSourceResponse(event_generator()) -def entrypoint(): +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): # pragma: no cover + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue + + logger.info("Route: %s, Methods: %s", path, ", ".join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + server = uvicorn.Server(config) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + return server.shutdown() + + +async def run_server(args, **uvicorn_kwargs) -> None: # pragma: no cover + app = await init_app() + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + **uvicorn_kwargs, + ) + + await shutdown_task + + +async def init_app(): + app = FastAPI() + + def healthcheck(): + return "OK" + + def add_extra_routes(app: FastAPI): + """Read extra_routes from config and dynamically add routes to app""" + config = get_config() + sync_forwarders: Dict[str, Forwarder] = dict() + stream_forwarders: Dict[str, StreamingForwarder] = dict() + for route in config.get("sync", {}).get("extra_routes", []): + sync_forwarders[route] = load_forwarder(route) + for route in config.get("stream", {}).get("extra_routes", []): + stream_forwarders[route] = load_streaming_forwarder(route) + + all_routes = set(list(sync_forwarders.keys()) + list(stream_forwarders.keys())) + + for route in all_routes: + # This route is a catch-all for any requests that don't match the /predict or /stream routes + # It will treat the request as a streaming request if the "stream" body parameter is set to true + # NOTE: it is important for this to be defined AFTER the /predict and /stream endpoints + # because FastAPI will match the first route that matches the request path + async def predict_or_stream( + request: EndpointPredictV1Request, + background_tasks: BackgroundTasks, + sync_forwarder=Depends(lambda: sync_forwarders.get(route)), + stream_forwarder=Depends(lambda: stream_forwarders.get(route)), + limiter=Depends(get_concurrency_limiter), + ): + if not request.args: + raise Exception("Request has no args") + if request.args.root.get("stream", False) and stream_forwarder: + return await stream(request, stream_forwarder, limiter) + elif request.args.root.get("stream") is not True and sync_forwarder: + return predict(request, background_tasks, sync_forwarder, limiter) + else: + raise Exception("No forwarder configured for this route") + + logger.info(f"Adding route {route}") + app.add_api_route( + path=route, + endpoint=predict_or_stream, + methods=["POST"], + ) + + app.add_api_route(path="/healthz", endpoint=healthcheck, methods=["GET"]) + app.add_api_route(path="/readyz", endpoint=healthcheck, methods=["GET"]) + app.add_api_route(path="/predict", endpoint=predict, methods=["POST"]) + app.add_api_route(path="/stream", endpoint=stream, methods=["POST"]) + + add_extra_routes(app) + return app + + +def entrypoint(): # pragma: no cover parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--num-workers", type=int, required=True) - parser.add_argument("--host", type=str, default="[::]") + parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=5000) parser.add_argument("--set", type=str, action="append") parser.add_argument("--graceful-timeout", type=int, default=600) args, extra_args = parser.parse_known_args() - values = [f"CONFIG_FILE={args.config}"] + os.environ["CONFIG_FILE"] = args.config if args.set is not None: - values.append(f"CONFIG_OVERRIDES={';'.join(args.set)}") - envs = [] - for v in values: - envs.extend(["--env", v]) - - command = [ - "gunicorn", - "--bind", - f"{args.host}:{args.port}", - "--timeout", - "1200", - "--keep-alive", - "2", - "--worker-class", - "uvicorn.workers.UvicornWorker", - "--workers", - str(args.num_workers), - "--graceful-timeout", - str(args.graceful_timeout), - *envs, - "model_engine_server.inference.forwarding.http_forwarder:app", - *extra_args, - ] - subprocess.run(command) + os.environ["CONFIG_OVERRIDES"] = ";".join(args.set) + + asyncio.run( + run_server( + args, + timeout_keep_alive=2, + timeout_graceful_shutdown=args.graceful_timeout, + workers=args.num_workers, + *extra_args, + ) + ) if __name__ == "__main__": diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt index aeeb5efd..972a5247 100644 --- a/model-engine/model_engine_server/inference/requirements_base.txt +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -10,7 +10,7 @@ importlib-metadata<5.0;python_version<"3.8" scale-launch>=0.1.0 smart_open==5.1.0 typing-extensions>=4.1.1 -uvicorn==0.17.6 +uvicorn==0.30.6 waitress==2.0.0 # HACK: at time of adding, these deps are imported by model-engine/model_engine_server files diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index 55a4b9f2..0f7cd2ec 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -12,6 +12,7 @@ pytest-cov==2.10.0 pytest-mypy==0.9.1 pytest-mypy-plugins==1.10.1 pytest-pylint==0.18.0 +requests-mock==1.9.3 types-cachetools==5.3.0.5 types-croniter==1.4.0.0 types-PyYAML==6.0.7 diff --git a/model-engine/requirements.in b/model-engine/requirements.in index f70d4503..b9b44867 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -57,6 +57,6 @@ tokenizers~=0.15.2 tqdm~=4.64 transformers==4.38.0 twine==3.7.1 -uvicorn==0.17.6 +uvicorn==0.30.6 uvloop==0.17.0 yarl~=1.4 \ No newline at end of file diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index fb0d4d24..3d19c348 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -25,8 +25,6 @@ anyio==3.7.1 # azure-core # httpx # starlette -asgiref==3.7.2 - # via uvicorn asn1crypto==1.5.1 # via scramp async-timeout==4.0.2 @@ -535,7 +533,6 @@ typing-extensions==4.10.0 # via # aioredis # annotated-types - # asgiref # azure-core # azure-keyvault-secrets # azure-servicebus @@ -562,6 +559,7 @@ typing-extensions==4.10.0 # sqlalchemy # starlette # typing-inspect + # uvicorn typing-inspect==0.9.0 # via dataclasses-json tzdata==2023.3 @@ -578,7 +576,7 @@ urllib3==1.26.16 # kubernetes # kubernetes-asyncio # requests -uvicorn==0.17.6 +uvicorn==0.30.6 # via -r model-engine/requirements.in uvloop==0.17.0 # via -r model-engine/requirements.in diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py index 870f4075..e07a7c73 100644 --- a/model-engine/tests/unit/inference/conftest.py +++ b/model-engine/tests/unit/inference/conftest.py @@ -11,6 +11,11 @@ ) +@pytest.fixture +def anyio_backend(): + return "asyncio" + + @pytest.fixture def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest: model_config = CreateBatchCompletionsModelConfig( @@ -34,7 +39,11 @@ def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineReq @pytest.fixture def create_batch_completions_tool_completion_request(): model_config = CreateBatchCompletionsModelConfig( - checkpoint_path="checkpoint_path", model="model", num_shards=4, seed=123, labels={} + checkpoint_path="checkpoint_path", + model="model", + num_shards=4, + seed=123, + labels={}, ) return CreateBatchCompletionsEngineRequest( @@ -100,7 +109,11 @@ def __init__(self, logprob: float): mock_vllm_request_output3.outputs[0].logprobs = [ {4: Logprob(0.1), 5: Logprob(0.2), 6: Logprob(0.3)} ] - return [mock_vllm_request_output1, mock_vllm_request_output2, mock_vllm_request_output3] + return [ + mock_vllm_request_output1, + mock_vllm_request_output2, + mock_vllm_request_output3, + ] @pytest.fixture diff --git a/model-engine/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py index 68c9ab32..5c996303 100644 --- a/model-engine/tests/unit/inference/test_forwarding.py +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -14,6 +14,7 @@ LoadForwarder, LoadStreamingForwarder, StreamingForwarder, + load_named_config, ) from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, @@ -22,6 +23,7 @@ from tests.unit.conftest import FakeStreamingStorageGateway PAYLOAD: Mapping[str, str] = {"hello": "world"} +PAYLOAD_END = "[DONE]" def mocked_get(*args, **kwargs): # noqa @@ -63,7 +65,11 @@ class Event: class mocked_static_events: def events(self) -> list: payload_json = json.dumps(PAYLOAD) - return [Event(data=payload_json), Event(data=payload_json)] + return [ + Event(data=payload_json), + Event(data=payload_json), + Event(data=PAYLOAD_END), + ] return mocked_static_events() @@ -96,6 +102,76 @@ def post_inference_hooks_handler(): return handler +def mocked_config_content(): + return { + "forwarder": { + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + }, + "max_concurrency": 42, + } + } + + +def mocked_config_overrides(): + return [ + "forwarder.sync.extra_routes=[/v1/chat/completions]", + "forwarder.stream.extra_routes=[/v1/chat/completions]", + "forwarder.sync.healthcheck_route=/health", + "forwarder.stream.healthcheck_route=/health", + ] + + +# patch open(config_uri, "rt") and have output be mocked_config_content +@mock.patch("builtins.open", mock.mock_open(read_data=json.dumps(mocked_config_content()))) +def test_load_named_config(): + output = load_named_config("dummy.yml", config_overrides=mocked_config_overrides()) + expected_output = { + "name": "forwarder", + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/health", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + "extra_routes": ["/v1/chat/completions"], + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/health", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + "extra_routes": ["/v1/chat/completions"], + }, + "max_concurrency": 42, + } + assert output == expected_output + + @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) def test_forwarders(post_inference_hooks_handler): @@ -131,16 +207,18 @@ def _check_responses_not_wrapped(json_response) -> None: def _check_streaming(streaming_response) -> None: streaming_response_list = list(streaming_response) - assert len(streaming_response_list) == 2 + assert len(streaming_response_list) == 3 assert streaming_response_list[0] == {"result": PAYLOAD} assert streaming_response_list[1] == {"result": PAYLOAD} + assert streaming_response_list[2] == {"result": PAYLOAD_END} def _check_streaming_serialized(streaming_response) -> None: streaming_response_list = list(streaming_response) - assert len(streaming_response_list) == 2 + assert len(streaming_response_list) == 3 assert streaming_response_list[0] == {"result": json.dumps(PAYLOAD)} assert streaming_response_list[1] == {"result": json.dumps(PAYLOAD)} + assert streaming_response_list[2] == {"result": PAYLOAD_END} @mock.patch("requests.post", mocked_post) diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index fff38834..e765accc 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -1,16 +1,23 @@ +import json import threading from dataclasses import dataclass from typing import Mapping from unittest import mock import pytest +import requests_mock from fastapi import BackgroundTasks from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpointConfig from model_engine_server.inference.forwarding.forwarding import Forwarder from model_engine_server.inference.forwarding.http_forwarder import ( MultiprocessingConcurrencyLimiter, get_concurrency_limiter, + get_forwarder_loader, + get_streaming_forwarder_loader, + init_app, predict, ) from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( @@ -135,7 +142,34 @@ def mock_request(): ) -@mock.patch("model_engine_server.inference.forwarding.http_forwarder.get_config", mocked_get_config) +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config, +) +def test_get_forwarder_loader(): + loader = get_forwarder_loader() + assert loader.predict_route == "/predict" + + loader = get_forwarder_loader("/v1/chat/completions") + assert loader.predict_route == "/v1/chat/completions" + + +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config, +) +def test_get_streaming_forwarder_loader(): + loader = get_streaming_forwarder_loader() + assert loader.predict_route == "/stream" + + loader = get_streaming_forwarder_loader("/v1/chat/completions") + assert loader.predict_route == "/v1/chat/completions" + + +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config, +) def test_get_concurrency_limiter(): limiter = get_concurrency_limiter() assert isinstance(limiter, MultiprocessingConcurrencyLimiter) @@ -197,3 +231,141 @@ def test_handler_with_logging(post_inference_hooks_handler_with_logging): ) except Exception as e: pytest.fail(f"Unexpected exception: {e}") + + +# Test the fastapi app + + +def mocked_get_config_with_extra_paths(): + return { + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + "extra_routes": ["/v1/chat/completions"], + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + "extra_routes": ["/v1/chat/completions"], + }, + "max_concurrency": 42, + } + + +def get_predict_endpoint(config): + cfg_sync = config["sync"] + predict_endpoint = ( + f"http://{cfg_sync['user_hostname']}:{cfg_sync['user_port']}{cfg_sync['predict_route']}" + ) + return predict_endpoint + + +def get_healthcheck_endpoint(config): + cfg_sync = config["sync"] + healthcheck_endpoint = ( + f"http://{cfg_sync['user_hostname']}:{cfg_sync['user_port']}{cfg_sync['healthcheck_route']}" + ) + return healthcheck_endpoint + + +def get_stream_endpoint(config): + cfg_stream = config["stream"] + stream_endpoint = f"http://{cfg_stream['user_hostname']}:{cfg_stream['user_port']}{cfg_stream['predict_route']}" + return stream_endpoint + + +def get_chat_endpoint(config): + cfg_sync = config["sync"] + chat_endpoint = ( + f"http://{cfg_sync['user_hostname']}:{cfg_sync['user_port']}{cfg_sync['extra_routes'][0]}" + ) + return chat_endpoint + + +def mocked_get_endpoint_config(): + return ModelEndpointConfig( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + ) + + +@pytest.fixture() +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config_with_extra_paths, +) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +async def mocked_app(): + with requests_mock.Mocker() as req_mock: + healthcheck_endpoint = get_healthcheck_endpoint(mocked_get_config_with_extra_paths()) + print(healthcheck_endpoint) + req_mock.get( + healthcheck_endpoint, + json={"status": "ok"}, + ) + return await init_app() + + +def wrap_request(request): + return {"url": "", "args": request} + + +def wrap_result(result): + return {"result": result} + + +@pytest.mark.anyio +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config_with_extra_paths, +) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +async def test_mocked_app_success(mocked_app): + config = mocked_get_config_with_extra_paths() + config_sync = config["sync"] + # config_stream = config["stream"] + + predict_endpoint = get_predict_endpoint(config) + healthcheck_endpoint = get_healthcheck_endpoint(config) + + # stream_endpoint = get_stream_endpoint(config) + chat_endpoint = get_chat_endpoint(config) + + raw_payload = {"prompt": "Hello", "stream": False} + raw_result = {"message": "Hello World"} + + payload = wrap_request(raw_payload) + expected_result = wrap_result( + json.dumps(raw_result) if config_sync["serialize_results_as_string"] else raw_result + ) + with TestClient(mocked_app) as client, requests_mock.Mocker() as req_mock: + req_mock.get(healthcheck_endpoint, json={"status": "ok"}) + req_mock.post(predict_endpoint, json=raw_result) + response = client.post("/predict", json=payload) + assert response.status_code == 200 + assert response.json() == expected_result + + req_mock.post(chat_endpoint, json=raw_result) + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + assert response.json() == expected_result + + # TODO: add tests for streaming; it's not as trivial as I'd hoped From 62ebf4dc00e1958c221a58c8b85e69d2b48085e6 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 9 Sep 2024 11:30:25 -0700 Subject: [PATCH 377/425] Introduce alembic to repo (#610) --- .ruff.toml | 2 +- .../db/migrations/alembic.ini | 105 +++ .../db/migrations/alembic/README | 20 + .../db/migrations/alembic/env.py | 99 +++ .../db/migrations/alembic/script.py.mako | 24 + .../2024_09_09_1736-fa3267c80731_initial.py | 25 + .../db/migrations/initial.sql | 701 ++++++++++++++++++ 7 files changed, 975 insertions(+), 1 deletion(-) create mode 100644 model-engine/model_engine_server/db/migrations/alembic.ini create mode 100644 model-engine/model_engine_server/db/migrations/alembic/README create mode 100644 model-engine/model_engine_server/db/migrations/alembic/env.py create mode 100644 model-engine/model_engine_server/db/migrations/alembic/script.py.mako create mode 100644 model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py create mode 100644 model-engine/model_engine_server/db/migrations/initial.sql diff --git a/.ruff.toml b/.ruff.toml index 3a61ae77..69f83253 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -2,4 +2,4 @@ line-length = 100 ignore = ["E501"] -exclude = ["gen"] +exclude = ["gen", "alembic"] diff --git a/model-engine/model_engine_server/db/migrations/alembic.ini b/model-engine/model_engine_server/db/migrations/alembic.ini new file mode 100644 index 00000000..23f7c0ea --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic.ini @@ -0,0 +1,105 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/model-engine/model_engine_server/db/migrations/alembic/README b/model-engine/model_engine_server/db/migrations/alembic/README new file mode 100644 index 00000000..cfedbcc5 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/README @@ -0,0 +1,20 @@ +# Setup + +We introduce alembic by +1. dumping the current db schemas into 'initial.sql' via pg_dump + +``` +pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql +``` + +2. writing an initial revision that reads and applies intial.sql script + +``` +alembic revision -m “initial” +``` + +3. Stamping the current revision to our production db to avoid actually running it on production + +``` +alembic stamp fa3267c80731 +``` diff --git a/model-engine/model_engine_server/db/migrations/alembic/env.py b/model-engine/model_engine_server/db/migrations/alembic/env.py new file mode 100644 index 00000000..3f4b73b3 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/env.py @@ -0,0 +1,99 @@ +import logging +import os +from logging.config import fileConfig + +from alembic import context +from model_engine_server.db.base import get_engine_url +from sqlalchemy import engine_from_config, pool + +env = os.environ.get("ENV") +assert env is not None, "Expected ENV to be a nonempty environment variable." + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +config.set_main_option("sqlalchemy.url", get_engine_url(env, read_only=False).url) + +ALEMBIC_TABLE_NAME = "alembic_version_model_engine" + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = None + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + version_table=ALEMBIC_TABLE_NAME, + ) + + try: + with context.begin_transaction(): + context.run_migrations() + except Exception as e: + logging.exception("Error during migration: %s", str(e)) + raise e + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + version_table=ALEMBIC_TABLE_NAME, + ) + + try: + with context.begin_transaction(): + context.run_migrations() + except Exception as e: + logging.exception("Error during migration: %s", str(e)) + raise e + finally: + connection.close() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/model-engine/model_engine_server/db/migrations/alembic/script.py.mako b/model-engine/model_engine_server/db/migrations/alembic/script.py.mako new file mode 100644 index 00000000..55df2863 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py new file mode 100644 index 00000000..acad7cff --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py @@ -0,0 +1,25 @@ +"""“initial” + +Revision ID: fa3267c80731 +Revises: +Create Date: 2024-09-09 17:36:30.097136 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "fa3267c80731" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with open("migrations/initial.sql") as fd: + op.execute(fd.read()) + + +def downgrade() -> None: + pass diff --git a/model-engine/model_engine_server/db/migrations/initial.sql b/model-engine/model_engine_server/db/migrations/initial.sql new file mode 100644 index 00000000..93655bf3 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/initial.sql @@ -0,0 +1,701 @@ +-- +-- PostgreSQL database dump +-- + +-- Dumped from database version 13.12 +-- Dumped by pg_dump version 13.16 (Ubuntu 13.16-1.pgdg20.04+1) + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +-- +-- Name: hosted_model_inference; Type: SCHEMA; Schema: -; Owner: - +-- + +CREATE SCHEMA hosted_model_inference; + + +-- +-- Name: model; Type: SCHEMA; Schema: -; Owner: - +-- + +CREATE SCHEMA model; + + +SET default_tablespace = ''; + +SET default_table_access_method = heap; + +-- +-- Name: batch_jobs; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.batch_jobs ( + id text NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + completed_at timestamp with time zone, + status text NOT NULL, + created_by character varying(24) NOT NULL, + owner character varying(24) NOT NULL, + model_bundle_id text NOT NULL, + model_endpoint_id text, + task_ids_location text, + result_location text +); + + +-- +-- Name: bundles; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.bundles ( + id text NOT NULL, + name character varying(50), + created_by character varying(24), + created_at timestamp with time zone DEFAULT now(), + location text, + version character varying(24), + registered_model_name text, + bundle_metadata json, + requirements json, + env_params json, + packaging_type text, + app_config json, + model_artifact_ids text[] DEFAULT '{}'::text[], + schema_location text, + owner character varying(24) NOT NULL, + flavor text, + artifact_requirements text[], + artifact_app_config json, + artifact_framework_type text, + artifact_pytorch_image_tag text, + artifact_tensorflow_version text, + artifact_image_repository text, + artifact_image_tag text, + cloudpickle_artifact_load_predict_fn text, + cloudpickle_artifact_load_model_fn text, + zip_artifact_load_predict_fn_module_path text, + zip_artifact_load_model_fn_module_path text, + runnable_image_repository text, + runnable_image_tag text, + runnable_image_command text[], + runnable_image_env json, + runnable_image_protocol text, + artifact_location text, + runnable_image_readiness_initial_delay_seconds integer, + triton_enhanced_runnable_image_model_repository text, + triton_enhanced_runnable_image_model_replicas json, + triton_enhanced_runnable_image_num_cpu numeric, + triton_enhanced_runnable_image_commit_tag text, + triton_enhanced_runnable_image_storage text, + triton_enhanced_runnable_image_memory text, + triton_enhanced_runnable_image_readiness_initial_delay_seconds integer, + streaming_enhanced_runnable_image_streaming_command text[], + runnable_image_predict_route text, + streaming_enhanced_runnable_image_streaming_predict_route text, + runnable_image_healthcheck_route text, + CONSTRAINT bundles_flavor_0 CHECK ((flavor = ANY (ARRAY['cloudpickle_artifact'::text, 'zip_artifact'::text, 'runnable_image'::text, 'triton_enhanced_runnable_image'::text, 'streaming_enhanced_runnable_image'::text]))), + CONSTRAINT bundles_flavor_1 CHECK (((flavor ~~ '%_artifact'::text) = (artifact_requirements IS NOT NULL))), + CONSTRAINT bundles_flavor_10 CHECK (((flavor = 'zip_artifact'::text) = (zip_artifact_load_predict_fn_module_path IS NOT NULL))), + CONSTRAINT bundles_flavor_11 CHECK (((flavor = 'zip_artifact'::text) = (zip_artifact_load_model_fn_module_path IS NOT NULL))), + CONSTRAINT bundles_flavor_12 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_repository IS NOT NULL))), + CONSTRAINT bundles_flavor_13 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_14 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_command IS NOT NULL))), + CONSTRAINT bundles_flavor_15 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_protocol IS NOT NULL))), + CONSTRAINT bundles_flavor_16 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_model_repository IS NOT NULL))), + CONSTRAINT bundles_flavor_17 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_num_cpu IS NOT NULL))), + CONSTRAINT bundles_flavor_18 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_commit_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_19 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_readiness_initial_delay_seconds IS NOT NULL))), + CONSTRAINT bundles_flavor_2 CHECK (((flavor ~~ '%_artifact'::text) = (artifact_location IS NOT NULL))), + CONSTRAINT bundles_flavor_20 CHECK (((flavor = 'streaming_enhanced_runnable_image'::text) = (streaming_enhanced_runnable_image_streaming_command IS NOT NULL))), + CONSTRAINT bundles_flavor_21 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_predict_route IS NOT NULL))), + CONSTRAINT bundles_flavor_22 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_healthcheck_route IS NOT NULL))), + CONSTRAINT bundles_flavor_23 CHECK (((flavor = 'streaming_enhanced_runnable_image'::text) = (streaming_enhanced_runnable_image_streaming_predict_route IS NOT NULL))), + CONSTRAINT bundles_flavor_3 CHECK (((flavor ~~ '%_artifact'::text) = (artifact_framework_type IS NOT NULL))), + CONSTRAINT bundles_flavor_4 CHECK (((artifact_framework_type = 'pytorch'::text) = (artifact_pytorch_image_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_5 CHECK (((artifact_framework_type = 'tensorflow'::text) = (artifact_tensorflow_version IS NOT NULL))), + CONSTRAINT bundles_flavor_6 CHECK (((artifact_framework_type = 'custom_base_image'::text) = (artifact_image_repository IS NOT NULL))), + CONSTRAINT bundles_flavor_7 CHECK (((artifact_framework_type = 'custom_base_image'::text) = (artifact_image_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_8 CHECK (((flavor = 'cloudpickle_artifact'::text) = (cloudpickle_artifact_load_predict_fn IS NOT NULL))), + CONSTRAINT bundles_flavor_9 CHECK (((flavor = 'cloudpickle_artifact'::text) = (cloudpickle_artifact_load_model_fn IS NOT NULL))) +); + + +-- +-- Name: docker_image_batch_job_bundles; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.docker_image_batch_job_bundles ( + id text NOT NULL, + name text NOT NULL, + created_by character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + owner character varying(24) NOT NULL, + image_repository text NOT NULL, + image_tag text NOT NULL, + command text[] NOT NULL, + env json NOT NULL, + mount_location text, + cpus text, + memory text, + storage text, + gpus integer, + gpu_type text, + public boolean +); + + +-- +-- Name: endpoints; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.endpoints ( + id text NOT NULL, + name text, + created_by character varying(24), + created_at timestamp with time zone DEFAULT now(), + last_updated_at timestamp with time zone DEFAULT now(), + current_bundle_id text, + endpoint_metadata jsonb, + creation_task_id text, + endpoint_type text, + destination text, + endpoint_status text, + owner character varying(24) NOT NULL, + public_inference boolean +); + + +-- +-- Name: triggers; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.triggers ( + id character varying NOT NULL, + name character varying NOT NULL, + owner character varying NOT NULL, + created_by character varying NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + cron_schedule character varying NOT NULL, + docker_image_batch_job_bundle_id character varying NOT NULL, + default_job_config jsonb, + default_job_metadata jsonb +); + + +-- +-- Name: model_artifacts; Type: TABLE; Schema: model; Owner: - +-- + +CREATE TABLE model.model_artifacts ( + id text NOT NULL, + name text NOT NULL, + description text, + is_public boolean NOT NULL, + created_by character varying(24) NOT NULL, + owner character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + input_schema json, + output_schema json, + config json NOT NULL, + location text NOT NULL, + format text NOT NULL, + format_metadata json NOT NULL, + source text NOT NULL, + source_metadata json NOT NULL +); + + +-- +-- Name: model_versions; Type: TABLE; Schema: model; Owner: - +-- + +CREATE TABLE model.model_versions ( + id text NOT NULL, + model_id text NOT NULL, + version_number integer NOT NULL, + tags text[] NOT NULL, + created_by character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + launch_model_bundle_id text, + nucleus_model_id text, + metadata json DEFAULT '{}'::json NOT NULL +); + + +-- +-- Name: models; Type: TABLE; Schema: model; Owner: - +-- + +CREATE TABLE model.models ( + id text NOT NULL, + name text NOT NULL, + description text, + task_types text[] NOT NULL, + created_by character varying(24) NOT NULL, + owner character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL +); + + +-- +-- Name: batch_jobs batch_jobs_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.batch_jobs + ADD CONSTRAINT batch_jobs_pkey PRIMARY KEY (id); + + +-- +-- Name: bundles bundles_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.bundles + ADD CONSTRAINT bundles_pkey PRIMARY KEY (id); + + +-- +-- Name: docker_image_batch_job_bundles docker_image_batch_job_bundles_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.docker_image_batch_job_bundles + ADD CONSTRAINT docker_image_batch_job_bundles_pkey PRIMARY KEY (id); + + +-- +-- Name: endpoints endpoint_name_created_by_uc; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoint_name_created_by_uc UNIQUE (name, created_by); + + +-- +-- Name: endpoints endpoint_name_owner_uc; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoint_name_owner_uc UNIQUE (name, owner); + + +-- +-- Name: endpoints endpoints_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoints_pkey PRIMARY KEY (id); + + +-- +-- Name: triggers triggers_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.triggers + ADD CONSTRAINT triggers_pkey PRIMARY KEY (id); + + +-- +-- Name: triggers uq_triggers_name_owner; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.triggers + ADD CONSTRAINT uq_triggers_name_owner UNIQUE (name, owner); + + +-- +-- Name: model_versions launch_model_bundle_id_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT launch_model_bundle_id_uc UNIQUE (launch_model_bundle_id); + + +-- +-- Name: model_artifacts model_artifacts_owner_name_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_artifacts + ADD CONSTRAINT model_artifacts_owner_name_uc UNIQUE (owner, name); + + +-- +-- Name: model_artifacts model_artifacts_pkey; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_artifacts + ADD CONSTRAINT model_artifacts_pkey PRIMARY KEY (id); + + +-- +-- Name: model_versions model_id_version_number_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_id_version_number_uc UNIQUE (model_id, version_number); + + +-- +-- Name: model_versions model_versions_pkey; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_versions_pkey PRIMARY KEY (id); + + +-- +-- Name: models models_owner_name_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.models + ADD CONSTRAINT models_owner_name_uc UNIQUE (owner, name); + + +-- +-- Name: models models_pkey; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.models + ADD CONSTRAINT models_pkey PRIMARY KEY (id); + + +-- +-- Name: model_versions nucleus_model_id_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT nucleus_model_id_uc UNIQUE (nucleus_model_id); + + +-- +-- Name: endpoint_name_llm_uc; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE UNIQUE INDEX endpoint_name_llm_uc ON hosted_model_inference.endpoints USING btree (name) WHERE (endpoint_metadata ? '_llm'::text); + + +-- +-- Name: idx_endpoint_metadata; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX idx_endpoint_metadata ON hosted_model_inference.endpoints USING gin (endpoint_metadata); + + +-- +-- Name: idx_trigger_name; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX idx_trigger_name ON hosted_model_inference.triggers USING btree (name); + + +-- +-- Name: ix_hosted_model_inference_batch_jobs_created_by; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_batch_jobs_created_by ON hosted_model_inference.batch_jobs USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_batch_jobs_owner; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_batch_jobs_owner ON hosted_model_inference.batch_jobs USING btree (owner); + + +-- +-- Name: ix_hosted_model_inference_bundles_created_by; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_bundles_created_by ON hosted_model_inference.bundles USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_bundles_name; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_bundles_name ON hosted_model_inference.bundles USING btree (name); + + +-- +-- Name: ix_hosted_model_inference_docker_image_batch_job_bundle_79a0; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_docker_image_batch_job_bundle_79a0 ON hosted_model_inference.docker_image_batch_job_bundles USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_docker_image_batch_job_bundles_owner; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_docker_image_batch_job_bundles_owner ON hosted_model_inference.docker_image_batch_job_bundles USING btree (owner); + + +-- +-- Name: ix_hosted_model_inference_endpoints_created_by; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_endpoints_created_by ON hosted_model_inference.endpoints USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_endpoints_name; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_endpoints_name ON hosted_model_inference.endpoints USING btree (name); + + +-- +-- Name: ix_model_model_artifacts_created_by; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_created_by ON model.model_artifacts USING btree (created_by); + + +-- +-- Name: ix_model_model_artifacts_description; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_description ON model.model_artifacts USING btree (description); + + +-- +-- Name: ix_model_model_artifacts_format; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_format ON model.model_artifacts USING btree (format); + + +-- +-- Name: ix_model_model_artifacts_is_public; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_is_public ON model.model_artifacts USING btree (is_public); + + +-- +-- Name: ix_model_model_artifacts_name; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_name ON model.model_artifacts USING btree (name); + + +-- +-- Name: ix_model_model_artifacts_owner; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_owner ON model.model_artifacts USING btree (owner); + + +-- +-- Name: ix_model_model_artifacts_source; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_source ON model.model_artifacts USING btree (source); + + +-- +-- Name: ix_model_model_versions_created_by; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_created_by ON model.model_versions USING btree (created_by); + + +-- +-- Name: ix_model_model_versions_model_id; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_model_id ON model.model_versions USING btree (model_id); + + +-- +-- Name: ix_model_model_versions_tags; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_tags ON model.model_versions USING btree (tags); + + +-- +-- Name: ix_model_model_versions_version_number; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_version_number ON model.model_versions USING btree (version_number); + + +-- +-- Name: ix_model_models_created_by; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_created_by ON model.models USING btree (created_by); + + +-- +-- Name: ix_model_models_description; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_description ON model.models USING btree (description); + + +-- +-- Name: ix_model_models_name; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_name ON model.models USING btree (name); + + +-- +-- Name: ix_model_models_owner; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_owner ON model.models USING btree (owner); + + +-- +-- Name: ix_model_models_task_types; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_task_types ON model.models USING btree (task_types); + + +-- +-- Name: batch_jobs batch_jobs_model_bundle_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.batch_jobs + ADD CONSTRAINT batch_jobs_model_bundle_id_fkey FOREIGN KEY (model_bundle_id) REFERENCES hosted_model_inference.bundles(id); + + +-- +-- Name: batch_jobs batch_jobs_model_endpoint_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.batch_jobs + ADD CONSTRAINT batch_jobs_model_endpoint_id_fkey FOREIGN KEY (model_endpoint_id) REFERENCES hosted_model_inference.endpoints(id) ON DELETE SET NULL; + + +-- +-- Name: endpoints endpoints_current_bundle_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoints_current_bundle_id_fkey FOREIGN KEY (current_bundle_id) REFERENCES hosted_model_inference.bundles(id); + + +-- +-- Name: triggers triggers_docker_image_batch_job_bundle_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.triggers + ADD CONSTRAINT triggers_docker_image_batch_job_bundle_id_fkey FOREIGN KEY (docker_image_batch_job_bundle_id) REFERENCES hosted_model_inference.docker_image_batch_job_bundles(id); + + +-- +-- Name: model_versions model_versions_launch_model_bundle_id_fkey; Type: FK CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_versions_launch_model_bundle_id_fkey FOREIGN KEY (launch_model_bundle_id) REFERENCES hosted_model_inference.bundles(id); + + +-- +-- Name: model_versions model_versions_model_id_fkey; Type: FK CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_versions_model_id_fkey FOREIGN KEY (model_id) REFERENCES model.models(id); + + +-- +-- Name: SCHEMA hosted_model_inference; Type: ACL; Schema: -; Owner: - +-- + +GRANT USAGE ON SCHEMA hosted_model_inference TO fivetran; + + +-- +-- Name: SCHEMA model; Type: ACL; Schema: -; Owner: - +-- + +GRANT USAGE ON SCHEMA model TO fivetran; + + +-- +-- Name: TABLE batch_jobs; Type: ACL; Schema: hosted_model_inference; Owner: - +-- + +GRANT SELECT ON TABLE hosted_model_inference.batch_jobs TO fivetran; + + +-- +-- Name: TABLE bundles; Type: ACL; Schema: hosted_model_inference; Owner: - +-- + +GRANT SELECT ON TABLE hosted_model_inference.bundles TO fivetran; + + +-- +-- Name: TABLE docker_image_batch_job_bundles; Type: ACL; Schema: hosted_model_inference; Owner: - +-- + +GRANT SELECT ON TABLE hosted_model_inference.docker_image_batch_job_bundles TO fivetran; + + +-- +-- Name: TABLE endpoints; Type: ACL; Schema: hosted_model_inference; Owner: - +-- + +GRANT SELECT ON TABLE hosted_model_inference.endpoints TO fivetran; + + +-- +-- Name: TABLE triggers; Type: ACL; Schema: hosted_model_inference; Owner: - +-- + +GRANT SELECT ON TABLE hosted_model_inference.triggers TO fivetran; + + +-- +-- Name: TABLE model_artifacts; Type: ACL; Schema: model; Owner: - +-- + +GRANT SELECT ON TABLE model.model_artifacts TO fivetran; + + +-- +-- Name: TABLE model_versions; Type: ACL; Schema: model; Owner: - +-- + +GRANT SELECT ON TABLE model.model_versions TO fivetran; + + +-- +-- Name: TABLE models; Type: ACL; Schema: model; Owner: - +-- + +GRANT SELECT ON TABLE model.models TO fivetran; + + +-- +-- Name: DEFAULT PRIVILEGES FOR TABLES; Type: DEFAULT ACL; Schema: hosted_model_inference; Owner: - +-- + +ALTER DEFAULT PRIVILEGES FOR ROLE postgres IN SCHEMA hosted_model_inference GRANT SELECT ON TABLES TO fivetran; + + +-- +-- Name: DEFAULT PRIVILEGES FOR TABLES; Type: DEFAULT ACL; Schema: model; Owner: - +-- + +ALTER DEFAULT PRIVILEGES FOR ROLE postgres IN SCHEMA model GRANT SELECT ON TABLES TO fivetran; + + +-- +-- PostgreSQL database dump complete +-- + From 6bbacf083b7a116b8be0f04584f3deccc1d23b8a Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 11 Sep 2024 15:15:16 -0700 Subject: [PATCH 378/425] Chat completion API (#609) * WIP * separate vllm params * Updates to MIG + LAG * Update k8s deployments w/ extra routes * Add chat template override * Update route * Chat completion use cases * Add api routes * fix batch mypy * fix tests * fix tests * PR cleanup * Fix column type * DB migration for chat completion * Add database migration to helm chart * Fix db manager initialization * Fix subcommand building quotes * more fixes * Fix 'required' fields from openai spec * fix extra_args parsing in forwarder * Update destination_path to be optional * Update black config to ignore gen/ * Strict nullable openapi generation * Clean up error handling * fix test * Try to get helm charts working on circleci * fix test * Fix migration job * fix alembic env * fix migration file * Fix alembic bug * add readme for running migrations * more test coverage * more test coverage * more test coverage * Update http forwarder to be async * Fix error handling * Working forwarder * no cover for test - to be covered by integration tests * Add test for 400 handling in forwarder * remove unused code * cleanup * test cov * test cov --- .black.toml | 2 + charts/model-engine/Chart.yaml | 2 +- ...t_job.yaml => database_migration_job.yaml} | 11 +- .../service_template_config_map.yaml | 4 + charts/model-engine/values.yaml | 1 + charts/model-engine/values_sample.yaml | 1 + clients/python/llmengine/completion.py | 6 +- .../model_engine_server/api/v2/__init__.py | 2 + .../api/v2/chat_completion.py | 277 ++++++++++++++ .../common/dtos/llms/batch_completion.py | 4 +- .../common/dtos/llms/chat_completion.py | 127 ++----- .../common/dtos/llms/model_endpoints.py | 12 + .../common/dtos/llms/vllm.py | 146 ++++++++ .../model_engine_server/common/dtos/tasks.py | 1 + .../common/types/gen/openai.py | 189 +++++----- model-engine/model_engine_server/db/base.py | 19 +- .../model_engine_server/db/migrations/README | 43 +++ .../db/migrations/alembic/README | 20 - .../db/migrations/alembic/env.py | 7 +- .../2024_09_09_1736-fa3267c80731_initial.py | 7 +- ...711e35_chat_completion_add_extra_routes.py | 32 ++ .../db/migrations/initial.sql | 84 ----- .../db/migrations/run_database_migration.sh | 10 + .../db/models/hosted_model_inference.py | 7 +- .../domain/entities/llm_entity.py | 1 + .../domain/entities/model_bundle_entity.py | 1 + .../model_engine_server/domain/exceptions.py | 6 + .../use_cases/llm_model_endpoint_use_cases.py | 350 +++++++++++++++++- .../inference/forwarding/forwarding.py | 24 +- .../inference/forwarding/http_forwarder.py | 5 +- .../gateways/abs_llm_artifact_gateway.py | 3 +- ...eaming_model_endpoint_inference_gateway.py | 28 +- ...e_sync_model_endpoint_inference_gateway.py | 24 +- .../gateways/resources/k8s_resource_types.py | 24 +- .../service_template_config_map_circleci.yaml | 24 ++ .../db_model_bundle_repository.py | 22 +- model-engine/tests/unit/conftest.py | 72 +++- model-engine/tests/unit/domain/conftest.py | 28 ++ .../tests/unit/domain/test_llm_use_cases.py | 87 ++++- .../tests/unit/inference/test_forwarding.py | 32 +- .../tests/unit/inference/test_vllm_batch.py | 8 +- ...test_live_async_model_inference_gateway.py | 1 + ...eaming_model_endpoint_inference_gateway.py | 23 +- ...e_sync_model_endpoint_inference_gateway.py | 22 +- scripts/generate-openai-types.sh | 2 + scripts/openai-spec.yaml | 6 +- 46 files changed, 1416 insertions(+), 391 deletions(-) rename charts/model-engine/templates/{database_init_job.yaml => database_migration_job.yaml} (85%) create mode 100644 model-engine/model_engine_server/api/v2/chat_completion.py create mode 100644 model-engine/model_engine_server/common/dtos/llms/vllm.py create mode 100644 model-engine/model_engine_server/db/migrations/README delete mode 100644 model-engine/model_engine_server/db/migrations/alembic/README create mode 100644 model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py create mode 100755 model-engine/model_engine_server/db/migrations/run_database_migration.sh diff --git a/.black.toml b/.black.toml index b9123f45..ab65d233 100644 --- a/.black.toml +++ b/.black.toml @@ -16,6 +16,8 @@ exclude = ''' | buck-out | build | dist + | alembic + | gen )/ ) ''' diff --git a/charts/model-engine/Chart.yaml b/charts/model-engine/Chart.yaml index 175dba37..1ebd5db6 100644 --- a/charts/model-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.2 +version: 0.1.3 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/model-engine/templates/database_init_job.yaml b/charts/model-engine/templates/database_migration_job.yaml similarity index 85% rename from charts/model-engine/templates/database_init_job.yaml rename to charts/model-engine/templates/database_migration_job.yaml index 0c273de9..183814c6 100644 --- a/charts/model-engine/templates/database_init_job.yaml +++ b/charts/model-engine/templates/database_migration_job.yaml @@ -1,12 +1,12 @@ -{{- if or (.Values.secrets.kubernetesDatabaseSecretName) (.Values.db.runDbInitScript) }} +{{- if or (.Values.secrets.kubernetesDatabaseSecretName) (.Values.db.runDbMigrationScript) }} apiVersion: batch/v1 kind: Job metadata: - name: {{ include "modelEngine.fullname" . }}-database-setup + name: {{ include "modelEngine.fullname" . }}-database-migration labels: {{- include "modelEngine.labels" . | nindent 4 }} annotations: - "helm.sh/hook": pre-install + "helm.sh/hook": pre-install,pre-upgrade "helm.sh/hook-weight": "-1" "helm.sh/hook-delete-policy": hook-succeeded spec: @@ -31,9 +31,8 @@ spec: - dumb-init - -- args: - - python - - -m - - model_engine_server.entrypoints.init_database + - bash + - /workspace/model-engine/model_engine_server/db/migrations/run_database_migration.sh {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} {{- include "modelEngine.volumeMounts" . | indent 10 }} serviceAccountName: {{ include "modelEngine.fullname" . }} diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index a756e99e..19f70286 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -180,6 +180,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" {{- $sync_forwarder_template_env | nindent 14 }} readinessProbe: httpGet: diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index b75b6efa..1ea7522e 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -5,6 +5,7 @@ redis: auth: db: runDbInitScript: false + runDbMigrationScript: false balloonNodeSelector: node-lifecycle: normal nodeSelector: diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 430abea6..7eb04e52 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -26,6 +26,7 @@ keyvaultName: llm-engine-keyvault db: runDbInitScript: false + runDbMigrationScript: false # serviceAccount [required] specifies the service account for LLM Engine server deployments (e.g gateway, cache, and builder deployments). serviceAccount: diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 8a9dd5ec..29617f26 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -690,8 +690,10 @@ def get_batch_completion( ```python from llmengine import Completion - response = Completion.get_batch_completion(job_id="job-id") - print(response) + response = Completion.get_batch_completion(job_id="job_id") + print( + f"Current job status for {job_id} is {job.status}" + ) ``` """ response = cls._get( diff --git a/model-engine/model_engine_server/api/v2/__init__.py b/model-engine/model_engine_server/api/v2/__init__.py index dbcb4d67..d8d906b6 100644 --- a/model-engine/model_engine_server/api/v2/__init__.py +++ b/model-engine/model_engine_server/api/v2/__init__.py @@ -3,8 +3,10 @@ from fastapi import APIRouter from .batch_completion import batch_completions_router_v2 +from .chat_completion import chat_router_v2 llm_router_v2 = APIRouter(prefix="/v2") llm_router_v2.include_router(batch_completions_router_v2) +llm_router_v2.include_router(chat_router_v2) __all__: Sequence[str] = ("llm_router_v2",) diff --git a/model-engine/model_engine_server/api/v2/chat_completion.py b/model-engine/model_engine_server/api/v2/chat_completion.py new file mode 100644 index 00000000..f5d5a2db --- /dev/null +++ b/model-engine/model_engine_server/api/v2/chat_completion.py @@ -0,0 +1,277 @@ +import traceback +from datetime import datetime +from typing import Any + +import pytz +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.llms import ( + ChatCompletionV2ErrorChunk, + ChatCompletionV2Request, + ChatCompletionV2Response, + ChatCompletionV2ResponseItem, + StreamError, + StreamErrorContent, + TokenUsage, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + InvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + ChatCompletionStreamV2UseCase, + ChatCompletionSyncV2UseCase, +) +from sse_starlette import EventSourceResponse + +from .common import get_metric_metadata, record_route_call + +logger = make_logger(logger_name()) + +chat_router_v2 = APIRouter(dependencies=[Depends(record_route_call)]) + + +def handle_streaming_exception( + e: Exception, + code: int, + message: str, +): # pragma: no cover + tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": message, + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Exception: %s", structured_log) + return { + "data": ChatCompletionV2ErrorChunk( + request_id=str(request_id), + error=StreamError( + status_code=code, + content=StreamErrorContent( + error=message, + timestamp=timestamp, + ), + ), + ).model_dump_json(exclude_none=True) + } + + +async def handle_stream_request( + external_interfaces: ExternalInterfaces, + background_tasks: BackgroundTasks, + request: ChatCompletionV2Request, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): # pragma: no cover + use_case = ChatCompletionStreamV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + + try: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + ) as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, + detail="Internal error occurred. Our team has been notified.", + ) from exc + + async def event_generator(): + try: + ttft = None + message = None + with timer() as use_case_timer: # todo, this should be move to start of method + async for message in response: + if ttft is None: + ttft = use_case_timer.lap() + # if ttft is None and message.startswith("data"): + # ttft = use_case_timer.lap() + print("message", message.model_dump_json(exclude_none=True)) + yield {"data": message.model_dump_json(exclude_none=True)} + + if message: + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=(message.usage.prompt_tokens if message.usage else None), + num_completion_tokens=( + message.usage.completion_tokens if message.usage else None + ), + total_duration=use_case_timer.duration, + ), + metric_metadata, + ) + + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) + + return EventSourceResponse(event_generator()) + + +async def handle_sync_request( + external_interfaces: ExternalInterfaces, + request: ChatCompletionV2Request, + background_tasks: BackgroundTasks, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): + try: + use_case = ChatCompletionSyncV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + with timer() as use_case_timer: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=(response.usage.prompt_tokens if response.usage else None), + num_completion_tokens=( + response.usage.completion_tokens if response.usage else None + ), + total_duration=use_case_timer.duration, + ), + metric_metadata, + ) + return response + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + raise HTTPException( + status_code=500, + detail=f"Upstream service error for request_id {request_id}", + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"POST /completions-sync to endpoint {model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except InvalidRequestException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except EndpointUnsupportedRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Endpoint does not support request: {str(exc)}", + ) from exc + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=f"Unsupported inference type: {str(exc)}", + ) from exc + + +def to_error_details(exc: Exception) -> Any: + if not exc.args or len(exc.args) == 0: + return str(exc) + if len(exc.args) == 1: + return exc.args[0] + else: + return exc.args + + +@chat_router_v2.post("/chat/completions", response_model=ChatCompletionV2ResponseItem) +async def chat_completion( + request: ChatCompletionV2Request, + background_tasks: BackgroundTasks, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> ChatCompletionV2Response: # pragma: no cover + model_endpoint_name = request.model + if hmi_config.sensitive_log_mode: + logger.info( + f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) to endpoint {model_endpoint_name} for {auth}" + ) + else: + logger.info( + f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}" + ) + + if request.stream: + return await handle_stream_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) + else: + return await handle_sync_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 1c471851..1ebe1f78 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -4,7 +4,7 @@ from model_engine_server.common.dtos.llms.chat_completion import ( ChatCompletionV2Request, - ChatCompletionV2Response, + ChatCompletionV2SyncResponse, ) from model_engine_server.common.dtos.llms.completion import ( CompletionOutput, @@ -197,7 +197,7 @@ class FilteredChatCompletionV2Request(ChatCompletionV2Request): # V2 DTOs for batch completions CompletionRequest: TypeAlias = Union[FilteredCompletionV2Request, FilteredChatCompletionV2Request] -CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2Response] +CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2SyncResponse] CreateBatchCompletionsV2RequestContent: TypeAlias = Union[ List[FilteredCompletionV2Request], List[FilteredChatCompletionV2Request] ] diff --git a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py index c573b526..bfb5ab09 100644 --- a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py @@ -1,104 +1,21 @@ -from typing import Any, Dict, List, Optional +from typing import Optional, Union +from model_engine_server.common.dtos.llms.completion import StreamError +from model_engine_server.common.dtos.llms.vllm import VLLMChatCompletionAdditionalParams +from model_engine_server.common.pydantic_types import BaseModel, Field from model_engine_server.common.types.gen.openai import ( CreateChatCompletionRequest, CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, ) -from pydantic import Field +from sse_starlette import EventSourceResponse from typing_extensions import Annotated # Fields that are a part of OpenAI spec but are not supported by model engine UNSUPPORTED_FIELDS = ["service_tier"] -class VLLMAdditionalFields: - chat_template: Annotated[ - Optional[str], - Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ), - ] - chat_template_kwargs: Annotated[ - Optional[Dict[str, Any]], - Field( - default=None, - description=( - "Additional kwargs to pass to the template renderer. " - "Will be accessible by the chat template." - ), - ), - ] - - guided_json: Annotated[ - Optional[Dict[str, Any]], - Field( - default=None, - description="JSON schema for guided decoding. Only supported in vllm.", - ), - ] - - guided_regex: Annotated[ - Optional[str], - Field( - default=None, - description="Regex for guided decoding. Only supported in vllm.", - ), - ] - guided_choice: Annotated[ - Optional[List[str]], - Field( - default=None, - description="Choices for guided decoding. Only supported in vllm.", - ), - ] - - guided_grammar: Annotated[ - Optional[str], - Field( - default=None, - description="Context-free grammar for guided decoding. Only supported in vllm.", - ), - ] - - guided_decoding_backend: Annotated[ - Optional[str], - Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be either " - "'outlines' / 'lm-format-enforcer'" - ), - ), - ] - - guided_whitespace_pattern: Annotated[ - Optional[str], - Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding." - ), - ), - ] - - skip_special_tokens: Annotated[ - Optional[bool], - Field( - True, - description="Whether to skip special tokens in the output. Only supported in vllm.", - ), - ] - - -class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields): +class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMChatCompletionAdditionalParams): model: Annotated[ str, Field( @@ -115,20 +32,22 @@ class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields) ), ] - top_k: Annotated[ - Optional[int], - Field( - None, - ge=-1, - description="Controls the number of top tokens to consider. -1 means consider all tokens.", - ), - ] - include_stop_str_in_output: Annotated[ - Optional[bool], - Field(None, description="Whether to include the stop strings in output text."), - ] +ChatCompletionV2SyncResponse = CreateChatCompletionResponse +ChatCompletionV2SuccessChunk = CreateChatCompletionStreamResponse + + +class ChatCompletionV2ErrorChunk(BaseModel): + error: StreamError + + +ChatCompletionV2Chunk = Union[ChatCompletionV2SuccessChunk, ChatCompletionV2ErrorChunk] +ChatCompletionV2StreamResponse = ( + EventSourceResponse # EventSourceResponse[ChatCompletionV2Chunk | ChatCompletionV2ErrorChunk] +) +ChatCompletionV2Response = Union[ChatCompletionV2SyncResponse, ChatCompletionV2StreamResponse] -class ChatCompletionV2Response(CreateChatCompletionResponse): - pass +# This is a version of ChatCompletionV2Response that is used by pydantic to determine the response model +# Since EventSourceResponse isn't a pydanitc model, we need to use a Union of the two response types +ChatCompletionV2ResponseItem = Union[ChatCompletionV2SyncResponse, ChatCompletionV2Chunk] diff --git a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py index 5b870532..d44d7d0e 100644 --- a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py @@ -69,6 +69,10 @@ class CreateLLMModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = True # LLM endpoints are public by default. + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) class CreateLLMModelEndpointV1Response(BaseModel): @@ -90,6 +94,10 @@ class GetLLMModelEndpointV1Response(BaseModel): num_shards: Optional[int] = None quantize: Optional[Quantization] = None checkpoint_path: Optional[str] = None + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) spec: Optional[GetModelEndpointV1Response] = None @@ -136,6 +144,10 @@ class UpdateLLMModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = None + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) class UpdateLLMModelEndpointV1Response(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py new file mode 100644 index 00000000..700af2b1 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -0,0 +1,146 @@ +from typing import Any, Dict, List, Optional + +from model_engine_server.common.pydantic_types import Field +from typing_extensions import Annotated + +# This was last synced w/ vLLM v0.5.5 on 2024-09-03 + + +class VLLMSamplingParams: + best_of: Optional[int] = Field( + None, + description="""Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. This is treated as + the beam width when `use_beam_search` is True. By default, `best_of` + is set to `n`.""", + ) + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + min_p: Optional[float] = Field( + None, + description="""Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this.""", + ) + use_beam_search: Optional[bool] = Field( + None, + description="""Whether to use beam search for sampling.""", + ) + length_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes sequences based on their length. + Used in beam search.""", + ) + repetition_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens.""", + ) + early_stopping: Optional[bool] = Field( + None, + description="""Controls the stopping condition for beam search. It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm).""", + ) + stop_token_ids: Optional[List[int]] = Field( + default_factory=list, + description="""List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens.""", + ) + include_stop_str_in_output: Annotated[ + Optional[bool], + Field( + None, + description="""Whether to include the stop strings in + output text. Defaults to False.""", + ), + ] + ignore_eos: Optional[bool] = Field( + None, + description="""Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated.""", + ) + min_tokens: Optional[int] = Field( + None, + description="""Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated""", + ) + + skip_special_tokens: Optional[bool] = Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ) + + spaces_between_special_tokens: Optional[bool] = Field( + True, + description="Whether to add spaces between special tokens in the output. Only supported in vllm.", + ) + + +class VLLMChatCompletionAdditionalParams(VLLMSamplingParams): + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the model's tokenizer " + "does not define one and no override template is given" + ), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) diff --git a/model-engine/model_engine_server/common/dtos/tasks.py b/model-engine/model_engine_server/common/dtos/tasks.py index 874c50a8..98335277 100644 --- a/model-engine/model_engine_server/common/dtos/tasks.py +++ b/model-engine/model_engine_server/common/dtos/tasks.py @@ -49,6 +49,7 @@ class EndpointPredictV1Request(BaseModel): callback_url: Optional[str] = None callback_auth: Optional[CallbackAuth] = None return_pickled: bool = False + destination_path: Optional[str] = None class SyncEndpointPredictV1Request(EndpointPredictV1Request): diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py index 9ac7a40f..5337b6e1 100644 --- a/model-engine/model_engine_server/common/types/gen/openai.py +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openai-spec.yaml -# timestamp: 2024-08-20T08:20:04+00:00 +# timestamp: 2024-09-10T16:20:49+00:00 from __future__ import annotations @@ -11,9 +11,9 @@ class Error(BaseModel): - code: str + code: Annotated[Optional[str], Field(...)] message: str - param: str + param: Annotated[Optional[str], Field(...)] type: str @@ -82,7 +82,7 @@ class Choice(BaseModel): ), ] index: int - logprobs: Logprobs + logprobs: Annotated[Optional[Logprobs], Field(...)] text: str @@ -97,7 +97,7 @@ class ImageUrl(BaseModel): Field(description="Either a URL of the image or the base64 encoded image data."), ] detail: Annotated[ - Optional[Literal["auto", "low", "high"]], + Literal["auto", "low", "high"], Field( "auto", description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", @@ -258,7 +258,7 @@ class ChatCompletionRequestFunctionMessage(BaseModel): Literal["function"], Field(description="The role of the messages author, in this case `function`."), ] - content: Annotated[str, Field(description="The contents of the function message.")] + content: Annotated[Optional[str], Field(description="The contents of the function message.")] name: Annotated[str, Field(description="The name of the function to call.")] @@ -499,7 +499,7 @@ class TopLogprob(BaseModel): ), ] bytes: Annotated[ - List[int], + Optional[List[int]], Field( description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." ), @@ -515,7 +515,7 @@ class ChatCompletionTokenLogprob(BaseModel): ), ] bytes: Annotated[ - List[int], + Optional[List[int]], Field( description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." ), @@ -530,12 +530,15 @@ class ChatCompletionTokenLogprob(BaseModel): class Logprobs2(BaseModel): content: Annotated[ - List[ChatCompletionTokenLogprob], + Optional[List[ChatCompletionTokenLogprob]], Field(description="A list of message content tokens with log probability information."), ] refusal: Annotated[ - List[ChatCompletionTokenLogprob], - Field(description="A list of message refusal tokens with log probability information."), + Optional[List[ChatCompletionTokenLogprob]], + Field( + None, + description="A list of message refusal tokens with log probability information.", + ), ] @@ -546,7 +549,7 @@ class Choice3(BaseModel): Field(None, description="Log probability information for the choice."), ] finish_reason: Annotated[ - Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]], Field( description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" ), @@ -626,7 +629,7 @@ class CreateImageRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Literal["dall-e-2", "dall-e-3"]]], + Optional[Union[Optional[str], Literal["dall-e-2", "dall-e-3"]]], Field( "dall-e-2", description="The model to use for image generation.", @@ -644,7 +647,7 @@ class CreateImageRequest(BaseModel): ), ] quality: Annotated[ - Optional[Literal["standard", "hd"]], + Literal["standard", "hd"], Field( "standard", description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", @@ -731,7 +734,7 @@ class CreateImageEditRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Literal["dall-e-2"]]], + Optional[Union[Optional[str], Literal["dall-e-2"]]], Field( "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", @@ -782,7 +785,7 @@ class CreateImageVariationRequest(BaseModel): ), ] model: Annotated[ - Optional[Union[str, Literal["dall-e-2"]]], + Optional[Union[Optional[str], Literal["dall-e-2"]]], Field( "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", @@ -828,7 +831,7 @@ class CreateImageVariationRequest(BaseModel): class CreateModerationRequest(BaseModel): input: Annotated[Union[str, List[str]], Field(description="The input text to classify")] model: Annotated[ - Optional[Union[str, Literal["text-moderation-latest", "text-moderation-stable"]]], + Union[str, Literal["text-moderation-latest", "text-moderation-stable"]], Field( "text-moderation-latest", description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", @@ -1085,21 +1088,21 @@ class NEpochs(RootModel[int]): class Hyperparameters(BaseModel): batch_size: Annotated[ - Optional[Union[Literal["auto"], BatchSize]], + Union[Literal["auto"], BatchSize], Field( "auto", description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", ), ] learning_rate_multiplier: Annotated[ - Optional[Union[Literal["auto"], LearningRateMultiplier]], + Union[Literal["auto"], LearningRateMultiplier], Field( "auto", description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", ), ] n_epochs: Annotated[ - Optional[Union[Literal["auto"], NEpochs]], + Union[Literal["auto"], NEpochs], Field( "auto", description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", @@ -1277,7 +1280,7 @@ class CreateEmbeddingRequest(BaseModel): ), ] encoding_format: Annotated[ - Optional[Literal["float", "base64"]], + Literal["float", "base64"], Field( "float", description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", @@ -1341,21 +1344,21 @@ class CreateTranscriptionRequest(BaseModel): ), ] response_format: Annotated[ - Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]], + Literal["json", "text", "srt", "verbose_json", "vtt"], Field( "json", description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", ), ] temperature: Annotated[ - Optional[float], + float, Field( 0, description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", ), ] timestamp_granularities__: Annotated[ - Optional[List[Literal["word", "segment"]]], + List[Literal["word", "segment"]], Field( ["segment"], alias="timestamp_granularities[]", @@ -1447,14 +1450,14 @@ class CreateTranslationRequest(BaseModel): ), ] response_format: Annotated[ - Optional[str], + str, Field( "json", description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", ), ] temperature: Annotated[ - Optional[float], + float, Field( 0, description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", @@ -1506,14 +1509,14 @@ class CreateSpeechRequest(BaseModel): ), ] response_format: Annotated[ - Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]], + Literal["mp3", "opus", "aac", "flac", "wav", "pcm"], Field( "mp3", description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.", ), ] speed: Annotated[ - Optional[float], + float, Field( 1.0, description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", @@ -1659,7 +1662,7 @@ class Error1(BaseModel): code: Annotated[str, Field(description="A machine-readable error code.")] message: Annotated[str, Field(description="A human-readable error message.")] param: Annotated[ - str, + Optional[str], Field( description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific." ), @@ -1824,7 +1827,7 @@ class AssistantsApiResponseFormatOption( class CodeInterpreter(BaseModel): file_ids: Annotated[ - Optional[List[str]], + List[str], Field( [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", @@ -1851,7 +1854,7 @@ class ToolResources(BaseModel): class CodeInterpreter1(BaseModel): file_ids: Annotated[ - Optional[List[str]], + List[str], Field( [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", @@ -2003,7 +2006,7 @@ class ToolResources1(BaseModel): class CodeInterpreter2(BaseModel): file_ids: Annotated[ - Optional[List[str]], + List[str], Field( [], description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", @@ -2201,7 +2204,7 @@ class RunToolCallObject(BaseModel): class CodeInterpreter3(BaseModel): file_ids: Annotated[ - Optional[List[str]], + List[str], Field( [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", @@ -2256,13 +2259,13 @@ class ThreadObject(BaseModel): Field(description="The Unix timestamp (in seconds) for when the thread was created."), ] tool_resources: Annotated[ - ToolResources4, + Optional[ToolResources4], Field( description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), ] metadata: Annotated[ - Dict[str, Any], + Optional[Dict[str, Any]], Field( description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), @@ -2485,7 +2488,7 @@ class ImageFile(BaseModel): ), ] detail: Annotated[ - Optional[Literal["auto", "low", "high"]], + Literal["auto", "low", "high"], Field( "auto", description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", @@ -2507,7 +2510,7 @@ class ImageFile1(BaseModel): ), ] detail: Annotated[ - Optional[Literal["auto", "low", "high"]], + Literal["auto", "low", "high"], Field( "auto", description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", @@ -2529,7 +2532,7 @@ class ImageUrl1(BaseModel): ), ] detail: Annotated[ - Optional[Literal["auto", "low", "high"]], + Literal["auto", "low", "high"], Field( "auto", description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`", @@ -2551,7 +2554,7 @@ class ImageUrl2(BaseModel): ), ] detail: Annotated[ - Optional[Literal["auto", "low", "high"]], + Literal["auto", "low", "high"], Field( "auto", description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`.", @@ -2762,7 +2765,7 @@ class Function5(BaseModel): name: Annotated[str, Field(description="The name of the function.")] arguments: Annotated[str, Field(description="The arguments passed to the function.")] output: Annotated[ - str, + Optional[str], Field( description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." ), @@ -2876,13 +2879,13 @@ class VectorStoreObject(BaseModel): ), ] last_active_at: Annotated[ - int, + Optional[int], Field( description="The Unix timestamp (in seconds) for when the vector store was last active." ), ] metadata: Annotated[ - Dict[str, Any], + Optional[Dict[str, Any]], Field( description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), @@ -4096,7 +4099,7 @@ class CreateCompletionRequest(BaseModel): ), ] prompt: Annotated[ - Union[str, List[str], Prompt, Prompt1], + Optional[Union[Optional[str], List[str], Prompt, Prompt1]], Field( description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n" ), @@ -4177,7 +4180,7 @@ class CreateCompletionRequest(BaseModel): ), ] stop: Annotated[ - Optional[Union[str, Stop]], + Optional[Union[Optional[str], Stop]], Field( None, description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", @@ -4283,8 +4286,11 @@ class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCal class ChatCompletionResponseMessage(BaseModel): - content: Annotated[str, Field(description="The contents of the message.")] - refusal: Annotated[str, Field(description="The refusal message generated by the model.")] + content: Annotated[Optional[str], Field(description="The contents of the message.")] + refusal: Annotated[ + Optional[str], + Field(None, description="The refusal message generated by the model."), + ] tool_calls: Optional[ChatCompletionMessageToolCalls] = None role: Annotated[ Literal["assistant"], @@ -4308,7 +4314,10 @@ class Choice1(BaseModel): ] index: Annotated[int, Field(description="The index of the choice in the list of choices.")] message: ChatCompletionResponseMessage - logprobs: Annotated[Logprobs2, Field(description="Log probability information for the choice.")] + logprobs: Annotated[ + Optional[Logprobs2], + Field(None, description="Log probability information for the choice."), + ] class CreateChatCompletionResponse(BaseModel): @@ -4437,19 +4446,19 @@ class FineTuningJob(BaseModel): ), ] error: Annotated[ - Error1, + Optional[Error1], Field( description="For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure." ), ] fine_tuned_model: Annotated[ - str, + Optional[str], Field( description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running." ), ] finished_at: Annotated[ - int, + Optional[int], Field( description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running." ), @@ -4481,7 +4490,7 @@ class FineTuningJob(BaseModel): ), ] trained_tokens: Annotated[ - int, + Optional[int], Field( description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running." ), @@ -4493,7 +4502,7 @@ class FineTuningJob(BaseModel): ), ] validation_file: Annotated[ - str, + Optional[str], Field( description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents)." ), @@ -4530,14 +4539,14 @@ class AssistantObject(BaseModel): Field(description="The Unix timestamp (in seconds) for when the assistant was created."), ] name: Annotated[ - str, + Optional[str], Field( description="The name of the assistant. The maximum length is 256 characters.\n", max_length=256, ), ] description: Annotated[ - str, + Optional[str], Field( description="The description of the assistant. The maximum length is 512 characters.\n", max_length=512, @@ -4550,7 +4559,7 @@ class AssistantObject(BaseModel): ), ] instructions: Annotated[ - str, + Optional[str], Field( description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", max_length=256000, @@ -4571,7 +4580,7 @@ class AssistantObject(BaseModel): ), ] metadata: Annotated[ - Dict[str, Any], + Optional[Dict[str, Any]], Field( description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), @@ -4663,7 +4672,7 @@ class CreateAssistantRequest(BaseModel): ), ] tools: Annotated[ - Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( [], description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", @@ -4743,7 +4752,7 @@ class ModifyAssistantRequest(BaseModel): ), ] tools: Annotated[ - Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( [], description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", @@ -4865,39 +4874,39 @@ class RunObject(BaseModel): ), ] required_action: Annotated[ - RequiredAction, + Optional[RequiredAction], Field( description="Details on the action required to continue the run. Will be `null` if no action is required." ), ] last_error: Annotated[ - LastError, + Optional[LastError], Field( description="The last error associated with this run. Will be `null` if there are no errors." ), ] expires_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run will expire."), ] started_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run was started."), ] cancelled_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run was cancelled."), ] failed_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run failed."), ] completed_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run was completed."), ] incomplete_details: Annotated[ - IncompleteDetails, + Optional[IncompleteDetails], Field( description="Details on why the run is incomplete. Will be `null` if the run is not incomplete." ), @@ -4922,7 +4931,7 @@ class RunObject(BaseModel): ), ] metadata: Annotated[ - Dict[str, Any], + Optional[Dict[str, Any]], Field( description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), @@ -4943,23 +4952,23 @@ class RunObject(BaseModel): ), ] max_prompt_tokens: Annotated[ - int, + Optional[int], Field( description="The maximum number of prompt tokens specified to have been used over the course of the run.\n", ge=256, ), ] max_completion_tokens: Annotated[ - int, + Optional[int], Field( description="The maximum number of completion tokens specified to have been used over the course of the run.\n", ge=256, ), ] - truncation_strategy: TruncationObject - tool_choice: AssistantsApiToolChoiceOption + truncation_strategy: Annotated[Optional[TruncationObject], Field(...)] + tool_choice: Annotated[Optional[AssistantsApiToolChoiceOption], Field(...)] parallel_tool_calls: ParallelToolCalls - response_format: AssistantsApiResponseFormatOption + response_format: Annotated[Optional[AssistantsApiResponseFormatOption], Field(...)] class ListRunsResponse(BaseModel): @@ -5254,7 +5263,7 @@ class ProjectServiceAccountCreateResponse(BaseModel): class ChatCompletionRequestAssistantMessage(BaseModel): content: Annotated[ - Optional[Union[str, Content2]], + Optional[Union[Optional[str], Content2]], Field( None, description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", @@ -5348,7 +5357,7 @@ class CreateRunRequest(BaseModel): model: Annotated[ Optional[ Union[ - str, + Optional[str], Literal[ "gpt-4o", "gpt-4o-2024-08-06", @@ -5521,15 +5530,15 @@ class MessageObject(BaseModel): ), ] incomplete_details: Annotated[ - IncompleteDetails1, + Optional[IncompleteDetails1], Field(description="On an incomplete message, details about why the message is incomplete."), ] completed_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the message was completed."), ] incomplete_at: Annotated[ - int, + Optional[int], Field( description="The Unix timestamp (in seconds) for when the message was marked as incomplete." ), @@ -5550,25 +5559,25 @@ class MessageObject(BaseModel): Field(description="The content of the message in array of text and/or images."), ] assistant_id: Annotated[ - str, + Optional[str], Field( description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message." ), ] run_id: Annotated[ - str, + Optional[str], Field( description="The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints." ), ] attachments: Annotated[ - List[Attachment], + Optional[List[Attachment]], Field( description="A list of files attached to the message, and the tools they were added to." ), ] metadata: Annotated[ - Dict[str, Any], + Optional[Dict[str, Any]], Field( description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), @@ -5695,7 +5704,7 @@ class VectorStoreFileObject(BaseModel): ), ] last_error: Annotated[ - LastError2, + Optional[LastError2], Field( description="The last error associated with this vector store file. Will be `null` if there are no errors." ), @@ -5908,7 +5917,7 @@ class CreateChatCompletionRequest(BaseModel): ), ] stop: Annotated[ - Optional[Union[str, Stop1]], + Union[Optional[str], Stop1], Field( None, description="Up to 4 sequences where the API will stop generating further tokens.\n", @@ -5997,7 +6006,7 @@ class CreateThreadAndRunRequest(BaseModel): model: Annotated[ Optional[ Union[ - str, + Optional[str], Literal[ "gpt-4o", "gpt-4o-2024-08-06", @@ -6158,31 +6167,31 @@ class RunStepObject(BaseModel): Field(description="The details of the run step."), ] last_error: Annotated[ - LastError1, + Optional[LastError1], Field( description="The last error associated with this run step. Will be `null` if there are no errors." ), ] expired_at: Annotated[ - int, + Optional[int], Field( description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired." ), ] cancelled_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run step was cancelled."), ] failed_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run step failed."), ] completed_at: Annotated[ - int, + Optional[int], Field(description="The Unix timestamp (in seconds) for when the run step completed."), ] metadata: Annotated[ - Dict[str, Any], + Optional[Dict[str, Any]], Field( description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py index 9a67eb65..5033d8ad 100644 --- a/model-engine/model_engine_server/db/base.py +++ b/model-engine/model_engine_server/db/base.py @@ -286,27 +286,34 @@ def get_session_async_null_pool(self) -> async_sessionmaker: return self.sessions.session_async_null_pool.session -db_manager = DBManager(infra_config()) +db_manager: Optional[DBManager] = None + + +def get_db_manager(): + global db_manager + if db_manager is None: + db_manager = DBManager(infra_config()) + return db_manager def get_session(): - return db_manager.get_session_sync() + return get_db_manager().get_session_sync() def get_session_read_only(): - return db_manager.get_session_sync_ro() + return get_db_manager().get_session_sync_ro() def get_session_async(): - return db_manager.get_session_async() + return get_db_manager().get_session_async() def get_session_async_null_pool(): - return db_manager.get_session_async_null_pool() + return get_db_manager().get_session_async_null_pool() def get_session_read_only_async(): - return db_manager.get_session_async_ro() + return get_db_manager().get_session_async_ro() Base = declarative_base() diff --git a/model-engine/model_engine_server/db/migrations/README b/model-engine/model_engine_server/db/migrations/README new file mode 100644 index 00000000..34a0d901 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/README @@ -0,0 +1,43 @@ +# Setup + +We introduce alembic by +1. dumping the current db schemas into 'initial.sql' via pg_dump + +``` +pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql +``` + +2. writing an initial revision that reads and applies intial.sql script + +``` +alembic revision -m “initial” +``` + +3. Stamping the current revision to our production db to avoid actually running it on production + +``` +alembic stamp fa3267c80731 +``` + + +# Test db migration from scratch + +## Set up postgresql + +``` +docker pull postgres +docker run --name postgres -e POSTGRES_PASSWORD=password -d -p 5432:5432 postgres +``` + +## Run migration script + +``` +PYTHONPATH="${PYTHONPATH}:" +ML_INFRA_DATABASE_URL="postgresql://postgres:password@localhost:54320/postgres" bash run_database_migration.sh +``` + + +To reset db, you can recreate docker or run +``` +psql "$ML_INFRA_DATABASE_URL" -c "DROP table if exists public.alembic_version_model_engine; DROP schema if exists hosted_model_inference CASCADE; DROP schema if exists model CASCADE" +``` diff --git a/model-engine/model_engine_server/db/migrations/alembic/README b/model-engine/model_engine_server/db/migrations/alembic/README deleted file mode 100644 index cfedbcc5..00000000 --- a/model-engine/model_engine_server/db/migrations/alembic/README +++ /dev/null @@ -1,20 +0,0 @@ -# Setup - -We introduce alembic by -1. dumping the current db schemas into 'initial.sql' via pg_dump - -``` -pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql -``` - -2. writing an initial revision that reads and applies intial.sql script - -``` -alembic revision -m “initial” -``` - -3. Stamping the current revision to our production db to avoid actually running it on production - -``` -alembic stamp fa3267c80731 -``` diff --git a/model-engine/model_engine_server/db/migrations/alembic/env.py b/model-engine/model_engine_server/db/migrations/alembic/env.py index 3f4b73b3..24c67aa5 100644 --- a/model-engine/model_engine_server/db/migrations/alembic/env.py +++ b/model-engine/model_engine_server/db/migrations/alembic/env.py @@ -7,7 +7,10 @@ from sqlalchemy import engine_from_config, pool env = os.environ.get("ENV") -assert env is not None, "Expected ENV to be a nonempty environment variable." +if env is None: + assert ( + os.getenv("ML_INFRA_DATABASE_URL") is not None + ), "Expected ML_INFRA_DATABASE_URL to be set if ENV is not set." # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -53,6 +56,7 @@ def run_migrations_offline() -> None: literal_binds=True, dialect_opts={"paramstyle": "named"}, version_table=ALEMBIC_TABLE_NAME, + version_table_schema="public", ) try: @@ -81,6 +85,7 @@ def run_migrations_online() -> None: connection=connection, target_metadata=target_metadata, version_table=ALEMBIC_TABLE_NAME, + version_table_schema="public", ) try: diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py index acad7cff..efee8963 100644 --- a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py @@ -6,6 +6,11 @@ """ +from pathlib import Path + +INITIAL_MIGRATION_PATH = Path(__file__).parent / "../../initial.sql" + + import sqlalchemy as sa from alembic import op @@ -17,7 +22,7 @@ def upgrade() -> None: - with open("migrations/initial.sql") as fd: + with open(INITIAL_MIGRATION_PATH) as fd: op.execute(fd.read()) diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py new file mode 100644 index 00000000..43279e0f --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py @@ -0,0 +1,32 @@ +"""chat completion - Add extra_routes + +Revision ID: b574e9711e35 +Revises: fa3267c80731 +Create Date: 2024-09-09 18:31:59.422082 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import ARRAY + +# revision identifiers, used by Alembic. +revision = "b574e9711e35" +down_revision = "fa3267c80731" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "bundles", + sa.Column("runnable_image_extra_routes", ARRAY(sa.Text), nullable=True), + schema="hosted_model_inference", + ) + + +def downgrade(): + op.drop_column( + "bundles", + "runnable_image_extra_routes", + schema="hosted_model_inference", + ) diff --git a/model-engine/model_engine_server/db/migrations/initial.sql b/model-engine/model_engine_server/db/migrations/initial.sql index 93655bf3..65728431 100644 --- a/model-engine/model_engine_server/db/migrations/initial.sql +++ b/model-engine/model_engine_server/db/migrations/initial.sql @@ -611,90 +611,6 @@ ALTER TABLE ONLY model.model_versions ADD CONSTRAINT model_versions_model_id_fkey FOREIGN KEY (model_id) REFERENCES model.models(id); --- --- Name: SCHEMA hosted_model_inference; Type: ACL; Schema: -; Owner: - --- - -GRANT USAGE ON SCHEMA hosted_model_inference TO fivetran; - - --- --- Name: SCHEMA model; Type: ACL; Schema: -; Owner: - --- - -GRANT USAGE ON SCHEMA model TO fivetran; - - --- --- Name: TABLE batch_jobs; Type: ACL; Schema: hosted_model_inference; Owner: - --- - -GRANT SELECT ON TABLE hosted_model_inference.batch_jobs TO fivetran; - - --- --- Name: TABLE bundles; Type: ACL; Schema: hosted_model_inference; Owner: - --- - -GRANT SELECT ON TABLE hosted_model_inference.bundles TO fivetran; - - --- --- Name: TABLE docker_image_batch_job_bundles; Type: ACL; Schema: hosted_model_inference; Owner: - --- - -GRANT SELECT ON TABLE hosted_model_inference.docker_image_batch_job_bundles TO fivetran; - - --- --- Name: TABLE endpoints; Type: ACL; Schema: hosted_model_inference; Owner: - --- - -GRANT SELECT ON TABLE hosted_model_inference.endpoints TO fivetran; - - --- --- Name: TABLE triggers; Type: ACL; Schema: hosted_model_inference; Owner: - --- - -GRANT SELECT ON TABLE hosted_model_inference.triggers TO fivetran; - - --- --- Name: TABLE model_artifacts; Type: ACL; Schema: model; Owner: - --- - -GRANT SELECT ON TABLE model.model_artifacts TO fivetran; - - --- --- Name: TABLE model_versions; Type: ACL; Schema: model; Owner: - --- - -GRANT SELECT ON TABLE model.model_versions TO fivetran; - - --- --- Name: TABLE models; Type: ACL; Schema: model; Owner: - --- - -GRANT SELECT ON TABLE model.models TO fivetran; - - --- --- Name: DEFAULT PRIVILEGES FOR TABLES; Type: DEFAULT ACL; Schema: hosted_model_inference; Owner: - --- - -ALTER DEFAULT PRIVILEGES FOR ROLE postgres IN SCHEMA hosted_model_inference GRANT SELECT ON TABLES TO fivetran; - - --- --- Name: DEFAULT PRIVILEGES FOR TABLES; Type: DEFAULT ACL; Schema: model; Owner: - --- - -ALTER DEFAULT PRIVILEGES FOR ROLE postgres IN SCHEMA model GRANT SELECT ON TABLES TO fivetran; - - -- -- PostgreSQL database dump complete -- diff --git a/model-engine/model_engine_server/db/migrations/run_database_migration.sh b/model-engine/model_engine_server/db/migrations/run_database_migration.sh new file mode 100755 index 00000000..8b25f20e --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/run_database_migration.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Get the directory of this script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# Change directory to the directory of this script +cd $DIR + +# Runs database migration +alembic upgrade head \ No newline at end of file diff --git a/model-engine/model_engine_server/db/models/hosted_model_inference.py b/model-engine/model_engine_server/db/models/hosted_model_inference.py index 8e028a2c..72f7806f 100644 --- a/model-engine/model_engine_server/db/models/hosted_model_inference.py +++ b/model-engine/model_engine_server/db/models/hosted_model_inference.py @@ -146,6 +146,7 @@ class Bundle(Base): runnable_image_env = Column(JSON, nullable=True) runnable_image_protocol = Column(Text, nullable=True) runnable_image_readiness_initial_delay_seconds = Column(Integer, nullable=True) + runnable_image_extra_routes = Column(ARRAY(Text), nullable=True) # Streaming Enhanced Runnable Image fields streaming_enhanced_runnable_image_streaming_command = Column(ARRAY(Text), nullable=True) @@ -205,6 +206,7 @@ def __init__( runnable_image_env: Optional[Dict[str, Any]] = None, runnable_image_protocol: Optional[str] = None, runnable_image_readiness_initial_delay_seconds: Optional[int] = None, + runnable_image_extra_routes: Optional[List[str]] = None, # Streaming Enhanced Runnable Image fields streaming_enhanced_runnable_image_streaming_command: Optional[List[str]] = None, streaming_enhanced_runnable_image_streaming_predict_route: Optional[str] = None, @@ -260,6 +262,7 @@ def __init__( self.runnable_image_healthcheck_route = runnable_image_healthcheck_route self.runnable_image_env = runnable_image_env self.runnable_image_protocol = runnable_image_protocol + self.runnable_image_extra_routes = runnable_image_extra_routes self.runnable_image_readiness_initial_delay_seconds = ( runnable_image_readiness_initial_delay_seconds ) @@ -632,7 +635,9 @@ class BatchJob(Base): created_by = Column(String(SHORT_STRING), index=True, nullable=False) owner = Column(String(SHORT_STRING), index=True, nullable=False) model_bundle_id = Column( - Text, ForeignKey("hosted_model_inference.bundles.id", ondelete="SET NULL"), nullable=False + Text, + ForeignKey("hosted_model_inference.bundles.id", ondelete="SET NULL"), + nullable=False, ) model_endpoint_id = Column( Text, ForeignKey("hosted_model_inference.endpoints.id"), nullable=True diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index 4da8c278..937a739f 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -29,3 +29,4 @@ class LLMMetadata: num_shards: int quantize: Optional[Quantization] = None checkpoint_path: Optional[str] = None + chat_template_override: Optional[str] = None diff --git a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py index d5d0a5f3..32c5c4e5 100644 --- a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -157,6 +157,7 @@ class RunnableImageLike(BaseModel, ABC): env: Optional[Dict[str, str]] = None protocol: Literal["http"] # TODO: add support for other protocols (e.g. grpc) readiness_initial_delay_seconds: int = 120 + extra_routes: List[str] = Field(default_factory=list) class RunnableImageFlavor(RunnableImageLike): diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py index 5b81a68e..075b4823 100644 --- a/model-engine/model_engine_server/domain/exceptions.py +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -80,6 +80,12 @@ class EndpointUnsupportedInferenceTypeException(DomainException): """ +class EndpointUnsupportedRequestException(DomainException): + """ + Throw if the request is unsupported by the endpoint. + """ + + class EndpointResourceInvalidRequestException(DomainException): """ Thrown if the endpoint resource requests are invalid. diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 2f6c9e37..84c50f69 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -4,6 +4,7 @@ Read model endpoint creation logs: GET model-endpoints//creation-logs """ +import base64 import datetime import json import math @@ -17,6 +18,9 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( + ChatCompletionV2Request, + ChatCompletionV2SuccessChunk, + ChatCompletionV2SyncResponse, CompletionOutput, CompletionStreamOutput, CompletionStreamV1Request, @@ -69,6 +73,7 @@ ModelEndpointType, Quantization, RunnableImageFlavor, + RunnableImageLike, StreamingEnhancedRunnableImageFlavor, ) from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( @@ -79,6 +84,7 @@ EndpointInfraStateNotFound, EndpointLabelsException, EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, FailToInferHardwareException, InvalidRequestException, LatestImageTagNotFoundException, @@ -123,8 +129,13 @@ logger = make_logger(logger_name()) +OPENAI_CHAT_COMPLETION_PATH = "/v1/chat/completions" +CHAT_TEMPLATE_MAX_LENGTH = 10_000 +CHAT_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM] + LLM_METADATA_KEY = "_llm" RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] +VLLM_MODEL_WEIGHTS_FOLDER = "model_files" INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { LLMInferenceFramework.DEEPSPEED: "instant-llm", @@ -287,7 +298,6 @@ def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRep async def _get_latest_batch_v2_tag(inference_framework: LLMInferenceFramework) -> str: config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) - print(config_map) batch_key = f"{inference_framework}_batch_v2" if batch_key not in config_map: raise LatestImageTagNotFoundException( @@ -350,6 +360,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( num_shards=llm_metadata["num_shards"], quantize=llm_metadata.get("quantize"), checkpoint_path=llm_metadata.get("checkpoint_path"), + chat_template_override=llm_metadata.get("chat_template_override"), spec=model_endpoint_entity_to_get_model_endpoint_response(model_endpoint), ) return response @@ -387,6 +398,21 @@ def validate_quantization( ) +def validate_chat_template( + chat_template: Optional[str], inference_framework: LLMInferenceFramework +) -> None: + if chat_template is not None: + if len(chat_template) > CHAT_TEMPLATE_MAX_LENGTH: + raise ObjectHasInvalidValueException( + f"Chat template length must be less than {CHAT_TEMPLATE_MAX_LENGTH}." + ) + + if inference_framework != LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException( + f"Chat template is only supported for inference framework {LLMInferenceFramework.VLLM}." + ) + + def validate_checkpoint_path_uri(checkpoint_path: str) -> None: if ( not checkpoint_path.startswith("s3://") @@ -425,6 +451,13 @@ def validate_checkpoint_files(checkpoint_files: List[str]) -> None: raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.") +def encode_template(chat_template: str) -> str: + """Base64 encode the chat template to safely pass it to bash.""" + + encoded = base64.b64encode(chat_template.encode("utf-8")).decode("utf-8") + return encoded + + class CreateLLMModelBundleV1UseCase: def __init__( self, @@ -463,6 +496,7 @@ async def execute( num_shards: int, quantize: Optional[Quantization], checkpoint_path: Optional[str], + chat_template_override: Optional[str], ) -> ModelBundle: if source == LLMSource.HUGGING_FACE: self.check_docker_image_exists_for_image_tag( @@ -495,6 +529,7 @@ async def execute( num_shards, quantize, checkpoint_path, + chat_template_override, ) elif framework == LLMInferenceFramework.LIGHTLLM: bundle_id = await self.create_lightllm_bundle( @@ -632,7 +667,7 @@ def load_model_weights_sub_commands_s3( validate_checkpoint_files(checkpoint_files) # filter to configs ('*.model' and '*.json') and weights ('*.safetensors') - file_selection_str = "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'" + file_selection_str = '--include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*"' subcommands.append( f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) @@ -783,10 +818,14 @@ async def create_vllm_bundle( num_shards: int, quantize: Optional[Quantization], checkpoint_path: Optional[str], + chat_template_override: Optional[str], ): command = [] subcommands = [] + print("here") + print(chat_template_override) + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder if "mistral" in model_name: @@ -800,25 +839,34 @@ async def create_vllm_bundle( final_weights_folder, ) - subcommands.append( - f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" - ) + vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" - if quantize: - if quantize == Quantization.AWQ: - subcommands[-1] = subcommands[-1] + f" --quantization {quantize}" - else: + chat_template_cmd = None + if chat_template_override: + # We encode the chat template as base64 to avoid issues with special characters + # and decode it via bash + chat_template_cmd = f'export CHAT_TEMPLATE=$(echo "{encode_template(chat_template_override)}" | base64 --decode)' + subcommands.append(chat_template_cmd) + vllm_cmd += ' --chat-template "$CHAT_TEMPLATE"' + + if quantize: # pragma: no cover + if quantize != Quantization.AWQ: raise InvalidRequestException(f"Quantization {quantize} is not supported by vLLM.") + vllm_cmd += f" --quantization {quantize}" + if hmi_config.sensitive_log_mode: # pragma: no cover - subcommands[-1] = subcommands[-1] + " --disable-log-requests" + vllm_cmd += " --disable-log-requests" - if "llama-3-70b" in model_name: - subcommands[-1] = subcommands[-1] + " --gpu-memory-utilization 0.95 --enforce-eager" + additional_args = infer_addition_engine_args_from_model_name(model_name) - if "gemma-2" in model_name: - subcommands[-1] = subcommands[-1] + " --attention-backend FLASHINFER" + if additional_args.max_gpu_memory_utilization: + vllm_cmd += f" --gpu-memory-utilization {additional_args.max_gpu_memory_utilization} --enforce-eager" + if additional_args.attention_backend: + vllm_cmd += " --attention-backend FLASHINFER" + + subcommands.append(vllm_cmd) command = [ "/bin/bash", "-c", @@ -842,6 +890,7 @@ async def create_vllm_bundle( healthcheck_route="/health", predict_route="/predict", streaming_predict_route="/stream", + extra_routes=[OPENAI_CHAT_COMPLETION_PATH], env={}, ), metadata={}, @@ -1027,12 +1076,14 @@ async def execute( ) if request.labels is None: raise EndpointLabelsException("Endpoint labels cannot be None!") + validate_labels(request.labels) validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) validate_model_name(request.model_name, request.inference_framework) validate_num_shards(request.num_shards, request.inference_framework, request.gpus) validate_quantization(request.quantize, request.inference_framework) + validate_chat_template(request.chat_template_override, request.inference_framework) if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, @@ -1061,6 +1112,7 @@ async def execute( num_shards=request.num_shards, quantize=request.quantize, checkpoint_path=request.checkpoint_path, + chat_template_override=request.chat_template_override, ) validate_resource_requests( bundle=bundle, @@ -1091,6 +1143,7 @@ async def execute( num_shards=request.num_shards, quantize=request.quantize, checkpoint_path=request.checkpoint_path, + chat_template_override=request.chat_template_override, ) ) @@ -1273,6 +1326,7 @@ async def execute( or request.num_shards or request.quantize or request.checkpoint_path + or request.chat_template_override ): llm_metadata = (model_endpoint.record.metadata or {}).get(LLM_METADATA_KEY, {}) inference_framework = llm_metadata["inference_framework"] @@ -1298,6 +1352,10 @@ async def execute( request.gpus or infra_state.resource_state.gpus, ) validate_quantization(quantize, inference_framework) + validate_chat_template(request.chat_template_override, inference_framework) + chat_template_override = request.chat_template_override or llm_metadata.get( + "chat_template_override" + ) bundle = await self.create_llm_model_bundle_use_case.execute( user, @@ -1310,6 +1368,7 @@ async def execute( num_shards=num_shards, quantize=quantize, checkpoint_path=checkpoint_path, + chat_template_override=chat_template_override, ) metadata = endpoint_record.metadata or {} @@ -1322,6 +1381,7 @@ async def execute( num_shards=num_shards, quantize=quantize, checkpoint_path=checkpoint_path, + chat_template_override=chat_template_override, ) ) endpoint_record.metadata = metadata @@ -1716,10 +1776,10 @@ async def execute( ): raise ObjectNotAuthorizedException - if ( - model_endpoint.record.endpoint_type is not ModelEndpointType.SYNC - and model_endpoint.record.endpoint_type is not ModelEndpointType.STREAMING - ): + if model_endpoint.record.endpoint_type not in [ + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + ]: raise EndpointUnsupportedInferenceTypeException( f"Endpoint {model_endpoint_name} does not serve sync requests." ) @@ -2378,6 +2438,256 @@ async def _response_chunk_generator( # raising an exception if it is not one of the frameworks handled above. +def validate_endpoint_supports_chat_completion( + endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response +): + if endpoint_content.inference_framework not in CHAT_SUPPORTED_INFERENCE_FRAMEWORKS: + raise EndpointUnsupportedInferenceTypeException( + f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support chat completion." + ) + + if ( + not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike) + or OPENAI_CHAT_COMPLETION_PATH + not in endpoint.record.current_model_bundle.flavor.extra_routes + ): + raise EndpointUnsupportedRequestException("Endpoint does not support chat completion") + + +class ChatCompletionSyncV2UseCase: + """ + Use case for running a chat completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, user: User, model_endpoint_name: str, request: ChatCompletionV2Request + ) -> ChatCompletionV2SyncResponse: # pragma: no cover + """ + Runs the use case to create a sync inference task. + + Args: + user: The user who is creating the sync inference task. + model_endpoint_name: The name of the model endpoint for the task. + request: The body of the request to forward to the endpoint. + + Returns: + A response object that contains the status and result of the task. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if ( + model_endpoint.record.endpoint_type is not ModelEndpointType.SYNC + and model_endpoint.record.endpoint_type is not ModelEndpointType.STREAMING + ): + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} does not serve sync requests." + ) + + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + validate_endpoint_supports_chat_completion(model_endpoint, endpoint_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_CHAT_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + try: + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + # reset model name to correct value + output["model"] = model_endpoint.record.name + return ChatCompletionV2SyncResponse.model_validate(output) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(exc.content) + raise exc + + +class ChatCompletionStreamV2UseCase: + """ + Use case for running a chat completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, model_endpoint_name: str, request: ChatCompletionV2Request, user: User + ) -> AsyncIterable[ChatCompletionV2SuccessChunk]: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if model_endpoint.record.endpoint_type != ModelEndpointType.STREAMING: + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} is not a streaming endpoint." + ) + + inference_gateway = ( + self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() + ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + validate_endpoint_supports_chat_completion(model_endpoint, model_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if model_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_CHAT_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + + return self._response_chunk_generator( + request_id=request_id, + model_endpoint=model_endpoint, + model_content=model_content, + inference_gateway=inference_gateway, + inference_request=inference_request, + ) + + async def _response_chunk_generator( + self, + request_id: Optional[str], + model_endpoint: ModelEndpoint, + model_content: GetLLMModelEndpointV1Response, + inference_gateway: StreamingModelEndpointInferenceGateway, + inference_request: SyncEndpointPredictV1Request, + ) -> AsyncIterable[ChatCompletionV2SuccessChunk]: + """ + Async generator yielding tokens to stream for the completions response. Should only be called when + returned directly by execute(). + """ + try: + predict_result = inference_gateway.streaming_predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + ) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(str(exc)) + + raise exc + + async for res in predict_result: + if not res.status == TaskStatus.SUCCESS or res.result is None: + raise UpstreamServiceError( + status_code=500, + content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), + ) + else: + result = res.result["result"] + # Reset model name to correct value + if "DONE" in result: + continue + result["model"] = model_endpoint.record.name + yield ChatCompletionV2SuccessChunk.model_validate(result) + + class ModelDownloadV1UseCase: def __init__( self, @@ -2659,7 +2969,9 @@ async def execute( gpu_type=hardware.gpu_type, ) - if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1: + if ( + engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1 + ): # pragma: no cover raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 20476339..3cc53d7c 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -1,13 +1,15 @@ +import ast import json import os import time from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Iterable, List, Optional, Sequence, Tuple import requests import sseclient import yaml +from fastapi import HTTPException from fastapi.responses import JSONResponse from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import get_endpoint_config @@ -335,7 +337,7 @@ class StreamingForwarder(ModelEngineSerializationMixin): serialize_results_as_string: bool post_inference_hooks_handler: PostInferenceHooksHandler # unused for now - def __call__(self, json_payload: Any) -> Iterator[Any]: + def __call__(self, json_payload: Any) -> Iterable[Any]: json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload @@ -350,6 +352,11 @@ def __call__(self, json_payload: Any) -> Iterator[Any]: }, stream=True, ) + + if response.status_code != 200: + print(response.json()) + raise HTTPException(status_code=response.status_code, detail=response.json()) + except Exception: logger.exception( f"Failed to get response for request ({json_payload_repr}) " @@ -358,8 +365,14 @@ def __call__(self, json_payload: Any) -> Iterator[Any]: raise client = sseclient.SSEClient(response) - for event in client.events(): - yield self.get_response_payload_stream(using_serialize_results_as_string, event.data) + + def event_stream(): + for event in client.events(): + yield self.get_response_payload_stream( + using_serialize_results_as_string, event.data + ) + + return event_stream() @dataclass(frozen=True) @@ -526,7 +539,8 @@ def _cast_value(value: Any) -> Any: if value.isdigit(): return int(value) elif value.startswith("[") and value.endswith("]"): - return [_cast_value(v) for v in value[1:-1].split(",")] + # Can't use json because it doesn't support single quotes + return ast.literal_eval(value) else: return value diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index adfbde59..00a0c19c 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -1,11 +1,11 @@ import argparse import asyncio -import json import os import signal from functools import lru_cache from typing import Any, Dict, Optional +import orjson import uvicorn from fastapi import BackgroundTasks, Depends, FastAPI from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter @@ -105,12 +105,11 @@ async def stream( else: logger.debug(f"Received request: {payload}") - # has internal error logging for each processing stage responses = forwarder(payload) async def event_generator(): for response in responses: - yield {"data": json.dumps(response)} + yield {"data": orjson.dumps(response).decode("utf-8")} return EventSourceResponse(event_generator()) diff --git a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py index d68d539c..a1236138 100644 --- a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py @@ -14,7 +14,8 @@ def _get_abs_container_client(bucket: str) -> ContainerClient: blob_service_client = BlobServiceClient( - f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", DefaultAzureCredential() + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), ) return blob_service_client.get_container_client(container=bucket) diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index ae790eef..c6e8837d 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -14,6 +14,7 @@ from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( + InvalidRequestException, NoHealthyUpstreamException, TooManyRequestsException, UpstreamServiceError, @@ -44,7 +45,7 @@ ) -def _get_streaming_endpoint_url(deployment_name: str) -> str: +def _get_streaming_endpoint_url(deployment_name: str, path: str = "/stream") -> str: if CIRCLECI: # Circle CI: a NodePort is used to expose the service # The IP address is obtained from `minikube ip`. @@ -58,7 +59,7 @@ def _get_streaming_endpoint_url(deployment_name: str) -> str: protocol = "http" # no need to hit external DNS resolution if we're w/in the k8s cluster hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" - return f"{protocol}://{hostname}/stream" + return f"{protocol}://{hostname}{path}" def _serialize_json(data) -> str: @@ -89,6 +90,7 @@ async def make_single_request(self, request_url: str, payload_json: Dict[str, An headers={"Content-Type": "application/json"}, ) status = aio_resp.status + print(status) if status == 200: async with EventSource(response=aio_resp) as event_source: async for event in event_source: @@ -139,7 +141,8 @@ async def make_request_with_retries( try: async for attempt in AsyncRetrying( stop=stop_any( - stop_after_attempt(num_retries + 1), stop_after_delay(timeout_seconds) + stop_after_attempt(num_retries + 1), + stop_after_delay(timeout_seconds), ), retry=retry_if_exception_type( ( @@ -156,7 +159,10 @@ async def make_request_with_retries( ), ): with attempt: - logger.info(f"Retry number {attempt.retry_state.attempt_number}") + if attempt.retry_state.attempt_number > 1: + logger.info( + f"Retry number {attempt.retry_state.attempt_number}" + ) # pragma: no cover response = self.make_single_request(request_url, payload_json) async for item in response: yield orjson.loads(item) @@ -186,7 +192,9 @@ async def make_request_with_retries( async def streaming_predict( self, topic: str, predict_request: SyncEndpointPredictV1Request ) -> AsyncIterable[SyncEndpointPredictV1Response]: - deployment_url = _get_streaming_endpoint_url(topic) + deployment_url = _get_streaming_endpoint_url( + topic, path=predict_request.destination_path or "/stream" + ) try: timeout_seconds = ( @@ -201,14 +209,22 @@ async def streaming_predict( ) response = self.make_request_with_retries( request_url=deployment_url, - payload_json=predict_request.dict(), + payload_json=predict_request.model_dump(exclude_none=True), timeout_seconds=timeout_seconds, num_retries=num_retries, ) + print(response) async for item in response: yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) except UpstreamServiceError as exc: logger.error(f"Service error on streaming task: {exc.content!r}") + + if exc.status_code == 400: + error_json = orjson.loads(exc.content.decode("utf-8")) + if "result" in error_json: + error_json = orjson.loads(error_json["result"]) + raise InvalidRequestException(error_json) + try: error_json = orjson.loads(exc.content.decode("utf-8")) result_traceback = ( diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index 48ae3410..f7781ea3 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -13,6 +13,7 @@ from model_engine_server.core.config import infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ( + InvalidRequestException, NoHealthyUpstreamException, TooManyRequestsException, UpstreamServiceError, @@ -41,7 +42,7 @@ ) -def _get_sync_endpoint_url(deployment_name: str) -> str: +def _get_sync_endpoint_url(deployment_name: str, destination_path: str = "/predict") -> str: if CIRCLECI: # Circle CI: a NodePort is used to expose the service # The IP address is obtained from `minikube ip`. @@ -55,7 +56,7 @@ def _get_sync_endpoint_url(deployment_name: str) -> str: protocol = "http" # no need to hit external DNS resolution if we're w/in the k8s cluster hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" - return f"{protocol}://{hostname}/predict" + return f"{protocol}://{hostname}{destination_path}" def _serialize_json(data) -> str: @@ -139,7 +140,8 @@ async def make_request_with_retries( ), ): with attempt: - logger.info(f"Retry number {attempt.retry_state.attempt_number}") + if attempt.retry_state.attempt_number > 1: # pragma: no cover + logger.info(f"Retry number {attempt.retry_state.attempt_number}") return await self.make_single_request(request_url, payload_json) except RetryError as e: if type(e.last_attempt.exception()) == TooManyRequestsException: @@ -163,7 +165,9 @@ async def make_request_with_retries( async def predict( self, topic: str, predict_request: SyncEndpointPredictV1Request ) -> SyncEndpointPredictV1Response: - deployment_url = _get_sync_endpoint_url(topic) + deployment_url = _get_sync_endpoint_url( + topic, destination_path=predict_request.destination_path or "/predict" + ) try: timeout_seconds = ( @@ -178,12 +182,20 @@ async def predict( ) response = await self.make_request_with_retries( request_url=deployment_url, - payload_json=predict_request.dict(), + payload_json=predict_request.model_dump(exclude_none=True), timeout_seconds=timeout_seconds, num_retries=num_retries, ) except UpstreamServiceError as exc: logger.error(f"Service error on sync task: {exc.content!r}") + + if exc.status_code == 400: + error_json = orjson.loads(exc.content.decode("utf-8")) + if "result" in error_json: + error_json = orjson.loads(error_json["result"]) + + raise InvalidRequestException(error_json) + try: # Try to parse traceback from the response, fallback to just return all the content if failed. # Three cases considered: @@ -193,6 +205,7 @@ async def predict( error_json = orjson.loads(exc.content.decode("utf-8")) if "result" in error_json: error_json = orjson.loads(error_json["result"]) + detail = error_json.get("detail", {}) if not isinstance(detail, dict): result_traceback = orjson.dumps(error_json) @@ -204,6 +217,7 @@ async def predict( status=TaskStatus.FAILURE, traceback=result_traceback, ) + except Exception as e: logger.error(f"Failed to parse error: {e}") return SyncEndpointPredictV1Response( diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 481ab0a6..8a6f0a8e 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -162,6 +162,7 @@ class _RunnableImageDeploymentArguments(_BaseDeploymentArguments): FORWARDER_CPUS_LIMIT: float FORWARDER_MEMORY_LIMIT: str FORWARDER_STORAGE_LIMIT: str + FORWARDER_EXTRA_ROUTES: List[str] USER_CONTAINER_PORT: int @@ -217,7 +218,9 @@ class DeploymentRunnableImageSyncCpuArguments( class DeploymentRunnableImageSyncGpuArguments( - _RunnableImageDeploymentArguments, _SyncRunnableImageDeploymentArguments, _GpuArguments + _RunnableImageDeploymentArguments, + _SyncRunnableImageDeploymentArguments, + _GpuArguments, ): """Keyword-arguments for substituting into GPU sync deployment templates for runnable images.""" @@ -247,7 +250,9 @@ class DeploymentRunnableImageAsyncGpuArguments( class DeploymentTritonEnhancedRunnableImageSyncCpuArguments( - _RunnableImageDeploymentArguments, _SyncRunnableImageDeploymentArguments, _TritonArguments + _RunnableImageDeploymentArguments, + _SyncRunnableImageDeploymentArguments, + _TritonArguments, ): """Keyword-arguments for substituting into CPU sync deployment templates for triton-enhanced runnable images. @@ -274,7 +279,10 @@ class DeploymentTritonEnhancedRunnableImageAsyncCpuArguments( class DeploymentTritonEnhancedRunnableImageAsyncGpuArguments( - _RunnableImageDeploymentArguments, _AsyncDeploymentArguments, _GpuArguments, _TritonArguments + _RunnableImageDeploymentArguments, + _AsyncDeploymentArguments, + _GpuArguments, + _TritonArguments, ): """Keyword-arguments for substituting GPU async deployment templates for triton-enhanced runnable images. @@ -609,6 +617,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -657,6 +666,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -708,6 +718,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Streaming Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, @@ -753,6 +764,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Streaming Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, @@ -799,6 +811,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, @@ -843,6 +856,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, @@ -889,6 +903,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -945,6 +960,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -1003,6 +1019,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, @@ -1055,6 +1072,7 @@ def get_endpoint_resource_arguments_from_request( FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 1b35ad6a..4a61d564 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -617,6 +617,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" @@ -885,6 +889,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" @@ -1111,6 +1119,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" @@ -1852,6 +1864,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" @@ -2127,6 +2143,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" @@ -2360,6 +2380,10 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" diff --git a/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py index 9408d59b..5371bcea 100644 --- a/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py @@ -122,14 +122,16 @@ def translate_model_bundle_orm_to_model_bundle( flavor=model_bundle_orm.flavor, requirements=model_bundle_orm.artifact_requirements, location=model_bundle_orm.artifact_location, - framework=None - if model_bundle_orm.artifact_framework_type is None - else dict_not_none( - framework_type=model_bundle_orm.artifact_framework_type, - pytorch_image_tag=model_bundle_orm.artifact_pytorch_image_tag, - tensorflow_version=model_bundle_orm.artifact_tensorflow_version, - image_repository=model_bundle_orm.artifact_image_repository, - image_tag=model_bundle_orm.artifact_image_tag, + framework=( + None + if model_bundle_orm.artifact_framework_type is None + else dict_not_none( + framework_type=model_bundle_orm.artifact_framework_type, + pytorch_image_tag=model_bundle_orm.artifact_pytorch_image_tag, + tensorflow_version=model_bundle_orm.artifact_tensorflow_version, + image_repository=model_bundle_orm.artifact_image_repository, + image_tag=model_bundle_orm.artifact_image_tag, + ) ), app_config=model_bundle_orm.artifact_app_config, load_predict_fn=model_bundle_orm.cloudpickle_artifact_load_predict_fn, @@ -144,6 +146,7 @@ def translate_model_bundle_orm_to_model_bundle( env=model_bundle_orm.runnable_image_env, protocol=model_bundle_orm.runnable_image_protocol, readiness_initial_delay_seconds=model_bundle_orm.runnable_image_readiness_initial_delay_seconds, + extra_routes=model_bundle_orm.runnable_image_extra_routes, streaming_command=model_bundle_orm.streaming_enhanced_runnable_image_streaming_command, streaming_predict_route=model_bundle_orm.streaming_enhanced_runnable_image_streaming_predict_route, triton_model_repository=model_bundle_orm.triton_enhanced_runnable_image_model_repository, @@ -161,7 +164,7 @@ def translate_model_bundle_orm_to_model_bundle( packaging_type=model_bundle_orm.packaging_type, app_config=model_bundle_orm.app_config, ) - return ModelBundle.parse_obj(kwargs) + return ModelBundle.model_validate(kwargs) def translate_kwargs_to_model_bundle_orm( @@ -212,6 +215,7 @@ def translate_kwargs_to_model_bundle_orm( runnable_image_readiness_initial_delay_seconds=flavor_dict.get( "readiness_initial_delay_seconds" ), + runnable_image_extra_routes=flavor_dict.get("extra_routes"), streaming_enhanced_runnable_image_streaming_command=flavor_dict.get("streaming_command"), streaming_enhanced_runnable_image_streaming_predict_route=flavor_dict.get( "streaming_predict_route" diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 97beffb8..5c1ce58f 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -738,7 +738,10 @@ async def get_job_template_for_model( return self.db.get((model_name, fine_tuning_method), None) async def write_job_template_for_model( - self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate + self, + model_name: str, + fine_tuning_method: str, + job_template: LLMFineTuneTemplate, ): self.db[(model_name, fine_tuning_method)] = job_template @@ -761,7 +764,10 @@ class FakeLLMArtifactGateway(LLMArtifactGateway): def __init__(self): self.existing_models = [] self.s3_bucket = { - "fake-checkpoint": ["model-fake.bin, model-fake2.bin", "model-fake.safetensors"], + "fake-checkpoint": [ + "model-fake.bin, model-fake2.bin", + "model-fake.safetensors", + ], "llama-7b/tokenizer.json": ["llama-7b/tokenizer.json"], "llama-7b/tokenizer_config.json": ["llama-7b/tokenizer_config.json"], "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"], @@ -793,6 +799,50 @@ def __init__(self): "use_cache": True, "vocab_size": 32000, } + self.tokenizer_config = { + "add_bos_token": True, + "add_eos_token": False, + "add_prefix_space": None, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "1": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "2": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + }, + "additional_special_tokens": [], + "bos_token": "", + "chat_template": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n", + "clean_up_tokenization_spaces": False, + "eos_token": "", + "legacy": False, + "model_max_length": 1000000000000000019884624838656, + "pad_token": None, + "sp_model_kwargs": {}, + "spaces_between_special_tokens": False, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": False, + } def _add_model(self, owner: str, model_name: str): self.existing_models.append((owner, model_name)) @@ -816,7 +866,7 @@ def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: return self.model_config -class FakeTriggerRepository(TriggerRepository): +class FakeTriggerRepository(TriggerRepository): # pragma: no cover def __init__(self, contents: Optional[Dict[str, Trigger]] = None): self.db = {} if contents is None else contents self.next_id = 0 @@ -1960,7 +2010,7 @@ def fake_docker_repository_image_never_exists() -> FakeDockerRepository: @pytest.fixture -def fake_docker_repository_image_never_exists_and_builds_dont_work() -> FakeDockerRepository: +def fake_docker_repository_image_never_exists_and_builds_dont_work() -> (FakeDockerRepository): repo = FakeDockerRepository(image_always_exists=False, raises_error=True) return repo @@ -1990,7 +2040,7 @@ def fake_model_endpoint_record_repository() -> FakeModelEndpointRecordRepository @pytest.fixture -def fake_docker_image_batch_job_bundle_repository() -> FakeDockerImageBatchJobBundleRepository: +def fake_docker_image_batch_job_bundle_repository() -> (FakeDockerImageBatchJobBundleRepository): repo = FakeDockerImageBatchJobBundleRepository() return repo @@ -2073,25 +2123,27 @@ def fake_model_primitive_gateway() -> FakeModelPrimitiveGateway: @pytest.fixture -def fake_async_model_endpoint_inference_gateway() -> FakeAsyncModelEndpointInferenceGateway: +def fake_async_model_endpoint_inference_gateway() -> (FakeAsyncModelEndpointInferenceGateway): gateway = FakeAsyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_streaming_model_endpoint_inference_gateway() -> FakeStreamingModelEndpointInferenceGateway: +def fake_streaming_model_endpoint_inference_gateway() -> ( + FakeStreamingModelEndpointInferenceGateway +): gateway = FakeStreamingModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferenceGateway: +def fake_sync_model_endpoint_inference_gateway() -> (FakeSyncModelEndpointInferenceGateway): gateway = FakeSyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_inference_autoscaling_metrics_gateway() -> FakeInferenceAutoscalingMetricsGateway: +def fake_inference_autoscaling_metrics_gateway() -> (FakeInferenceAutoscalingMetricsGateway): gateway = FakeInferenceAutoscalingMetricsGateway() return gateway @@ -3535,7 +3587,7 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An @pytest.fixture -def sync_endpoint_predict_request_1() -> Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]: +def sync_endpoint_predict_request_1() -> (Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]): request = SyncEndpointPredictV1Request( url="test_url", return_pickled=False, diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 23abfd9d..4882c3e3 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -335,6 +335,34 @@ def create_llm_model_endpoint_request_llama_3_70b() -> CreateLLMModelEndpointV1R ) +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_70b_chat() -> (CreateLLMModelEndpointV1Request): + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_70b_chat", + model_name="llama-3-70b", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-70b", + chat_template_override="test-template", + ) + + @pytest.fixture def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( CreateLLMModelEndpointV1Request diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 7dc149fd..06098ce4 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -39,6 +39,7 @@ is_model_name_suffix_valid, ) from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CHAT_TEMPLATE_MAX_LENGTH, CompletionStreamV1UseCase, CompletionSyncV1UseCase, CreateBatchCompletionsUseCase, @@ -53,6 +54,7 @@ _infer_hardware, merge_metadata, validate_and_update_completion_params, + validate_chat_template, validate_checkpoint_files, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -132,6 +134,7 @@ async def test_create_model_endpoint_use_case_success( "num_shards": create_llm_model_endpoint_request_async.num_shards, "quantize": None, "checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_async.chat_template_override, } } @@ -155,6 +158,7 @@ async def test_create_model_endpoint_use_case_success( "num_shards": create_llm_model_endpoint_request_sync.num_shards, "quantize": None, "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_sync.chat_template_override, } } @@ -179,7 +183,8 @@ async def test_create_model_endpoint_use_case_success( "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, - "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path, + "checkpoint_path": create_llm_model_endpoint_request_streaming.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } @@ -276,6 +281,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint( num_shards=request.num_shards, quantize=request.quantize, checkpoint_path=checkpoint_path, + chat_template_override=request.chat_template_override, ) @@ -333,6 +339,64 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( await use_case.execute(user=user, request=request) +@pytest.mark.asyncio +async def test_create_model_endpoint_w_chat_template( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_llama_3_70b_chat: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + print(create_llm_model_endpoint_request_llama_3_70b_chat) + response = await use_case.execute( + user=user, + request=create_llm_model_endpoint_request_llama_3_70b_chat, + ) + assert response.endpoint_creation_task_id + assert isinstance(response, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_llama_3_70b_chat.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_llama_3_70b_chat.model_name, + "source": create_llm_model_endpoint_request_llama_3_70b_chat.source, + "inference_framework": create_llm_model_endpoint_request_llama_3_70b_chat.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_llama_3_70b_chat.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_llama_3_70b_chat.num_shards, + "quantize": create_llm_model_endpoint_request_llama_3_70b_chat.quantize, + "checkpoint_path": create_llm_model_endpoint_request_llama_3_70b_chat.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_llama_3_70b_chat.chat_template_override, + } + } + + @pytest.mark.asyncio async def test_create_model_endpoint_text_generation_inference_use_case_success( test_api_key: str, @@ -386,6 +450,7 @@ async def test_create_model_endpoint_text_generation_inference_use_case_success( "num_shards": create_llm_model_endpoint_text_generation_inference_request_streaming.num_shards, "quantize": create_llm_model_endpoint_text_generation_inference_request_streaming.quantize, "checkpoint_path": create_llm_model_endpoint_text_generation_inference_request_streaming.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_text_generation_inference_request_streaming.chat_template_override, } } @@ -426,7 +491,7 @@ def test_load_model_weights_sub_commands( ) expected_result = [ - "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://fake-checkpoint/* test_folder", + './s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" s3://fake-checkpoint/* test_folder', ] assert expected_result == subcommands @@ -441,7 +506,7 @@ def test_load_model_weights_sub_commands( expected_result = [ "s5cmd > /dev/null || conda install -c conda-forge -y s5cmd", - "s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://fake-checkpoint/* test_folder", + 's5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" s3://fake-checkpoint/* test_folder', ] assert expected_result == subcommands @@ -515,6 +580,7 @@ async def test_create_model_endpoint_trt_llm_use_case_success( "num_shards": create_llm_model_endpoint_trt_llm_request_streaming.num_shards, "quantize": create_llm_model_endpoint_trt_llm_request_streaming.quantize, "checkpoint_path": create_llm_model_endpoint_trt_llm_request_streaming.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_trt_llm_request_streaming.chat_template_override, } } @@ -705,6 +771,7 @@ async def test_update_model_endpoint_use_case_success( "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory @@ -740,6 +807,7 @@ async def test_update_model_endpoint_use_case_success( "num_shards": create_llm_model_endpoint_request_streaming.num_shards, "quantize": None, "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory @@ -2720,3 +2788,16 @@ def test_merge_metadata(): "key2": "value2", "key3": "value3", } + + +def test_validate_chat_template(): + assert validate_chat_template(None, LLMInferenceFramework.DEEPSPEED) is None + good_chat_template = CHAT_TEMPLATE_MAX_LENGTH * "_" + assert validate_chat_template(good_chat_template, LLMInferenceFramework.VLLM) is None + + bad_chat_template = (CHAT_TEMPLATE_MAX_LENGTH + 1) * "_" + with pytest.raises(ObjectHasInvalidValueException): + validate_chat_template(bad_chat_template, LLMInferenceFramework.DEEPSPEED) + + with pytest.raises(ObjectHasInvalidValueException): + validate_chat_template(good_chat_template, LLMInferenceFramework.DEEPSPEED) diff --git a/model-engine/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py index 5c996303..0462b317 100644 --- a/model-engine/tests/unit/inference/test_forwarding.py +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -4,6 +4,7 @@ from unittest import mock import pytest +from fastapi import HTTPException from fastapi.responses import JSONResponse from model_engine_server.core.utils.env import environment from model_engine_server.domain.entities import ModelEndpointConfig @@ -45,6 +46,17 @@ def json(self) -> dict: return mocked_static_json() +def mocked_post_400(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 400 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + def mocked_post_500(*args, **kwargs): # noqa @dataclass class mocked_static_json: @@ -132,8 +144,8 @@ def mocked_config_content(): def mocked_config_overrides(): return [ - "forwarder.sync.extra_routes=[/v1/chat/completions]", - "forwarder.stream.extra_routes=[/v1/chat/completions]", + "forwarder.sync.extra_routes=['/v1/chat/completions']", + "forwarder.stream.extra_routes=['/v1/chat/completions']", "forwarder.sync.healthcheck_route=/health", "forwarder.stream.healthcheck_route=/health", ] @@ -406,6 +418,22 @@ def test_streaming_forwarders(post_inference_hooks_handler): _check_streaming(response) +@mock.patch("requests.post", mocked_post_400) +@mock.patch("requests.get", mocked_get) +@mock.patch("sseclient.SSEClient", mocked_sse_client) +def test_streaming_forwarder_400_upstream(post_inference_hooks_handler): + fwd = StreamingForwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + ) + with pytest.raises(HTTPException) as e: + fwd({"ignore": "me"}) + + assert e.value.status_code == 400 + + @mock.patch("requests.post", mocked_post) @mock.patch("requests.get", mocked_get) @mock.patch("sseclient.SSEClient", mocked_sse_client) diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py index c097f858..5462c6ae 100644 --- a/model-engine/tests/unit/inference/test_vllm_batch.py +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -312,7 +312,8 @@ def test_file_exists(): path = "test_path" with patch( - "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", mock_open_func + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + mock_open_func, ): result = file_exists(path) @@ -431,7 +432,10 @@ async def test_batch_inference_tool_completion( {"token": ".", "log_prob": -0.3870151937007904}, {"token": "\n", "log_prob": -0.027081478387117386}, {"token": "Final", "log_prob": -0.1980377733707428}, - {"token": " Answer", "log_prob": -0.0037908137310296297}, + { + "token": " Answer", + "log_prob": -0.0037908137310296297, + }, {"token": ":", "log_prob": -0.015637163072824478}, {"token": " ", "log_prob": -0.0010788579238578677}, {"token": "4", "log_prob": -0.04351021721959114}, diff --git a/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py index 1d38c223..8140b1c2 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py @@ -58,6 +58,7 @@ def test_task_create_get_args_callback( "callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()), "callback_url": endpoint_predict_request_2[0].callback_url, "return_pickled": endpoint_predict_request_2[0].return_pickled, + "destination_path": None, } assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1) assert ( diff --git a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py index e2cabc79..db48e7f8 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py @@ -9,7 +9,7 @@ SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from model_engine_server.domain.exceptions import UpstreamServiceError +from model_engine_server.domain.exceptions import InvalidRequestException, UpstreamServiceError from model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway import ( LiveStreamingModelEndpointInferenceGateway, ) @@ -214,3 +214,24 @@ async def test_predict_raises_traceback_not_json( } count += 1 assert count == 1 + + +@pytest.mark.asyncio +async def test_predict_upstream_raises_400( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] +): + gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) + content = json.dumps({"result": json.dumps({"error": "error"})}).encode("utf-8") + + fake_response = FakeResponse(status=400, message_content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + with pytest.raises(InvalidRequestException): + response = gateway.streaming_predict( + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + ) + async for message in response: + message diff --git a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py index afa1aee5..fdc74288 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -9,7 +9,7 @@ SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from model_engine_server.domain.exceptions import UpstreamServiceError +from model_engine_server.domain.exceptions import InvalidRequestException, UpstreamServiceError from model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway import ( LiveSyncModelEndpointInferenceGateway, ) @@ -229,3 +229,23 @@ async def test_predict_raises_traceback_wrapped_detail_array( "result": None, "traceback": """{"detail":[{"error":"error"}]}""", } + + +@pytest.mark.asyncio +async def test_predict_upstream_raises_400( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] +): + gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + + content = json.dumps({"result": json.dumps({"error": "error"})}).encode("utf-8") + fake_response = FakeResponse(status=400, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + # assert that the exception is raised + with pytest.raises(InvalidRequestException): + await gateway.predict( + topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + ) diff --git a/scripts/generate-openai-types.sh b/scripts/generate-openai-types.sh index bdfd0050..2a60d88c 100755 --- a/scripts/generate-openai-types.sh +++ b/scripts/generate-openai-types.sh @@ -14,6 +14,7 @@ datamodel-codegen \ --output-model-type pydantic_v2.BaseModel \ --enum-field-as-literal all \ --field-constraints \ + --strict-nullable \ --use-annotated CLIENT_DIR=${BASE_DIR}/clients/python/llmengine/data_types/gen @@ -27,6 +28,7 @@ datamodel-codegen \ --output-model-type pydantic.BaseModel \ --enum-field-as-literal all \ --field-constraints \ + --strict-nullable \ --use-annotated # Ignore mypy for this file diff --git a/scripts/openai-spec.yaml b/scripts/openai-spec.yaml index 8a50ee2f..7ec21c7f 100644 --- a/scripts/openai-spec.yaml +++ b/scripts/openai-spec.yaml @@ -9556,7 +9556,7 @@ components: required: - role - content - - refusal + # - refusal ChatCompletionStreamResponseDelta: type: object @@ -9822,7 +9822,7 @@ components: - finish_reason - index - message - - logprobs + # - logprobs properties: finish_reason: type: string @@ -9863,7 +9863,7 @@ components: nullable: true required: - content - - refusal + # - refusal created: type: integer From 624b91e62485f4ee4c614435ddf026a516550371 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 12 Sep 2024 12:03:30 -0700 Subject: [PATCH 379/425] Fix passing in vllm args options (#611) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types/chat_completion.py | 4 ++-- clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index aada0ef5..40ea1a28 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta38" +__version__ = "0.0.0beta39" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types/chat_completion.py b/clients/python/llmengine/data_types/chat_completion.py index adee0046..fdfa85b4 100644 --- a/clients/python/llmengine/data_types/chat_completion.py +++ b/clients/python/llmengine/data_types/chat_completion.py @@ -1,13 +1,13 @@ from typing import Any, Dict, List, Optional from .gen.openai import CreateChatCompletionRequest, CreateChatCompletionResponse -from .pydantic_types import Field +from .pydantic_types import BaseModel, Field # Fields that are a part of OpenAI spec but are not supported by model engine UNSUPPORTED_FIELDS = ["service_tier"] -class VLLMAdditionalFields: +class VLLMAdditionalFields(BaseModel): chat_template: Optional[str] = Field( default=None, description=( diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 3941db63..2d011128 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta38" +version = "0.0.0.beta39" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 43299efc..73db3f02 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.8", - version="0.0.0.beta38", + version="0.0.0.beta39", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From 65bbb63619d5097a42e89eb23fbf1807731b6c6f Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:07:27 -0700 Subject: [PATCH 380/425] Option to skip AWS profile set (#613) --- .../inference/batch_inference/vllm_batch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py index b31e7331..d4863391 100644 --- a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -31,7 +31,9 @@ AWS_REGION = os.getenv("AWS_REGION", "us-west-2") MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights") -os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") +SKIP_AWS_PROFILE_SET = os.getenv("SKIP_AWS_PROFILE_SET", "false").lower() == "true" +if not SKIP_AWS_PROFILE_SET: + os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") def get_cpu_cores_in_container(): From 86e3589a90a6db2950c2fced8b0606df2b573e9d Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:01:13 -0700 Subject: [PATCH 381/425] MLI-2949 Upgrade vllm to 0.6.1.post2 (#614) * Changes for vllm 0.6 * vllm 0.6 * updates * updates * fix --- .../inference/vllm/Dockerfile.vllm | 8 +------ .../inference/vllm/README.md | 2 +- .../inference/vllm/build_and_upload_image.sh | 2 +- .../inference/vllm/requirements-batch.txt | 3 ++- .../inference/vllm/requirements-dev.txt | 2 +- .../inference/vllm/vllm_server.py | 1 + scripts/throughput_benchmarks.py | 23 +++++++++++-------- 7 files changed, 20 insertions(+), 21 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm index bb4bf801..6a00adb7 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -ARG VLLM_VERSION=0.5.3.post1 +ARG VLLM_VERSION=0.6.1.post2 ARG VLLM_BASE_IMAGE=vllm/vllm-openai:v${VLLM_VERSION} FROM ${VLLM_BASE_IMAGE} AS base @@ -9,12 +9,6 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* \ apt-get clean -# Need to fix flashinfer at 0.0.8 to support gemma models -# See https://github.com/vllm-project/vllm/issues/7060#issuecomment-2266248014 -# vLLM 0.5.3 depends on torch 2.3.1 -RUN pip uninstall flashinfer -y -RUN pip install flashinfer==0.0.8 --index-url https://flashinfer.ai/whl/cu121/torch2.3 - WORKDIR /workspace RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz diff --git a/model-engine/model_engine_server/inference/vllm/README.md b/model-engine/model_engine_server/inference/vllm/README.md index 29b44d60..8f969f17 100644 --- a/model-engine/model_engine_server/inference/vllm/README.md +++ b/model-engine/model_engine_server/inference/vllm/README.md @@ -52,7 +52,7 @@ docker run \ -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ -p 5005:5005 \ - -e CONFIG_FILE=/workspace/examples/sample_config_gemma.json \ + -e CONFIG_FILE=/workspace/examples/v2/sample_config_gemma.json \ -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ --name vllm_batch \ ${IMAGE_BATCH} \ diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh index 8b6175b6..d7fcb547 100755 --- a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -4,7 +4,7 @@ set -eo pipefail # Build and push vLLM docker image to AWS ECR. # -# Usage: VLLM_VERSION=0.5.3.post1 ./build_and_upload_image.sh vllm|vllm_batch +# Usage: VLLM_VERSION=0.5.3.post1 ./build_and_upload_image.sh vllm|vllm_batch|vllm_batch_v2 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) PROJECT_DIR=$SCRIPT_DIR/../../../.. diff --git a/model-engine/model_engine_server/inference/vllm/requirements-batch.txt b/model-engine/model_engine_server/inference/vllm/requirements-batch.txt index f2865af3..04afaa23 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements-batch.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements-batch.txt @@ -3,4 +3,5 @@ boto3==1.34.15 smart-open==6.4.0 ddtrace==2.11.0 datadog==0.49.1 -dataclasses-json~=0.6.7 \ No newline at end of file +dataclasses-json~=0.6.7 +sse-starlette==2.1.3 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt index 34cee62b..066478b2 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt @@ -1 +1 @@ -vllm>=0.5.4 \ No newline at end of file +vllm==0.6.1.post2 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 5d19111a..4929ef72 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -230,6 +230,7 @@ async def run_server(args, **uvicorn_kwargs) -> None: shutdown_task = await serve_http( app, + engine=async_engine_client, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py index e970ead7..06888dd7 100644 --- a/scripts/throughput_benchmarks.py +++ b/scripts/throughput_benchmarks.py @@ -20,7 +20,7 @@ GATEWAY_URL = os.getenv("GATEWAY_URL") app = typer.Typer(name="throughput-benchmarks", add_completion=False) -MAX_CONTEXT_WINDOW = 4096 +MAX_CONTEXT_WINDOW = 100000 @dataclass @@ -453,11 +453,18 @@ def run_benchmarks( if output_file is not None: header = all_statistics[0].keys() - - with open(output_file, "a") as csvfile: - csv_writer = csv.DictWriter(csvfile, fieldnames=header) - csv_writer.writeheader() - csv_writer.writerows(all_statistics) + import os + + if not os.path.exists(output_file): + with open(output_file, "w") as csvfile: + print("creating the data in csv") + csv_writer = csv.DictWriter(csvfile, fieldnames=header) + csv_writer.writeheader() + csv_writer.writerows(all_statistics) + else: + with open(output_file, "a") as csvfile: + csv_writer = csv.DictWriter(csvfile, fieldnames=header) + csv_writer.writerows(all_statistics) @app.command() @@ -478,10 +485,6 @@ def run_benchmarks_concurrency_range( response_token_count_distribution_file: Optional[str] = None, prompts_list_override_file: Optional[str] = None, ): - if output_file is not None: - # Create empty file - with open(output_file, "w"): - pass for concurrency in range(concurrency_min, concurrency_max + 1, concurrency_step): run_benchmarks( model, From 80fa44dfb619d9151826f8880f350e4a51d1623c Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:51:41 -0700 Subject: [PATCH 382/425] add skipping aws profile code to v2 batch (#615) --- .../model_engine_server/inference/vllm/vllm_batch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py index f84674f7..b798d392 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_batch.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -37,7 +37,10 @@ CONFIG_FILE = os.getenv("CONFIG_FILE") AWS_REGION = os.getenv("AWS_REGION", "us-west-2") MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights") -os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + +SKIP_AWS_PROFILE_SET = os.getenv("SKIP_AWS_PROFILE_SET", "false").lower() == "true" +if not SKIP_AWS_PROFILE_SET: + os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") openai_serving_chat: OpenAIServingChat From 2c389ffb1f89ebd11fb83d765564cc13ebcb45a0 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 23 Sep 2024 13:02:43 -0700 Subject: [PATCH 383/425] Enable users to force redeploy endpoints (#617) * Enable users to force redeploy endpoints * Rename to force bundle recreation --- .../common/dtos/llms/model_endpoints.py | 8 ++++++++ .../domain/use_cases/llm_model_endpoint_use_cases.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py index d44d7d0e..c6f8df02 100644 --- a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py @@ -149,6 +149,14 @@ class UpdateLLMModelEndpointV1Request(BaseModel): description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", ) + force_bundle_recreation: Optional[bool] = False + """ + Whether to force recreate the underlying bundle. + + If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created + that we would like to pick up for existing endpoints + """ + class UpdateLLMModelEndpointV1Response(BaseModel): endpoint_creation_task_id: str diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 84c50f69..a78b41d6 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1320,7 +1320,8 @@ async def execute( metadata: Optional[Dict[str, Any]] if ( - request.model_name + request.force_bundle_recreation + or request.model_name or request.source or request.inference_framework_image_tag or request.num_shards From 1f03d44c84100ae3f4d5eac94e7c3eb87b55934d Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 23 Sep 2024 13:56:09 -0700 Subject: [PATCH 384/425] Remove print statement (#618) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index a78b41d6..5db0c617 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -823,9 +823,6 @@ async def create_vllm_bundle( command = [] subcommands = [] - print("here") - print(chat_template_override) - checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder if "mistral" in model_name: From 01c9387cd19c3396ba40b2397a303c62ac29833c Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 24 Sep 2024 16:13:37 -0700 Subject: [PATCH 385/425] Fix batch compeltion v2 for oai completion (#621) --- .../inference/vllm/vllm_batch.py | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py index b798d392..23d716e4 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_batch.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -3,9 +3,20 @@ import json import os import subprocess -from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Dict, List, Optional, Union +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Coroutine, + Dict, + List, + MutableMapping, + Optional, + Union, +) import smart_open +from fastapi import Request from model_engine_server.common.dtos.llms import ( BatchCompletionContent, BatchCompletionsModelConfig, @@ -25,6 +36,7 @@ random_uuid, ) from pydantic import TypeAdapter +from starlette.datastructures import Headers from tqdm import tqdm from typing_extensions import TypeAlias, assert_never from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams @@ -55,6 +67,26 @@ ] +async def dummy_receive() -> MutableMapping[str, Any]: + return {"type": "continue"} + + +# jank but create_completion expects a FastAPI Request object +dummy_request = Request( + scope={ + "type": "http", + "path": "/", + "headers": Headers().raw, + "http_version": "1.1", + "method": "GET", + "scheme": "https", + "client": ("127.0.0.1", 8080), + }, + # receive fn that doesn't terminate + receive=dummy_receive, +) + + async def download_model(checkpoint_path: str, target_dir: str) -> None: s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}" env = os.environ.copy() @@ -162,7 +194,9 @@ async def generate_v2_completions( ] = [] for request in requests: if isinstance(request, CompletionRequest): - results_generators.append(openai_serving_completion.create_completion(request)) + results_generators.append( + openai_serving_completion.create_completion(request, dummy_request) + ) elif isinstance(request, ChatCompletionRequest): results_generators.append(openai_serving_chat.create_chat_completion(request)) else: From ba065402c77a8050b456b8f09d879d3c0d894320 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:11:18 -0700 Subject: [PATCH 386/425] Multinode bundle db migration + orm class + entity (#620) * add multinode bundle file * migration script + orm class * entity, orm <-> entity conversion, orm init that I missed * dang it --- ...9_24_1456-f55525c81eb5_multinode_bundle.py | 42 +++++++++++++++++++ .../db/models/hosted_model_inference.py | 6 +++ .../domain/entities/model_bundle_entity.py | 2 + .../db_model_bundle_repository.py | 4 ++ 4 files changed, 54 insertions(+) create mode 100644 model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py new file mode 100644 index 00000000..532b0e38 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py @@ -0,0 +1,42 @@ +"""multinode_bundle + +Revision ID: f55525c81eb5 +Revises: b574e9711e35 +Create Date: 2024-09-24 14:56:36.287001 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import ARRAY + +# revision identifiers, used by Alembic. +revision = "f55525c81eb5" +down_revision = "b574e9711e35" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "bundles", + sa.Column("runnable_image_worker_command", ARRAY(sa.Text), nullable=True), + schema="hosted_model_inference", + ) + op.add_column( + "bundles", + sa.Column("runnable_image_worker_env", sa.JSON, nullable=True), + schema="hosted_model_inference", + ) + + +def downgrade() -> None: + op.drop_column( + "bundles", + "runnable_image_worker_command", + schema="hosted_model_inference", + ) + op.drop_column( + "bundles", + "runnable_image_worker_env", + schema="hosted_model_inference", + ) diff --git a/model-engine/model_engine_server/db/models/hosted_model_inference.py b/model-engine/model_engine_server/db/models/hosted_model_inference.py index 72f7806f..7661be46 100644 --- a/model-engine/model_engine_server/db/models/hosted_model_inference.py +++ b/model-engine/model_engine_server/db/models/hosted_model_inference.py @@ -147,6 +147,8 @@ class Bundle(Base): runnable_image_protocol = Column(Text, nullable=True) runnable_image_readiness_initial_delay_seconds = Column(Integer, nullable=True) runnable_image_extra_routes = Column(ARRAY(Text), nullable=True) + runnable_image_worker_command = Column(ARRAY(Text), nullable=True) + runnable_image_worker_env = Column(JSON, nullable=True) # Streaming Enhanced Runnable Image fields streaming_enhanced_runnable_image_streaming_command = Column(ARRAY(Text), nullable=True) @@ -207,6 +209,8 @@ def __init__( runnable_image_protocol: Optional[str] = None, runnable_image_readiness_initial_delay_seconds: Optional[int] = None, runnable_image_extra_routes: Optional[List[str]] = None, + runnable_image_worker_command: Optional[List[str]] = None, + runnable_image_worker_env: Optional[Dict[str, Any]] = None, # Streaming Enhanced Runnable Image fields streaming_enhanced_runnable_image_streaming_command: Optional[List[str]] = None, streaming_enhanced_runnable_image_streaming_predict_route: Optional[str] = None, @@ -263,6 +267,8 @@ def __init__( self.runnable_image_env = runnable_image_env self.runnable_image_protocol = runnable_image_protocol self.runnable_image_extra_routes = runnable_image_extra_routes + self.runnable_image_worker_command = runnable_image_worker_command + self.runnable_image_worker_env = runnable_image_worker_env self.runnable_image_readiness_initial_delay_seconds = ( runnable_image_readiness_initial_delay_seconds ) diff --git a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py index 32c5c4e5..40e26670 100644 --- a/model-engine/model_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -158,6 +158,8 @@ class RunnableImageLike(BaseModel, ABC): protocol: Literal["http"] # TODO: add support for other protocols (e.g. grpc) readiness_initial_delay_seconds: int = 120 extra_routes: List[str] = Field(default_factory=list) + worker_command: Optional[List[str]] = None + worker_env: Optional[Dict[str, str]] = None class RunnableImageFlavor(RunnableImageLike): diff --git a/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py index 5371bcea..b84a598e 100644 --- a/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py @@ -147,6 +147,8 @@ def translate_model_bundle_orm_to_model_bundle( protocol=model_bundle_orm.runnable_image_protocol, readiness_initial_delay_seconds=model_bundle_orm.runnable_image_readiness_initial_delay_seconds, extra_routes=model_bundle_orm.runnable_image_extra_routes, + worker_command=model_bundle_orm.runnable_image_worker_command, + worker_env=model_bundle_orm.runnable_image_worker_env, streaming_command=model_bundle_orm.streaming_enhanced_runnable_image_streaming_command, streaming_predict_route=model_bundle_orm.streaming_enhanced_runnable_image_streaming_predict_route, triton_model_repository=model_bundle_orm.triton_enhanced_runnable_image_model_repository, @@ -216,6 +218,8 @@ def translate_kwargs_to_model_bundle_orm( "readiness_initial_delay_seconds" ), runnable_image_extra_routes=flavor_dict.get("extra_routes"), + runnable_image_worker_command=flavor_dict.get("worker_command"), + runnable_image_worker_env=flavor_dict.get("worker_env"), streaming_enhanced_runnable_image_streaming_command=flavor_dict.get("streaming_command"), streaming_enhanced_runnable_image_streaming_predict_route=flavor_dict.get( "streaming_predict_route" From 72fd1b8da2b3919b5bb853603ef3d4f0d243ed59 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:31:04 -0700 Subject: [PATCH 387/425] Make Redis endpoint cache read service identifier (#622) * make redis cache read service identifier * handle SERVICE_IDENTIFIER='' --- .../repositories/redis_model_endpoint_cache_repository.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py b/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py index fb1cf630..feea00b7 100644 --- a/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py +++ b/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py @@ -1,4 +1,5 @@ import json +import os from typing import Optional import aioredis @@ -7,6 +8,8 @@ ModelEndpointCacheRepository, ) +SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") + class RedisModelEndpointCacheRepository(ModelEndpointCacheRepository): # TODO figure out exceptions that can be thrown @@ -32,7 +35,10 @@ def __init__( @staticmethod def _find_redis_key(key: str): - return f"launch-k8s-cache:{key}" + if SERVICE_IDENTIFIER: + return f"launch-k8s-cache:{SERVICE_IDENTIFIER}:{key}" + else: + return f"launch-k8s-cache:{key}" async def write_endpoint_info( self, From c1fc1c6b35d8142f092bebaeea02c99733029651 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 26 Sep 2024 17:53:00 -0700 Subject: [PATCH 388/425] set default storage request/limit for batch jobs (#624) --- charts/model-engine/templates/service_template_config_map.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 19f70286..1ce5d2d8 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -639,9 +639,11 @@ data: requests: cpu: 1 memory: 8Gi + ephemeral-storage: 10Gi limits: cpu: 4 memory: 32Gi + ephemeral-storage: 30Gi {{- if $require_aws_config }} volumeMounts: - name: config-volume From 41639dae1f7b48ef30d49244f6382f73cffd7765 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 26 Sep 2024 21:05:31 -0700 Subject: [PATCH 389/425] Bump server python version to 3.10 (#623) * bump python version to 3.10 * update req.txt * Upgrade celery to remove backports dependency * update pre-commit-config * Add widget for v1 vs v2 batch completion calls * Handle black + semgrep upgrade * fix ruff * Fix ruff * fix typealias declaration * Fix format * debug it test setup * downgrade to 3.10.14 - pyenv doesn't support yet * Bump pytorch image for python 3.10 --- .circleci/config.yml | 24 ++++----- .pre-commit-config.yaml | 16 +++--- examples/finetune_llama_2_on_science_qa.ipynb | 31 ++++++------ integration_tests/rest_api_utils.py | 8 +-- model-engine/Dockerfile | 6 +-- model-engine/model_engine_server/api/app.py | 2 +- .../model_engine_server/api/llms_v1.py | 2 +- .../api/v2/chat_completion.py | 2 +- .../model_engine_server/common/config.py | 6 +-- .../common/dtos/batch_jobs.py | 1 + .../model_engine_server/common/dtos/files.py | 1 + .../common/dtos/llms/chat_completion.py | 12 +++-- .../common/dtos/model_bundles.py | 1 + .../common/dtos/triggers.py | 1 + .../model_engine_server/common/env_vars.py | 1 + model-engine/model_engine_server/common/io.py | 1 + .../model_engine_server/core/aws/secrets.py | 1 + .../model_engine_server/core/celery/app.py | 6 +-- .../core/celery/celery_autoscaler.py | 6 +-- .../model_engine_server/core/config.py | 1 + .../model_engine_server/core/utils/env.py | 1 + .../model_engine_server/core/utils/format.py | 1 + .../core/utils/python_utils.py | 1 + .../model_engine_server/core/utils/timer.py | 1 + .../model_engine_server/core/utils/url.py | 1 + .../use_cases/llm_fine_tuning_use_cases.py | 6 +-- .../use_cases/llm_model_endpoint_use_cases.py | 2 +- .../use_cases/model_endpoint_use_cases.py | 4 +- ...populate_llm_fine_tuning_job_repository.py | 1 + .../entrypoints/start_fastapi_server.py | 1 + .../inference/forwarding/celery_forwarder.py | 8 +-- .../inference/forwarding/echo_server.py | 1 + ...eaming_model_endpoint_inference_gateway.py | 6 +-- ...e_sync_model_endpoint_inference_gateway.py | 6 +-- .../services/live_endpoint_builder_service.py | 12 ++--- .../services/model_endpoint_cache_service.py | 6 +-- .../service_builder/tasks_v1.py | 8 +-- model-engine/requirements.in | 2 +- model-engine/requirements.txt | 50 +++---------------- .../tests/unit/api/test_batch_jobs.py | 12 ++--- model-engine/tests/unit/conftest.py | 16 +++--- model-engine/tests/unit/domain/conftest.py | 12 ++--- ...st_live_batch_job_orchestration_service.py | 6 +-- requirements-dev.txt | 12 ++--- 44 files changed, 142 insertions(+), 163 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cef79018..09c86645 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -20,7 +20,7 @@ workflows: jobs: run_unit_tests_python_client: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm resource_class: small parallelism: 1 steps: @@ -34,7 +34,7 @@ jobs: - run_unit_tests_python_client run_unit_tests_server: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm environment: ML_INFRA_DATABASE_URL: postgresql://postgres@localhost/circle_test - image: circleci/postgres:12.9-postgis-ram @@ -54,7 +54,7 @@ jobs: - run_unit_tests_server build_docs: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm resource_class: small parallelism: 1 steps: @@ -70,7 +70,7 @@ jobs: mkdocs build --strict deploy_docs: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm resource_class: small parallelism: 1 steps: @@ -149,18 +149,18 @@ jobs: docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile \ --build-arg BASE_IMAGE=python:3.8-slim \ --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-base-requirements.txt" \ - -t temp:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1 . + -t temp:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1 . touch $CIRCLE_SHA1-requirements.txt echo -e "cloudpickle==2.1.0\npyyaml==6.0" > $CIRCLE_SHA1-requirements.txt DOCKER_BUILDKIT=1 docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile \ - --build-arg BASE_IMAGE=temp:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1 \ + --build-arg BASE_IMAGE=temp:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1 \ --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-requirements.txt" \ - -t $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-b8c25b . + -t $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1-b8c25b . rm $CIRCLE_SHA1-requirements.txt - minikube --logtostderr -v 1 image load $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.7.1-cuda11.0-cudnn8-runtime-$CIRCLE_SHA1-b8c25b + minikube --logtostderr -v 1 image load $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1-b8c25b - run: name: Install helm chart command: | @@ -168,10 +168,10 @@ jobs: cat model-engine/values_circleci.yaml | envsubst > model-engine/values_circleci_subst.yaml helm install model-engine model-engine --values model-engine/values_circleci_subst.yaml --set tag=$CIRCLE_SHA1 --atomic --debug - run: - name: Change python version to 3.8.12 + name: Change python version to 3.10.14 command: | - pyenv install 3.8.12 - pyenv global 3.8.12 + pyenv install 3.10.14 + pyenv global 3.10.14 - run: name: Install integration test dependencies command: | @@ -256,7 +256,7 @@ commands: - run: name: Ruff Lint Check command: | - ruff . + ruff check . - run: name: Type Check command: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb2d9cc0..36bf3e95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,23 +2,23 @@ fail_fast: false repos: - repo: https://github.com/psf/black # Make sure to update requirements-dev-extra.txt to match versions! - rev: 22.12.0 + rev: 24.8.0 hooks: - id: black name: "python:black" entry: black --config .black.toml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.278 + rev: v0.6.8 hooks: - id: ruff name: "python:ruff" - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: "python:isort" - repo: https://github.com/jazzband/pip-tools - rev: 7.0.0 + rev: 7.4.1 hooks: - id: pip-compile files: model-engine/requirements\.(in|txt) @@ -31,7 +31,7 @@ repos: --index-url=https://pypi.org/simple, ] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 # https://github.com/pre-commit/pre-commit-hooks/releases + rev: v4.6.0 # https://github.com/pre-commit/pre-commit-hooks/releases hooks: - id: check-added-large-files args: @@ -51,7 +51,7 @@ repos: - id: check-toml language: python - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.3.0' # Make sure this matches the version in requirements-dev.txt! + rev: 'v1.11.2' # Make sure this matches the version in requirements-dev.txt! hooks: - id: mypy name: mypy-clients-python @@ -59,7 +59,7 @@ repos: entry: mypy --config-file clients/python/mypy.ini language: system - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.3.0' # Make sure this matches the version in requirements-dev.txt! + rev: 'v1.11.2' # Make sure this matches the version in requirements-dev.txt! hooks: - id: mypy name: mypy-server @@ -74,7 +74,7 @@ repos: language: system stages: ["commit", "push"] - repo: https://github.com/returntocorp/semgrep - rev: 'v1.36.0' + rev: 'v1.89.0' hooks: - id: semgrep args: [ '--config', 'p/python', '--error' ] diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb index 1b8f0ce5..dad7fe5e 100644 --- a/examples/finetune_llama_2_on_science_qa.ipynb +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -13,27 +13,27 @@ }, { "cell_type": "markdown", - "source": [ - "# Packages Required\n", - "For this demo, we'll be using the `scale-llm-engine` package and `datasets` from Huggingface.\n" - ], + "id": "XK6VpTnOL4OV", "metadata": { "id": "XK6VpTnOL4OV" }, - "id": "XK6VpTnOL4OV" + "source": [ + "# Packages Required\n", + "For this demo, we'll be using the `scale-llm-engine` package and `datasets` from Huggingface.\n" + ] }, { "cell_type": "code", - "source": [ - "!pip install scale-llm-engine\n", - "!pip install datasets" - ], + "execution_count": null, + "id": "S5u6DdInMEQ7", "metadata": { "id": "S5u6DdInMEQ7" }, - "id": "S5u6DdInMEQ7", - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install scale-llm-engine\n", + "!pip install datasets" + ] }, { "cell_type": "markdown", @@ -57,7 +57,6 @@ "source": [ "from datasets import load_dataset\n", "from smart_open import smart_open\n", - "import pandas as pd\n", "\n", "dataset = load_dataset('derek-thomas/ScienceQA')\n", "dataset['train'].features" @@ -244,6 +243,9 @@ } ], "metadata": { + "colab": { + "provenance": [] + }, "kernelspec": { "display_name": "Environment (conda_pytorch_p38)", "language": "python", @@ -260,9 +262,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" - }, - "colab": { - "provenance": [] } }, "nbformat": 4, diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 6f6a9407..4fe29618 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -57,7 +57,7 @@ def my_model(**keyword_args): "load_model_fn": inspect.getsource(echo_load_model_fn), "framework": { "framework_type": "pytorch", - "pytorch_image_tag": "1.7.1-cuda11.0-cudnn8-runtime", + "pytorch_image_tag": "1.11.0-cuda11.3-cudnn8-runtime", }, "requirements": [ "cloudpickle==2.1.0", @@ -699,9 +699,9 @@ def create_llm_model_endpoint( if inference_framework: create_model_endpoint_request["inference_framework"] = inference_framework if inference_framework_image_tag: - create_model_endpoint_request[ - "inference_framework_image_tag" - ] = inference_framework_image_tag + create_model_endpoint_request["inference_framework_image_tag"] = ( + inference_framework_image_tag + ) response = requests.post( f"{BASE_PATH}/v1/llm/model-endpoints", json=create_model_endpoint_request, diff --git a/model-engine/Dockerfile b/model-engine/Dockerfile index 23eacd9c..45cd9630 100644 --- a/model-engine/Dockerfile +++ b/model-engine/Dockerfile @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:experimental -FROM python:3.8.18-slim as model-engine +FROM python:3.10.15-slim as model-engine WORKDIR /workspace RUN apt-get update && apt-get install -y \ @@ -30,11 +30,11 @@ RUN curl -LO "https://dl.k8s.io/release/v1.23.13/bin/linux/amd64/kubectl" \ && mv kubectl /usr/local/bin/kubectl # Pin pip version -RUN pip install pip==23.0.1 +RUN pip install pip==24.2 RUN chmod -R 777 /workspace # Install AWS CLI -RUN pip install awscli==1.25.62 --no-cache-dir +RUN pip install awscli==1.34.28 --no-cache-dir ## grab model_engine_server project (w/ requirements install layer caching) WORKDIR /workspace/model-engine/ diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index 164768c8..2f7a4b0e 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -64,7 +64,7 @@ async def dispatch(self, request: Request, call_next): }, ) except Exception as e: - tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + tb_str = traceback.format_exception(e) request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 53194fb3..a52d81c6 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -122,7 +122,7 @@ def handle_streaming_exception( code: int, message: str, ): - tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + tb_str = traceback.format_exception(e) request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { diff --git a/model-engine/model_engine_server/api/v2/chat_completion.py b/model-engine/model_engine_server/api/v2/chat_completion.py index f5d5a2db..0dc1f989 100644 --- a/model-engine/model_engine_server/api/v2/chat_completion.py +++ b/model-engine/model_engine_server/api/v2/chat_completion.py @@ -55,7 +55,7 @@ def handle_streaming_exception( code: int, message: str, ): # pragma: no cover - tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) + tb_str = traceback.format_exception(e) request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") structured_log = { diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py index 4531cd2a..1226d62a 100644 --- a/model-engine/model_engine_server/common/config.py +++ b/model-engine/model_engine_server/common/config.py @@ -73,9 +73,9 @@ class HostedModelInferenceServiceConfig: # Exactly one of the following three must be specified cache_redis_aws_url: Optional[str] = None # also using this to store sync autoscaling metrics cache_redis_azure_host: Optional[str] = None - cache_redis_aws_secret_name: Optional[ - str - ] = None # Not an env var because the redis cache info is already here + cache_redis_aws_secret_name: Optional[str] = ( + None # Not an env var because the redis cache info is already here + ) @classmethod def from_json(cls, json): diff --git a/model-engine/model_engine_server/common/dtos/batch_jobs.py b/model-engine/model_engine_server/common/dtos/batch_jobs.py index 12225537..7b6a62ed 100644 --- a/model-engine/model_engine_server/common/dtos/batch_jobs.py +++ b/model-engine/model_engine_server/common/dtos/batch_jobs.py @@ -1,6 +1,7 @@ """ DTOs for the batch job abstraction. """ + from datetime import datetime, timedelta from typing import Any, Collection, Dict, List, Optional diff --git a/model-engine/model_engine_server/common/dtos/files.py b/model-engine/model_engine_server/common/dtos/files.py index 8f09d9a3..8fa6e8a8 100644 --- a/model-engine/model_engine_server/common/dtos/files.py +++ b/model-engine/model_engine_server/common/dtos/files.py @@ -1,6 +1,7 @@ """ DTOs for Files API. """ + from typing import List from model_engine_server.common.pydantic_types import BaseModel, Field diff --git a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py index bfb5ab09..547ee5f9 100644 --- a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py @@ -9,7 +9,7 @@ CreateChatCompletionStreamResponse, ) from sse_starlette import EventSourceResponse -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias # Fields that are a part of OpenAI spec but are not supported by model engine UNSUPPORTED_FIELDS = ["service_tier"] @@ -41,13 +41,15 @@ class ChatCompletionV2ErrorChunk(BaseModel): error: StreamError -ChatCompletionV2Chunk = Union[ChatCompletionV2SuccessChunk, ChatCompletionV2ErrorChunk] -ChatCompletionV2StreamResponse = ( +ChatCompletionV2Chunk: TypeAlias = Union[ChatCompletionV2SuccessChunk, ChatCompletionV2ErrorChunk] +ChatCompletionV2StreamResponse: TypeAlias = ( EventSourceResponse # EventSourceResponse[ChatCompletionV2Chunk | ChatCompletionV2ErrorChunk] ) -ChatCompletionV2Response = Union[ChatCompletionV2SyncResponse, ChatCompletionV2StreamResponse] +ChatCompletionV2Response: TypeAlias = Union[ + ChatCompletionV2SyncResponse, ChatCompletionV2StreamResponse +] # This is a version of ChatCompletionV2Response that is used by pydantic to determine the response model # Since EventSourceResponse isn't a pydanitc model, we need to use a Union of the two response types -ChatCompletionV2ResponseItem = Union[ChatCompletionV2SyncResponse, ChatCompletionV2Chunk] +ChatCompletionV2ResponseItem: TypeAlias = Union[ChatCompletionV2SyncResponse, ChatCompletionV2Chunk] diff --git a/model-engine/model_engine_server/common/dtos/model_bundles.py b/model-engine/model_engine_server/common/dtos/model_bundles.py index cd6f7f30..99d1e13b 100644 --- a/model-engine/model_engine_server/common/dtos/model_bundles.py +++ b/model-engine/model_engine_server/common/dtos/model_bundles.py @@ -1,6 +1,7 @@ """ Contains various input and output types relating to Model Bundles for the server. """ + import datetime from enum import Enum from typing import Any, Dict, List, Optional diff --git a/model-engine/model_engine_server/common/dtos/triggers.py b/model-engine/model_engine_server/common/dtos/triggers.py index ed8d45cf..a7cf2750 100644 --- a/model-engine/model_engine_server/common/dtos/triggers.py +++ b/model-engine/model_engine_server/common/dtos/triggers.py @@ -1,6 +1,7 @@ """ Contains various input and output types relating to Triggers for the server. """ + import datetime from typing import Any, Dict, List, Optional diff --git a/model-engine/model_engine_server/common/env_vars.py b/model-engine/model_engine_server/common/env_vars.py index ad7478fa..2a69cbff 100644 --- a/model-engine/model_engine_server/common/env_vars.py +++ b/model-engine/model_engine_server/common/env_vars.py @@ -1,6 +1,7 @@ """ A place for defining, setting, and referencing all environment variables used in Launch. """ + import os import sys from typing import Optional, Sequence diff --git a/model-engine/model_engine_server/common/io.py b/model-engine/model_engine_server/common/io.py index ae53e7b9..c9d9458f 100644 --- a/model-engine/model_engine_server/common/io.py +++ b/model-engine/model_engine_server/common/io.py @@ -1,4 +1,5 @@ """Launch Input/Output utils.""" + import os from typing import Any diff --git a/model-engine/model_engine_server/core/aws/secrets.py b/model-engine/model_engine_server/core/aws/secrets.py index 3c39b259..0637b121 100644 --- a/model-engine/model_engine_server/core/aws/secrets.py +++ b/model-engine/model_engine_server/core/aws/secrets.py @@ -1,4 +1,5 @@ """AWS secrets module.""" + import json from functools import lru_cache from typing import Optional diff --git a/model-engine/model_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py index 80fda86b..af7790d1 100644 --- a/model-engine/model_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -27,9 +27,9 @@ # override the backend with a class instead of a URL, despite the fact # that the `backend` constructor arg type is a Union[str, Type[celery.backends.base.Backend]] backends.BACKEND_ALIASES["s3"] = "model_engine_server.core.celery.s3:S3Backend" -backends.BACKEND_ALIASES[ - "azureblockblob" -] = "model_engine_server.core.celery.abs:AzureBlockBlobBackend" +backends.BACKEND_ALIASES["azureblockblob"] = ( + "model_engine_server.core.celery.abs:AzureBlockBlobBackend" +) DEFAULT_TASK_VISIBILITY_SECONDS = 86400 diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py index 3abd6d4d..54e3c5bc 100644 --- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -68,7 +68,7 @@ class CeleryAutoscalerParams: def _hash_any_to_int(data: Any): - return int(hashlib.md5(str(data).encode()).hexdigest(), 16) + return int(hashlib.md5(str(data).encode()).hexdigest(), 16) # nosemgrep async def list_deployments(core_api, apps_api) -> Dict[Tuple[str, str], CeleryAutoscalerParams]: @@ -593,9 +593,7 @@ async def main(): broker_type = ( "redis" if isinstance(broker, RedisBroker) - else "sqs" - if isinstance(broker, SQSBroker) - else "servicebus" + else "sqs" if isinstance(broker, SQSBroker) else "servicebus" ) if broker_type == "redis": diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index b4c05042..dc6a492b 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -4,6 +4,7 @@ If this is not set, the default configuration file is used from model_engine_server.core/configs/default.yaml. """ + import inspect import os from contextlib import contextmanager diff --git a/model-engine/model_engine_server/core/utils/env.py b/model-engine/model_engine_server/core/utils/env.py index ef9a6a88..3eb87dd8 100644 --- a/model-engine/model_engine_server/core/utils/env.py +++ b/model-engine/model_engine_server/core/utils/env.py @@ -1,4 +1,5 @@ """Utilities for working with environment variables.""" + import os from typing import ContextManager, Dict, Optional, Sequence, Union diff --git a/model-engine/model_engine_server/core/utils/format.py b/model-engine/model_engine_server/core/utils/format.py index 39e82a57..a26bd2c5 100644 --- a/model-engine/model_engine_server/core/utils/format.py +++ b/model-engine/model_engine_server/core/utils/format.py @@ -1,4 +1,5 @@ """Utilities for formatting and printing messages, especially for CLI programs.""" + import traceback from logging import Logger from typing import Any, List, Optional, Sequence, Tuple, Union diff --git a/model-engine/model_engine_server/core/utils/python_utils.py b/model-engine/model_engine_server/core/utils/python_utils.py index a9297d42..2925c7d9 100644 --- a/model-engine/model_engine_server/core/utils/python_utils.py +++ b/model-engine/model_engine_server/core/utils/python_utils.py @@ -1,4 +1,5 @@ """Python-language-based utility functions.""" + import builtins from importlib import import_module from typing import Any, Optional diff --git a/model-engine/model_engine_server/core/utils/timer.py b/model-engine/model_engine_server/core/utils/timer.py index 5a2bd1be..edd80891 100644 --- a/model-engine/model_engine_server/core/utils/timer.py +++ b/model-engine/model_engine_server/core/utils/timer.py @@ -1,4 +1,5 @@ """Utilities for timing code blocks.""" + import inspect import time from datetime import timedelta diff --git a/model-engine/model_engine_server/core/utils/url.py b/model-engine/model_engine_server/core/utils/url.py index 81a48ffc..16ae3d6f 100644 --- a/model-engine/model_engine_server/core/utils/url.py +++ b/model-engine/model_engine_server/core/utils/url.py @@ -1,4 +1,5 @@ """URL-based utility functions.""" + import re from typing import NamedTuple, Optional diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py index 70da8a9e..02466a52 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -230,9 +230,9 @@ async def execute(self, user: User) -> ListFineTunesResponse: GetFineTuneResponse( id=job.id, status=job.status, - fine_tuned_model=job.annotations.get("fine_tuned_model") - if job.annotations - else None, + fine_tuned_model=( + job.annotations.get("fine_tuned_model") if job.annotations else None + ), ) for job in di_batch_jobs ] diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 5db0c617..aae8ff2f 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1715,7 +1715,7 @@ def model_output_to_completion_output( # Also the log probs don't look right, so returning log-probs is still broken num_completion_tokens = ( len(model_output["output_log_probs"]) - if type(model_output["output_log_probs"]) == list + if type(model_output["output_log_probs"]) is list else 1 ) # Output is just "output". See `exclude_input_in_output` inside of diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 9d355307..21a55dcb 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -178,7 +178,7 @@ def validate_billing_tags(billing_tags: Optional[Dict[str, Any]]) -> None: if billing_tags is None: return - if type(billing_tags) != dict: + if type(billing_tags) is not dict: raise EndpointBillingTagsMalformedException("Billing tags must be a json dictionary") required_keys = { @@ -195,7 +195,7 @@ def validate_billing_tags(billing_tags: Optional[Dict[str, Any]]) -> None: if len(missing_keys) > 0: raise EndpointBillingTagsMalformedException(f"Missing billing tag keys {missing_keys}") for k, v in billing_tags.items(): - if type(k) != str or type(v) not in [str, dict]: + if type(k) is not str or type(v) not in [str, dict]: raise EndpointBillingTagsMalformedException( "Billing tags must have string keys and string/dict values" ) diff --git a/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py index 0b971b06..2e0caf29 100644 --- a/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py +++ b/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py @@ -10,6 +10,7 @@ You will need a docker image from the fine-tuning repo. Refer to llm/finetune_pipeline/README.md for instructions. """ + import argparse import asyncio diff --git a/model-engine/model_engine_server/entrypoints/start_fastapi_server.py b/model-engine/model_engine_server/entrypoints/start_fastapi_server.py index 119935ff..90271625 100644 --- a/model-engine/model_engine_server/entrypoints/start_fastapi_server.py +++ b/model-engine/model_engine_server/entrypoints/start_fastapi_server.py @@ -3,6 +3,7 @@ You can do this with `start-fastapi-server`. """ + import argparse import subprocess from typing import List diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index d9c841f2..016ded85 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -70,9 +70,11 @@ def create_celery_service( aws_role=infra_config().profile_ml_inference_worker, task_visibility=task_visibility, broker_type=broker_type, - broker_transport_options={"predefined_queues": {queue_name: {"url": sqs_url}}} - if broker_type == str(BrokerType.SQS.value) - else None, + broker_transport_options=( + {"predefined_queues": {queue_name: {"url": sqs_url}}} + if broker_type == str(BrokerType.SQS.value) + else None + ), backend_protocol=backend_protocol, ) diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py index 3581f678..6ed33d40 100644 --- a/model-engine/model_engine_server/inference/forwarding/echo_server.py +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -1,6 +1,7 @@ """ This file is for testing purposes only. It serves as simple server to mock a deployed model. """ + import argparse import subprocess import time diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index c6e8837d..9cda4b9c 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -168,13 +168,13 @@ async def make_request_with_retries( yield orjson.loads(item) return except RetryError as e: - if type(e.last_attempt.exception()) == TooManyRequestsException: + if isinstance(e.last_attempt.exception(), TooManyRequestsException): logger.warning("Hit max # of retries, returning 429 to client") raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") - elif type(e.last_attempt.exception()) == NoHealthyUpstreamException: + elif isinstance(e.last_attempt.exception(), NoHealthyUpstreamException): logger.warning("Pods didn't spin up in time, returning 503 to client") raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") - elif type(e.last_attempt.exception()) == aiohttp.ClientConnectorError: + elif isinstance(e.last_attempt.exception(), aiohttp.ClientConnectorError): logger.warning("ClientConnectorError, returning 503 to client") raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") else: diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index f7781ea3..53230ff0 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -144,13 +144,13 @@ async def make_request_with_retries( logger.info(f"Retry number {attempt.retry_state.attempt_number}") return await self.make_single_request(request_url, payload_json) except RetryError as e: - if type(e.last_attempt.exception()) == TooManyRequestsException: + if isinstance(e.last_attempt.exception(), TooManyRequestsException): logger.warning("Hit max # of retries, returning 429 to client") raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") - elif type(e.last_attempt.exception()) == NoHealthyUpstreamException: + elif isinstance(e.last_attempt.exception(), NoHealthyUpstreamException): logger.warning("Pods didn't spin up in time, returning 503 to client") raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") - elif type(e.last_attempt.exception()) == aiohttp.ClientConnectorError: + elif isinstance(e.last_attempt.exception(), aiohttp.ClientConnectorError): logger.warning("ClientConnectorError, returning 503 to client") raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") else: diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 9f9f257d..7ae03645 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -437,12 +437,12 @@ def convert_artifact_like_bundle_to_runnable_image( if isinstance(model_bundle.flavor, ZipArtifactFlavor): if new_flavor.env is None: new_flavor.env = {} - new_flavor.env[ - "LOAD_PREDICT_FN_MODULE_PATH" - ] = model_bundle.flavor.load_predict_fn_module_path - new_flavor.env[ - "LOAD_MODEL_FN_MODULE_PATH" - ] = model_bundle.flavor.load_model_fn_module_path + new_flavor.env["LOAD_PREDICT_FN_MODULE_PATH"] = ( + model_bundle.flavor.load_predict_fn_module_path + ) + new_flavor.env["LOAD_MODEL_FN_MODULE_PATH"] = ( + model_bundle.flavor.load_model_fn_module_path + ) new_model_bundle.flavor = new_flavor new_model_bundle.model_artifact_ids = [] diff --git a/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py index 9169d883..7e193027 100644 --- a/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py +++ b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py @@ -26,9 +26,9 @@ def __init__( self.image_cache_service = image_cache_service async def execute(self, ttl_seconds: float): - endpoint_infra_states: Dict[ - str, Tuple[bool, ModelEndpointInfraState] - ] = await self.resource_gateway.get_all_resources() + endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]] = ( + await self.resource_gateway.get_all_resources() + ) for key, (is_key_an_endpoint_id, state) in endpoint_infra_states.items(): if is_key_an_endpoint_id: diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py index e9eca9a6..cd4ff63c 100644 --- a/model-engine/model_engine_server/service_builder/tasks_v1.py +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -94,9 +94,11 @@ def get_live_endpoint_builder_service( monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ), model_endpoint_cache_repository=RedisModelEndpointCacheRepository(redis_client=redis), - filesystem_gateway=ABSFilesystemGateway() - if infra_config().cloud_provider == "azure" - else S3FilesystemGateway(), + filesystem_gateway=( + ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway() + ), notification_gateway=notification_gateway, feature_flag_repo=RedisFeatureFlagRepository(redis_client=redis), ) diff --git a/model-engine/requirements.in b/model-engine/requirements.in index b9b44867..d503f7b8 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -13,7 +13,7 @@ boto3-stubs[essential]~=1.26.67 boto3~=1.21 botocore~=1.24 build~=1.0.3 -celery[redis,sqs,tblib]~=5.3.6 +celery[redis,sqs,tblib]~=5.4.0 click~=8.1 cloudpickle==2.1.0 croniter==1.4.1 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 3d19c348..6e784ecc 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --index-url=https://pypi.org/simple --no-emit-index-url --no-emit-trusted-host model-engine/requirements.in @@ -60,10 +60,6 @@ azure-servicebus==7.11.4 # via -r model-engine/requirements.in azure-storage-blob==12.19.0 # via -r model-engine/requirements.in -backports-zoneinfo[tzdata]==0.2.1 - # via - # celery - # kombu billiard==4.2.0 # via celery bleach==6.0.0 @@ -76,9 +72,7 @@ boto3==1.28.1 # celery # kombu boto3-stubs[essential]==1.26.67 - # via - # -r model-engine/requirements.in - # boto3-stubs + # via -r model-engine/requirements.in botocore==1.31.1 # via # -r model-engine/requirements.in @@ -94,10 +88,8 @@ cachetools==5.3.1 # via google-auth cattrs==23.1.2 # via ddtrace -celery[redis,sqs,tblib]==5.3.6 - # via - # -r model-engine/requirements.in - # celery +celery[redis,sqs,tblib]==5.4.0 + # via -r model-engine/requirements.in certifi==2023.7.22 # via # datadog-api-client @@ -219,17 +211,8 @@ idna==3.7 # yarl importlib-metadata==6.8.0 # via - # alembic - # build # keyring - # quart # twine -importlib-resources==6.1.1 - # via - # alembic - # jsonschema - # jsonschema-specifications - # keyring isodate==0.6.1 # via # azure-containerregistry @@ -332,8 +315,6 @@ pg8000==1.29.8 # via testing-postgresql pkginfo==1.9.6 # via twine -pkgutil-resolve-name==1.3.10 - # via jsonschema portalocker==2.8.2 # via msal-extensions priority==2.0.0 @@ -481,7 +462,6 @@ sqlalchemy[asyncio]==2.0.4 # via # -r model-engine/requirements.in # alembic - # sqlalchemy sse-starlette==1.6.1 # via -r model-engine/requirements.in sseclient-py==1.7.2 @@ -491,7 +471,6 @@ starlette[full]==0.36.3 # -r model-engine/requirements.in # fastapi # sse-starlette - # starlette stringcase==1.2.0 # via -r model-engine/requirements.in tblib==2.0.0 @@ -532,40 +511,25 @@ types-s3transfer==0.6.1 typing-extensions==4.10.0 # via # aioredis - # annotated-types # azure-core # azure-keyvault-secrets # azure-servicebus # azure-storage-blob # boto3-stubs - # botocore-stubs - # bytecode # cattrs # datadog-api-client # ddtrace # fastapi # huggingface-hub - # kombu - # mypy-boto3-cloudformation - # mypy-boto3-dynamodb - # mypy-boto3-ec2 - # mypy-boto3-lambda - # mypy-boto3-rds - # mypy-boto3-s3 - # mypy-boto3-sqs # pydantic # pydantic-core - # rich # sqlalchemy - # starlette # typing-inspect # uvicorn typing-inspect==0.9.0 # via dataclasses-json tzdata==2023.3 - # via - # backports-zoneinfo - # celery + # via celery urllib3==1.26.16 # via # botocore @@ -604,9 +568,7 @@ yarl==1.9.2 # -r model-engine/requirements.in # aiohttp zipp==3.16.0 - # via - # importlib-metadata - # importlib-resources + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: setuptools==69.0.3 diff --git a/model-engine/tests/unit/api/test_batch_jobs.py b/model-engine/tests/unit/api/test_batch_jobs.py index 5b638d1f..426bd1d3 100644 --- a/model-engine/tests/unit/api/test_batch_jobs.py +++ b/model-engine/tests/unit/api/test_batch_jobs.py @@ -293,9 +293,9 @@ def test_create_docker_image_batch_job_unauthorized( } ) del create_docker_image_batch_job_request["docker_image_batch_job_bundle_name"] - create_docker_image_batch_job_request[ - "docker_image_batch_job_bundle_id" - ] = docker_image_batch_job_bundle_1_v1[0].id + create_docker_image_batch_job_request["docker_image_batch_job_bundle_id"] = ( + docker_image_batch_job_bundle_1_v1[0].id + ) response = client.post( "/v1/docker-image-batch-jobs", auth=(test_api_key_2, ""), @@ -335,9 +335,9 @@ def test_create_docker_image_batch_job_bundle_id_and_name( docker_image_batch_job_bundle_1_v1[0].id: docker_image_batch_job_bundle_1_v1[0] } ) - create_docker_image_batch_job_request[ - "docker_image_batch_job_bundle_id" - ] = docker_image_batch_job_bundle_1_v1[0].id + create_docker_image_batch_job_request["docker_image_batch_job_bundle_id"] = ( + docker_image_batch_job_bundle_1_v1[0].id + ) response = client.post( "/v1/docker-image-batch-jobs", auth=(test_api_key, ""), diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 5c1ce58f..34f00c92 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -2010,7 +2010,7 @@ def fake_docker_repository_image_never_exists() -> FakeDockerRepository: @pytest.fixture -def fake_docker_repository_image_never_exists_and_builds_dont_work() -> (FakeDockerRepository): +def fake_docker_repository_image_never_exists_and_builds_dont_work() -> FakeDockerRepository: repo = FakeDockerRepository(image_always_exists=False, raises_error=True) return repo @@ -2040,7 +2040,7 @@ def fake_model_endpoint_record_repository() -> FakeModelEndpointRecordRepository @pytest.fixture -def fake_docker_image_batch_job_bundle_repository() -> (FakeDockerImageBatchJobBundleRepository): +def fake_docker_image_batch_job_bundle_repository() -> FakeDockerImageBatchJobBundleRepository: repo = FakeDockerImageBatchJobBundleRepository() return repo @@ -2123,27 +2123,25 @@ def fake_model_primitive_gateway() -> FakeModelPrimitiveGateway: @pytest.fixture -def fake_async_model_endpoint_inference_gateway() -> (FakeAsyncModelEndpointInferenceGateway): +def fake_async_model_endpoint_inference_gateway() -> FakeAsyncModelEndpointInferenceGateway: gateway = FakeAsyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_streaming_model_endpoint_inference_gateway() -> ( - FakeStreamingModelEndpointInferenceGateway -): +def fake_streaming_model_endpoint_inference_gateway() -> FakeStreamingModelEndpointInferenceGateway: gateway = FakeStreamingModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_sync_model_endpoint_inference_gateway() -> (FakeSyncModelEndpointInferenceGateway): +def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferenceGateway: gateway = FakeSyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_inference_autoscaling_metrics_gateway() -> (FakeInferenceAutoscalingMetricsGateway): +def fake_inference_autoscaling_metrics_gateway() -> FakeInferenceAutoscalingMetricsGateway: gateway = FakeInferenceAutoscalingMetricsGateway() return gateway @@ -3587,7 +3585,7 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An @pytest.fixture -def sync_endpoint_predict_request_1() -> (Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]): +def sync_endpoint_predict_request_1() -> Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]: request = SyncEndpointPredictV1Request( url="test_url", return_pickled=False, diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 4882c3e3..1e30911a 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -162,7 +162,7 @@ def update_model_endpoint_request( @pytest.fixture -def create_docker_image_batch_job_bundle_request() -> (CreateDockerImageBatchJobBundleV1Request): +def create_docker_image_batch_job_bundle_request() -> CreateDockerImageBatchJobBundleV1Request: return CreateDockerImageBatchJobBundleV1Request( name="name", image_repository="repo", @@ -336,7 +336,7 @@ def create_llm_model_endpoint_request_llama_3_70b() -> CreateLLMModelEndpointV1R @pytest.fixture -def create_llm_model_endpoint_request_llama_3_70b_chat() -> (CreateLLMModelEndpointV1Request): +def create_llm_model_endpoint_request_llama_3_70b_chat() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_llama_3_70b_chat", model_name="llama-3-70b", @@ -422,7 +422,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( @pytest.fixture -def create_llm_model_endpoint_trt_llm_request_streaming() -> (CreateLLMModelEndpointV1Request): +def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_trt_llm_streaming", model_name="llama-2-7b", @@ -449,7 +449,7 @@ def create_llm_model_endpoint_trt_llm_request_streaming() -> (CreateLLMModelEndp @pytest.fixture -def create_llm_model_endpoint_trt_llm_request_async() -> (CreateLLMModelEndpointV1Request): +def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_tgi_async", model_name="llama-2-7b", @@ -477,7 +477,7 @@ def create_llm_model_endpoint_trt_llm_request_async() -> (CreateLLMModelEndpoint @pytest.fixture -def create_llm_model_endpoint_request_invalid_model_name() -> (CreateLLMModelEndpointV1Request): +def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_1", model_name="nonexist", @@ -503,7 +503,7 @@ def create_llm_model_endpoint_request_invalid_model_name() -> (CreateLLMModelEnd @pytest.fixture -def create_llm_model_endpoint_request_invalid_quantization() -> (CreateLLMModelEndpointV1Request): +def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( name="test_llm_endpoint_name_1", model_name="nonexist", diff --git a/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py index 11b2abe5..7f80b4a1 100644 --- a/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py +++ b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py @@ -50,9 +50,9 @@ def live_batch_job_orchestration_service( assert model_endpoint_1.infra_state is not None assert model_endpoint_runnable.infra_state is not None gateway.db[model_endpoint_1.infra_state.deployment_name] = model_endpoint_1.infra_state - gateway.db[ - model_endpoint_runnable.infra_state.deployment_name - ] = model_endpoint_runnable.infra_state + gateway.db[model_endpoint_runnable.infra_state.deployment_name] = ( + model_endpoint_runnable.infra_state + ) return LiveBatchJobOrchestrationService( model_endpoint_service=fake_live_model_endpoint_service, batch_job_record_repository=fake_batch_job_record_repository, diff --git a/requirements-dev.txt b/requirements-dev.txt index 87959381..5e673c87 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,10 +1,10 @@ # Make sure to update .pre-commit-config.yaml to match versions! -black[jupyter]==22.12.0 +black[jupyter]==24.8.0 datamodel-code-generator>=0.25.8 -ruff==0.0.278 +ruff==0.6.8 ipython==8.12.0 # 8.12.0 is the last version to support Python 3.8 -isort==5.12.0 -mypy==1.3.0 -pip-tools==7.0.0 +isort==5.13.2 +mypy==1.11.2 +pip-tools==7.4.1 poetry==1.8.2 -pre-commit==3.3.3 \ No newline at end of file +pre-commit==3.8.0 From 1e35c171ade19c62001bec6260740bf92c78ad03 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 30 Sep 2024 10:57:33 -0700 Subject: [PATCH 390/425] Upgrade vLLM to 0.6.2 (#626) * Bump vllm 0.6.2 * Get multimodal batch completions working w/ 0.6.2 * Get vllm_server working * replace pydantic v2 anyurl --- .../common/dtos/llms/batch_completion.py | 5 +- .../common/dtos/llms/vllm.py | 28 +- .../common/pydantic_types.py | 116 +- .../common/types/gen/openai.py | 1600 +++++++---------- .../inference/vllm/Dockerfile.vllm | 2 +- .../inference/vllm/README.md | 3 +- .../inference/vllm/build_and_upload_image.sh | 2 +- .../vllm/examples/v2/gemma/README.md | 19 + .../config.json} | 4 +- .../v2/gemma/config_w_oai_chat_content.json | 33 + .../vllm/examples/v2/gemma/data_oai_chat.json | 7 + .../v2/gemma/data_oai_completion.json | 16 + .../examples/v2/llama-3.2-vision/README.md | 19 + .../examples/v2/llama-3.2-vision/config.json | 18 + .../v2/llama-3.2-vision/data_oai_chat.json | 22 + .../v2/llama-3.2-vision/output_oi_chat.json | 1 + .../examples/v2/sample_data_chat_gemma.json | 1 - .../inference/vllm/requirements-dev.txt | 2 +- .../inference/vllm/vllm_batch.py | 41 +- .../inference/vllm/vllm_server.py | 40 +- scripts/generate-openai-types.sh | 5 + 21 files changed, 965 insertions(+), 1019 deletions(-) create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md rename model-engine/model_engine_server/inference/vllm/examples/v2/{sample_config_gemma.json => gemma/config.json} (68%) create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json create mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json delete mode 100644 model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 1ebe1f78..019aa707 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -11,6 +11,7 @@ CompletionV2Request, CompletionV2Response, ) +from model_engine_server.common.dtos.llms.vllm import VLLMModelConfig from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field from typing_extensions import TypeAlias @@ -40,7 +41,7 @@ class ToolConfig(BaseModel): """ -class BatchCompletionsModelConfig(BaseModel): +class BatchCompletionsModelConfig(VLLMModelConfig): model: str = Field( description="ID of the model to use.", examples=["mixtral-8x7b-instruct"], @@ -62,7 +63,7 @@ class BatchCompletionsModelConfig(BaseModel): max_context_length: Optional[int] = Field( default=None, ge=1, - description="Maximum context length to use for the model. Defaults to the max allowed by the model", + description="Maximum context length to use for the model. Defaults to the max allowed by the model. Deprecated in favor of max_model_len.", ) seed: Optional[int] = Field(default=None, description="Random seed for the model.") diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index 700af2b1..af207d94 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -1,12 +1,36 @@ from typing import Any, Dict, List, Optional -from model_engine_server.common.pydantic_types import Field +from model_engine_server.common.pydantic_types import BaseModel, Field from typing_extensions import Annotated # This was last synced w/ vLLM v0.5.5 on 2024-09-03 -class VLLMSamplingParams: +class VLLMModelConfig(BaseModel): + """Model configuration for VLLM""" + + max_model_len: Optional[int] = Field( + None, + description="""Model context length, If unspecified, will be automatically derived from the model config""", + ) + + max_num_seqs: Optional[int] = Field( + None, + description="""Maximum number of sequences per iteration""", + ) + + enforce_eager: Optional[bool] = Field( + None, + description="""Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal perforamnce and flexibility""", + ) + + gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%.", + ) + + +class VLLMSamplingParams(BaseModel): best_of: Optional[int] = Field( None, description="""Number of output sequences that are generated from the prompt. diff --git a/model-engine/model_engine_server/common/pydantic_types.py b/model-engine/model_engine_server/common/pydantic_types.py index 19fc99c0..6768acae 100644 --- a/model-engine/model_engine_server/common/pydantic_types.py +++ b/model-engine/model_engine_server/common/pydantic_types.py @@ -1,8 +1,122 @@ +from typing import Any, Type, TypeVar + +from pydantic import AnyHttpUrl as PyAnyHttpUrl +from pydantic import AnyUrl as PyAnyUrl +from pydantic import AnyWebsocketUrl as PyAnyWebsocketUrl from pydantic import BaseModel as PydanticBaseModel -from pydantic import ConfigDict, Field, RootModel, ValidationError, model_validator # noqa: F401 +from pydantic import model_validator # noqa: F401 +from pydantic import ConfigDict, Field # noqa: F401 +from pydantic import FileUrl as PyFileUrl +from pydantic import FtpUrl as PyFtpUrl +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler # noqa: F401 +from pydantic import HttpUrl as PyHttpUrl +from pydantic import RootModel, TypeAdapter, ValidationError # noqa: F401 +from pydantic import WebsocketUrl as PyWebsocketUrl +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema, core_schema class BaseModel(PydanticBaseModel): """Common pydantic configurations for model engine""" model_config = ConfigDict(protected_namespaces=()) + + +# See https://github.com/patrsc/pydantic-string-url +# just copied it over cause it was a single file + +"""Pydantic URL types based on strings.""" + + +T = TypeVar("T", bound=PyAnyUrl) + + +class AnyUrl(str): + """Pydantic's AnyUrl based on str.""" + + _pydantic_type = PyAnyUrl + _example_url = "http://www.example.com/" + + def __init__(self, url: str) -> None: + """Initialize.""" + pydantic_url = validate_url(url, self._pydantic_type) + super().__init__() + self.url = pydantic_url + + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, # pylint: disable=unused-argument + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + """Get pydantic core schema.""" + return core_schema.no_info_after_validator_function(cls._validate, handler(str)) + + @classmethod + def __get_pydantic_json_schema__( + cls, + schema: CoreSchema, + handler: GetJsonSchemaHandler, + ) -> JsonSchemaValue: + """Get pydantic JSON schema.""" + json_schema = handler(schema) + json_schema = handler.resolve_ref_schema(json_schema) + json_schema["format"] = "uri" + json_schema["minLength"] = 1 + json_schema["maxLength"] = 65536 + json_schema["examples"] = [cls._example_url] + return json_schema + + @classmethod + def _validate(cls, __input_value: str) -> "AnyUrl": + return cls(__input_value) + + +def validate_url(s: str, cls: Type[T]) -> T: + """Validate if string has the format of a proper URL or given Pydantic type.""" + # This uses pydantic's class just for validation. + a = TypeAdapter(cls) + url = a.validate_python(s, strict=True) + return url + + +class AnyHttpUrl(AnyUrl): + """Pydantic's AnyHttpUrl based on str.""" + + _pydantic_type = PyAnyHttpUrl + _example_url = "http://www.example.com/" + + +class HttpUrl(AnyUrl): + """Pydantic's HttpUrl based on str.""" + + _pydantic_type = PyHttpUrl + _example_url = "http://www.example.com/" + + +class AnyWebsocketUrl(AnyUrl): + """Pydantic's AnyWebsocketUrl based on str.""" + + _pydantic_type = PyAnyWebsocketUrl + _example_url = "ws://www.example.com/" + + +class WebsocketUrl(AnyUrl): + """Pydantic's WebsocketUrl based on str.""" + + _pydantic_type = PyWebsocketUrl + _example_url = "ws://www.example.com/" + + +class FileUrl(AnyUrl): + """Pydantic's FileUrl based on str.""" + + _pydantic_type = PyFileUrl + _example_url = "file://www.example.com/" + + +class FtpUrl(AnyUrl): + """Pydantic's FtpUrl based on str.""" + + _pydantic_type = PyFtpUrl + _example_url = "ftp://www.example.com/" diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py index 5337b6e1..f9444769 100644 --- a/model-engine/model_engine_server/common/types/gen/openai.py +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -1,19 +1,25 @@ # generated by datamodel-codegen: # filename: openai-spec.yaml -# timestamp: 2024-09-10T16:20:49+00:00 +# timestamp: 2024-09-30T08:39:28+00:00 from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union -from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel -from typing_extensions import Annotated, Literal +from model_engine_server.common.pydantic_types import ( + AnyUrl, + BaseModel, + ConfigDict, + Field, + RootModel, +) +from typing_extensions import Annotated class Error(BaseModel): - code: Annotated[Optional[str], Field(...)] + code: Annotated[Optional[str], Field(...)] = None message: str - param: Annotated[Optional[str], Field(...)] + param: Annotated[Optional[str], Field(...)] = None type: str @@ -31,7 +37,6 @@ class Prompt(RootModel[Optional[List[int]]]): root: Annotated[ Optional[List[int]], Field( - "<|endoftext|>", description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", examples=["[1212, 318, 257, 1332, 13]"], min_length=1, @@ -47,7 +52,6 @@ class Prompt1(RootModel[Optional[List[Prompt1Item]]]): root: Annotated[ Optional[List[Prompt1Item]], Field( - "<|endoftext|>", description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", examples=["[[1212, 318, 257, 1332, 13]]"], min_length=1, @@ -59,7 +63,6 @@ class Stop(RootModel[Optional[List[str]]]): root: Annotated[ Optional[List[str]], Field( - None, description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", max_length=4, min_length=1, @@ -99,10 +102,9 @@ class ImageUrl(BaseModel): detail: Annotated[ Literal["auto", "low", "high"], Field( - "auto", - description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding).", + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding)." ), - ] + ] = "auto" class ChatCompletionRequestMessageContentPartImage(BaseModel): @@ -177,10 +179,9 @@ class ChatCompletionRequestSystemMessage(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." ), - ] + ] = None class Content1(RootModel[List[ChatCompletionRequestUserMessageContentPart]]): @@ -205,17 +206,15 @@ class ChatCompletionRequestUserMessage(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." ), - ] + ] = None class Content2(RootModel[Optional[List[ChatCompletionRequestAssistantMessageContentPart]]]): root: Annotated[ Optional[List[ChatCompletionRequestAssistantMessageContentPart]], Field( - None, description="An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`.", min_length=1, title="Array of content parts", @@ -258,7 +257,9 @@ class ChatCompletionRequestFunctionMessage(BaseModel): Literal["function"], Field(description="The role of the messages author, in this case `function`."), ] - content: Annotated[Optional[str], Field(description="The contents of the function message.")] + content: Annotated[ + Optional[str], Field(description="The contents of the function message.") + ] = None name: Annotated[str, Field(description="The name of the function to call.")] @@ -273,10 +274,9 @@ class ChatCompletionFunctions(BaseModel): description: Annotated[ Optional[str], Field( - None, - description="A description of what the function does, used by the model to choose when and how to call the function.", + description="A description of what the function does, used by the model to choose when and how to call the function." ), - ] + ] = None name: Annotated[ str, Field( @@ -294,10 +294,9 @@ class FunctionObject(BaseModel): description: Annotated[ Optional[str], Field( - None, - description="A description of what the function does, used by the model to choose when and how to call the function.", + description="A description of what the function does, used by the model to choose when and how to call the function." ), - ] + ] = None name: Annotated[ str, Field( @@ -308,10 +307,9 @@ class FunctionObject(BaseModel): strict: Annotated[ Optional[bool], Field( - False, - description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling).", + description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling)." ), - ] + ] = False class ResponseFormatText(BaseModel): @@ -339,24 +337,22 @@ class JsonSchema(BaseModel): description: Annotated[ Optional[str], Field( - None, - description="A description of what the response format is for, used by the model to determine how to respond in the format.", + description="A description of what the response format is for, used by the model to determine how to respond in the format." ), - ] + ] = None name: Annotated[ str, Field( description="The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." ), ] - schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(None, alias="schema")] + schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(alias="schema")] = None strict: Annotated[ Optional[bool], Field( - False, - description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs).", + description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs)." ), - ] + ] = False class ResponseFormatJsonSchema(BaseModel): @@ -408,26 +404,22 @@ class ChatCompletionMessageToolCall(BaseModel): class Function2(BaseModel): - name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None arguments: Annotated[ Optional[str], Field( - None, - description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." ), - ] + ] = None class ChatCompletionMessageToolCallChunk(BaseModel): index: int - id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None type: Annotated[ Optional[Literal["function"]], - Field( - None, - description="The type of the tool. Currently, only `function` is supported.", - ), - ] + Field(description="The type of the tool. Currently, only `function` is supported."), + ] = None function: Optional[Function2] = None @@ -442,41 +434,39 @@ class ChatCompletionStreamOptions(BaseModel): include_usage: Annotated[ Optional[bool], Field( - None, - description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n", + description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n" ), - ] + ] = None class FunctionCall2(BaseModel): arguments: Annotated[ Optional[str], Field( - None, - description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.", + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." ), - ] - name: Annotated[Optional[str], Field(None, description="The name of the function to call.")] + ] = None + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None class ChatCompletionStreamResponseDelta(BaseModel): - content: Annotated[Optional[str], Field(None, description="The contents of the chunk message.")] + content: Annotated[Optional[str], Field(description="The contents of the chunk message.")] = ( + None + ) function_call: Annotated[ Optional[FunctionCall2], Field( - None, - description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." ), - ] + ] = None tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None role: Annotated[ Optional[Literal["system", "user", "assistant", "tool"]], - Field(None, description="The role of the author of this message."), - ] + Field(description="The role of the author of this message."), + ] = None refusal: Annotated[ - Optional[str], - Field(None, description="The refusal message generated by the model."), - ] + Optional[str], Field(description="The refusal message generated by the model.") + ] = None class Stop1(RootModel[List[str]]): @@ -535,19 +525,16 @@ class Logprobs2(BaseModel): ] refusal: Annotated[ Optional[List[ChatCompletionTokenLogprob]], - Field( - None, - description="A list of message refusal tokens with log probability information.", - ), - ] + Field(description="A list of message refusal tokens with log probability information."), + ] = None class Choice3(BaseModel): delta: ChatCompletionStreamResponseDelta logprobs: Annotated[ Optional[Logprobs2], - Field(None, description="Log probability information for the choice."), - ] + Field(description="Log probability information for the choice."), + ] = None finish_reason: Annotated[ Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]], Field( @@ -591,18 +578,16 @@ class CreateChatCompletionStreamResponse(BaseModel): service_tier: Annotated[ Optional[Literal["scale", "default"]], Field( - None, description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", examples=["scale"], ), - ] + ] = None system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["chat.completion.chunk"], Field(description="The object type, which is always `chat.completion.chunk`."), @@ -610,10 +595,9 @@ class CreateChatCompletionStreamResponse(BaseModel): usage: Annotated[ Optional[Usage], Field( - None, - description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n', + description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n' ), - ] + ] = None class CreateChatCompletionImageResponse(BaseModel): @@ -630,86 +614,73 @@ class CreateImageRequest(BaseModel): ] model: Annotated[ Optional[Union[Optional[str], Literal["dall-e-2", "dall-e-3"]]], - Field( - "dall-e-2", - description="The model to use for image generation.", - examples=["dall-e-3"], - ), - ] + Field(description="The model to use for image generation.", examples=["dall-e-3"]), + ] = "dall-e-2" n: Annotated[ Optional[int], Field( - 1, description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", examples=[1], ge=1, le=10, ), - ] + ] = 1 quality: Annotated[ Literal["standard", "hd"], Field( - "standard", description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", examples=["standard"], ), - ] + ] = "standard" response_format: Annotated[ Optional[Literal["url", "b64_json"]], Field( - "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", examples=["url"], ), - ] + ] = "url" size: Annotated[ Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]], Field( - "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", examples=["1024x1024"], ), - ] + ] = "1024x1024" style: Annotated[ Optional[Literal["vivid", "natural"]], Field( - "vivid", description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", examples=["vivid"], ), - ] + ] = "vivid" user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ), - ] + ] = None class Image(BaseModel): b64_json: Annotated[ Optional[str], Field( - None, - description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`.", + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`." ), - ] + ] = None url: Annotated[ Optional[str], Field( - None, - description="The URL of the generated image, if `response_format` is `url` (default).", + description="The URL of the generated image, if `response_format` is `url` (default)." ), - ] + ] = None revised_prompt: Annotated[ Optional[str], Field( - None, - description="The prompt that was used to generate the image, if there was any revision to the prompt.", + description="The prompt that was used to generate the image, if there was any revision to the prompt." ), - ] + ] = None class CreateImageEditRequest(BaseModel): @@ -729,52 +700,46 @@ class CreateImageEditRequest(BaseModel): mask: Annotated[ Optional[bytes], Field( - None, - description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`.", + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`." ), - ] + ] = None model: Annotated[ Optional[Union[Optional[str], Literal["dall-e-2"]]], Field( - "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", examples=["dall-e-2"], ), - ] + ] = "dall-e-2" n: Annotated[ Optional[int], Field( - 1, description="The number of images to generate. Must be between 1 and 10.", examples=[1], ge=1, le=10, ), - ] + ] = 1 size: Annotated[ Optional[Literal["256x256", "512x512", "1024x1024"]], Field( - "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", examples=["1024x1024"], ), - ] + ] = "1024x1024" response_format: Annotated[ Optional[Literal["url", "b64_json"]], Field( - "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", examples=["url"], ), - ] + ] = "url" user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ), - ] + ] = None class CreateImageVariationRequest(BaseModel): @@ -787,45 +752,40 @@ class CreateImageVariationRequest(BaseModel): model: Annotated[ Optional[Union[Optional[str], Literal["dall-e-2"]]], Field( - "dall-e-2", description="The model to use for image generation. Only `dall-e-2` is supported at this time.", examples=["dall-e-2"], ), - ] + ] = "dall-e-2" n: Annotated[ Optional[int], Field( - 1, description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", examples=[1], ge=1, le=10, ), - ] + ] = 1 response_format: Annotated[ Optional[Literal["url", "b64_json"]], Field( - "url", description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", examples=["url"], ), - ] + ] = "url" size: Annotated[ Optional[Literal["256x256", "512x512", "1024x1024"]], Field( - "1024x1024", description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", examples=["1024x1024"], ), - ] + ] = "1024x1024" user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ), - ] + ] = None class CreateModerationRequest(BaseModel): @@ -833,11 +793,10 @@ class CreateModerationRequest(BaseModel): model: Annotated[ Union[str, Literal["text-moderation-latest", "text-moderation-stable"]], Field( - "text-moderation-latest", description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", examples=["text-moderation-stable"], ), - ] + ] = "text-moderation-latest" class Categories(BaseModel): @@ -1041,10 +1000,9 @@ class CompleteUploadRequest(BaseModel): md5: Annotated[ Optional[str], Field( - None, - description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n", + description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n" ), - ] + ] = None class CancelUploadRequest(BaseModel): @@ -1090,24 +1048,21 @@ class Hyperparameters(BaseModel): batch_size: Annotated[ Union[Literal["auto"], BatchSize], Field( - "auto", - description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n" ), - ] + ] = "auto" learning_rate_multiplier: Annotated[ Union[Literal["auto"], LearningRateMultiplier], Field( - "auto", - description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n" ), - ] + ] = "auto" n_epochs: Annotated[ Union[Literal["auto"], NEpochs], Field( - "auto", - description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n" ), - ] + ] = "auto" class Wandb(BaseModel): @@ -1121,24 +1076,21 @@ class Wandb(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="A display name to set for the run. If not set, we will use the Job ID as the name.\n", + description="A display name to set for the run. If not set, we will use the Job ID as the name.\n" ), - ] + ] = None entity: Annotated[ Optional[str], Field( - None, - description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n", + description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n" ), - ] + ] = None tags: Annotated[ Optional[List[str]], Field( - None, - description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n', + description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n' ), - ] + ] = None class Integration(BaseModel): @@ -1173,42 +1125,36 @@ class CreateFineTuningJobRequest(BaseModel): ] hyperparameters: Annotated[ Optional[Hyperparameters], - Field(None, description="The hyperparameters used for the fine-tuning job."), - ] + Field(description="The hyperparameters used for the fine-tuning job."), + ] = None suffix: Annotated[ Optional[str], Field( - None, description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.\n', max_length=40, min_length=1, ), - ] + ] = None validation_file: Annotated[ Optional[str], Field( - None, description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", examples=["file-abc123"], ), - ] + ] = None integrations: Annotated[ Optional[List[Integration]], - Field( - None, - description="A list of integrations to enable for your fine-tuning job.", - ), - ] + Field(description="A list of integrations to enable for your fine-tuning job."), + ] = None seed: Annotated[ Optional[int], Field( - None, description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.\nIf a seed is not specified, one will be generated for you.\n", examples=[42], ge=0, le=2147483647, ), - ] + ] = None class Input(RootModel[List[str]]): @@ -1282,27 +1228,24 @@ class CreateEmbeddingRequest(BaseModel): encoding_format: Annotated[ Literal["float", "base64"], Field( - "float", description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", examples=["float"], ), - ] + ] = "float" dimensions: Annotated[ Optional[int], Field( - None, description="The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.\n", ge=1, ), - ] + ] = None user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ), - ] + ] = None class Usage1(BaseModel): @@ -1332,39 +1275,34 @@ class CreateTranscriptionRequest(BaseModel): language: Annotated[ Optional[str], Field( - None, - description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n", + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n" ), - ] + ] = None prompt: Annotated[ Optional[str], Field( - None, - description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n", + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n" ), - ] + ] = None response_format: Annotated[ Literal["json", "text", "srt", "verbose_json", "vtt"], Field( - "json", - description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" ), - ] + ] = "json" temperature: Annotated[ float, Field( - 0, - description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" ), - ] + ] = 0 timestamp_granularities__: Annotated[ List[Literal["word", "segment"]], Field( - ["segment"], alias="timestamp_granularities[]", description="The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.\n", ), - ] + ] = ["segment"] class CreateTranscriptionResponseJson(BaseModel): @@ -1414,15 +1352,12 @@ class CreateTranscriptionResponseVerboseJson(BaseModel): text: Annotated[str, Field(description="The transcribed text.")] words: Annotated[ Optional[List[TranscriptionWord]], - Field(None, description="Extracted words and their corresponding timestamps."), - ] + Field(description="Extracted words and their corresponding timestamps."), + ] = None segments: Annotated[ Optional[List[TranscriptionSegment]], - Field( - None, - description="Segments of the transcribed text and their corresponding details.", - ), - ] + Field(description="Segments of the transcribed text and their corresponding details."), + ] = None class CreateTranslationRequest(BaseModel): @@ -1445,24 +1380,21 @@ class CreateTranslationRequest(BaseModel): prompt: Annotated[ Optional[str], Field( - None, - description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n", + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n" ), - ] + ] = None response_format: Annotated[ str, Field( - "json", - description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n", + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" ), - ] + ] = "json" temperature: Annotated[ float, Field( - 0, - description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n", + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" ), - ] + ] = 0 class CreateTranslationResponseJson(BaseModel): @@ -1478,11 +1410,8 @@ class CreateTranslationResponseVerboseJson(BaseModel): text: Annotated[str, Field(description="The translated text.")] segments: Annotated[ Optional[List[TranscriptionSegment]], - Field( - None, - description="Segments of the translated text and their corresponding details.", - ), - ] + Field(description="Segments of the translated text and their corresponding details."), + ] = None class CreateSpeechRequest(BaseModel): @@ -1511,19 +1440,17 @@ class CreateSpeechRequest(BaseModel): response_format: Annotated[ Literal["mp3", "opus", "aac", "flac", "wav", "pcm"], Field( - "mp3", - description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.", + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." ), - ] + ] = "mp3" speed: Annotated[ float, Field( - 1.0, description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", ge=0.25, le=4.0, ), - ] + ] = 1.0 class Model(BaseModel): @@ -1578,10 +1505,9 @@ class OpenAIFile(BaseModel): status_details: Annotated[ Optional[str], Field( - None, - description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`.", + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`." ), - ] + ] = None class Upload(BaseModel): @@ -1613,12 +1539,12 @@ class Upload(BaseModel): ] object: Annotated[ Optional[Literal["upload"]], - Field(None, description='The object type, which is always "upload".'), - ] + Field(description='The object type, which is always "upload".'), + ] = None file: Annotated[ Optional[OpenAIFile], - Field(None, description="The ready File object after the Upload is completed."), - ] + Field(description="The ready File object after the Upload is completed."), + ] = None class UploadPart(BaseModel): @@ -1666,7 +1592,7 @@ class Error1(BaseModel): Field( description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific." ), - ] + ] = None class NEpochs1(RootModel[int]): @@ -1754,13 +1680,12 @@ class FineTuningJobCheckpoint(BaseModel): class FinetuneCompletionRequestInput(BaseModel): prompt: Annotated[ - Optional[str], - Field(None, description="The input prompt for this training example."), - ] + Optional[str], Field(description="The input prompt for this training example.") + ] = None completion: Annotated[ Optional[str], - Field(None, description="The desired completion for this training example."), - ] + Field(description="The desired completion for this training example."), + ] = None class CompletionUsage(BaseModel): @@ -1829,22 +1754,20 @@ class CodeInterpreter(BaseModel): file_ids: Annotated[ List[str], Field( - [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", max_length=20, ), - ] + ] = [] class FileSearch(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", max_length=1, ), - ] + ] = None class ToolResources(BaseModel): @@ -1856,11 +1779,10 @@ class CodeInterpreter1(BaseModel): file_ids: Annotated[ List[str], Field( - [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", max_length=20, ), - ] + ] = [] class ChunkingStrategy(BaseModel): @@ -1902,25 +1824,22 @@ class VectorStore(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", max_length=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy, ChunkingStrategy1]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch1(BaseModel): @@ -1934,11 +1853,10 @@ class FileSearch1(BaseModel): vector_stores: Annotated[ Optional[List[VectorStore]], Field( - None, description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", max_length=1, ), - ] + ] = None class ChunkingStrategy2(BaseModel): @@ -1960,36 +1878,32 @@ class VectorStore1(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", max_length=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy2, ChunkingStrategy3]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch2(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", max_length=1, ), - ] + ] = None vector_stores: Annotated[ List[VectorStore1], Field( @@ -2008,22 +1922,20 @@ class CodeInterpreter2(BaseModel): file_ids: Annotated[ List[str], Field( - [], description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", max_length=20, ), - ] + ] = [] class FileSearch3(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", max_length=1, ), - ] + ] = None class ToolResources2(BaseModel): @@ -2048,12 +1960,11 @@ class FileSearch4(BaseModel): max_num_results: Annotated[ Optional[int], Field( - None, description="The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", ge=1, le=50, ), - ] + ] = None class AssistantToolsFileSearch(BaseModel): @@ -2062,9 +1973,8 @@ class AssistantToolsFileSearch(BaseModel): Field(description="The type of tool being defined: `file_search`"), ] file_search: Annotated[ - Optional[FileSearch4], - Field(None, description="Overrides for the file search tool."), - ] + Optional[FileSearch4], Field(description="Overrides for the file search tool.") + ] = None class AssistantToolsFileSearchTypeOnly(BaseModel): @@ -2092,11 +2002,10 @@ class TruncationObject(BaseModel): last_messages: Annotated[ Optional[int], Field( - None, description="The number of most recent messages from the thread when constructing the context for the run.", ge=1, ), - ] + ] = None class Function3(BaseModel): @@ -2125,10 +2034,9 @@ class IncompleteDetails(BaseModel): reason: Annotated[ Optional[Literal["max_completion_tokens", "max_prompt_tokens"]], Field( - None, - description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run.", + description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run." ), - ] + ] = None class ModifyRunRequest(BaseModel): @@ -2138,27 +2046,22 @@ class ModifyRunRequest(BaseModel): metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class ToolOutput(BaseModel): tool_call_id: Annotated[ Optional[str], Field( - None, - description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for.", + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for." ), - ] + ] = None output: Annotated[ Optional[str], - Field( - None, - description="The output of the tool call to be submitted to continue the run.", - ), - ] + Field(description="The output of the tool call to be submitted to continue the run."), + ] = None class SubmitToolOutputsRunRequest(BaseModel): @@ -2172,10 +2075,9 @@ class SubmitToolOutputsRunRequest(BaseModel): stream: Annotated[ Optional[bool], Field( - None, - description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" ), - ] + ] = None class Function4(BaseModel): @@ -2206,22 +2108,20 @@ class CodeInterpreter3(BaseModel): file_ids: Annotated[ List[str], Field( - [], description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", max_length=20, ), - ] + ] = [] class FileSearch5(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", max_length=1, ), - ] + ] = None class ToolResources3(BaseModel): @@ -2233,11 +2133,10 @@ class FileSearch6(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", max_length=1, ), - ] + ] = None class ToolResources4(BaseModel): @@ -2291,25 +2190,22 @@ class VectorStore2(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", max_length=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy4, ChunkingStrategy5]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch7(BaseModel): @@ -2323,11 +2219,10 @@ class FileSearch7(BaseModel): vector_stores: Annotated[ Optional[List[VectorStore2]], Field( - None, description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", max_length=1, ), - ] + ] = None class ChunkingStrategy6(BaseModel): @@ -2349,36 +2244,32 @@ class VectorStore3(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", max_length=10000, ), - ] + ] = None chunking_strategy: Annotated[ Optional[Union[ChunkingStrategy6, ChunkingStrategy7]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class FileSearch8(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", max_length=1, ), - ] + ] = None vector_stores: Annotated[ List[VectorStore3], Field( @@ -2397,11 +2288,10 @@ class FileSearch9(BaseModel): vector_store_ids: Annotated[ Optional[List[str]], Field( - None, description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", max_length=1, ), - ] + ] = None class ToolResources6(BaseModel): @@ -2416,17 +2306,15 @@ class ModifyThreadRequest(BaseModel): tool_resources: Annotated[ Optional[ToolResources6], Field( - None, - description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class DeleteThreadResponse(BaseModel): @@ -2452,13 +2340,12 @@ class IncompleteDetails1(BaseModel): class Attachment(BaseModel): file_id: Annotated[ - Optional[str], - Field(None, description="The ID of the file to attach to the message."), - ] + Optional[str], Field(description="The ID of the file to attach to the message.") + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearchTypeOnly]]], - Field(None, description="The tools to add this file to."), - ] + Field(description="The tools to add this file to."), + ] = None class ModifyMessageRequest(BaseModel): @@ -2468,10 +2355,9 @@ class ModifyMessageRequest(BaseModel): metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class DeleteMessageResponse(BaseModel): @@ -2490,10 +2376,9 @@ class ImageFile(BaseModel): detail: Annotated[ Literal["auto", "low", "high"], Field( - "auto", - description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." ), - ] + ] = "auto" class MessageContentImageFileObject(BaseModel): @@ -2505,17 +2390,15 @@ class ImageFile1(BaseModel): file_id: Annotated[ Optional[str], Field( - None, - description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.', + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' ), - ] + ] = None detail: Annotated[ Literal["auto", "low", "high"], Field( - "auto", - description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." ), - ] + ] = "auto" class MessageDeltaContentImageFileObject(BaseModel): @@ -2534,10 +2417,9 @@ class ImageUrl1(BaseModel): detail: Annotated[ Literal["auto", "low", "high"], Field( - "auto", - description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`" ), - ] + ] = "auto" class MessageContentImageUrlObject(BaseModel): @@ -2549,17 +2431,15 @@ class ImageUrl2(BaseModel): url: Annotated[ Optional[str], Field( - None, - description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp.", + description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." ), - ] + ] = None detail: Annotated[ Literal["auto", "low", "high"], Field( - "auto", - description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`.", + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`." ), - ] + ] = "auto" class MessageDeltaContentImageUrlObject(BaseModel): @@ -2617,9 +2497,9 @@ class MessageDeltaContentRefusalObject(BaseModel): class FileCitation1(BaseModel): file_id: Annotated[ Optional[str], - Field(None, description="The ID of the specific File the citation is from."), - ] - quote: Annotated[Optional[str], Field(None, description="The specific quote in the file.")] + Field(description="The ID of the specific File the citation is from."), + ] = None + quote: Annotated[Optional[str], Field(description="The specific quote in the file.")] = None class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): @@ -2629,20 +2509,17 @@ class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] text: Annotated[ Optional[str], - Field( - None, - description="The text in the message content that needs to be replaced.", - ), - ] + Field(description="The text in the message content that needs to be replaced."), + ] = None file_citation: Optional[FileCitation1] = None - start_index: Annotated[Optional[int], Field(None, ge=0)] - end_index: Annotated[Optional[int], Field(None, ge=0)] + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None class FilePath1(BaseModel): file_id: Annotated[ - Optional[str], Field(None, description="The ID of the file that was generated.") - ] + Optional[str], Field(description="The ID of the file that was generated.") + ] = None class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): @@ -2652,14 +2529,11 @@ class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] text: Annotated[ Optional[str], - Field( - None, - description="The text in the message content that needs to be replaced.", - ), - ] + Field(description="The text in the message content that needs to be replaced."), + ] = None file_path: Optional[FilePath1] = None - start_index: Annotated[Optional[int], Field(None, ge=0)] - end_index: Annotated[Optional[int], Field(None, ge=0)] + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None class LastError1(BaseModel): @@ -2685,8 +2559,8 @@ class RunStepDetailsMessageCreationObject(BaseModel): class MessageCreation1(BaseModel): message_id: Annotated[ Optional[str], - Field(None, description="The ID of the message that was created by this run step."), - ] + Field(description="The ID of the message that was created by this run step."), + ] = None class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): @@ -2704,8 +2578,8 @@ class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): type: Annotated[Literal["logs"], Field(description="Always `logs`.")] logs: Annotated[ Optional[str], - Field(None, description="The text output from the Code Interpreter tool call."), - ] + Field(description="The text output from the Code Interpreter tool call."), + ] = None class Image1(BaseModel): @@ -2722,8 +2596,8 @@ class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): class Image2(BaseModel): file_id: Annotated[ Optional[str], - Field(None, description="The [file](/docs/api-reference/files) ID of the image."), - ] + Field(description="The [file](/docs/api-reference/files) ID of the image."), + ] = None class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): @@ -2748,7 +2622,7 @@ class RunStepDetailsToolCallsFileSearchObject(BaseModel): class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] - id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None type: Annotated[ Literal["file_search"], Field( @@ -2769,7 +2643,7 @@ class Function5(BaseModel): Field( description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." ), - ] + ] = None class RunStepDetailsToolCallsFunctionObject(BaseModel): @@ -2786,22 +2660,21 @@ class RunStepDetailsToolCallsFunctionObject(BaseModel): class Function6(BaseModel): - name: Annotated[Optional[str], Field(None, description="The name of the function.")] + name: Annotated[Optional[str], Field(description="The name of the function.")] = None arguments: Annotated[ - Optional[str], Field(None, description="The arguments passed to the function.") - ] + Optional[str], Field(description="The arguments passed to the function.") + ] = None output: Annotated[ Optional[str], Field( - None, - description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet.", + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." ), - ] + ] = None class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] - id: Annotated[Optional[str], Field(None, description="The ID of the tool call object.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None type: Annotated[ Literal["function"], Field( @@ -2810,8 +2683,8 @@ class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): ] function: Annotated[ Optional[Function6], - Field(None, description="The definition of the function that was called."), - ] + Field(description="The definition of the function that was called."), + ] = None class VectorStoreExpirationAfter(BaseModel): @@ -2873,17 +2746,14 @@ class VectorStoreObject(BaseModel): expires_after: Optional[VectorStoreExpirationAfter] = None expires_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the vector store will expire.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the vector store will expire."), + ] = None last_active_at: Annotated[ Optional[int], Field( description="The Unix timestamp (in seconds) for when the vector store was last active." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( @@ -2896,15 +2766,14 @@ class UpdateVectorStoreRequest(BaseModel): model_config = ConfigDict( extra="forbid", ) - name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None expires_after: Optional[VectorStoreExpirationAfter] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class ListVectorStoresResponse(BaseModel): @@ -3078,37 +2947,28 @@ class DoneEvent(BaseModel): class Datum(BaseModel): code: Annotated[ - Optional[str], - Field(None, description="An error code identifying the error type."), - ] + Optional[str], Field(description="An error code identifying the error type.") + ] = None message: Annotated[ Optional[str], - Field( - None, - description="A human-readable message providing more details about the error.", - ), - ] + Field(description="A human-readable message providing more details about the error."), + ] = None param: Annotated[ Optional[str], - Field( - None, - description="The name of the parameter that caused the error, if applicable.", - ), - ] + Field(description="The name of the parameter that caused the error, if applicable."), + ] = None line: Annotated[ Optional[int], Field( - None, - description="The line number of the input file where the error occurred, if applicable.", + description="The line number of the input file where the error occurred, if applicable." ), - ] + ] = None class Errors(BaseModel): object: Annotated[ - Optional[str], - Field(None, description="The object type, which is always `list`."), - ] + Optional[str], Field(description="The object type, which is always `list`.") + ] = None data: Optional[List[Datum]] = None @@ -3149,137 +3009,100 @@ class Batch(BaseModel): output_file_id: Annotated[ Optional[str], Field( - None, - description="The ID of the file containing the outputs of successfully executed requests.", + description="The ID of the file containing the outputs of successfully executed requests." ), - ] + ] = None error_file_id: Annotated[ Optional[str], - Field( - None, - description="The ID of the file containing the outputs of requests with errors.", - ), - ] + Field(description="The ID of the file containing the outputs of requests with errors."), + ] = None created_at: Annotated[ int, Field(description="The Unix timestamp (in seconds) for when the batch was created."), ] in_progress_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch started processing.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch started processing."), + ] = None expires_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch will expire.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch will expire."), + ] = None finalizing_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch started finalizing.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch started finalizing."), + ] = None completed_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch was completed.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch was completed."), + ] = None failed_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch failed.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch failed."), + ] = None expired_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch expired.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch expired."), + ] = None cancelling_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch started cancelling.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch started cancelling."), + ] = None cancelled_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) for when the batch was cancelled.", - ), - ] + Field(description="The Unix timestamp (in seconds) for when the batch was cancelled."), + ] = None request_counts: Annotated[ Optional[RequestCounts], - Field( - None, - description="The request counts for different statuses within the batch.", - ), - ] + Field(description="The request counts for different statuses within the batch."), + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class BatchRequestInput(BaseModel): custom_id: Annotated[ Optional[str], Field( - None, - description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch.", + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch." ), - ] + ] = None method: Annotated[ Optional[Literal["POST"]], Field( - None, - description="The HTTP method to be used for the request. Currently only `POST` is supported.", + description="The HTTP method to be used for the request. Currently only `POST` is supported." ), - ] + ] = None url: Annotated[ Optional[str], Field( - None, - description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.", + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported." ), - ] + ] = None class Response(BaseModel): status_code: Annotated[ - Optional[int], Field(None, description="The HTTP status code of the response") - ] + Optional[int], Field(description="The HTTP status code of the response") + ] = None request_id: Annotated[ Optional[str], Field( - None, - description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support.", + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support." ), - ] + ] = None body: Annotated[ - Optional[Dict[str, Any]], - Field(None, description="The JSON body of the response"), - ] + Optional[Dict[str, Any]], Field(description="The JSON body of the response") + ] = None class Error2(BaseModel): - code: Annotated[Optional[str], Field(None, description="A machine-readable error code.")] - message: Annotated[Optional[str], Field(None, description="A human-readable error message.")] + code: Annotated[Optional[str], Field(description="A machine-readable error code.")] = None + message: Annotated[Optional[str], Field(description="A human-readable error message.")] = None class BatchRequestOutput(BaseModel): @@ -3287,46 +3110,41 @@ class BatchRequestOutput(BaseModel): custom_id: Annotated[ Optional[str], Field( - None, - description="A developer-provided per-request id that will be used to match outputs to inputs.", + description="A developer-provided per-request id that will be used to match outputs to inputs." ), - ] + ] = None response: Optional[Response] = None error: Annotated[ Optional[Error2], Field( - None, - description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure.", + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure." ), - ] + ] = None class ListBatchesResponse(BaseModel): data: List[Batch] - first_id: Annotated[Optional[str], Field(None, examples=["batch_abc123"])] - last_id: Annotated[Optional[str], Field(None, examples=["batch_abc456"])] + first_id: Annotated[Optional[str], Field(examples=["batch_abc123"])] = None + last_id: Annotated[Optional[str], Field(examples=["batch_abc456"])] = None has_more: bool object: Literal["list"] class AuditLogActorServiceAccount(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account id.")] + id: Annotated[Optional[str], Field(description="The service account id.")] = None class AuditLogActorUser(BaseModel): - id: Annotated[Optional[str], Field(None, description="The user id.")] - email: Annotated[Optional[str], Field(None, description="The user email.")] + id: Annotated[Optional[str], Field(description="The user id.")] = None + email: Annotated[Optional[str], Field(description="The user email.")] = None class AuditLogActorApiKey(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking id of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking id of the API key.")] = None type: Annotated[ Optional[Literal["user", "service_account"]], - Field( - None, - description="The type of API key. Can be either `user` or `service_account`.", - ), - ] + Field(description="The type of API key. Can be either `user` or `service_account`."), + ] = None user: Optional[AuditLogActorUser] = None service_account: Optional[AuditLogActorServiceAccount] = None @@ -3335,15 +3153,15 @@ class AuditLogActorSession(BaseModel): user: Optional[AuditLogActorUser] = None ip_address: Annotated[ Optional[str], - Field(None, description="The IP address from which the action was performed."), - ] + Field(description="The IP address from which the action was performed."), + ] = None class AuditLogActor(BaseModel): type: Annotated[ Optional[Literal["session", "api_key"]], - Field(None, description="The type of actor. Is either `session` or `api_key`."), - ] + Field(description="The type of actor. Is either `session` or `api_key`."), + ] = None session: Optional[AuditLogActorSession] = None api_key: Optional[AuditLogActorApiKey] = None @@ -3402,232 +3220,212 @@ class AuditLogEventType( class Project(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] - name: Annotated[Optional[str], Field(None, description="The project title.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None + name: Annotated[Optional[str], Field(description="The project title.")] = None class Data(BaseModel): scopes: Annotated[ Optional[List[str]], - Field( - None, - description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', - ), - ] + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None class ApiKeyCreated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None data: Annotated[ - Optional[Data], - Field(None, description="The payload used to create the API key."), - ] + Optional[Data], Field(description="The payload used to create the API key.") + ] = None class ChangesRequested(BaseModel): scopes: Annotated[ Optional[List[str]], - Field( - None, - description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`', - ), - ] + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None class ApiKeyUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None changes_requested: Annotated[ Optional[ChangesRequested], - Field(None, description="The payload used to update the API key."), - ] + Field(description="The payload used to update the API key."), + ] = None class ApiKeyDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The tracking ID of the API key.")] + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None class Data1(BaseModel): - email: Annotated[ - Optional[str], Field(None, description="The email invited to the organization.") - ] + email: Annotated[Optional[str], Field(description="The email invited to the organization.")] = ( + None + ) role: Annotated[ Optional[str], - Field( - None, - description="The role the email was invited to be. Is either `owner` or `member`.", - ), - ] + Field(description="The role the email was invited to be. Is either `owner` or `member`."), + ] = None class InviteSent(BaseModel): - id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None data: Annotated[ - Optional[Data1], - Field(None, description="The payload used to create the invite."), - ] + Optional[Data1], Field(description="The payload used to create the invite.") + ] = None class InviteAccepted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None class InviteDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The ID of the invite.")] + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None class LoginFailed(BaseModel): - error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None error_message: Annotated[ - Optional[str], Field(None, description="The error message of the failure.") - ] + Optional[str], Field(description="The error message of the failure.") + ] = None class LogoutFailed(BaseModel): - error_code: Annotated[Optional[str], Field(None, description="The error code of the failure.")] + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None error_message: Annotated[ - Optional[str], Field(None, description="The error message of the failure.") - ] + Optional[str], Field(description="The error message of the failure.") + ] = None class Settings(BaseModel): threads_ui_visibility: Annotated[ Optional[str], Field( - None, - description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`.", + description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`." ), - ] + ] = None usage_dashboard_visibility: Annotated[ Optional[str], Field( - None, - description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`.", + description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`." ), - ] + ] = None class ChangesRequested1(BaseModel): - title: Annotated[Optional[str], Field(None, description="The organization title.")] - description: Annotated[Optional[str], Field(None, description="The organization description.")] - name: Annotated[Optional[str], Field(None, description="The organization name.")] + title: Annotated[Optional[str], Field(description="The organization title.")] = None + description: Annotated[Optional[str], Field(description="The organization description.")] = None + name: Annotated[Optional[str], Field(description="The organization name.")] = None settings: Optional[Settings] = None class OrganizationUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The organization ID.")] + id: Annotated[Optional[str], Field(description="The organization ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested1], - Field(None, description="The payload used to update the organization settings."), - ] + Field(description="The payload used to update the organization settings."), + ] = None class Data2(BaseModel): - name: Annotated[Optional[str], Field(None, description="The project name.")] + name: Annotated[Optional[str], Field(description="The project name.")] = None title: Annotated[ Optional[str], - Field(None, description="The title of the project as seen on the dashboard."), - ] + Field(description="The title of the project as seen on the dashboard."), + ] = None class ProjectCreated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None data: Annotated[ - Optional[Data2], - Field(None, description="The payload used to create the project."), - ] + Optional[Data2], Field(description="The payload used to create the project.") + ] = None class ChangesRequested2(BaseModel): title: Annotated[ Optional[str], - Field(None, description="The title of the project as seen on the dashboard."), - ] + Field(description="The title of the project as seen on the dashboard."), + ] = None class ProjectUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested2], - Field(None, description="The payload used to update the project."), - ] + Field(description="The payload used to update the project."), + ] = None class ProjectArchived(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None class Data3(BaseModel): role: Annotated[ Optional[str], - Field( - None, - description="The role of the service account. Is either `owner` or `member`.", - ), - ] + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None class ServiceAccountCreated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account ID.")] + id: Annotated[Optional[str], Field(description="The service account ID.")] = None data: Annotated[ Optional[Data3], - Field(None, description="The payload used to create the service account."), - ] + Field(description="The payload used to create the service account."), + ] = None class ChangesRequested3(BaseModel): role: Annotated[ Optional[str], - Field( - None, - description="The role of the service account. Is either `owner` or `member`.", - ), - ] + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None class ServiceAccountUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account ID.")] + id: Annotated[Optional[str], Field(description="The service account ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested3], - Field(None, description="The payload used to updated the service account."), - ] + Field(description="The payload used to updated the service account."), + ] = None class ServiceAccountDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The service account ID.")] + id: Annotated[Optional[str], Field(description="The service account ID.")] = None class Data4(BaseModel): role: Annotated[ Optional[str], - Field(None, description="The role of the user. Is either `owner` or `member`."), - ] + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None class UserAdded(BaseModel): - id: Annotated[Optional[str], Field(None, description="The user ID.")] + id: Annotated[Optional[str], Field(description="The user ID.")] = None data: Annotated[ Optional[Data4], - Field(None, description="The payload used to add the user to the project."), - ] + Field(description="The payload used to add the user to the project."), + ] = None class ChangesRequested4(BaseModel): role: Annotated[ Optional[str], - Field(None, description="The role of the user. Is either `owner` or `member`."), - ] + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None class UserUpdated(BaseModel): - id: Annotated[Optional[str], Field(None, description="The project ID.")] + id: Annotated[Optional[str], Field(description="The project ID.")] = None changes_requested: Annotated[ Optional[ChangesRequested4], - Field(None, description="The payload used to update the user."), - ] + Field(description="The payload used to update the user."), + ] = None class UserDeleted(BaseModel): - id: Annotated[Optional[str], Field(None, description="The user ID.")] + id: Annotated[Optional[str], Field(description="The user ID.")] = None class AuditLog(BaseModel): @@ -3637,155 +3435,121 @@ class AuditLog(BaseModel): project: Annotated[ Optional[Project], Field( - None, - description="The project that the action was scoped to. Absent for actions not scoped to projects.", + description="The project that the action was scoped to. Absent for actions not scoped to projects." ), - ] + ] = None actor: AuditLogActor api_key_created: Annotated[ Optional[ApiKeyCreated], Field( - None, alias="api_key.created", description="The details for events with this `type`.", ), - ] + ] = None api_key_updated: Annotated[ Optional[ApiKeyUpdated], Field( - None, alias="api_key.updated", description="The details for events with this `type`.", ), - ] + ] = None api_key_deleted: Annotated[ Optional[ApiKeyDeleted], Field( - None, alias="api_key.deleted", description="The details for events with this `type`.", ), - ] + ] = None invite_sent: Annotated[ Optional[InviteSent], - Field( - None, - alias="invite.sent", - description="The details for events with this `type`.", - ), - ] + Field(alias="invite.sent", description="The details for events with this `type`."), + ] = None invite_accepted: Annotated[ Optional[InviteAccepted], Field( - None, alias="invite.accepted", description="The details for events with this `type`.", ), - ] + ] = None invite_deleted: Annotated[ Optional[InviteDeleted], Field( - None, alias="invite.deleted", description="The details for events with this `type`.", ), - ] + ] = None login_failed: Annotated[ Optional[LoginFailed], - Field( - None, - alias="login.failed", - description="The details for events with this `type`.", - ), - ] + Field(alias="login.failed", description="The details for events with this `type`."), + ] = None logout_failed: Annotated[ Optional[LogoutFailed], Field( - None, alias="logout.failed", description="The details for events with this `type`.", ), - ] + ] = None organization_updated: Annotated[ Optional[OrganizationUpdated], Field( - None, alias="organization.updated", description="The details for events with this `type`.", ), - ] + ] = None project_created: Annotated[ Optional[ProjectCreated], Field( - None, alias="project.created", description="The details for events with this `type`.", ), - ] + ] = None project_updated: Annotated[ Optional[ProjectUpdated], Field( - None, alias="project.updated", description="The details for events with this `type`.", ), - ] + ] = None project_archived: Annotated[ Optional[ProjectArchived], Field( - None, alias="project.archived", description="The details for events with this `type`.", ), - ] + ] = None service_account_created: Annotated[ Optional[ServiceAccountCreated], Field( - None, alias="service_account.created", description="The details for events with this `type`.", ), - ] + ] = None service_account_updated: Annotated[ Optional[ServiceAccountUpdated], Field( - None, alias="service_account.updated", description="The details for events with this `type`.", ), - ] + ] = None service_account_deleted: Annotated[ Optional[ServiceAccountDeleted], Field( - None, alias="service_account.deleted", description="The details for events with this `type`.", ), - ] + ] = None user_added: Annotated[ Optional[UserAdded], - Field( - None, - alias="user.added", - description="The details for events with this `type`.", - ), - ] + Field(alias="user.added", description="The details for events with this `type`."), + ] = None user_updated: Annotated[ Optional[UserUpdated], - Field( - None, - alias="user.updated", - description="The details for events with this `type`.", - ), - ] + Field(alias="user.updated", description="The details for events with this `type`."), + ] = None user_deleted: Annotated[ Optional[UserDeleted], - Field( - None, - alias="user.deleted", - description="The details for events with this `type`.", - ), - ] + Field(alias="user.deleted", description="The details for events with this `type`."), + ] = None class ListAuditLogsResponse(BaseModel): @@ -3824,11 +3588,8 @@ class Invite(BaseModel): ] accepted_at: Annotated[ Optional[int], - Field( - None, - description="The Unix timestamp (in seconds) of when the invite was accepted.", - ), - ] + Field(description="The Unix timestamp (in seconds) of when the invite was accepted."), + ] = None class InviteListResponse(BaseModel): @@ -3836,19 +3597,17 @@ class InviteListResponse(BaseModel): data: List[Invite] first_id: Annotated[ Optional[str], - Field(None, description="The first `invite_id` in the retrieved `list`"), - ] + Field(description="The first `invite_id` in the retrieved `list`"), + ] = None last_id: Annotated[ - Optional[str], - Field(None, description="The last `invite_id` in the retrieved `list`"), - ] + Optional[str], Field(description="The last `invite_id` in the retrieved `list`") + ] = None has_more: Annotated[ Optional[bool], Field( - None, - description="The `has_more` property is used for pagination to indicate there are additional results.", + description="The `has_more` property is used for pagination to indicate there are additional results." ), - ] + ] = None class InviteRequest(BaseModel): @@ -3918,10 +3677,9 @@ class Project1(BaseModel): archived_at: Annotated[ Optional[int], Field( - None, - description="The Unix timestamp (in seconds) of when the project was archived or `null`.", + description="The Unix timestamp (in seconds) of when the project was archived or `null`." ), - ] + ] = None status: Annotated[Literal["active", "archived"], Field(description="`active` or `archived`")] @@ -4048,8 +3806,8 @@ class ProjectServiceAccountDeleteResponse(BaseModel): class Owner(BaseModel): type: Annotated[ Optional[Literal["user", "service_account"]], - Field(None, description="`user` or `service_account`"), - ] + Field(description="`user` or `service_account`"), + ] = None user: Optional[ProjectUser] = None service_account: Optional[ProjectServiceAccount] = None @@ -4107,129 +3865,115 @@ class CreateCompletionRequest(BaseModel): best_of: Annotated[ Optional[int], Field( - 1, description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', ge=0, le=20, ), - ] + ] = 1 echo: Annotated[ Optional[bool], - Field(False, description="Echo back the prompt in addition to the completion\n"), - ] + Field(description="Echo back the prompt in addition to the completion\n"), + ] = False frequency_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 logit_bias: Annotated[ Optional[Dict[str, int]], Field( - None, - description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n', + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n' ), - ] + ] = None logprobs: Annotated[ Optional[int], Field( - None, description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", ge=0, le=5, ), - ] + ] = None max_tokens: Annotated[ Optional[int], Field( - 16, description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", examples=[16], ge=0, ), - ] + ] = 16 n: Annotated[ Optional[int], Field( - 1, description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", examples=[1], ge=1, le=128, ), - ] + ] = 1 presence_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 seed: Annotated[ Optional[int], Field( - None, description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", ge=-9223372036854775808, le=9223372036854775807, ), - ] + ] = None stop: Annotated[ Optional[Union[Optional[str], Stop]], Field( - None, - description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n" ), - ] + ] = None stream: Annotated[ Optional[bool], Field( - False, - description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" ), - ] + ] = False stream_options: Optional[ChatCompletionStreamOptions] = None suffix: Annotated[ Optional[str], Field( - None, description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", examples=["test."], ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ), - ] + ] = None class CreateCompletionResponse(BaseModel): @@ -4248,10 +3992,9 @@ class CreateCompletionResponse(BaseModel): system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["text_completion"], Field(description='The object type, which is always "text_completion"'), @@ -4286,11 +4029,10 @@ class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCal class ChatCompletionResponseMessage(BaseModel): - content: Annotated[Optional[str], Field(description="The contents of the message.")] + content: Annotated[Optional[str], Field(description="The contents of the message.")] = None refusal: Annotated[ - Optional[str], - Field(None, description="The refusal message generated by the model."), - ] + Optional[str], Field(description="The refusal message generated by the model.") + ] = None tool_calls: Optional[ChatCompletionMessageToolCalls] = None role: Annotated[ Literal["assistant"], @@ -4299,10 +4041,9 @@ class ChatCompletionResponseMessage(BaseModel): function_call: Annotated[ Optional[FunctionCall], Field( - None, - description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." ), - ] + ] = None class Choice1(BaseModel): @@ -4316,8 +4057,8 @@ class Choice1(BaseModel): message: ChatCompletionResponseMessage logprobs: Annotated[ Optional[Logprobs2], - Field(None, description="Log probability information for the choice."), - ] + Field(description="Log probability information for the choice."), + ] = None class CreateChatCompletionResponse(BaseModel): @@ -4338,18 +4079,16 @@ class CreateChatCompletionResponse(BaseModel): service_tier: Annotated[ Optional[Literal["scale", "default"]], Field( - None, description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", examples=["scale"], ), - ] + ] = None system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["chat.completion"], Field(description="The object type, which is always `chat.completion`."), @@ -4386,10 +4125,9 @@ class CreateChatCompletionFunctionResponse(BaseModel): system_fingerprint: Annotated[ Optional[str], Field( - None, - description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n", + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" ), - ] + ] = None object: Annotated[ Literal["chat.completion"], Field(description="The object type, which is always `chat.completion`."), @@ -4456,13 +4194,13 @@ class FineTuningJob(BaseModel): Field( description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running." ), - ] + ] = None finished_at: Annotated[ Optional[int], Field( description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running." ), - ] + ] = None hyperparameters: Annotated[ Hyperparameters1, Field( @@ -4494,7 +4232,7 @@ class FineTuningJob(BaseModel): Field( description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running." ), - ] + ] = None training_file: Annotated[ str, Field( @@ -4506,23 +4244,21 @@ class FineTuningJob(BaseModel): Field( description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents)." ), - ] + ] = None integrations: Annotated[ Optional[List[FineTuningIntegration]], Field( - None, description="A list of integrations to enable for this fine-tuning job.", max_length=5, ), - ] + ] = None seed: Annotated[int, Field(description="The seed used for the fine-tuning job.")] estimated_finish: Annotated[ Optional[int], Field( - None, - description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.", + description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running." ), - ] + ] = None class AssistantObject(BaseModel): @@ -4544,14 +4280,14 @@ class AssistantObject(BaseModel): description="The name of the assistant. The maximum length is 256 characters.\n", max_length=256, ), - ] + ] = None description: Annotated[ Optional[str], Field( description="The description of the assistant. The maximum length is 512 characters.\n", max_length=512, ), - ] + ] = None model: Annotated[ str, Field( @@ -4564,7 +4300,7 @@ class AssistantObject(BaseModel): description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", max_length=256000, ), - ] + ] = None tools: Annotated[ List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( @@ -4575,10 +4311,9 @@ class AssistantObject(BaseModel): tool_resources: Annotated[ Optional[ToolResources], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( @@ -4588,23 +4323,21 @@ class AssistantObject(BaseModel): temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 response_format: Optional[AssistantsApiResponseFormatOption] = None @@ -4650,69 +4383,61 @@ class CreateAssistantRequest(BaseModel): name: Annotated[ Optional[str], Field( - None, description="The name of the assistant. The maximum length is 256 characters.\n", max_length=256, ), - ] + ] = None description: Annotated[ Optional[str], Field( - None, description="The description of the assistant. The maximum length is 512 characters.\n", max_length=512, ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", max_length=256000, ), - ] + ] = None tools: Annotated[ List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( - [], description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", max_length=128, ), - ] + ] = [] tool_resources: Annotated[ Optional[ToolResources1], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 response_format: Optional[AssistantsApiResponseFormatOption] = None @@ -4723,76 +4448,67 @@ class ModifyAssistantRequest(BaseModel): model: Annotated[ Optional[str], Field( - None, - description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" ), - ] + ] = None name: Annotated[ Optional[str], Field( - None, description="The name of the assistant. The maximum length is 256 characters.\n", max_length=256, ), - ] + ] = None description: Annotated[ Optional[str], Field( - None, description="The description of the assistant. The maximum length is 512 characters.\n", max_length=512, ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", max_length=256000, ), - ] + ] = None tools: Annotated[ List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], Field( - [], description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", max_length=128, ), - ] + ] = [] tool_resources: Annotated[ Optional[ToolResources2], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 response_format: Optional[AssistantsApiResponseFormatOption] = None @@ -4888,23 +4604,23 @@ class RunObject(BaseModel): expires_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run will expire."), - ] + ] = None started_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run was started."), - ] + ] = None cancelled_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run was cancelled."), - ] + ] = None failed_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run failed."), - ] + ] = None completed_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run was completed."), - ] + ] = None incomplete_details: Annotated[ Optional[IncompleteDetails], Field( @@ -4939,32 +4655,28 @@ class RunObject(BaseModel): usage: RunCompletionUsage temperature: Annotated[ Optional[float], - Field( - None, - description="The sampling temperature used for this run. If not set, defaults to 1.", - ), - ] + Field(description="The sampling temperature used for this run. If not set, defaults to 1."), + ] = None top_p: Annotated[ Optional[float], Field( - None, - description="The nucleus sampling value used for this run. If not set, defaults to 1.", + description="The nucleus sampling value used for this run. If not set, defaults to 1." ), - ] + ] = None max_prompt_tokens: Annotated[ Optional[int], Field( description="The maximum number of prompt tokens specified to have been used over the course of the run.\n", ge=256, ), - ] + ] = None max_completion_tokens: Annotated[ Optional[int], Field( description="The maximum number of completion tokens specified to have been used over the course of the run.\n", ge=256, ), - ] + ] = None truncation_strategy: Annotated[Optional[TruncationObject], Field(...)] tool_choice: Annotated[Optional[AssistantsApiToolChoiceOption], Field(...)] parallel_tool_calls: ParallelToolCalls @@ -5020,17 +4732,15 @@ class CreateMessageRequest(BaseModel): attachments: Annotated[ Optional[List[Attachment]], Field( - None, - description="A list of files attached to the message, and the tools they should be added to.", + description="A list of files attached to the message, and the tools they should be added to." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class Text(BaseModel): @@ -5049,7 +4759,7 @@ class MessageContentTextObject(BaseModel): class Text1(BaseModel): - value: Annotated[Optional[str], Field(None, description="The data that makes up the text.")] + value: Annotated[Optional[str], Field(description="The data that makes up the text.")] = None annotations: Optional[ List[ Union[ @@ -5097,9 +4807,8 @@ class RunStepDetailsToolCallsCodeObject(BaseModel): class CodeInterpreter8(BaseModel): input: Annotated[ - Optional[str], - Field(None, description="The input to the Code Interpreter tool call."), - ] + Optional[str], Field(description="The input to the Code Interpreter tool call.") + ] = None outputs: Annotated[ Optional[ List[ @@ -5110,15 +4819,14 @@ class CodeInterpreter8(BaseModel): ] ], Field( - None, - description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type.", + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." ), - ] + ] = None class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] - id: Annotated[Optional[str], Field(None, description="The ID of the tool call.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None type: Annotated[ Literal["code_interpreter"], Field( @@ -5127,8 +4835,8 @@ class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): ] code_interpreter: Annotated[ Optional[CodeInterpreter8], - Field(None, description="The Code Interpreter tool call definition."), - ] + Field(description="The Code Interpreter tool call definition."), + ] = None class CreateVectorStoreRequest(BaseModel): @@ -5138,27 +4846,24 @@ class CreateVectorStoreRequest(BaseModel): file_ids: Annotated[ Optional[List[str]], Field( - None, description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", max_length=500, ), - ] - name: Annotated[Optional[str], Field(None, description="The name of the vector store.")] + ] = None + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None expires_after: Optional[VectorStoreExpirationAfter] = None chunking_strategy: Annotated[ Optional[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]], Field( - None, - description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty.", + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty." ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class StaticChunkingStrategyResponseParam(BaseModel): @@ -5265,13 +4970,12 @@ class ChatCompletionRequestAssistantMessage(BaseModel): content: Annotated[ Optional[Union[Optional[str], Content2]], Field( - None, - description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n", + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n" ), - ] + ] = None refusal: Annotated[ - Optional[str], Field(None, description="The refusal message by the assistant.") - ] + Optional[str], Field(description="The refusal message by the assistant.") + ] = None role: Annotated[ Literal["assistant"], Field(description="The role of the messages author, in this case `assistant`."), @@ -5279,28 +4983,23 @@ class ChatCompletionRequestAssistantMessage(BaseModel): name: Annotated[ Optional[str], Field( - None, - description="An optional name for the participant. Provides the model information to differentiate between participants of the same role.", + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." ), - ] + ] = None tool_calls: Optional[ChatCompletionMessageToolCalls] = None function_call: Annotated[ Optional[FunctionCall], Field( - None, - description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.", + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." ), - ] + ] = None class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): weight: Annotated[ Optional[Literal[0, 1]], - Field( - None, - description="Controls whether the assistant message is trained against (0 or 1)", - ), - ] + Field(description="Controls whether the assistant message is trained against (0 or 1)"), + ] = None role: Annotated[ Literal["assistant"], Field(description="The role of the messages author, in this case `assistant`."), @@ -5326,22 +5025,21 @@ class FinetuneChatRequestInput(BaseModel): ] ] ], - Field(None, min_length=1), - ] + Field(min_length=1), + ] = None tools: Annotated[ Optional[List[ChatCompletionTool]], - Field(None, description="A list of tools the model may generate JSON inputs for."), - ] + Field(description="A list of tools the model may generate JSON inputs for."), + ] = None parallel_tool_calls: Optional[ParallelToolCalls] = None functions: Annotated[ Optional[List[ChatCompletionFunctions]], Field( - None, description="A list of functions the model may generate JSON inputs for.", max_length=128, min_length=1, ), - ] + ] = None class CreateRunRequest(BaseModel): @@ -5387,90 +5085,77 @@ class CreateRunRequest(BaseModel): ] ], Field( - None, description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", examples=["gpt-4o"], ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, - description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis.", + description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis." ), - ] + ] = None additional_instructions: Annotated[ Optional[str], Field( - None, - description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions.", + description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions." ), - ] + ] = None additional_messages: Annotated[ Optional[List[CreateMessageRequest]], - Field( - None, - description="Adds additional messages to the thread before creating the run.", - ), - ] + Field(description="Adds additional messages to the thread before creating the run."), + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], Field( - None, description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", max_length=20, ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 stream: Annotated[ Optional[bool], Field( - None, - description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" ), - ] + ] = None max_prompt_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None max_completion_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None truncation_strategy: Optional[TruncationObject] = None tool_choice: Optional[AssistantsApiToolChoiceOption] = None parallel_tool_calls: Optional[ParallelToolCalls] = None @@ -5484,24 +5169,21 @@ class CreateThreadRequest(BaseModel): messages: Annotated[ Optional[List[CreateMessageRequest]], Field( - None, - description="A list of [messages](/docs/api-reference/messages) to start the thread with.", + description="A list of [messages](/docs/api-reference/messages) to start the thread with." ), - ] + ] = None tool_resources: Annotated[ Optional[ToolResources5], Field( - None, - description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None class MessageObject(BaseModel): @@ -5536,13 +5218,13 @@ class MessageObject(BaseModel): completed_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the message was completed."), - ] + ] = None incomplete_at: Annotated[ Optional[int], Field( description="The Unix timestamp (in seconds) for when the message was marked as incomplete." ), - ] + ] = None role: Annotated[ Literal["user", "assistant"], Field(description="The entity that produced the message. One of `user` or `assistant`."), @@ -5563,13 +5245,13 @@ class MessageObject(BaseModel): Field( description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message." ), - ] + ] = None run_id: Annotated[ Optional[str], Field( description="The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints." ), - ] + ] = None attachments: Annotated[ Optional[List[Attachment]], Field( @@ -5587,11 +5269,8 @@ class MessageObject(BaseModel): class Delta(BaseModel): role: Annotated[ Optional[Literal["user", "assistant"]], - Field( - None, - description="The entity that produced the message. One of `user` or `assistant`.", - ), - ] + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] = None content: Annotated[ Optional[ List[ @@ -5603,11 +5282,8 @@ class Delta(BaseModel): ] ] ], - Field( - None, - description="The content of the message in array of text and/or images.", - ), - ] + Field(description="The content of the message in array of text and/or images."), + ] = None class MessageDeltaObject(BaseModel): @@ -5664,10 +5340,9 @@ class RunStepDeltaStepDetailsToolCallsObject(BaseModel): ] ], Field( - None, - description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n", + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" ), - ] + ] = None class VectorStoreFileObject(BaseModel): @@ -5711,8 +5386,8 @@ class VectorStoreFileObject(BaseModel): ] chunking_strategy: Annotated[ Optional[Union[StaticChunkingStrategyResponseParam, OtherChunkingStrategyResponseParam]], - Field(None, description="The strategy used to chunk the file."), - ] + Field(description="The strategy used to chunk the file."), + ] = None class ListVectorStoreFilesResponse(BaseModel): @@ -5838,152 +5513,132 @@ class CreateChatCompletionRequest(BaseModel): frequency_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 logit_bias: Annotated[ Optional[Dict[str, int]], Field( - None, - description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n", + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n" ), - ] + ] = None logprobs: Annotated[ Optional[bool], Field( - False, - description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.", + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`." ), - ] + ] = False top_logprobs: Annotated[ Optional[int], Field( - None, description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", ge=0, le=20, ), - ] + ] = None max_tokens: Annotated[ Optional[int], Field( - None, - description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n" ), - ] + ] = None n: Annotated[ Optional[int], Field( - 1, description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", examples=[1], ge=1, le=128, ), - ] + ] = 1 presence_penalty: Annotated[ Optional[float], Field( - 0, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", ge=-2.0, le=2.0, ), - ] + ] = 0 response_format: Annotated[ Optional[Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema]], Field( - None, - description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n', + description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' ), - ] + ] = None seed: Annotated[ Optional[int], Field( - None, description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", ge=-9223372036854775808, le=9223372036854775807, ), - ] + ] = None service_tier: Annotated[ Optional[Literal["auto", "default"]], Field( - None, - description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n", + description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n" ), - ] + ] = None stop: Annotated[ Union[Optional[str], Stop1], - Field( - None, - description="Up to 4 sequences where the API will stop generating further tokens.\n", - ), - ] + Field(description="Up to 4 sequences where the API will stop generating further tokens.\n"), + ] = None stream: Annotated[ Optional[bool], Field( - False, - description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" ), - ] + ] = False stream_options: Optional[ChatCompletionStreamOptions] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 tools: Annotated[ Optional[List[ChatCompletionTool]], Field( - None, - description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n", + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n" ), - ] + ] = None tool_choice: Optional[ChatCompletionToolChoiceOption] = None parallel_tool_calls: Optional[ParallelToolCalls] = None user: Annotated[ Optional[str], Field( - None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ), - ] + ] = None function_call: Annotated[ Optional[Union[Literal["none", "auto"], ChatCompletionFunctionCallOption]], Field( - None, - description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n', + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n' ), - ] + ] = None functions: Annotated[ Optional[List[ChatCompletionFunctions]], Field( - None, description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", max_length=128, min_length=1, ), - ] + ] = None class CreateThreadAndRunRequest(BaseModel): @@ -5998,11 +5653,8 @@ class CreateThreadAndRunRequest(BaseModel): ] thread: Annotated[ Optional[CreateThreadRequest], - Field( - None, - description="If no thread is provided, an empty thread will be created.", - ), - ] + Field(description="If no thread is provided, an empty thread will be created."), + ] = None model: Annotated[ Optional[ Union[ @@ -6036,83 +5688,73 @@ class CreateThreadAndRunRequest(BaseModel): ] ], Field( - None, description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", examples=["gpt-4o"], ), - ] + ] = None instructions: Annotated[ Optional[str], Field( - None, - description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis.", + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis." ), - ] + ] = None tools: Annotated[ Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], Field( - None, description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", max_length=20, ), - ] + ] = None tool_resources: Annotated[ Optional[ToolResources3], Field( - None, - description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n", + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" ), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( - None, - description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n", + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" ), - ] + ] = None temperature: Annotated[ Optional[float], Field( - 1, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", examples=[1], ge=0.0, le=2.0, ), - ] + ] = 1 top_p: Annotated[ Optional[float], Field( - 1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", examples=[1], ge=0.0, le=1.0, ), - ] + ] = 1 stream: Annotated[ Optional[bool], Field( - None, - description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n", + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" ), - ] + ] = None max_prompt_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None max_completion_tokens: Annotated[ Optional[int], Field( - None, description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", ge=256, ), - ] + ] = None truncation_strategy: Optional[TruncationObject] = None tool_choice: Optional[AssistantsApiToolChoiceOption] = None parallel_tool_calls: Optional[ParallelToolCalls] = None @@ -6177,19 +5819,19 @@ class RunStepObject(BaseModel): Field( description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired." ), - ] + ] = None cancelled_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run step was cancelled."), - ] + ] = None failed_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run step failed."), - ] + ] = None completed_at: Annotated[ Optional[int], Field(description="The Unix timestamp (in seconds) for when the run step completed."), - ] + ] = None metadata: Annotated[ Optional[Dict[str, Any]], Field( @@ -6207,8 +5849,8 @@ class Delta1(BaseModel): RunStepDeltaStepDetailsToolCallsObject, ] ], - Field(None, description="The details of the run step."), - ] + Field(description="The details of the run step."), + ] = None class RunStepDeltaObject(BaseModel): diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm index 6a00adb7..6494be71 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -ARG VLLM_VERSION=0.6.1.post2 +ARG VLLM_VERSION=0.6.2 ARG VLLM_BASE_IMAGE=vllm/vllm-openai:v${VLLM_VERSION} FROM ${VLLM_BASE_IMAGE} AS base diff --git a/model-engine/model_engine_server/inference/vllm/README.md b/model-engine/model_engine_server/inference/vllm/README.md index 8f969f17..486b528c 100644 --- a/model-engine/model_engine_server/inference/vllm/README.md +++ b/model-engine/model_engine_server/inference/vllm/README.md @@ -25,6 +25,7 @@ docker run \ --shm-size=16gb \ --gpus '"device=0"' \ -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_server.py:/workspace/vllm_server.py \ -p 5005:5005 \ --name vllm \ ${IMAGE} \ @@ -52,7 +53,7 @@ docker run \ -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ -p 5005:5005 \ - -e CONFIG_FILE=/workspace/examples/v2/sample_config_gemma.json \ + -e CONFIG_FILE=/workspace/examples/v2/gemma/config.json \ -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ --name vllm_batch \ ${IMAGE_BATCH} \ diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh index d7fcb547..65c49b32 100755 --- a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -28,7 +28,7 @@ fi ACCOUNT=$1 IMAGE_TAG=$2 BUILD_TARGET=$3 -VLLM_VERSION=${VLLM_VERSION:-"0.5.3.post1"} +VLLM_VERSION=${VLLM_VERSION:-"0.6.2"} # if build target = vllm use vllm otherwise use vllm_batch if [ "$BUILD_TARGET" == "vllm" ]; then diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md new file mode 100644 index 00000000..08c3b213 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md @@ -0,0 +1,19 @@ +# quick commands + +``` +export MODEL=gemma-2-2b-it && export MODEL_PATH=/data/model_files/$MODEL +docker kill vllm_batch; docker rm vllm_batch; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=6,7"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ + -p 5005:5005 \ + -e CONFIG_FILE=/workspace/examples/v2/gemma/config.json \ + -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ + --name vllm_batch \ + ${IMAGE_BATCH} \ + python vllm_batch.py +``` \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config.json similarity index 68% rename from model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json rename to model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config.json index 2b1c020d..fc98e6d0 100644 --- a/model-engine/model_engine_server/inference/vllm/examples/v2/sample_config_gemma.json +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config.json @@ -1,6 +1,6 @@ { - "input_data_path": "./examples/v2/sample_data_chat_gemma.json", - "output_data_path": "./examples/v2/sample_output.json", + "input_data_path": "./examples/v2/gemma/data_oai_chat.json", + "output_data_path": "./examples/v2/gemma/output_oi_chat.json", "model_config": { "model": "gemma-2-2b-it", "checkpoint_path": "my_path", diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json new file mode 100644 index 00000000..15e35c4d --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json @@ -0,0 +1,33 @@ +{ + "content": [ + { + "messages": [ + { + "role": "user", + "content": "What is a good place for travel in the US?" + }, + { + "role": "assistant", + "content": "California." + }, + { + "role": "user", + "content": "What can I do in California?" + } + ], + "logprobs": true + } + ], + "output_data_path": "./examples/v2/sample_output.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "my_path", + "num_shards": 1, + "response_role": "assistant", + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json new file mode 100644 index 00000000..fbbf1286 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json @@ -0,0 +1,7 @@ +[ + {"messages": [ + {"role": "user", "content": "What is a good place for travel in the US?"}, + {"role": "assistant", "content": "California."}, + {"role": "user", "content": "What can I do in California?"}], + "logprobs": true} +] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json new file mode 100644 index 00000000..2b1500b4 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json @@ -0,0 +1,16 @@ +[ + { + "prompt": "What is a good place for travel in the US?", + "logprobs": true, + "echo": true, + "max_tokens": 7, + "temperature": 1 + }, + { + "prompt": "What is a good place for travel in the EU?", + "logprobs": true, + "echo": true, + "max_tokens": 7, + "temperature": 1 + } +] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md new file mode 100644 index 00000000..cc7f3052 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md @@ -0,0 +1,19 @@ +# quick commands + +``` +export MODEL=meta-llama/Llama-3.2-11B-Vision-Instruct && export MODEL_PATH=/data/model_files/$MODEL +docker kill vllm_batch; docker rm vllm_batch; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=6,7"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ + -p 5005:5005 \ + -e CONFIG_FILE=/workspace/examples/v2/llama-3.2-vision/config.json \ + -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ + --name vllm_batch \ + ${IMAGE_BATCH} \ + python vllm_batch.py +``` \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json new file mode 100644 index 00000000..a26a3b3c --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json @@ -0,0 +1,18 @@ +{ + "input_data_path": "./examples/v2/llama-3.2-vision/data_oai_chat.json", + "output_data_path": "./examples/v2/llama-3.2-vision/output_oi_chat.json", + "model_config": { + "model": "meta-llama/Llama-3.2-11B-Vision-Instruct", + "checkpoint_path": "my_path", + "num_shards": 1, + "max_model_len": 4096, + "max_num_seqs": 16, + "enforce_eager": true, + "response_role": "assistant", + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json new file mode 100644 index 00000000..2cdc9656 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json @@ -0,0 +1,22 @@ +[ + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + } + ] + } + ], + "max_tokens": 64 + } +] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json new file mode 100644 index 00000000..402429c6 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json @@ -0,0 +1 @@ +[{"id": "chat-b61abe3898714576802d92f36ab90c38", "object": "chat.completion", "created": 1727669398, "model": "/workspace/model_files", "choices": [{"index": 0, "message": {"role": "assistant", "content": "This image depicts a serene landscape with a long wooden boardwalk or path that stretches out into a field dotted with long green grass in the foreground and tall green and yellow grass and green and red shrubbery on the side of the path. In the background, there are large, short and thick green and yellow shrubs", "tool_calls": []}, "logprobs": null, "finish_reason": "length", "stop_reason": null}], "usage": {"prompt_tokens": 17, "total_tokens": 81, "completion_tokens": 64}, "prompt_logprobs": null}] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json b/model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json deleted file mode 100644 index 39722117..00000000 --- a/model-engine/model_engine_server/inference/vllm/examples/v2/sample_data_chat_gemma.json +++ /dev/null @@ -1 +0,0 @@ -[{"messages": [{"role": "user", "content": "What is a good place for travel in the US?"}, {"role": "assistant", "content": "California."}, {"role": "user", "content": "What can I do in California?"}], "logprobs": true}] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt index 066478b2..d330101a 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt @@ -1 +1 @@ -vllm==0.6.1.post2 \ No newline at end of file +vllm==0.6.2 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py index 23d716e4..ea0989fe 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_batch.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -25,6 +25,7 @@ CreateBatchCompletionsEngineRequest, CreateBatchCompletionsV1RequestContent, TokenOutput, + VLLMModelConfig, ) from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, @@ -40,10 +41,11 @@ from tqdm import tqdm from typing_extensions import TypeAlias, assert_never from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.utils import merge_async_iterators CONFIG_FILE = os.getenv("CONFIG_FILE") @@ -117,7 +119,7 @@ async def download_model(checkpoint_path: str, target_dir: str) -> None: async def generate_v1_completions( - engine: AsyncEngineClient, + engine: EngineClient, content: CreateBatchCompletionsV1RequestContent, ) -> List[Optional[CompletionV1Output]]: prompts = content.prompts @@ -181,7 +183,7 @@ async def generate_v1_completions( async def generate_v2_completions( - engine: AsyncEngineClient, + engine: EngineClient, requests: Union[List[CompletionRequest], List[ChatCompletionRequest]], ) -> List[Union[CompletionResponse, ErrorResponse, None]]: bar = tqdm(total=len(requests), desc="Processed requests") @@ -214,7 +216,7 @@ async def generate_v2_completions( async def generate_completions( - engine: AsyncEngineClient, request: _BatchCompletionContent + engine: EngineClient, request: _BatchCompletionContent ) -> Union[List[Optional[CompletionV1Output]], List[Optional[CompletionResponse]]]: if isinstance(request, CreateBatchCompletionsV1RequestContent): return await generate_v1_completions(engine, request) @@ -227,30 +229,36 @@ async def generate_completions( async def init_engine( model: str, request: CreateBatchCompletionsEngineRequest, -) -> AsyncEngineClient: +) -> EngineClient: global openai_serving_chat global openai_serving_completion if request.attention_backend is not None: os.environ["ATTENTION_BACKEND"] = request.attention_backend + parsed_configs = VLLMModelConfig.model_validate_json(request.model_cfg.model_dump_json()) + if not parsed_configs.max_model_len: + parsed_configs.max_model_len = request.model_cfg.max_context_length + + print("VLLM additional configs:", parsed_configs.model_dump()) + engine_args = AsyncEngineArgs( model=model, tensor_parallel_size=request.model_cfg.num_shards, seed=request.model_cfg.seed or 0, disable_log_requests=True, gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, - max_model_len=request.model_cfg.max_context_length, + **parsed_configs.model_dump(exclude_none=True), ) - async_engine_client = AsyncLLMEngine.from_engine_args(engine_args) - model_config = await async_engine_client.get_model_config() - served_model_names = [model] + engine_client = AsyncLLMEngine.from_engine_args(engine_args) + model_config = await engine_client.get_model_config() + base_model_paths = [BaseModelPath(name=model, model_path=model)] openai_serving_chat = OpenAIServingChat( - async_engine_client, + engine_client, model_config, - served_model_names, + base_model_paths, response_role=request.model_cfg.response_role or "assistant", lora_modules=None, prompt_adapters=None, @@ -259,15 +267,15 @@ async def init_engine( ) openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=None, prompt_adapters=None, request_logger=None, ) - return async_engine_client + return engine_client def overwrite_request(request: Dict[str, Any], model: str) -> Dict[str, Any]: @@ -291,7 +299,10 @@ def load_batch_content( return TypeAdapter( Union[List[CompletionRequest], List[ChatCompletionRequest]] ).validate_python( - [overwrite_request(req.model_dump(exclude_none=True), model) for req in content] + [ + overwrite_request(req.model_dump(exclude_none=True, mode="json"), model) + for req in content + ] ) return content diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 4929ef72..6e3e8cbd 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -3,6 +3,7 @@ import json import os import signal +import socket import subprocess import traceback from logging import Logger @@ -11,9 +12,9 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException, Request from fastapi.responses import Response, StreamingResponse from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http -from vllm.entrypoints.openai.api_server import build_async_engine_client, init_app +from vllm.entrypoints.openai.api_server import build_app, build_async_engine_client, init_app_state from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor @@ -25,7 +26,7 @@ logger = Logger("vllm_server") -async_engine_client: AsyncEngineClient +engine_client: EngineClient TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds @@ -45,7 +46,7 @@ async def generate(request: Request) -> Response: """ # check health before accepting request and fail fast if engine isn't healthy try: - await async_engine_client.check_health() + await engine_client.check_health() request_dict = await request.json() prompt = request_dict.pop("prompt") @@ -75,12 +76,12 @@ async def generate(request: Request) -> Response: ) guided_decoding_backend = ( - await async_engine_client.get_decoding_config() + await engine_client.get_decoding_config() ).guided_decoding_backend guided_decode_logit_processor = await get_guided_decoding_logits_processor( guided_decoding_backend, partial_openai_request, - await async_engine_client.get_tokenizer(lora_request=None), + await engine_client.get_tokenizer(lora_request=None), ) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: @@ -89,10 +90,10 @@ async def generate(request: Request) -> Response: request_id = random_uuid() - results_generator = async_engine_client.generate(prompt, sampling_params, request_id) + results_generator = engine_client.generate(prompt, sampling_params, request_id) async def abort_request() -> None: - await async_engine_client.abort(request_id) + await engine_client.abort(request_id) if stream: # Streaming case @@ -127,7 +128,7 @@ async def stream_results() -> AsyncGenerator[str, None]: last_output_text = request_output.outputs[-1].text if await request.is_disconnected(): # Abort the request if the client disconnects. - await async_engine_client.abort(request_id) + await engine_client.abort(request_id) return Response(status_code=499) final_output = request_output @@ -223,14 +224,27 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - global async_engine_client - async with build_async_engine_client(args) as async_engine_client: - app = await init_app(async_engine_client, args) + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # nosemgrep + temp_socket.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + global engine_client + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + temp_socket.close() app.include_router(router) shutdown_task = await serve_http( app, - engine=async_engine_client, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/scripts/generate-openai-types.sh b/scripts/generate-openai-types.sh index 2a60d88c..e787cfe8 100755 --- a/scripts/generate-openai-types.sh +++ b/scripts/generate-openai-types.sh @@ -17,6 +17,11 @@ datamodel-codegen \ --strict-nullable \ --use-annotated +# replace pydantic import w/ our custom module to replace the AnyUrl types +# Pydantic AnyUrl is super problematic for various reasons +sed -i 's/^from pydantic import /from model_engine_server.common.pydantic_types import /' ${DEST_DIR}/openai.py + + CLIENT_DIR=${BASE_DIR}/clients/python/llmengine/data_types/gen # Generate OpenAPI types for client From 515ab650dd930b84fbb7d7368454a3e9773c2100 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Wed, 2 Oct 2024 19:18:10 -0700 Subject: [PATCH 391/425] Update docs to sunset free demo (#625) * Update docs to sunset free demo * Fix wording and formatting --- docs/getting_started.md | 23 ++++++++--------------- docs/pricing.md | 10 ++-------- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/docs/getting_started.md b/docs/getting_started.md index fea0531a..5dd3d422 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -1,8 +1,8 @@ # Getting Started -The fastest way to get started with LLM Engine is to use the Python client in this repository to -run inference and fine-tuning on Scale's infrastructure. This path does not require you to install -anything on your infrastructure, and Scale's free research preview gives you access to experimentation using open source LLMs. +**Note: As of October 31st 2024, LLM Engine's public demo service is sunsetted. We have thus removed the documentation +pieces relating to calling the demo service, procuring a Spellbook API key, etc. Please view our Self Hosting Guide instead. +We will however leave behind the Example Code snippets for posterity, and as a reference for self-hosted and Scale internal users.** To start, install LLM Engine via pip: @@ -11,28 +11,21 @@ To start, install LLM Engine via pip: pip install scale-llm-engine ``` -## Scale API Keys +## Scale user ID -Next, you need a Scale Spellbook API key. +Next, you need a Scale user ID. Recall that this is only applicable to Scale internal users for now, and we are just leaving +this note to serve as internal documentation. -### Retrieving your API Key - -To retrieve your API key, head to [Scale Spellbook](https://spellbook.scale.com) where -you will get an API key on the [settings](https://spellbook.scale.com/settings) page. - -!!! note "Different API Keys for different Scale Products" - - If you have leveraged Scale's platform for annotation work in the past, please note that your Spellbook API key will be different than the Scale Annotation API key. You will want to create a Spellbook API key before getting started. ### Set your API Key LLM Engine uses environment variables to access your API key. -Set this API key as the `SCALE_API_KEY` environment variable by running the following command in your terminal before you run your python application. +Set the `SCALE_API_KEY` environment variable to your Scale user ID by running the following command in your terminal before you run your python application. ``` -export SCALE_API_KEY="[Your API key]" +export SCALE_API_KEY="[Your Scale user ID]" ``` You can also add in the line above to your `.zshrc` or `.bash_profile` so it's automatically set for future sessions. diff --git a/docs/pricing.md b/docs/pricing.md index b61923b9..e1ffa2fe 100644 --- a/docs/pricing.md +++ b/docs/pricing.md @@ -1,16 +1,10 @@ # Pricing -LLM Engine is an open-source project and free [self-hosting](../guides/self_hosting) will always be an option. - -A hosted option for LLM Engine is being offered initially as a free preview via [Scale](https://scale.com/) [Spellbook](https://spellbook.scale.com/). +LLM Engine is an open-source project and free [self-hosting](../guides/self_hosting) will always be an option. As of October 31st 2024, +the free demo service is sunsetted. ## Self-Hosted Models We are committed to supporting the open-source community. [Self-hosting](../guides/self_hosting) LLM Engine will remain free and open-source. We would love [contributions](../contributing) from the community make this even more amazing! -## Hosted Models - -Once the limited preview period has ended, billing for hosted models will be managed through the Scale [Spellbook](https://spellbook.scale.com/settings) product. - -Scale Spellbook leverages usage-based spending, billed to a credit card. Details on usage-based pricing will be shared with everyone before completing the limited preview. From 2061eff2931a582b0ac980d594ff254ad5f8f3fb Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 3 Oct 2024 19:56:04 -0700 Subject: [PATCH 392/425] Add OpenAI compatible v2 completion (#627) * Add OpenAI compatible v2 completion * Fix batch completion response * fix test coverage * Get model_name from v2 requests correctly * Make post inference hook optional and add async version of forward * update predict method to be fully async * woops forgot api route * Add async for stream forwarder * Update to asyncgenerator * Move aiohttp_sse_client to common to avoid unnecessary imports * Fix tests to use aioresponses for aiohttp mocking * Skip streaming tests till I have time --- .../api/v2/chat_completion.py | 134 +++++---- .../model_engine_server/api/v2/common.py | 8 +- .../model_engine_server/api/v2/completion.py | 284 ++++++++++++++++++ .../gateways => common}/aiohttp_sse_client.py | 0 .../common/dtos/llms/batch_completion.py | 22 +- .../common/dtos/llms/chat_completion.py | 22 +- .../common/dtos/llms/completion.py | 26 +- .../common/dtos/llms/vllm.py | 61 ++++ .../use_cases/llm_model_endpoint_use_cases.py | 278 ++++++++++++++++- .../inference/forwarding/celery_forwarder.py | 3 +- .../inference/forwarding/echo_server.py | 7 +- .../inference/forwarding/forwarding.py | 165 +++++++--- .../inference/forwarding/http_forwarder.py | 54 ++-- ...eaming_model_endpoint_inference_gateway.py | 3 +- model-engine/requirements-test.txt | 1 + .../unit/inference/test_http_forwarder.py | 22 +- 16 files changed, 922 insertions(+), 168 deletions(-) create mode 100644 model-engine/model_engine_server/api/v2/completion.py rename model-engine/model_engine_server/{infra/gateways => common}/aiohttp_sse_client.py (100%) diff --git a/model-engine/model_engine_server/api/v2/chat_completion.py b/model-engine/model_engine_server/api/v2/chat_completion.py index 0dc1f989..b4b02837 100644 --- a/model-engine/model_engine_server/api/v2/chat_completion.py +++ b/model-engine/model_engine_server/api/v2/chat_completion.py @@ -11,14 +11,15 @@ ) from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.llms import ( - ChatCompletionV2ErrorChunk, ChatCompletionV2Request, ChatCompletionV2Response, ChatCompletionV2ResponseItem, + ChatCompletionV2StreamErrorChunk, StreamError, StreamErrorContent, TokenUsage, ) +from model_engine_server.common.dtos.llms.chat_completion import ChatCompletionV2StreamSuccessChunk from model_engine_server.core.auth.authentication_repository import User from model_engine_server.core.loggers import ( LoggerTagKey, @@ -65,7 +66,7 @@ def handle_streaming_exception( } logger.error("Exception: %s", structured_log) return { - "data": ChatCompletionV2ErrorChunk( + "data": ChatCompletionV2StreamErrorChunk( request_id=str(request_id), error=StreamError( status_code=code, @@ -92,76 +93,85 @@ async def handle_stream_request( tokenizer_repository=external_interfaces.tokenizer_repository, ) - try: - response = await use_case.execute( - user=auth, model_endpoint_name=model_endpoint_name, request=request - ) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail=str(exc), - ) from exc - except ( - EndpointUnsupportedInferenceTypeException, - EndpointUnsupportedRequestException, - ) as exc: - raise HTTPException( - status_code=400, - detail=str(exc), - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - except Exception as exc: - raise HTTPException( - status_code=500, - detail="Internal error occurred. Our team has been notified.", - ) from exc - - async def event_generator(): + with timer() as use_case_timer: try: - ttft = None - message = None - with timer() as use_case_timer: # todo, this should be move to start of method + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + # We fetch the first response to check if upstream request was successful + # If it was not, this will raise the corresponding HTTPException + # If it was, we will proceed to the event generator + first_message: ChatCompletionV2StreamSuccessChunk = await response.__anext__() + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + ) as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, + detail="Internal error occurred. Our team has been notified.", + ) from exc + + async def event_generator(timer: timer = use_case_timer): + try: + ttft = None + message = None + yield {"data": first_message.model_dump_json(exclude_none=True)} + async for message in response: if ttft is None: - ttft = use_case_timer.lap() + ttft = timer.lap() # if ttft is None and message.startswith("data"): - # ttft = use_case_timer.lap() + # ttft = timer.lap() print("message", message.model_dump_json(exclude_none=True)) yield {"data": message.model_dump_json(exclude_none=True)} - if message: - background_tasks.add_task( - external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, - TokenUsage( - num_prompt_tokens=(message.usage.prompt_tokens if message.usage else None), - num_completion_tokens=( - message.usage.completion_tokens if message.usage else None + if message: + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=( + message.usage.prompt_tokens if message.usage else None + ), + num_completion_tokens=( + message.usage.completion_tokens if message.usage else None + ), + total_duration=timer.duration, ), - total_duration=use_case_timer.duration, - ), - metric_metadata, - ) + metric_metadata, + ) - # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object - except InvalidRequestException as exc: - yield handle_streaming_exception(exc, 400, str(exc)) - except UpstreamServiceError as exc: - request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) - logger.exception( - f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" - ) - yield handle_streaming_exception( - exc, - 500, - f"Upstream service error for request_id {request_id}", - ) - except Exception as exc: - yield handle_streaming_exception( - exc, 500, "Internal error occurred. Our team has been notified." - ) + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) - return EventSourceResponse(event_generator()) + return EventSourceResponse(event_generator(timer=use_case_timer)) async def handle_sync_request( diff --git a/model-engine/model_engine_server/api/v2/common.py b/model-engine/model_engine_server/api/v2/common.py index 50c61df2..d651eb4b 100644 --- a/model-engine/model_engine_server/api/v2/common.py +++ b/model-engine/model_engine_server/api/v2/common.py @@ -19,7 +19,13 @@ async def get_metric_metadata( request: Request, auth: User = Depends(verify_authentication), ) -> MetricMetadata: - model_name = request.query_params.get("model", None) + # note that this is ok because request will cache the body + body = await request.json() + model_name = body.get("model", None) + if not model_name: + # get model name from batch completion request + model_name = body.get("model_config", {}).get("model", None) + return MetricMetadata(user=auth, model_name=model_name) diff --git a/model-engine/model_engine_server/api/v2/completion.py b/model-engine/model_engine_server/api/v2/completion.py new file mode 100644 index 00000000..250acd83 --- /dev/null +++ b/model-engine/model_engine_server/api/v2/completion.py @@ -0,0 +1,284 @@ +import traceback +from datetime import datetime +from typing import Any + +import pytz +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.llms import ( + CompletionV2Request, + CompletionV2Response, + CompletionV2StreamErrorChunk, + StreamError, + StreamErrorContent, + TokenUsage, +) +from model_engine_server.common.dtos.llms.completion import CompletionV2StreamSuccessChunk +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + InvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CompletionStreamV2UseCase, + CompletionSyncV2UseCase, +) +from sse_starlette import EventSourceResponse + +from .common import get_metric_metadata, record_route_call + +logger = make_logger(logger_name()) + +completion_router_v2 = APIRouter(dependencies=[Depends(record_route_call)]) + + +def handle_streaming_exception( + e: Exception, + code: int, + message: str, +): # pragma: no cover + tb_str = traceback.format_exception(e) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": message, + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Exception: %s", structured_log) + return { + "data": CompletionV2StreamErrorChunk( + request_id=str(request_id), + error=StreamError( + status_code=code, + content=StreamErrorContent( + error=message, + timestamp=timestamp, + ), + ), + ).model_dump_json(exclude_none=True) + } + + +async def handle_stream_request( + external_interfaces: ExternalInterfaces, + background_tasks: BackgroundTasks, + request: CompletionV2Request, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): # pragma: no cover + use_case = CompletionStreamV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + + with timer() as use_case_timer: + try: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + # We fetch the first response to check if upstream request was successful + # If it was not, this will raise the corresponding HTTPException + # If it was, we will proceed to the event generator + first_message: CompletionV2StreamSuccessChunk = await response.__anext__() + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + ) as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, + detail="Internal error occurred. Our team has been notified.", + ) from exc + + async def event_generator(timer: timer = use_case_timer): + try: + ttft = None + message = None + yield {"data": first_message.model_dump_json(exclude_none=True)} + async for message in response: + if ttft is None: + ttft = timer.lap() + # if ttft is None and message.startswith("data"): + # ttft = timer.lap() + yield {"data": message.model_dump_json(exclude_none=True)} + + if message: + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=( + message.usage.prompt_tokens if message.usage else None + ), + num_completion_tokens=( + message.usage.completion_tokens if message.usage else None + ), + total_duration=timer.duration, + ), + metric_metadata, + ) + + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) + + return EventSourceResponse(event_generator()) + + +async def handle_sync_request( + external_interfaces: ExternalInterfaces, + request: CompletionV2Request, + background_tasks: BackgroundTasks, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): + try: + use_case = CompletionSyncV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + with timer() as use_case_timer: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=(response.usage.prompt_tokens if response.usage else None), + num_completion_tokens=( + response.usage.completion_tokens if response.usage else None + ), + total_duration=use_case_timer.duration, + ), + metric_metadata, + ) + return response + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + raise HTTPException( + status_code=500, + detail=f"Upstream service error for request_id {request_id}", + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"POST /completions-sync to endpoint {model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except InvalidRequestException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except EndpointUnsupportedRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Endpoint does not support request: {str(exc)}", + ) from exc + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=f"Unsupported inference type: {str(exc)}", + ) from exc + + +def to_error_details(exc: Exception) -> Any: + if not exc.args or len(exc.args) == 0: + return str(exc) + if len(exc.args) == 1: + return exc.args[0] + else: + return exc.args + + +@completion_router_v2.post("/completions", response_model=CompletionV2Response) +async def completion( + request: CompletionV2Request, + background_tasks: BackgroundTasks, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> CompletionV2Response: # pragma: no cover + model_endpoint_name = request.model + if hmi_config.sensitive_log_mode: + logger.info( + f"POST /v2/completion ({('stream' if request.stream else 'sync')}) to endpoint {model_endpoint_name} for {auth}" + ) + else: + logger.info( + f"POST /v2/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}" + ) + + if request.stream: + return await handle_stream_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) + else: + return await handle_sync_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) diff --git a/model-engine/model_engine_server/infra/gateways/aiohttp_sse_client.py b/model-engine/model_engine_server/common/aiohttp_sse_client.py similarity index 100% rename from model-engine/model_engine_server/infra/gateways/aiohttp_sse_client.py rename to model-engine/model_engine_server/common/aiohttp_sse_client.py diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index 019aa707..e8eff546 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -1,6 +1,5 @@ -# Make sure to keep this in sync with inference/batch_inference/dto.py. from enum import Enum -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from model_engine_server.common.dtos.llms.chat_completion import ( ChatCompletionV2Request, @@ -9,7 +8,7 @@ from model_engine_server.common.dtos.llms.completion import ( CompletionOutput, CompletionV2Request, - CompletionV2Response, + CompletionV2SyncResponse, ) from model_engine_server.common.dtos.llms.vllm import VLLMModelConfig from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field @@ -197,15 +196,16 @@ class FilteredChatCompletionV2Request(ChatCompletionV2Request): # V2 DTOs for batch completions -CompletionRequest: TypeAlias = Union[FilteredCompletionV2Request, FilteredChatCompletionV2Request] -CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2SyncResponse] -CreateBatchCompletionsV2RequestContent: TypeAlias = Union[ - List[FilteredCompletionV2Request], List[FilteredChatCompletionV2Request] -] +CompletionRequest: TypeAlias = FilteredCompletionV2Request | FilteredChatCompletionV2Request +CompletionResponse: TypeAlias = CompletionV2SyncResponse | ChatCompletionV2SyncResponse +CreateBatchCompletionsV2RequestContent: TypeAlias = ( + List[FilteredCompletionV2Request] | List[FilteredChatCompletionV2Request] +) + CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig -BatchCompletionContent = Union[ - CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent -] +BatchCompletionContent = ( + CreateBatchCompletionsV1RequestContent | CreateBatchCompletionsV2RequestContent +) class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): diff --git a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py index 547ee5f9..a5f89394 100644 --- a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional from model_engine_server.common.dtos.llms.completion import StreamError from model_engine_server.common.dtos.llms.vllm import VLLMChatCompletionAdditionalParams @@ -33,23 +33,23 @@ class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMChatCompletionAdd ] -ChatCompletionV2SyncResponse = CreateChatCompletionResponse -ChatCompletionV2SuccessChunk = CreateChatCompletionStreamResponse +ChatCompletionV2SyncResponse: TypeAlias = CreateChatCompletionResponse +ChatCompletionV2StreamSuccessChunk: TypeAlias = CreateChatCompletionStreamResponse -class ChatCompletionV2ErrorChunk(BaseModel): +class ChatCompletionV2StreamErrorChunk(BaseModel): error: StreamError -ChatCompletionV2Chunk: TypeAlias = Union[ChatCompletionV2SuccessChunk, ChatCompletionV2ErrorChunk] +ChatCompletionV2Chunk: TypeAlias = ( + ChatCompletionV2StreamSuccessChunk | ChatCompletionV2StreamErrorChunk +) ChatCompletionV2StreamResponse: TypeAlias = ( - EventSourceResponse # EventSourceResponse[ChatCompletionV2Chunk | ChatCompletionV2ErrorChunk] + EventSourceResponse # EventSourceResponse[ChatCompletionV2Chunk] ) -ChatCompletionV2Response: TypeAlias = Union[ - ChatCompletionV2SyncResponse, ChatCompletionV2StreamResponse -] +ChatCompletionV2Response: TypeAlias = ChatCompletionV2SyncResponse | ChatCompletionV2StreamResponse # This is a version of ChatCompletionV2Response that is used by pydantic to determine the response model -# Since EventSourceResponse isn't a pydanitc model, we need to use a Union of the two response types -ChatCompletionV2ResponseItem: TypeAlias = Union[ChatCompletionV2SyncResponse, ChatCompletionV2Chunk] +# Since EventSourceResponse isn't a pydantic model, we need to use a Union of the two response types +ChatCompletionV2ResponseItem: TypeAlias = ChatCompletionV2SyncResponse | ChatCompletionV2Chunk diff --git a/model-engine/model_engine_server/common/dtos/llms/completion.py b/model-engine/model_engine_server/common/dtos/llms/completion.py index a680a2b7..44ae72db 100644 --- a/model-engine/model_engine_server/common/dtos/llms/completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/completion.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TypeAlias +from model_engine_server.common.dtos.llms.vllm import VLLMCompletionAdditionalParams from model_engine_server.common.pydantic_types import BaseModel, Field from model_engine_server.common.types.gen.openai import ( CreateCompletionRequest, CreateCompletionResponse, ) +from sse_starlette import EventSourceResponse from typing_extensions import Annotated # Fields that are a part of OpenAI spec but are not supported by model engine @@ -288,7 +290,7 @@ def inter_token_latency(self) -> Optional[float]: # Only for streaming requests return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) -class CompletionV2Request(CreateCompletionRequest): +class CompletionV2Request(CreateCompletionRequest, VLLMCompletionAdditionalParams): model: Annotated[ str, Field( @@ -320,5 +322,21 @@ class CompletionV2Request(CreateCompletionRequest): ] -class CompletionV2Response(CreateCompletionResponse): - pass +CompletionV2SyncResponse: TypeAlias = CreateCompletionResponse +CompletionV2StreamSuccessChunk: TypeAlias = CreateCompletionResponse + + +class CompletionV2StreamErrorChunk(BaseModel): + error: StreamError + + +CompletionV2StreamChunk: TypeAlias = CompletionV2StreamSuccessChunk | CompletionV2StreamErrorChunk +CompletionV2StreamResponse: TypeAlias = ( + EventSourceResponse # EventSourceResponse[CompletionV2StreamChunk] +) + +CompletionV2Response: TypeAlias = CompletionV2SyncResponse | CompletionV2StreamResponse + +# This is a version of CompletionV2Response that is used by pydantic to determine the response model +# Since EventSourceResponse isn't a pydantic model, we need to use a Union of the two response types +CompletionV2ResponseItem: TypeAlias = CompletionV2SyncResponse | CompletionV2StreamChunk diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index af207d94..f23059af 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -1,6 +1,11 @@ from typing import Any, Dict, List, Optional from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.common.types.gen.openai import ( + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ResponseFormatText, +) from typing_extensions import Annotated # This was last synced w/ vLLM v0.5.5 on 2024-09-03 @@ -168,3 +173,59 @@ class VLLMChatCompletionAdditionalParams(VLLMSamplingParams): "for guided json decoding." ), ) + + +class VLLMCompletionAdditionalParams(VLLMSamplingParams): + add_special_tokens: Optional[bool] = Field( + default=None, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt." + ), + ) + + response_format: Optional[ + ResponseFormatText | ResponseFormatJsonObject | ResponseFormatJsonSchema + ] = Field( + default=None, + description=( + "Similar to chat completion, this parameter specifies the format of " + "output. Only {'type': 'json_object'} or {'type': 'text' } is " + "supported." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index aae8ff2f..6e1fa867 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -12,14 +12,14 @@ import re from dataclasses import asdict from functools import lru_cache -from typing import Any, AsyncIterable, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union import yaml from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( ChatCompletionV2Request, - ChatCompletionV2SuccessChunk, + ChatCompletionV2StreamSuccessChunk, ChatCompletionV2SyncResponse, CompletionOutput, CompletionStreamOutput, @@ -50,6 +50,11 @@ UpdateBatchCompletionsV2Response, VLLMEngineAdditionalArgs, ) +from model_engine_server.common.dtos.llms.completion import ( + CompletionV2Request, + CompletionV2StreamSuccessChunk, + CompletionV2SyncResponse, +) from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus @@ -133,6 +138,9 @@ CHAT_TEMPLATE_MAX_LENGTH = 10_000 CHAT_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM] +OPENAI_COMPLETION_PATH = "/v1/completions" +OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM] + LLM_METADATA_KEY = "_llm" RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] VLLM_MODEL_WEIGHTS_FOLDER = "model_files" @@ -210,6 +218,8 @@ "llama-3-1-70b-instruct", "llama-3-1-405b", "llama-3-1-405b-instruct", + "llama-3-2-11b-vision-instruct", + "llama-3-2-90b-vision-instruct", "falcon-7b", "falcon-7b-instruct", "falcon-40b", @@ -887,7 +897,10 @@ async def create_vllm_bundle( healthcheck_route="/health", predict_route="/predict", streaming_predict_route="/stream", - extra_routes=[OPENAI_CHAT_COMPLETION_PATH], + extra_routes=[ + OPENAI_CHAT_COMPLETION_PATH, + OPENAI_COMPLETION_PATH, + ], env={}, ), metadata={}, @@ -2436,9 +2449,260 @@ async def _response_chunk_generator( # raising an exception if it is not one of the frameworks handled above. +def validate_endpoint_supports_openai_completion( + endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response +): # pragma: no cover + if endpoint_content.inference_framework not in OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS: + raise EndpointUnsupportedInferenceTypeException( + f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support openai compatible completion." + ) + + if ( + not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike) + or OPENAI_COMPLETION_PATH not in endpoint.record.current_model_bundle.flavor.extra_routes + ): + raise EndpointUnsupportedRequestException( + "Endpoint does not support v2 openai compatible completion" + ) + + +class CompletionSyncV2UseCase: + """ + Use case for running a v2 openai compatible completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): # pragma: no cover + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, user: User, model_endpoint_name: str, request: CompletionV2Request + ) -> CompletionV2SyncResponse: # pragma: no cover + """ + Runs the use case to create a sync inference task. + + Args: + user: The user who is creating the sync inference task. + model_endpoint_name: The name of the model endpoint for the task. + request: The body of the request to forward to the endpoint. + + Returns: + A response object that contains the status and result of the task. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if ( + model_endpoint.record.endpoint_type is not ModelEndpointType.SYNC + and model_endpoint.record.endpoint_type is not ModelEndpointType.STREAMING + ): + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} does not serve sync requests." + ) + + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + validate_endpoint_supports_openai_completion(model_endpoint, endpoint_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + try: + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + # reset model name to correct value + output["model"] = model_endpoint.record.name + return CompletionV2SyncResponse.model_validate(output) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(exc.content) + raise exc + + +class CompletionStreamV2UseCase: + """ + Use case for running a v2 openai compatible completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): # pragma: no cover + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, model_endpoint_name: str, request: CompletionV2Request, user: User + ) -> AsyncGenerator[CompletionV2StreamSuccessChunk, None]: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if model_endpoint.record.endpoint_type != ModelEndpointType.STREAMING: + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} is not a streaming endpoint." + ) + + inference_gateway = ( + self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() + ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + validate_endpoint_supports_openai_completion(model_endpoint, model_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if model_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + + return self._response_chunk_generator( + request_id=request_id, + model_endpoint=model_endpoint, + model_content=model_content, + inference_gateway=inference_gateway, + inference_request=inference_request, + ) + + async def _response_chunk_generator( + self, + request_id: Optional[str], + model_endpoint: ModelEndpoint, + model_content: GetLLMModelEndpointV1Response, + inference_gateway: StreamingModelEndpointInferenceGateway, + inference_request: SyncEndpointPredictV1Request, + ) -> AsyncGenerator[CompletionV2StreamSuccessChunk, None]: # pragma: no cover + """ + Async generator yielding tokens to stream for the completions response. Should only be called when + returned directly by execute(). + """ + try: + predict_result = inference_gateway.streaming_predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + ) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(str(exc)) + + raise exc + + async for res in predict_result: + if not res.status == TaskStatus.SUCCESS or res.result is None: + raise UpstreamServiceError( + status_code=500, + content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), + ) + else: + result = res.result["result"] + # Reset model name to correct value + if "DONE" in result: + continue + result["model"] = model_endpoint.record.name + yield CompletionV2StreamSuccessChunk.model_validate(result) + + def validate_endpoint_supports_chat_completion( endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response -): +): # pragma: no cover if endpoint_content.inference_framework not in CHAT_SUPPORTED_INFERENCE_FRAMEWORKS: raise EndpointUnsupportedInferenceTypeException( f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support chat completion." @@ -2585,7 +2849,7 @@ def __init__( async def execute( self, model_endpoint_name: str, request: ChatCompletionV2Request, user: User - ) -> AsyncIterable[ChatCompletionV2SuccessChunk]: # pragma: no cover + ) -> AsyncGenerator[ChatCompletionV2StreamSuccessChunk, None]: # pragma: no cover request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) @@ -2654,7 +2918,7 @@ async def _response_chunk_generator( model_content: GetLLMModelEndpointV1Response, inference_gateway: StreamingModelEndpointInferenceGateway, inference_request: SyncEndpointPredictV1Request, - ) -> AsyncIterable[ChatCompletionV2SuccessChunk]: + ) -> AsyncGenerator[ChatCompletionV2StreamSuccessChunk, None]: """ Async generator yielding tokens to stream for the completions response. Should only be called when returned directly by execute(). @@ -2683,7 +2947,7 @@ async def _response_chunk_generator( if "DONE" in result: continue result["model"] = model_endpoint.record.name - yield ChatCompletionV2SuccessChunk.model_validate(result) + yield ChatCompletionV2StreamSuccessChunk.model_validate(result) class ModelDownloadV1UseCase: diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 016ded85..a4daff0c 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -119,7 +119,8 @@ def after_return( ) request_params = args[0] request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) - forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore + if forwarder.post_inference_hooks_handler: + forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore # See documentation for options: # https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py index 6ed33d40..3d6333a3 100644 --- a/model-engine/model_engine_server/inference/forwarding/echo_server.py +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -3,8 +3,8 @@ """ import argparse +import asyncio import subprocess -import time from fastapi import FastAPI, Request from fastapi.responses import JSONResponse @@ -13,17 +13,20 @@ app = FastAPI() +@app.get("/health") @app.get("/healthz") @app.get("/readyz") def healthcheck(): return "OK" +@app.post("/v1/chat/completions") @app.post("/predict") async def predict(request: Request): dictionary = await request.json() + print("Received request", dictionary, flush=True) if "delay" in dictionary: - time.sleep(dictionary["delay"]) + await asyncio.sleep(dictionary["delay"]) return dictionary diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 3cc53d7c..f471ec64 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -4,13 +4,16 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterable, List, Optional, Sequence, Tuple +from typing import Any, AsyncGenerator, Iterable, List, Optional, Sequence, Tuple +import aiohttp +import orjson import requests import sseclient import yaml from fastapi import HTTPException from fastapi.responses import JSONResponse +from model_engine_server.common.aiohttp_sse_client import EventSource from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.inference.common import get_endpoint_config from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( @@ -122,6 +125,10 @@ def parse_to_object_or_string(value: str) -> object: return value +def _serialize_json(data) -> str: + return orjson.dumps(data).decode() + + @dataclass class Forwarder(ModelEngineSerializationMixin): """Forwards inference requests to another service via HTTP POST. @@ -142,7 +149,49 @@ class Forwarder(ModelEngineSerializationMixin): serialize_results_as_string: bool wrap_response: bool forward_http_status: bool - post_inference_hooks_handler: PostInferenceHooksHandler + post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None + + async def forward(self, json_payload: Any) -> Any: + json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) + json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload + + logger.info(f"Accepted request, forwarding {json_payload_repr=}") + + try: + async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: + response_raw = await aioclient.post( + self.predict_endpoint, + json=json_payload, + headers={"Content-Type": "application/json"}, + ) + response = await response_raw.json() + + except Exception: + logger.exception( + f"Failed to get response for request ({json_payload_repr}) " + "from user-defined inference service." + ) + raise + if isinstance(response, dict): + logger.info( + f"Got response from user-defined service: {response.keys()=}, {response_raw.status=}" + ) + elif isinstance(response, list): + logger.info( + f"Got response from user-defined service: {len(response)=}, {response_raw.status=}" + ) + else: + logger.info( + f"Got response from user-defined service: {response=}, {response_raw.status=}" + ) + + if self.wrap_response: + response = self.get_response_payload(using_serialize_results_as_string, response) + + if self.forward_http_status: + return JSONResponse(content=response, status_code=response_raw.status) + else: + return response def __call__(self, json_payload: Any) -> Any: json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) @@ -210,7 +259,7 @@ class LoadForwarder: wrap_response: bool = True forward_http_status: bool = False - def load(self, resources: Path, cache: Any) -> Forwarder: + def load(self, resources: Optional[Path], cache: Any) -> Forwarder: if self.use_grpc: raise NotImplementedError( "User-defined service **MUST** use HTTP at the moment. " @@ -288,23 +337,26 @@ def endpoint(route: str) -> str: else: serialize_results_as_string = self.serialize_results_as_string - endpoint_config = get_endpoint_config() - handler = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - billing_queue=endpoint_config.billing_queue, - billing_tags=endpoint_config.billing_tags, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - endpoint_id=endpoint_config.endpoint_id, - endpoint_type=endpoint_config.endpoint_type, - bundle_id=endpoint_config.bundle_id, - labels=endpoint_config.labels, - streaming_storage_gateway=FirehoseStreamingStorageGateway(), - ) + try: + endpoint_config = get_endpoint_config() + handler = PostInferenceHooksHandler( + endpoint_name=endpoint_config.endpoint_name, + bundle_name=endpoint_config.bundle_name, + post_inference_hooks=endpoint_config.post_inference_hooks, + user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, + default_callback_url=endpoint_config.default_callback_url, + default_callback_auth=endpoint_config.default_callback_auth, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), + ) + except Exception: + handler = None return Forwarder( predict_endpoint=pred, @@ -335,7 +387,38 @@ class StreamingForwarder(ModelEngineSerializationMixin): predict_endpoint: str model_engine_unwrap: bool serialize_results_as_string: bool - post_inference_hooks_handler: PostInferenceHooksHandler # unused for now + post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None # unused for now + + async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # pragma: no cover + json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) + json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload + + logger.info(f"Accepted request, forwarding {json_payload_repr=}") + + try: + response: aiohttp.ClientResponse + async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: + response = await aioclient.post( + self.predict_endpoint, + json=json_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status != 200: + raise HTTPException(status_code=response.status, detail=await response.json()) + + async with EventSource(response=response) as event_source: + async for event in event_source: + yield self.get_response_payload_stream( + using_serialize_results_as_string, event.data + ) + + except Exception: + logger.exception( + f"Failed to get response for request ({json_payload_repr}) " + "from user-defined inference service." + ) + raise def __call__(self, json_payload: Any) -> Iterable[Any]: json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) @@ -354,7 +437,6 @@ def __call__(self, json_payload: Any) -> Iterable[Any]: ) if response.status_code != 200: - print(response.json()) raise HTTPException(status_code=response.status_code, detail=response.json()) except Exception: @@ -396,7 +478,7 @@ class LoadStreamingForwarder: model_engine_unwrap: bool = True serialize_results_as_string: bool = False - def load(self, resources: Path, cache: Any) -> StreamingForwarder: + def load(self, resources: Optional[Path], cache: Any) -> StreamingForwarder: if self.use_grpc: raise NotImplementedError( "User-defined service **MUST** use HTTP at the moment. " @@ -474,23 +556,26 @@ def endpoint(route: str) -> str: else: serialize_results_as_string = self.serialize_results_as_string - endpoint_config = get_endpoint_config() - handler = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - billing_queue=endpoint_config.billing_queue, - billing_tags=endpoint_config.billing_tags, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - endpoint_id=endpoint_config.endpoint_id, - endpoint_type=endpoint_config.endpoint_type, - bundle_id=endpoint_config.bundle_id, - labels=endpoint_config.labels, - streaming_storage_gateway=FirehoseStreamingStorageGateway(), - ) + try: + endpoint_config = get_endpoint_config() + handler = PostInferenceHooksHandler( + endpoint_name=endpoint_config.endpoint_name, + bundle_name=endpoint_config.bundle_name, + post_inference_hooks=endpoint_config.post_inference_hooks, + user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, + default_callback_url=endpoint_config.default_callback_url, + default_callback_auth=endpoint_config.default_callback_auth, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), + ) + except Exception: + handler = None return StreamingForwarder( predict_endpoint=pred, diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index 00a0c19c..89fcb3fb 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -34,7 +34,7 @@ def get_config(): ) -def get_forwarder_loader(destination_path: Optional[str] = None): +def get_forwarder_loader(destination_path: Optional[str] = None) -> LoadForwarder: config = get_config()["sync"] if "extra_routes" in config: del config["extra_routes"] @@ -44,7 +44,9 @@ def get_forwarder_loader(destination_path: Optional[str] = None): return forwarder_loader -def get_streaming_forwarder_loader(destination_path: Optional[str] = None): +def get_streaming_forwarder_loader( + destination_path: Optional[str] = None, +) -> LoadStreamingForwarder: config = get_config()["stream"] if "extra_routes" in config: del config["extra_routes"] @@ -55,7 +57,7 @@ def get_streaming_forwarder_loader(destination_path: Optional[str] = None): @lru_cache() -def get_concurrency_limiter(): +def get_concurrency_limiter() -> MultiprocessingConcurrencyLimiter: config = get_config() concurrency = int(config.get("max_concurrency", 100)) return MultiprocessingConcurrencyLimiter( @@ -64,27 +66,28 @@ def get_concurrency_limiter(): @lru_cache() -def load_forwarder(destination_path: Optional[str] = None): +def load_forwarder(destination_path: Optional[str] = None) -> Forwarder: return get_forwarder_loader(destination_path).load(None, None) @lru_cache() -def load_streaming_forwarder(destination_path: Optional[str] = None): +def load_streaming_forwarder(destination_path: Optional[str] = None) -> StreamingForwarder: return get_streaming_forwarder_loader(destination_path).load(None, None) -def predict( +async def predict( request: EndpointPredictV1Request, background_tasks: BackgroundTasks, - forwarder=Depends(load_forwarder), - limiter=Depends(get_concurrency_limiter), + forwarder: Forwarder = Depends(load_forwarder), + limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter), ): with limiter: try: - response = forwarder(request.model_dump()) - background_tasks.add_task( - forwarder.post_inference_hooks_handler.handle, request, response - ) + response = await forwarder.forward(request.model_dump()) + if forwarder.post_inference_hooks_handler: + background_tasks.add_task( + forwarder.post_inference_hooks_handler.handle, request, response + ) return response except Exception: logger.error(f"Failed to decode payload from: {request}") @@ -93,8 +96,8 @@ def predict( async def stream( request: EndpointPredictV1Request, - forwarder=Depends(load_streaming_forwarder), - limiter=Depends(get_concurrency_limiter), + forwarder: StreamingForwarder = Depends(load_streaming_forwarder), + limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter), ): with limiter: try: @@ -105,10 +108,16 @@ async def stream( else: logger.debug(f"Received request: {payload}") - responses = forwarder(payload) + responses = forwarder.forward(payload) + # We fetch the first response to check if upstream request was successful + # If it was not, this will raise the corresponding HTTPException + # If it was, we will proceed to the event generator + initial_response = await responses.__anext__() async def event_generator(): - for response in responses: + yield {"data": orjson.dumps(initial_response).decode("utf-8")} + + async for response in responses: yield {"data": orjson.dumps(response).decode("utf-8")} return EventSourceResponse(event_generator()) @@ -181,6 +190,13 @@ def add_extra_routes(app: FastAPI): all_routes = set(list(sync_forwarders.keys()) + list(stream_forwarders.keys())) for route in all_routes: + + def get_sync_forwarder(route=route): + return sync_forwarders.get(route) + + def get_stream_forwarder(route=route): + return stream_forwarders.get(route) + # This route is a catch-all for any requests that don't match the /predict or /stream routes # It will treat the request as a streaming request if the "stream" body parameter is set to true # NOTE: it is important for this to be defined AFTER the /predict and /stream endpoints @@ -188,8 +204,8 @@ def add_extra_routes(app: FastAPI): async def predict_or_stream( request: EndpointPredictV1Request, background_tasks: BackgroundTasks, - sync_forwarder=Depends(lambda: sync_forwarders.get(route)), - stream_forwarder=Depends(lambda: stream_forwarders.get(route)), + sync_forwarder: Forwarder = Depends(get_sync_forwarder), + stream_forwarder: StreamingForwarder = Depends(get_stream_forwarder), limiter=Depends(get_concurrency_limiter), ): if not request.args: @@ -197,7 +213,7 @@ async def predict_or_stream( if request.args.root.get("stream", False) and stream_forwarder: return await stream(request, stream_forwarder, limiter) elif request.args.root.get("stream") is not True and sync_forwarder: - return predict(request, background_tasks, sync_forwarder, limiter) + return await predict(request, background_tasks, sync_forwarder, limiter) else: raise Exception("No forwarder configured for this route") diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 9cda4b9c..7b4a0e7d 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -4,6 +4,7 @@ import orjson import requests import sseclient +from model_engine_server.common.aiohttp_sse_client import EventSource from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, @@ -22,7 +23,6 @@ from model_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( StreamingModelEndpointInferenceGateway, ) -from model_engine_server.infra.gateways.aiohttp_sse_client import EventSource from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port from orjson import JSONDecodeError from tenacity import ( @@ -90,7 +90,6 @@ async def make_single_request(self, request_url: str, payload_json: Dict[str, An headers={"Content-Type": "application/json"}, ) status = aio_resp.status - print(status) if status == 200: async with EventSource(response=aio_resp) as event_source: async for event in event_source: diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index 0f7cd2ec..0115722b 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -1,3 +1,4 @@ +aioresponses>=0.7.6 coverage==5.5 diff-cover==7.7.0 frozendict==2.3.4 diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py index e765accc..0edfb444 100644 --- a/model-engine/tests/unit/inference/test_http_forwarder.py +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -6,7 +6,8 @@ import pytest import requests_mock -from fastapi import BackgroundTasks +from aioresponses import aioresponses +from fastapi import BackgroundTasks, FastAPI from fastapi.responses import JSONResponse from fastapi.testclient import TestClient from model_engine_server.common.dtos.tasks import EndpointPredictV1Request @@ -310,15 +311,15 @@ def mocked_get_endpoint_config(): "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", mocked_get_endpoint_config, ) -async def mocked_app(): +async def mocked_app() -> FastAPI: with requests_mock.Mocker() as req_mock: healthcheck_endpoint = get_healthcheck_endpoint(mocked_get_config_with_extra_paths()) - print(healthcheck_endpoint) req_mock.get( healthcheck_endpoint, json={"status": "ok"}, ) - return await init_app() + app = await init_app() + return app def wrap_request(request): @@ -356,14 +357,19 @@ async def test_mocked_app_success(mocked_app): expected_result = wrap_result( json.dumps(raw_result) if config_sync["serialize_results_as_string"] else raw_result ) - with TestClient(mocked_app) as client, requests_mock.Mocker() as req_mock: - req_mock.get(healthcheck_endpoint, json={"status": "ok"}) - req_mock.post(predict_endpoint, json=raw_result) + with TestClient( + mocked_app + ) as client, aioresponses() as aio_mock, requests_mock.Mocker() as req_mock: + req_mock.get( + healthcheck_endpoint, + json={"status": "ok"}, + ) + aio_mock.post(predict_endpoint, status=200, payload=raw_result) response = client.post("/predict", json=payload) assert response.status_code == 200 assert response.json() == expected_result - req_mock.post(chat_endpoint, json=raw_result) + aio_mock.post(chat_endpoint, status=200, payload=raw_result) response = client.post("/v1/chat/completions", json=payload) assert response.status_code == 200 assert response.json() == expected_result From 1b8ee43239be95578ce3c16a1f194bb7994a1ff6 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 4 Oct 2024 15:30:12 -0700 Subject: [PATCH 393/425] Add completion routes to main router (#628) * Add completion routes to main router * Fix completion stream api * Relax content type expectation from inference server --- model-engine/model_engine_server/api/v2/__init__.py | 2 ++ .../model_engine_server/api/v2/chat_completion.py | 1 - model-engine/model_engine_server/api/v2/completion.py | 3 ++- .../model_engine_server/common/types/gen/openai.py | 4 ++-- .../inference/forwarding/forwarding.py | 8 ++++++-- .../live_streaming_model_endpoint_inference_gateway.py | 1 - scripts/openai-spec.yaml | 1 + 7 files changed, 13 insertions(+), 7 deletions(-) diff --git a/model-engine/model_engine_server/api/v2/__init__.py b/model-engine/model_engine_server/api/v2/__init__.py index d8d906b6..abb0fdec 100644 --- a/model-engine/model_engine_server/api/v2/__init__.py +++ b/model-engine/model_engine_server/api/v2/__init__.py @@ -4,9 +4,11 @@ from .batch_completion import batch_completions_router_v2 from .chat_completion import chat_router_v2 +from .completion import completion_router_v2 llm_router_v2 = APIRouter(prefix="/v2") llm_router_v2.include_router(batch_completions_router_v2) llm_router_v2.include_router(chat_router_v2) +llm_router_v2.include_router(completion_router_v2) __all__: Sequence[str] = ("llm_router_v2",) diff --git a/model-engine/model_engine_server/api/v2/chat_completion.py b/model-engine/model_engine_server/api/v2/chat_completion.py index b4b02837..614f159d 100644 --- a/model-engine/model_engine_server/api/v2/chat_completion.py +++ b/model-engine/model_engine_server/api/v2/chat_completion.py @@ -135,7 +135,6 @@ async def event_generator(timer: timer = use_case_timer): ttft = timer.lap() # if ttft is None and message.startswith("data"): # ttft = timer.lap() - print("message", message.model_dump_json(exclude_none=True)) yield {"data": message.model_dump_json(exclude_none=True)} if message: diff --git a/model-engine/model_engine_server/api/v2/completion.py b/model-engine/model_engine_server/api/v2/completion.py index 250acd83..ed529fe3 100644 --- a/model-engine/model_engine_server/api/v2/completion.py +++ b/model-engine/model_engine_server/api/v2/completion.py @@ -13,6 +13,7 @@ from model_engine_server.common.dtos.llms import ( CompletionV2Request, CompletionV2Response, + CompletionV2ResponseItem, CompletionV2StreamErrorChunk, StreamError, StreamErrorContent, @@ -246,7 +247,7 @@ def to_error_details(exc: Exception) -> Any: return exc.args -@completion_router_v2.post("/completions", response_model=CompletionV2Response) +@completion_router_v2.post("/completions", response_model=CompletionV2ResponseItem) async def completion( request: CompletionV2Request, background_tasks: BackgroundTasks, diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py index f9444769..964d0c33 100644 --- a/model-engine/model_engine_server/common/types/gen/openai.py +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openai-spec.yaml -# timestamp: 2024-09-30T08:39:28+00:00 +# timestamp: 2024-10-04T21:01:02+00:00 from __future__ import annotations @@ -79,7 +79,7 @@ class Logprobs(BaseModel): class Choice(BaseModel): finish_reason: Annotated[ - Literal["stop", "length", "content_filter"], + Optional[Literal["stop", "length", "content_filter"]], Field( description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n" ), diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index f471ec64..c3970107 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -164,7 +164,9 @@ async def forward(self, json_payload: Any) -> Any: json=json_payload, headers={"Content-Type": "application/json"}, ) - response = await response_raw.json() + response = await response_raw.json( + content_type=None + ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error except Exception: logger.exception( @@ -405,7 +407,9 @@ async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # prag ) if response.status != 200: - raise HTTPException(status_code=response.status, detail=await response.json()) + raise HTTPException( + status_code=response.status, detail=await response.json(content_type=None) + ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error async with EventSource(response=response) as event_source: async for event in event_source: diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 7b4a0e7d..00c51e87 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -212,7 +212,6 @@ async def streaming_predict( timeout_seconds=timeout_seconds, num_retries=num_retries, ) - print(response) async for item in response: yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) except UpstreamServiceError as exc: diff --git a/scripts/openai-spec.yaml b/scripts/openai-spec.yaml index 7ec21c7f..6eb3f1cf 100644 --- a/scripts/openai-spec.yaml +++ b/scripts/openai-spec.yaml @@ -8940,6 +8940,7 @@ components: `length` if the maximum number of tokens specified in the request was reached, or `content_filter` if content was omitted due to a flag from our content filters. enum: ["stop", "length", "content_filter"] + nullable: true index: type: integer logprobs: From 9c5579f08cb3bafaa5fb7d199544d8dd585f5b70 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 8 Oct 2024 16:16:04 -0700 Subject: [PATCH 394/425] Add chat template override to client (#629) --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/data_types/rest.py | 21 +++++++++++++++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 40ea1a28..21bed5ce 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta39" +__version__ = "0.0.0beta40" import os from typing import Sequence diff --git a/clients/python/llmengine/data_types/rest.py b/clients/python/llmengine/data_types/rest.py index 33f62750..d24b08b1 100644 --- a/clients/python/llmengine/data_types/rest.py +++ b/clients/python/llmengine/data_types/rest.py @@ -177,6 +177,10 @@ class CreateLLMEndpointRequest(BaseModel): """ Whether the endpoint can be used for inference for all users. LLM endpoints are public by default. """ + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) class CreateLLMEndpointResponse(BaseModel): @@ -237,6 +241,11 @@ class GetLLMEndpointResponse(BaseModel): ) """(For self-hosted users) Model endpoint details.""" + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + class ListLLMEndpointsResponse(BaseModel): """ @@ -291,6 +300,18 @@ class UpdateLLMEndpointRequest(BaseModel): default_callback_url: Optional[HttpUrl] default_callback_auth: Optional[CallbackAuth] public_inference: Optional[bool] + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + force_bundle_recreation: Optional[bool] = False + """ + Whether to force recreate the underlying bundle. + + If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created + that we would like to pick up for existing endpoints + """ class UpdateLLMEndpointResponse(BaseModel): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2d011128..9f963abb 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta39" +version = "0.0.0.beta40" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 73db3f02..ea6c5e02 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.8", - version="0.0.0.beta39", + version="0.0.0.beta40", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From 8830f893611a8d97f0c67c10f3dd55baace04c05 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:50:02 -0700 Subject: [PATCH 395/425] Multinode serving (#574) API/schema changes: Bundles/bundle v2: Use the new fields in the entity for doing multinode Endpoints: nodes_per_worker param that controls whether to use a LWS or not Create LWS endpoint, delete LWS endpoint is allowed Modifying a LWS endpoint to a non-LWS endpoint or modifying a non-LWS endpoint into a LWS endpoint is not allowed (similar to how we don't allow switching from sync and async) (also changing nodes_per_worker will not be allowed). Enforced by this not being capable through the API. Get LWS endpoint is allowed Only available LWS endpoint type will be gpu + streaming, cpu-only LWS doesn't really make sense imo, async/sync could make sense but we don't need it for LLM serving at this point k8s_erd get_resources + get_all_resources are made LWS compatible LWS endpoints create a service (separate from the service that the LWS auto-creates) + (if istio is enabled) a ServiceEntry. If istio is enabled, gateway will hackily manually resolve the service's IP and make a request there. LLM endpoints: create a multinode bundle if situation calls for it, use endpoint's api to include this new bundle (TODO reflect in production) build new vllm image to use multinode Endpoint will do tensorparallel=(# gpus in a single node), pipelineparallel=(# nodes) Python client: Add way to set nodes_per_worker on create + get nodes_per_worker on get/list * mark a bunch of todos, maybe not complete though * more todos in k8s_endpoint_resource_delegate * mark more todos * partially string nodes_per_worker through * wip k8s_endpoint_resource_delegate * lws service template config map yaml * delete lws * small things for delete * create_lws * mark more places * think this is sufficient for LWS args? * fix most of the domain test cases * fix rest of the domain test cases * partial llm stuff * validation of bundle extra params in endpoint use case layer * fix domain tests * try fixing more tests * fix more tests * don't set nodes per worker in update pls * validate endpoint + multinode compat, more lws loading stuff * k8s resource delegate stuff * . * add code to allow vllm to do multinode * start on creating the multinode bundle * create the bundle * try fixing tests? * fix remaining unit tests * screwed up the merge oops * temp turn off cache for testing, also try fixing the service template config map * hopefully the autogen templates was correct lol * oops need to await * fix a few typos * one more typo * refactor out _get_deployment * wip refactored out get_resources into deployment/lws types, todo the lws code * get_resources for lws, almost done * priority class * get_all_deployments * black * comments * try making the custom_obj_client calls return ApiException * a bunch of test stubs * one more test * client * comment out model endpoint infra gateway update_model_endpoint_infra nodes_per_worker * fill in some tests * more domain tests * get multinode test * fix more tests * stub test * delete some things that were added back in the merge conflict * fix some semantic merge broken things * delete test update multinode since that's not allowed in the api * update client * unmark some todos * mark more todos * unmark more todos * fix test * remove a test that isn't allowed in the api * more test * get test mostly working * autogen template * get test to pass * format * add explicit resource * use the example config * silly bug * fix test * . * update multinode in gateway test case * some cleanup * strip out that worker env/command metadata hack * black * turn cache back on * unmark todos that I've done * clean up more todos, add multinode deployment validation to live endpoint builder svc * cleanup * try commenting out some mocks since we might not need to mock them * fix test, blackwell isn't out yet * . * uncomment the config map stuff * . * more cleanup * oops * k8s doesn't allow underscores dang it * vllm worker doesn't get ray cluster size * oops * aah * try putting lws specific container envs first * try using LWS_LEADER_ADDRESS since that seems like it's provided * oops can't do own address * add labels to the leader/worker tpl for selecting later * create new service * dumb mistake * add lws service yaml + code, add unit tests * autogen tpl * oops * fix test * delete the right service * black * todos * add actual using of the multinode workers in the bundle command * yaml for lws service entry * lws service entry crud * thread the manually_make_dns_query param through, the actual code isn't implemented yet though * only when istio enabled * implementation for manually resolve dns * implementation for manually resolve dns * dumb * vllm host ip * tmp comment out the pipeline parallel stuff to test the connectivity stuff first * patch vllm host ip * uncomment pipeline parallel * autogen template * fix tests * black * clean up todos * respond to some smaller comments * respond to more comments re creating the bundle command * comment wording * respond to a few comments * more comments * comment: remove env var * comment: pull out is_multinode * clarify some documentation in the client * more docs * more explanations * comment * more comments * comment * comments * . * remove the extra configmap call --- .../service_template_config_map.yaml | 417 +++++++++ charts/model-engine/values_circleci.yaml | 18 +- charts/model-engine/values_sample.yaml | 18 +- clients/python/llmengine/data_types/rest.py | 6 + clients/python/llmengine/model.py | 32 +- .../api/model_endpoints_v1.py | 2 +- .../common/dtos/batch_jobs.py | 3 + .../common/dtos/endpoint_builder.py | 1 + .../common/dtos/llms/model_endpoints.py | 1 + .../common/dtos/model_endpoints.py | 1 + .../common/resource_limits.py | 2 +- .../domain/entities/model_endpoint_entity.py | 7 + ...eaming_model_endpoint_inference_gateway.py | 2 +- .../sync_model_endpoint_inference_gateway.py | 2 +- .../domain/services/model_endpoint_service.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 338 +++++-- .../use_cases/model_endpoint_use_cases.py | 49 ++ .../streaming_inference_use_cases.py | 14 +- .../use_cases/sync_inference_use_cases.py | 14 +- .../inference/vllm/Dockerfile.vllm | 1 + .../inference/vllm/init_ray.sh | 101 +++ .../infra/gateways/dns_resolver.py | 22 + .../live_model_endpoint_infra_gateway.py | 5 + ...eaming_model_endpoint_inference_gateway.py | 25 +- ...e_sync_model_endpoint_inference_gateway.py | 25 +- .../gateways/model_endpoint_infra_gateway.py | 1 + .../k8s_endpoint_resource_delegate.py | 738 ++++++++++++++-- .../gateways/resources/k8s_resource_types.py | 139 +++ .../live_endpoint_resource_gateway.py | 4 +- .../service_template_config_map_circleci.yaml | 708 ++++++++++++--- .../infra/services/live_batch_job_service.py | 1 + .../services/live_endpoint_builder_service.py | 8 + .../services/live_model_endpoint_service.py | 2 + model-engine/tests/unit/api/conftest.py | 5 + .../tests/unit/api/test_model_endpoints.py | 26 + model-engine/tests/unit/conftest.py | 63 +- model-engine/tests/unit/domain/conftest.py | 43 + .../tests/unit/domain/test_llm_use_cases.py | 45 + .../domain/test_model_endpoint_use_cases.py | 58 ++ .../resources/example_lws_config.json | 829 ++++++++++++++++++ .../test_k8s_endpoint_resource_delegate.py | 139 ++- .../test_live_model_endpoint_infra_gateway.py | 32 + .../tests/unit/infra/repositories/conftest.py | 1 + .../test_live_model_endpoint_service.py | 2 + 44 files changed, 3651 insertions(+), 300 deletions(-) create mode 100755 model-engine/model_engine_server/inference/vllm/init_ray.sh create mode 100644 model-engine/model_engine_server/infra/gateways/dns_resolver.py create mode 100644 model-engine/tests/unit/infra/gateways/resources/example_lws_config.json diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 1ce5d2d8..6ef924df 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -472,6 +472,387 @@ data: unsafeSsl: "false" databaseIndex: "${REDIS_DB_INDEX}" {{- end }} + {{- range $device := tuple "gpu" }} + {{- range $mode := tuple "streaming"}} + leader-worker-set-{{ $mode }}-{{ $device }}.yaml: |- + apiVersion: leaderworkerset.x-k8s.io/v1 + kind: LeaderWorkerSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + replicas: ${MIN_WORKERS} + leaderWorkerTemplate: + size: ${LWS_SIZE} + restartPolicy: RecreateGroupOnPodRestart # TODO un-hardcode? if necessary + leaderTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: leader + {{- $service_template_labels | nindent 14 }} + sidecar.istio.io/inject: "false" # Never inject istio, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + {{- include "modelEngine.serviceTemplateAffinity" . | nindent 14 }} + {{- if eq $mode "async" }} # TODO + terminationGracePeriodSeconds: 1800 + {{- else }} + terminationGracePeriodSeconds: 600 + {{- end }} + {{- if $service_template_service_account_name }} + serviceAccount: {{ $service_template_service_account_name }} + {{- else }} + serviceAccount: {{ $launch_name }} + {{- end }} + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 14 }} + {{- end }} + {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + {{- end }} + priorityClassName: ${PRIORITY} + containers: + {{- if eq $mode "sync" }} + - name: http-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- $sync_forwarder_template_env | nindent 16 }} + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 16 }} + ports: + - containerPort: ${FORWARDER_PORT} + name: http + {{- else if eq $mode "streaming" }} + - name: http-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- $sync_forwarder_template_env | nindent 16 }} + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 16 }} + ports: + - containerPort: ${FORWARDER_PORT} + name: http + {{- else if eq $mode "async" }} + - name: celery-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --queue + - "${QUEUE}" + - --task-visibility + - "VISIBILITY_24H" + - --set + - "forwarder.async.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- if eq $celery_broker_type "sqs" }} + - --sqs-url + - "${SQS_QUEUE_URL}" + {{- end }} + - --num-workers + - "${PER_WORKER}" + - --broker-type + - {{ $celery_broker_type }} + {{- if eq $celery_broker_type "servicebus" }} + - --backend-protocol + - abs + {{- end }} + {{- $async_forwarder_template_env | nindent 16 }} + resources: + requests: + cpu: 0.1 + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 16 }} + {{- end }} + - name: lws-leader + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} + readinessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - mountPath: /dev/shm + name: dshm + {{- if $mount_infra_config }} + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + {{- end }} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + {{- if $require_aws_config }} + - name: config-volume + configMap: + {{- if $service_template_aws_config_map_name }} + name: {{ $service_template_aws_config_map_name }} + {{- else }} + name: {{ $aws_config_map_name }} + {{- end }} + {{- end }} + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + {{- if $config_values }} + - name: infra-service-config-volume + configMap: + name: {{ $launch_name }}-service-config + items: + - key: infra_service_config + path: config.yaml + {{- end }} + workerTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: worker + {{- $service_template_labels | nindent 14 }} + sidecar.istio.io/inject: "false" # Never inject istio for LWS, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + {{- include "modelEngine.serviceTemplateAffinity" . | nindent 14 }} + {{- if eq $mode "async" }} # TODO + terminationGracePeriodSeconds: 1800 + {{- else }} + terminationGracePeriodSeconds: 600 + {{- end }} + {{- if $service_template_service_account_name }} + serviceAccount: {{ $service_template_service_account_name }} + {{- else }} + serviceAccount: {{ $launch_name }} + {{- end }} + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 14 }} + {{- end }} + {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + {{- end }} + priorityClassName: ${PRIORITY} + containers: + - name: lws-worker + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${WORKER_COMMAND} + env: ${WORKER_ENV} + resources: + requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - mountPath: /dev/shm + name: dshm + {{- if $mount_infra_config }} + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + {{- end }} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + {{- if $require_aws_config }} + - name: config-volume + configMap: + {{- if $service_template_aws_config_map_name }} + name: {{ $service_template_aws_config_map_name }} + {{- else }} + name: {{ $aws_config_map_name }} + {{- end }} + {{- end }} + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + {{- if $config_values }} + - name: infra-service-config-volume + configMap: + name: {{ $launch_name }}-service-config + items: + - key: infra_service_config + path: config.yaml + {{- end }} + {{- end }} # mode + {{- end }} # device service.yaml: |- apiVersion: v1 kind: Service @@ -490,6 +871,25 @@ data: protocol: TCP name: http ${NODE_PORT_DICT} + lws-service.yaml: |- + apiVersion: v1 + kind: Service + metadata: + name: ${SERVICE_NAME_OVERRIDE} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + type: ${SERVICE_TYPE} + selector: + app: ${RESOURCE_NAME} + role: leader + ports: + - port: 80 + targetPort: ${SERVICE_TARGET_PORT} + protocol: TCP + name: http + ${NODE_PORT_DICT} {{- if .Values.virtualservice.enabled }} virtual-service.yaml: |- apiVersion: networking.istio.io/v1alpha3 @@ -526,6 +926,23 @@ data: loadBalancer: simple: LEAST_REQUEST {{- end }} + lws-service-entry.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: ServiceEntry + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + hosts: + - "${SERVICE_NAME_OVERRIDE}.${NAMESPACE}.svc.cluster.local" + location: MESH_EXTERNAL + ports: + - number: 80 + name: http + protocol: HTTP + resolution: NONE vertical-pod-autoscaler.yaml: |- apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index f897f4df..ba7fa812 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -242,36 +242,49 @@ recommendedHardware: memory: 24Gi storage: 80Gi gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 - gpu_memory_le: 48 cpus: 20 gpus: 2 memory: 48Gi storage: 80Gi gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 - gpu_memory_le: 96 cpus: 40 gpus: 4 memory: 96Gi storage: 96Gi gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 - gpu_memory_le: 180 cpus: 20 gpus: 2 memory: 160Gi storage: 160Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 320 cpus: 40 gpus: 4 memory: 320Gi storage: 320Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 640 cpus: 80 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 1280 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 2 byModelName: - name: llama-3-8b-instruct-262k cpus: 20 @@ -279,15 +292,18 @@ recommendedHardware: memory: 40Gi storage: 40Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - name: deepseek-coder-v2 cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - name: deepseek-coder-v2-instruct cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi - gpu_type: nvidia-hopper-h100 \ No newline at end of file + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 \ No newline at end of file diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 7eb04e52..a9d8d7e0 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -325,36 +325,49 @@ recommendedHardware: memory: 24Gi storage: 80Gi gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 - gpu_memory_le: 48 cpus: 20 gpus: 2 memory: 48Gi storage: 80Gi gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 - gpu_memory_le: 96 cpus: 40 gpus: 4 memory: 96Gi storage: 96Gi gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 - gpu_memory_le: 180 cpus: 20 gpus: 2 memory: 160Gi storage: 160Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 320 cpus: 40 gpus: 4 memory: 320Gi storage: 320Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 640 cpus: 80 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 640 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 2 byModelName: - name: llama-3-8b-instruct-262k cpus: 20 @@ -362,15 +375,18 @@ recommendedHardware: memory: 40Gi storage: 40Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - name: deepseek-coder-v2 cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - name: deepseek-coder-v2-instruct cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi - gpu_type: nvidia-hopper-h100 \ No newline at end of file + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 \ No newline at end of file diff --git a/clients/python/llmengine/data_types/rest.py b/clients/python/llmengine/data_types/rest.py index d24b08b1..88aa3ad6 100644 --- a/clients/python/llmengine/data_types/rest.py +++ b/clients/python/llmengine/data_types/rest.py @@ -86,6 +86,10 @@ class ModelEndpointDeploymentState(BaseModel): class ModelEndpointResourceState(BaseModel): """ This is the entity-layer class for the resource settings per worker of a Model Endpoint. + Note: the values for cpus/gpus/memory/storage are per node, i.e. a single "worker" may consist of + multiple underlying "nodes" (corresponding to kubernetes pods), and the values for cpus/gpus/memory/storage + are the resources allocated for a single node. Thus, the total resource allocation + for the entire worker is multiplied by the value of `nodes_per_worker`. """ cpus: CpuSpecificationType # TODO(phil): try to use decimal.Decimal @@ -93,6 +97,7 @@ class ModelEndpointResourceState(BaseModel): memory: StorageSpecificationType gpu_type: Optional[GpuType] storage: Optional[StorageSpecificationType] + nodes_per_worker: int = Field(..., ge=1) # Multinode support. >1 = multinode. optimize_costs: Optional[bool] @@ -164,6 +169,7 @@ class CreateLLMEndpointRequest(BaseModel): memory: Optional[StorageSpecificationType] gpu_type: Optional[GpuType] storage: Optional[StorageSpecificationType] + nodes_per_worker: Optional[int] = None optimize_costs: Optional[bool] = None min_workers: int max_workers: int diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index cca90657..c03be3f5 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -47,6 +47,7 @@ def create( memory: Optional[str] = None, storage: Optional[str] = None, gpus: Optional[int] = None, + nodes_per_worker: int = 1, min_workers: int = 0, max_workers: int = 1, per_worker: int = 2, @@ -93,22 +94,35 @@ def create( For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). cpus (`Optional[int]`): - Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater + Number of cpus each node in the worker should get, e.g. 1, 2, etc. This must be greater than or equal to 1. Recommendation is set it to 8 * GPU count. Can be inferred from the model size. memory (`Optional[str]`): - Amount of memory each worker should get, e.g. "4Gi", "512Mi", etc. This must + Amount of memory each node in the worker should get, e.g. "4Gi", "512Mi", etc. This must be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. Can be inferred from the model size. storage (`Optional[str]`): - Amount of local ephemeral storage each worker should get, e.g. "4Gi", + Amount of local ephemeral storage each node in the worker should get, e.g. "4Gi", "512Mi", etc. This must be a positive amount of storage. Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. Can be inferred from the model size. gpus (`Optional[int]`): - Number of gpus each worker should get, e.g. 0, 1, etc. Can be inferred from the model size. + Number of gpus each node in the worker should get, e.g. 0, 1, etc. Can be inferred from the model size. + + nodes_per_worker (`int`): + Number of nodes per worker. Used to request multinode serving. This must be greater than or equal to 1. + Controls how many nodes to dedicate to one instance of the model. + Specifically, if `nodes_per_worker` is set to greater than 1, the model will be sharded across + `nodes_per_worker` nodes (e.g. kubernetes pods). One of these nodes will be a "leader" node and receive requests. + LLM Engine will set up the inter-node communication. + Any compute resource requests (i.e. cpus, memory, storage) apply to each individual node, thus the total resources + allocated are multiplied by this number. This is useful for models that require more memory than a single node can provide. + Note: autoscaling is not supported for multinode serving. + Further note: if your model can fit on GPUs on only one machine, e.g. you have access to an 8xA100 machine and your model fits + on 8 A100s, it is recommended to set `nodes_per_worker` to 1 and the rest of the resources accordingly. + `nodes_per_worker > 1` should only be set if you require more resources than a single machine can provide. min_workers (`int`): The minimum number of workers. Must be greater than or equal to 0. This @@ -297,6 +311,7 @@ def create( endpoint_type=ModelEndpointType(endpoint_type), gpus=gpus, gpu_type=GpuType(gpu_type) if gpu_type is not None else None, + nodes_per_worker=nodes_per_worker, labels=labels or {}, max_workers=max_workers, memory=memory, @@ -482,6 +497,7 @@ def update( labels: Optional[Dict[str, str]] = None, request_headers: Optional[Dict[str, str]] = None, ) -> UpdateLLMEndpointResponse: + # Can't adjust nodes_per_worker """ Update an LLM model. Note: This API is only available for self-hosted users. @@ -511,20 +527,20 @@ def update( For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). cpus (`Optional[int]`): - Number of cpus each worker should get, e.g. 1, 2, etc. This must be greater + Number of cpus each node in the worker should get, e.g. 1, 2, etc. This must be greater than or equal to 1. Recommendation is set it to 8 * GPU count. memory (`Optional[str]`): - Amount of memory each worker should get, e.g. "4Gi", "512Mi", etc. This must + Amount of memory each node in the worker should get, e.g. "4Gi", "512Mi", etc. This must be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. storage (`Optional[str]`): - Amount of local ephemeral storage each worker should get, e.g. "4Gi", + Amount of local ephemeral storage each node in the worker should get, e.g. "4Gi", "512Mi", etc. This must be a positive amount of storage. Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. gpus (`Optional[int]`): - Number of gpus each worker should get, e.g. 0, 1, etc. + Number of gpus each node in the worker should get, e.g. 0, 1, etc. min_workers (`Optional[int]`): The minimum number of workers. Must be greater than or equal to 0. This diff --git a/model-engine/model_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py index 662e5ef8..fd3a06a4 100644 --- a/model-engine/model_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -137,7 +137,7 @@ async def update_model_endpoint( external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> UpdateModelEndpointV1Response: """ - Lists the Models owned by the current owner. + Updates the Model endpoint. """ logger.info(f"PUT /model-endpoints/{model_endpoint_id} with {request} for {auth}") try: diff --git a/model-engine/model_engine_server/common/dtos/batch_jobs.py b/model-engine/model_engine_server/common/dtos/batch_jobs.py index 7b6a62ed..8d24665e 100644 --- a/model-engine/model_engine_server/common/dtos/batch_jobs.py +++ b/model-engine/model_engine_server/common/dtos/batch_jobs.py @@ -65,6 +65,9 @@ class CreateDockerImageBatchJobResourceRequests(BaseModel): gpus: Optional[int] = None gpu_type: Optional[GpuType] = None storage: Optional[StorageSpecificationType] = None + nodes_per_worker: Optional[int] = ( + None # TODO this is used only for inferring hardware, if multinode batch jobs is added we can reuse this field + ) model_config = ConfigDict(from_attributes=True) @classmethod diff --git a/model-engine/model_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py index 64ea43d0..2f5c5dbc 100644 --- a/model-engine/model_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -22,6 +22,7 @@ class BuildEndpointRequest(BaseModel): memory: StorageSpecificationType gpu_type: Optional[GpuType] = None storage: Optional[StorageSpecificationType] = None + nodes_per_worker: int = 1 # Multinode support. >1 = multinode. optimize_costs: bool aws_role: str results_s3_bucket: str diff --git a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py index c6f8df02..82619f47 100644 --- a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py @@ -58,6 +58,7 @@ class CreateLLMModelEndpointV1Request(BaseModel): memory: Optional[StorageSpecificationType] = None gpu_type: Optional[GpuType] = None storage: Optional[StorageSpecificationType] = None + nodes_per_worker: Optional[int] = None optimize_costs: Optional[bool] = None min_workers: int max_workers: int diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index a173cfe0..8e6f929e 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -58,6 +58,7 @@ class CreateModelEndpointV1Request(BaseModel): memory: StorageSpecificationType gpu_type: Optional[GpuType] = None storage: StorageSpecificationType + nodes_per_worker: int = Field(gt=0, default=1) optimize_costs: Optional[bool] = None min_workers: int = Field(..., ge=0) max_workers: int = Field(..., ge=0) diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index 04a07edc..57145c64 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -65,7 +65,7 @@ def validate_resource_requests( gpus: Optional[int], gpu_type: Optional[GpuType], ) -> None: - """Validates whether cpu/memory requests are reasonable""" + """Validates whether cpu/memory requests are reasonable. Shouldn't need to validate any nodes_per_worker in the multinode case""" if (gpus is None or gpus == 0) and gpu_type is not None: raise EndpointResourceInvalidRequestException( diff --git a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py index 814c8683..f4e4db3c 100644 --- a/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py @@ -37,6 +37,12 @@ class ModelEndpointStatus(str, Enum): class ModelEndpointResourceState(BaseModel): """ This is the entity-layer class for the resource settings per worker of a Model Endpoint. + Note: in the multinode case, there are multiple "nodes" per "worker". + "Nodes" is analogous to a single k8s pod that may take up all the GPUs on a single machine. + "Workers" is the smallest unit that a request can be made to, and consists of one leader "node" and + multiple follower "nodes" (named "worker" in the k8s LeaderWorkerSet definition). + cpus/gpus/memory/storage are per-node, thus the total consumption by a "worker" + is cpus/gpus/etc. multiplied by nodes_per_worker. """ cpus: CpuSpecificationType # TODO(phil): try to use decimal.Decimal @@ -44,6 +50,7 @@ class ModelEndpointResourceState(BaseModel): memory: StorageSpecificationType gpu_type: Optional[GpuType] = None storage: Optional[StorageSpecificationType] = None + nodes_per_worker: int = Field(..., ge=1) # Multinode support. >1 = multinode. optimize_costs: Optional[bool] = None diff --git a/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py index 8b80a525..b00470c3 100644 --- a/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py @@ -17,7 +17,7 @@ class StreamingModelEndpointInferenceGateway(ABC): @abstractmethod def streaming_predict( - self, topic: str, predict_request: SyncEndpointPredictV1Request + self, topic: str, predict_request: SyncEndpointPredictV1Request, manually_resolve_dns: bool ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a streaming response. diff --git a/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py index 90d77950..8df1277f 100644 --- a/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py @@ -16,7 +16,7 @@ class SyncModelEndpointInferenceGateway(ABC): @abstractmethod async def predict( - self, topic: str, predict_request: SyncEndpointPredictV1Request + self, topic: str, predict_request: SyncEndpointPredictV1Request, manually_resolve_dns: bool ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. diff --git a/model-engine/model_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py index aed90ddd..4cc89227 100644 --- a/model-engine/model_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -76,6 +76,7 @@ async def create_model_endpoint( memory: StorageSpecificationType, gpu_type: Optional[GpuType], storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, min_workers: int, max_workers: int, diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 6e1fa867..7f78dc50 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -507,11 +507,18 @@ async def execute( quantize: Optional[Quantization], checkpoint_path: Optional[str], chat_template_override: Optional[str], + nodes_per_worker: int, ) -> ModelBundle: + multinode = nodes_per_worker > 1 if source == LLMSource.HUGGING_FACE: self.check_docker_image_exists_for_image_tag( framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework] ) + if multinode and framework != LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException( + f"Multinode is not supported for framework {framework}." + ) + if framework == LLMInferenceFramework.DEEPSPEED: bundle_id = await self.create_deepspeed_bundle( user, @@ -531,16 +538,29 @@ async def execute( checkpoint_path, ) elif framework == LLMInferenceFramework.VLLM: - bundle_id = await self.create_vllm_bundle( - user, - model_name, - framework_image_tag, - endpoint_name, - num_shards, - quantize, - checkpoint_path, - chat_template_override, - ) + if multinode: + bundle_id = await self.create_vllm_multinode_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + nodes_per_worker, + quantize, + checkpoint_path, + chat_template_override, + ) + else: + bundle_id = await self.create_vllm_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + ) elif framework == LLMInferenceFramework.LIGHTLLM: bundle_id = await self.create_lightllm_bundle( user, @@ -819,17 +839,21 @@ async def create_deepspeed_bundle( ) ).model_bundle_id - async def create_vllm_bundle( + def _create_vllm_bundle_command( self, - user: User, model_name: str, framework_image_tag: str, - endpoint_unique_name: str, num_shards: int, quantize: Optional[Quantization], checkpoint_path: Optional[str], chat_template_override: Optional[str], + multinode: bool, + is_worker: bool, + nodes_per_worker: int = 1, # only used if multinode ): + """ + VLLM start command for the single worker, or the leader in a LeaderWorkerSet. + """ command = [] subcommands = [] @@ -846,65 +870,187 @@ async def create_vllm_bundle( final_weights_folder, ) - vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" + if multinode and not is_worker: + ray_cmd = "/workspace/init_ray.sh leader --ray_cluster_size=$RAY_CLUSTER_SIZE --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + subcommands.append(ray_cmd) + elif multinode and is_worker: + ray_cmd = "/workspace/init_ray.sh worker --ray_address=$LWS_LEADER_ADDRESS.svc.cluster.local --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + subcommands.append(ray_cmd) - chat_template_cmd = None - if chat_template_override: - # We encode the chat template as base64 to avoid issues with special characters - # and decode it via bash - chat_template_cmd = f'export CHAT_TEMPLATE=$(echo "{encode_template(chat_template_override)}" | base64 --decode)' - subcommands.append(chat_template_cmd) - vllm_cmd += ' --chat-template "$CHAT_TEMPLATE"' + if not is_worker: + vllm_cmd = "" - if quantize: # pragma: no cover - if quantize != Quantization.AWQ: - raise InvalidRequestException(f"Quantization {quantize} is not supported by vLLM.") + vllm_cmd += f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" - vllm_cmd += f" --quantization {quantize}" + if multinode: + vllm_cmd += f" --pipeline-parallel-size {nodes_per_worker}" - if hmi_config.sensitive_log_mode: # pragma: no cover - vllm_cmd += " --disable-log-requests" + chat_template_cmd = None + if chat_template_override: + # We encode the chat template as base64 to avoid issues with special characters + # and decode it via bash + chat_template_cmd = f'export CHAT_TEMPLATE=$(echo "{encode_template(chat_template_override)}" | base64 --decode)' + subcommands.append(chat_template_cmd) + vllm_cmd += ' --chat-template "$CHAT_TEMPLATE"' - additional_args = infer_addition_engine_args_from_model_name(model_name) + if quantize: # pragma: no cover + if quantize != Quantization.AWQ: + raise InvalidRequestException( + f"Quantization {quantize} is not supported by vLLM." + ) - if additional_args.max_gpu_memory_utilization: - vllm_cmd += f" --gpu-memory-utilization {additional_args.max_gpu_memory_utilization} --enforce-eager" + vllm_cmd += f" --quantization {quantize}" - if additional_args.attention_backend: - vllm_cmd += " --attention-backend FLASHINFER" + if hmi_config.sensitive_log_mode: # pragma: no cover + vllm_cmd += " --disable-log-requests" + + additional_args = infer_addition_engine_args_from_model_name(model_name) + + if additional_args.max_gpu_memory_utilization: + vllm_cmd += f" --gpu-memory-utilization {additional_args.max_gpu_memory_utilization} --enforce-eager" + + if additional_args.attention_backend: + vllm_cmd += " --attention-backend FLASHINFER" + + subcommands.append(vllm_cmd) - subcommands.append(vllm_cmd) command = [ "/bin/bash", "-c", ";".join(subcommands), ] + return command + + async def create_vllm_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + chat_template_override: Optional[str], + ): + command = self._create_vllm_bundle_command( + model_name, + framework_image_tag, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + multinode=False, + is_worker=False, + nodes_per_worker=1, + ) + + create_model_bundle_v2_request = CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.vllm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/predict", + streaming_predict_route="/stream", + extra_routes=[ + OPENAI_CHAT_COMPLETION_PATH, + OPENAI_COMPLETION_PATH, + ], + env={}, + ), + metadata={}, + ) + return ( await self.create_model_bundle_use_case.execute( user, - CreateModelBundleV2Request( - name=endpoint_unique_name, - schema_location="TBA", - flavor=StreamingEnhancedRunnableImageFlavor( - flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, - repository=hmi_config.vllm_repository, - tag=framework_image_tag, - command=command, - streaming_command=command, - protocol="http", - readiness_initial_delay_seconds=10, - healthcheck_route="/health", - predict_route="/predict", - streaming_predict_route="/stream", - extra_routes=[ - OPENAI_CHAT_COMPLETION_PATH, - OPENAI_COMPLETION_PATH, - ], - env={}, - ), - metadata={}, - ), + create_model_bundle_v2_request, + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + + async def create_vllm_multinode_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + nodes_per_worker: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + chat_template_override: Optional[str], + ): + leader_command = self._create_vllm_bundle_command( + model_name, + framework_image_tag, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + multinode=True, + is_worker=False, + nodes_per_worker=nodes_per_worker, + ) + worker_command = self._create_vllm_bundle_command( + model_name, + framework_image_tag, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + multinode=True, + is_worker=True, + nodes_per_worker=nodes_per_worker, + ) + + # These env vars e.g. K8S_OWN_POD_NAME, K8S_OWN_POD_NAME, K8S_OWN_NAMESPACE, K8S_LWS_CLUSTER_SIZE will be filled in automatically for all LWS pods through + # Launch's k8s_endpoint_resource_delegate + common_vllm_envs = { + "VLLM_HOST_IP": "$(K8S_OWN_POD_NAME).$(K8S_LWS_NAME).$(K8S_OWN_NAMESPACE).svc.cluster.local", # this needs to match what's given as --own-address in the vllm start command + "NCCL_SOCKET_IFNAME": "eth0", + "GLOO_SOCKET_IFNAME": "eth0", # maybe don't need + "NCCL_DEBUG": "INFO", # TODO remove once fully tested, will keep around for now + "VLLM_LOGGING_LEVEL": "INFO", # TODO remove once fully tested, will keep around for now + "RAY_CLUSTER_SIZE": "$(K8S_LWS_CLUSTER_SIZE)", + } + + create_model_bundle_v2_request = CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.vllm_repository, + tag=framework_image_tag, + command=leader_command, + streaming_command=leader_command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/predict", + streaming_predict_route="/stream", + extra_routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH], + env=common_vllm_envs, + worker_command=worker_command, + worker_env=common_vllm_envs, + ), + metadata={}, + ) + + return ( + await self.create_model_bundle_use_case.execute( + user, + create_model_bundle_v2_request, do_auth_check=False, # Skip auth check because llm create endpoint is called as the user itself, # but the user isn't directly making the action. It should come from the fine tune @@ -1073,6 +1219,7 @@ async def execute( and request.cpus and request.memory and request.storage + and request.nodes_per_worker ): raise RuntimeError("Some hardware info is missing unexpectedly.") validate_deployment_resources( @@ -1111,6 +1258,14 @@ async def execute( request.inference_framework ) + if ( + request.nodes_per_worker > 1 + and not request.inference_framework == LLMInferenceFramework.VLLM + ): + raise ObjectHasInvalidValueException( + "Multinode endpoints are only supported for VLLM models." + ) + bundle = await self.create_llm_model_bundle_use_case.execute( user, endpoint_name=request.name, @@ -1123,6 +1278,7 @@ async def execute( quantize=request.quantize, checkpoint_path=request.checkpoint_path, chat_template_override=request.chat_template_override, + nodes_per_worker=request.nodes_per_worker, ) validate_resource_requests( bundle=bundle, @@ -1170,6 +1326,7 @@ async def execute( memory=request.memory, gpu_type=request.gpu_type, storage=request.storage, + nodes_per_worker=request.nodes_per_worker, optimize_costs=bool(request.optimize_costs), min_workers=request.min_workers, max_workers=request.max_workers, @@ -1380,6 +1537,7 @@ async def execute( quantize=quantize, checkpoint_path=checkpoint_path, chat_template_override=chat_template_override, + nodes_per_worker=model_endpoint.infra_state.resource_state.nodes_per_worker, ) metadata = endpoint_record.metadata or {} @@ -1803,6 +1961,12 @@ async def execute( endpoint_id=model_endpoint.record.id ) endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) validated_request = validate_and_update_completion_params( endpoint_content.inference_framework, request ) @@ -1835,6 +1999,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: @@ -1884,6 +2049,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1943,6 +2109,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -1993,6 +2160,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2036,6 +2204,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2157,6 +2326,12 @@ async def execute( ) request = validated_request + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + args: Any = None num_prompt_tokens = None if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: @@ -2288,6 +2463,7 @@ async def execute( inference_gateway=inference_gateway, inference_request=inference_request, num_prompt_tokens=num_prompt_tokens, + manually_resolve_dns=manually_resolve_dns, ) async def _response_chunk_generator( @@ -2299,13 +2475,16 @@ async def _response_chunk_generator( inference_gateway: StreamingModelEndpointInferenceGateway, inference_request: SyncEndpointPredictV1Request, num_prompt_tokens: Optional[int], + manually_resolve_dns: bool, ) -> AsyncIterable[CompletionStreamV1Response]: """ Async generator yielding tokens to stream for the completions response. Should only be called when returned directly by execute(). """ predict_result = inference_gateway.streaming_predict( - topic=model_endpoint.record.destination, predict_request=inference_request + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) num_completion_tokens = 0 @@ -2542,6 +2721,12 @@ async def execute( ) endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + validate_endpoint_supports_openai_completion(model_endpoint, endpoint_content) # if inference framework is VLLM, we need to set the model to use the weights folder @@ -2558,6 +2743,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2640,6 +2826,13 @@ async def execute( ) model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + validate_endpoint_supports_openai_completion(model_endpoint, model_content) # if inference framework is VLLM, we need to set the model to use the weights folder @@ -2659,6 +2852,7 @@ async def execute( model_content=model_content, inference_gateway=inference_gateway, inference_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) async def _response_chunk_generator( @@ -2668,6 +2862,7 @@ async def _response_chunk_generator( model_content: GetLLMModelEndpointV1Response, inference_gateway: StreamingModelEndpointInferenceGateway, inference_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, ) -> AsyncGenerator[CompletionV2StreamSuccessChunk, None]: # pragma: no cover """ Async generator yielding tokens to stream for the completions response. Should only be called when @@ -2677,6 +2872,7 @@ async def _response_chunk_generator( predict_result = inference_gateway.streaming_predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) except UpstreamServiceError as exc: # Expect upstream inference service to handle bulk of input validation @@ -2792,6 +2988,12 @@ async def execute( ) endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + validate_endpoint_supports_chat_completion(model_endpoint, endpoint_content) # if inference framework is VLLM, we need to set the model to use the weights folder @@ -2808,6 +3010,7 @@ async def execute( predict_result = await inference_gateway.predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2890,6 +3093,12 @@ async def execute( ) model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) validate_endpoint_supports_chat_completion(model_endpoint, model_content) # if inference framework is VLLM, we need to set the model to use the weights folder @@ -2909,6 +3118,7 @@ async def execute( model_content=model_content, inference_gateway=inference_gateway, inference_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) async def _response_chunk_generator( @@ -2918,6 +3128,7 @@ async def _response_chunk_generator( model_content: GetLLMModelEndpointV1Response, inference_gateway: StreamingModelEndpointInferenceGateway, inference_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, ) -> AsyncGenerator[ChatCompletionV2StreamSuccessChunk, None]: """ Async generator yielding tokens to stream for the completions response. Should only be called when @@ -2927,6 +3138,7 @@ async def _response_chunk_generator( predict_result = inference_gateway.streaming_predict( topic=model_endpoint.record.destination, predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, ) except UpstreamServiceError as exc: # Expect upstream inference service to handle bulk of input validation @@ -2992,6 +3204,7 @@ async def _fill_hardware_info( or request.cpus is None or request.memory is None or request.storage is None + or request.nodes_per_worker is None ): if not ( request.gpus is None @@ -2999,9 +3212,10 @@ async def _fill_hardware_info( and request.cpus is None and request.memory is None and request.storage is None + and request.nodes_per_worker is None ): raise ObjectHasInvalidValueException( - "All hardware spec fields (gpus, gpu_type, cpus, memory, storage) must be provided if any hardware spec field is missing." + "All hardware spec fields (gpus, gpu_type, cpus, memory, storage, nodes_per_worker) must be provided if any hardware spec field is missing." ) checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) hardware_info = await _infer_hardware( @@ -3012,6 +3226,7 @@ async def _fill_hardware_info( request.cpus = hardware_info.cpus request.memory = hardware_info.memory request.storage = hardware_info.storage + request.nodes_per_worker = hardware_info.nodes_per_worker if hardware_info.gpus: # make lint happy request.num_shards = hardware_info.gpus @@ -3088,6 +3303,7 @@ async def _infer_hardware( memory = by_model_name[model_name]["memory"] storage = by_model_name[model_name]["storage"] gpu_type = by_model_name[model_name]["gpu_type"] + nodes_per_worker = by_model_name[model_name]["nodes_per_worker"] else: by_gpu_memory_gb = sorted(by_gpu_memory_gb, key=lambda x: x["gpu_memory_le"]) for recs in by_gpu_memory_gb: @@ -3097,12 +3313,18 @@ async def _infer_hardware( memory = recs["memory"] storage = recs["storage"] gpu_type = recs["gpu_type"] + nodes_per_worker = recs["nodes_per_worker"] break else: raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") return CreateDockerImageBatchJobResourceRequests( - cpus=cpus, gpus=gpus, memory=memory, storage=storage, gpu_type=gpu_type + cpus=cpus, + gpus=gpus, + memory=memory, + storage=storage, + gpu_type=gpu_type, + nodes_per_worker=nodes_per_worker, ) diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 21a55dcb..d5318f74 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -27,8 +27,10 @@ LiveAuthorizationModule, ) from model_engine_server.domain.entities import ( + ModelBundle, ModelEndpoint, ModelEndpointType, + RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor, ) from model_engine_server.domain.exceptions import ( @@ -215,6 +217,47 @@ def validate_post_inference_hooks(user: User, post_inference_hooks: Optional[Lis ) +def validate_bundle_multinode_compatibility(bundle: ModelBundle, nodes_per_worker: int): + """ + Only some bundles can be multinode compatible. + """ + if nodes_per_worker == 1: + return + # can type ignore, bundle.flavor is a RunnableImageFlavor/StreamingEnhancedRunnableImageFlavor thus it has worker_command and worker_env + if ( + type(bundle.flavor) in {RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor} + and bundle.flavor.worker_command is not None # type: ignore + and bundle.flavor.worker_env is not None # type: ignore + ): + return + raise ObjectHasInvalidValueException( + f"Bundle {bundle.name} is not multinode compatible. It must be a RunnableImage and have worker_command and worker_args set." + ) + + +def validate_endpoint_resource_multinode_compatibility( + gpu_type: Optional[str], + gpus: Optional[int], + endpoint_type: ModelEndpointType, + nodes_per_worker: int, +): + """ + Only gpu streaming endpoints can be multinode compatible. + """ + if nodes_per_worker == 1: + return + if ( + endpoint_type == ModelEndpointType.STREAMING + and gpu_type is not None + and gpus is not None + and gpus > 0 + ): + return + raise ObjectHasInvalidValueException( + "Endpoint is not multinode compatible. Only streaming GPU endpoints can be multinode compatible." + ) + + class CreateModelEndpointV1UseCase: def __init__( self, @@ -241,8 +284,13 @@ async def execute( bundle = await self.model_bundle_repository.get_model_bundle( model_bundle_id=request.model_bundle_id ) + if bundle is None: raise ObjectNotFoundException + validate_bundle_multinode_compatibility(bundle, request.nodes_per_worker) + validate_endpoint_resource_multinode_compatibility( + request.gpu_type, request.gpus, request.endpoint_type, request.nodes_per_worker + ) if not self.authz_module.check_access_read_owned_entity(user, bundle): raise ObjectNotAuthorizedException if not isinstance(bundle.flavor, StreamingEnhancedRunnableImageFlavor) and ( @@ -300,6 +348,7 @@ async def execute( memory=request.memory, gpu_type=request.gpu_type, storage=request.storage, + nodes_per_worker=request.nodes_per_worker, optimize_costs=bool(request.optimize_costs), min_workers=request.min_workers, max_workers=request.max_workers, diff --git a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index b358ea04..bdd27476 100644 --- a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -1,5 +1,6 @@ from typing import AsyncIterable +from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, @@ -67,6 +68,17 @@ async def execute( await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint_id ) + # Hack: manually resolve dns if istio is present. Since we do not inject istio for multinode, + # empirically we find that without manual dns resolution, requests to the k8s service DNS name fail, + # likely because the requests are getting changed by Istio. A fix is to resolve the service DNS name + # (e.g. model-endpoint-foo.namespace.svc.cluster.local) to the actual IP address of the service + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) return inference_gateway.streaming_predict( - topic=model_endpoint.record.destination, predict_request=request + topic=model_endpoint.record.destination, + predict_request=request, + manually_resolve_dns=manually_resolve_dns, ) diff --git a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index 6835f74a..4985063a 100644 --- a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -1,3 +1,4 @@ +from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, @@ -71,6 +72,17 @@ async def execute( await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint_id ) + # Hack: manually resolve dns if istio is present. Since we do not inject istio for multinode, + # empirically we find that without manual dns resolution, requests to the k8s service DNS name fail, + # likely because the requests are getting changed by Istio. A fix is to resolve the service DNS name + # (e.g. model-endpoint-foo.namespace.svc.cluster.local) to the actual IP address of the service + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) return await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=request + topic=model_endpoint.record.destination, + predict_request=request, + manually_resolve_dns=manually_resolve_dns, ) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm index 6494be71..8b005722 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -20,6 +20,7 @@ RUN ln -s /usr/bin/python3 /usr/bin/python FROM base AS vllm COPY model-engine/model_engine_server/inference/vllm/vllm_server.py /workspace/vllm_server.py +COPY model-engine/model_engine_server/inference/vllm/init_ray.sh /workspace/init_ray.sh # Need to override entrypoint from parent image ENTRYPOINT ["/bin/env"] diff --git a/model-engine/model_engine_server/inference/vllm/init_ray.sh b/model-engine/model_engine_server/inference/vllm/init_ray.sh new file mode 100755 index 00000000..f685206a --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/init_ray.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +# From https://github.com/kubernetes-sigs/lws/blob/main/docs/examples/vllm/build/ray_init.sh +subcommand=$1 +shift + +ray_port=6379 +ray_init_timeout=1200 # Needs to be large enough to overcome any skew from the s5cmd command + any pod startup time + +case "$subcommand" in + worker) + ray_address="" + while [ $# -gt 0 ]; do + case "$1" in + --ray_address=*) + ray_address="${1#*=}" + ;; + --ray_port=*) + ray_port="${1#*=}" + ;; + --ray_init_timeout=*) + ray_init_timeout="${1#*=}" + ;; + --own_address=*) + own_address="${1#*=}" + ;; + *) + echo "unknown argument: $1" + exit 1 + esac + shift + done + + if [ -z "$ray_address" ]; then + echo "Error: Missing argument --ray_address" + exit 1 + fi + for (( i=0; i < $ray_init_timeout; i+=5 )); do + ray start --address=$ray_address:$ray_port --block --node-ip-address=$own_address + if [ $? -eq 0 ]; then + echo "Worker: Ray runtime started with head address $ray_address:$ray_port" + exit 0 + fi + echo $? + echo "Waiting until the ray worker is active..." + sleep 5s; + done + echo "Ray worker starts timeout, head address: $ray_address:$ray_port" + exit 1 + ;; + + leader) + ray_cluster_size="" + while [ $# -gt 0 ]; do + case "$1" in + --ray_port=*) + ray_port="${1#*=}" + ;; + --ray_cluster_size=*) + ray_cluster_size="${1#*=}" + ;; + --ray_init_timeout=*) + ray_init_timeout="${1#*=}" + ;; + --own_address=*) + own_address="${1#*=}" + ;; + *) + echo "unknown argument: $1" + exit 1 + esac + shift + done + + if [ -z "$ray_cluster_size" ]; then + echo "Error: Missing argument --ray_cluster_size" + exit 1 + fi + + # start the ray daemon + ray start --head --port=$ray_port --node-ip-address=$own_address + # wait until all workers are active + for (( i=0; i < $ray_init_timeout; i+=5 )); do + active_nodes=`python3 -c 'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))'` + if [ $active_nodes -eq $ray_cluster_size ]; then + echo "All ray workers are active and the ray cluster is initialized successfully." + exit 0 + fi + echo "Wait for all ray workers to be active. $active_nodes/$ray_cluster_size is active" + sleep 5s; + done + + echo "Waiting for all ray workers to be active timed out." + exit 1 + ;; + + *) + echo "unknown subcommand: $subcommand" + exit 1 + ;; +esac diff --git a/model-engine/model_engine_server/infra/gateways/dns_resolver.py b/model-engine/model_engine_server/infra/gateways/dns_resolver.py new file mode 100644 index 00000000..0579f98a --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/dns_resolver.py @@ -0,0 +1,22 @@ +import socket +from typing import Union + + +def resolve_dns(host: str, port: Union[str, int] = "http") -> str: + """ + Returns an IP address of the given host, e.g. "256.256.256.256" for IPv4, or + "[0000:0000:0000::0000]" for IPv6. You should be able to just substitute this into a URL. + """ + addrinfo = socket.getaddrinfo(host, port) + if len(addrinfo) == 0: + raise ValueError("Host not found.") + # Probably just need the first one + socket_type = addrinfo[0][0] + ip = addrinfo[0][4][0] + # Do I want to do anything with port? it probably ends up being the default (e.g. 80 for http, 443 for https) + if socket_type == socket.AF_INET6: + return f"[{ip}]" + elif socket_type == socket.AF_INET: + return ip + else: + raise ValueError("Unknown socket type.") diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py index b92726e2..7294e533 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py @@ -62,6 +62,7 @@ def create_model_endpoint_infra( memory: StorageSpecificationType, gpu_type: Optional[GpuType], storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, aws_role: str, results_s3_bucket: str, @@ -88,6 +89,7 @@ def create_model_endpoint_infra( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, aws_role=aws_role, results_s3_bucket=results_s3_bucket, @@ -151,6 +153,8 @@ async def update_model_endpoint_infra( gpu_type = infra_state.resource_state.gpu_type if storage is None: storage = infra_state.resource_state.storage + # Don't allow changing nodes_per_worker + nodes_per_worker = infra_state.resource_state.nodes_per_worker if optimize_costs is None: optimize_costs = infra_state.resource_state.optimize_costs or False if child_fn_info is None: @@ -199,6 +203,7 @@ async def update_model_endpoint_infra( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, aws_role=aws_role, results_s3_bucket=results_s3_bucket, diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 00c51e87..d4df0b6b 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -23,6 +23,7 @@ from model_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( StreamingModelEndpointInferenceGateway, ) +from model_engine_server.infra.gateways.dns_resolver import resolve_dns from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port from orjson import JSONDecodeError from tenacity import ( @@ -45,20 +46,27 @@ ) -def _get_streaming_endpoint_url(deployment_name: str, path: str = "/stream") -> str: +def _get_streaming_endpoint_url( + service_name: str, path: str = "/stream", manually_resolve_dns: bool = False +) -> str: if CIRCLECI: # Circle CI: a NodePort is used to expose the service # The IP address is obtained from `minikube ip`. protocol: str = "http" - hostname: str = f"192.168.49.2:{get_node_port(deployment_name)}" + hostname: str = f"192.168.49.2:{get_node_port(service_name)}" elif LOCAL: # local development: the svc.cluster.local address is only available w/in the k8s cluster protocol = "https" - hostname = f"{deployment_name}.{infra_config().dns_host_domain}" + hostname = f"{service_name}.{infra_config().dns_host_domain}" + elif manually_resolve_dns: + protocol = "http" + hostname = resolve_dns( + f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", port=protocol + ) else: protocol = "http" # no need to hit external DNS resolution if we're w/in the k8s cluster - hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" + hostname = f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" return f"{protocol}://{hostname}{path}" @@ -189,10 +197,15 @@ async def make_request_with_retries( raise Exception("Should never reach this line") async def streaming_predict( - self, topic: str, predict_request: SyncEndpointPredictV1Request + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool = False, ) -> AsyncIterable[SyncEndpointPredictV1Response]: deployment_url = _get_streaming_endpoint_url( - topic, path=predict_request.destination_path or "/stream" + topic, + path=predict_request.destination_path or "/stream", + manually_resolve_dns=manually_resolve_dns, ) try: diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index 53230ff0..2683123c 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -21,6 +21,7 @@ from model_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( SyncModelEndpointInferenceGateway, ) +from model_engine_server.infra.gateways.dns_resolver import resolve_dns from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port from tenacity import ( AsyncRetrying, @@ -42,20 +43,27 @@ ) -def _get_sync_endpoint_url(deployment_name: str, destination_path: str = "/predict") -> str: +def _get_sync_endpoint_url( + service_name: str, destination_path: str = "/predict", manually_resolve_dns: bool = False +) -> str: if CIRCLECI: # Circle CI: a NodePort is used to expose the service # The IP address is obtained from `minikube ip`. protocol: str = "http" - hostname: str = f"192.168.49.2:{get_node_port(deployment_name)}" + hostname: str = f"192.168.49.2:{get_node_port(service_name)}" elif LOCAL: # local development: the svc.cluster.local address is only available w/in the k8s cluster protocol = "https" - hostname = f"{deployment_name}.{infra_config().dns_host_domain}" + hostname = f"{service_name}.{infra_config().dns_host_domain}" + elif manually_resolve_dns: + protocol = "http" + hostname = resolve_dns( + f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", port=protocol + ) else: protocol = "http" # no need to hit external DNS resolution if we're w/in the k8s cluster - hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" + hostname = f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" return f"{protocol}://{hostname}{destination_path}" @@ -163,10 +171,15 @@ async def make_request_with_retries( return {} async def predict( - self, topic: str, predict_request: SyncEndpointPredictV1Request + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool = False, ) -> SyncEndpointPredictV1Response: deployment_url = _get_sync_endpoint_url( - topic, destination_path=predict_request.destination_path or "/predict" + topic, + destination_path=predict_request.destination_path or "/predict", + manually_resolve_dns=manually_resolve_dns, ) try: diff --git a/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py index 044bc038..a61a890f 100644 --- a/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py @@ -30,6 +30,7 @@ def create_model_endpoint_infra( memory: StorageSpecificationType, gpu_type: Optional[GpuType], storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, aws_role: str, results_s3_bucket: str, diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 153b8a9f..112d3554 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -61,6 +61,17 @@ BASE_PATH_IN_ENDPOINT = "/app" DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"} +LWS_DEFAULT_ENV_VAR = { + "K8S_OWN_POD_NAME", + "K8S_OWN_NAMESPACE", + "K8S_LWS_NAME", + "K8S_LWS_CLUSTER_SIZE", +} + +# These two should match the values present in `service_template_config_map.yaml` +# for the container names in the LWS template. +LWS_LEADER_CONTAINER_NAME = "lws-leader" +LWS_WORKER_CONTAINER_NAME = "lws-worker" _lazy_load_kubernetes_clients = True _kubernetes_apps_api = None @@ -225,8 +236,39 @@ def get_main_container_from_deployment_template(deployment_template: Dict[str, A return user_container -def add_datadog_env_to_main_container(deployment_template: Dict[str, Any]) -> None: - user_container = get_main_container_from_deployment_template(deployment_template) +def get_leader_container_from_lws_template(lws_template: Dict[str, Any]): + containers = lws_template["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "containers" + ] + for container in containers: + if container["name"] == LWS_LEADER_CONTAINER_NAME: + leader_container = container + break + else: + raise ValueError( + f"leader container (container['name'] == '{LWS_LEADER_CONTAINER_NAME}') not found in lws template when adding datadog env to leader container." + ) + return leader_container + + +def get_worker_container_from_lws_template(lws_template: Dict[str, Any]): + containers = lws_template["spec"]["leaderWorkerTemplate"]["workerTemplate"]["spec"][ + "containers" + ] + for container in containers: + if container["name"] == LWS_WORKER_CONTAINER_NAME: + worker_container = container + break + else: + raise ValueError( + f"worker container (container['name'] == '{LWS_WORKER_CONTAINER_NAME}') not found in lws template when adding datadog env to worker container." + ) + return worker_container + + +def add_datadog_env_to_container( + deployment_template: Dict[str, Any], user_container: Dict[str, Any] +) -> None: user_container_envs = [] for env in user_container["env"]: @@ -261,16 +303,51 @@ def add_datadog_env_to_main_container(deployment_template: Dict[str, Any]) -> No user_container["env"] = user_container_envs +def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None: + container_envs = [] + container_envs.extend( + [ + {"name": "K8S_OWN_POD_NAME", "valueFrom": {"fieldRef": {"fieldPath": "metadata.name"}}}, + { + "name": "K8S_OWN_NAMESPACE", + "valueFrom": {"fieldRef": {"fieldPath": "metadata.namespace"}}, + }, + { + "name": "K8S_LWS_NAME", + "valueFrom": { + "fieldRef": {"fieldPath": "metadata.labels['leaderworkerset.sigs.k8s.io/name']"} + }, + }, + { + "name": "K8S_LWS_CLUSTER_SIZE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.annotations['leaderworkerset.sigs.k8s.io/size']" + } + }, + }, + ] + ) + + for env in container["env"]: + if env["name"] not in LWS_DEFAULT_ENV_VAR: + container_envs.append(env) + container["env"] = container_envs + + class K8SEndpointResourceDelegate: async def create_or_update_resources( self, request: CreateOrUpdateResourcesRequest, sqs_queue_name: Optional[str] = None, sqs_queue_url: Optional[str] = None, - ) -> None: + ) -> str: + """ + Returns a "destination", i.e. the name of the service/sqs queue to send tasks to the endpoint + """ await maybe_load_kube_config() try: - await self._create_or_update_resources( + return await self._create_or_update_resources( request=request, sqs_queue_name=sqs_queue_name, sqs_queue_url=sqs_queue_url, @@ -329,6 +406,18 @@ def _get_env_value_from_envlist( return envvar.value return None + @staticmethod + def _get_env_value_from_envlist_for_custom_object( + envlist: Optional[List[Dict]], name: str + ): # pragma: no cover + # Custom objects client returns nested Dicts, not objects. + if envlist is None: + return None + for envvar in envlist: + if envvar["name"] == name: + return envvar["value"] + return None + def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> CommonEndpointParams: """ Reads some values from k8s common to both sync and async endpoints @@ -391,6 +480,67 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common ) return common_build_endpoint_request + def _get_common_endpoint_params_for_lws_type(self, lws_config: Any) -> CommonEndpointParams: + main_container = self._get_main_leader_container_from_lws(lws_config) + launch_container = self._get_launch_container_from_lws(lws_config) + + resources = main_container["resources"] + image = main_container["image"] + + cpus = resources["requests"]["cpu"] + memory = resources["requests"]["memory"] + gpus = int((resources["limits"] or dict()).get("nvidia.com/gpu", 0)) + storage = resources["requests"].get("ephemeral-storage") + + envlist = launch_container["env"] + # There really isn't a bundle_url for LWS since those use RunnableImages + bundle_url = ( + self._get_env_value_from_envlist_for_custom_object(envlist, "BUNDLE_URL") or image + ) + aws_role = self._get_env_value_from_envlist_for_custom_object(envlist, "AWS_PROFILE") + results_s3_bucket = self._get_env_value_from_envlist_for_custom_object( + envlist, "RESULTS_S3_BUCKET" + ) + + # AWS_PROFILE and RESULTS_S3_BUCKET should always be set, but if not present + # we can fetch them from the config. + if aws_role is None: + aws_role = infra_config().profile_ml_inference_worker + if results_s3_bucket is None: + results_s3_bucket = infra_config().s3_bucket + + if bundle_url is None or aws_role is None or results_s3_bucket is None: + raise ValueError("Failed to fetch common endpoint values.") + + try: + node_selector = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "nodeSelector" + ] + gpu_type = node_selector.get("k8s.amazonaws.com/accelerator", None) + except KeyError: + gpu_type = None + + try: + labels = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["metadata"][ + "labels" + ] + except KeyError: + labels = None + + common_build_endpoint_request: CommonEndpointParams = dict( + cpus=cpus, + memory=memory, + gpus=gpus, + gpu_type=gpu_type, + storage=storage, + bundle_url=bundle_url, + aws_role=aws_role, + results_s3_bucket=results_s3_bucket, + image=image, + labels=labels, + ) + return common_build_endpoint_request + @staticmethod def _get_main_container(deployment_config: V1Deployment) -> V1Container: pod_containers = deployment_config.spec.template.spec.containers @@ -417,8 +567,83 @@ def _get_launch_container(deployment_config: V1Deployment) -> V1Container: raise ValueError("No main container detected") return name_to_container["main"] + @staticmethod + def _get_main_leader_container_from_lws(lws_config: Any): + """ + Similar to _get_main_container, this returns a nested dict. + """ + leader_containers = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "containers" + ] + name_to_container = {container["name"]: container for container in leader_containers} + if LWS_LEADER_CONTAINER_NAME not in name_to_container: + raise ValueError("No main leader container detected") + return name_to_container[LWS_LEADER_CONTAINER_NAME] + + @staticmethod + def _get_launch_container_from_lws(lws_config: Any): + leader_containers = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "containers" + ] + name_to_container = {container["name"]: container for container in leader_containers} + # If a celery forwarder is present, use that + if "celery-forwarder" in name_to_container: + return name_to_container["celery-forwarder"] + + # If a http forwarder is present, use that + if "http-forwarder" in name_to_container: + return name_to_container["http-forwarder"] + + # Don't need backwards compatibility here + raise ValueError("No forwarder container detected") + # --- Private low level fns that interact with k8s + @staticmethod + async def _create_lws( + lws: Dict[str, Any], + name: str, + ) -> None: + """ + Lower-level function to create/replace a LWS + Args: + lws: LWS body (a nested Dict in format specified by Kubernetes) + name: The name of the LWS on k8s + Returns: + Nothing: raises k8s APIException if failure + """ + custom_objects_api = get_kubernetes_custom_objects_client() + try: + await custom_objects_api.create_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + body=lws, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"LeaderWorkerSet {name} already exists, replacing") + existing_lws = await custom_objects_api.get_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=name, + ) + new_lws = deep_update(existing_lws, lws) + await custom_objects_api.replace_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=name, + body=new_lws, + ) + else: + logger.exception("Got an exception when trying to apply the LeaderWorkerSet") + raise + @staticmethod async def _create_deployment( model_endpoint_record: ModelEndpointRecord, deployment: Dict[str, Any], name: str @@ -782,6 +1007,46 @@ async def _create_virtual_service(virtual_service: Dict[str, Any], name: str) -> logger.exception("Got an exception when trying to apply the VirtualService") raise + @staticmethod + async def _create_lws_service_entry(lws_service_entry: Dict[str, Any], name: str) -> None: + # Note: this istio ServiceEntry is specific to the LWS case, + # as it is used to enable the "hack" where we manually resolve + # the IP of a K8s service and route to the IP directly. + custom_objects_api = get_kubernetes_custom_objects_client() + try: + await custom_objects_api.create_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + body=lws_service_entry, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"ServiceEntry {name} already exists, replacing") + # The async k8s client has a bug with patching custom objects, so we manually + # merge the new ServiceEntry with the old one and then replace the old one with the merged + # one. + existing_service_entry = await custom_objects_api.get_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + name=name, + ) + new_service_entry = deep_update(existing_service_entry, lws_service_entry) + await custom_objects_api.replace_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + name=name, + body=new_service_entry, + ) + else: + logger.exception("Got an exception when trying to apply the ServiceEntry") + raise + @staticmethod async def _create_service(service, name: str) -> None: """ @@ -818,7 +1083,7 @@ async def _get_config_maps( ) -> List[kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap]: """ Gets ConfigMaps associated with a given user id + endpoint name - This should be considered the same abstraction level as get_deployment + This should be considered the same abstraction level as _get_deployment """ k8s_core_api = get_kubernetes_core_client() @@ -841,6 +1106,34 @@ async def _get_config_maps( ) return config_maps.items + @staticmethod + async def _get_deployment(endpoint_id, deployment_name): + """ + Gets the Deployment associated with a given endpoint_id + deployment name + Handles a legacy fallback case as well, where Deployments were named differently. + + """ + apps_client = get_kubernetes_apps_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + deployment_config = await apps_client.read_namespaced_deployment( + name=k8s_resource_group_name, namespace=hmi_config.endpoint_namespace + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Could not find resource, falling back to legacy deployment_name: " + f"{k8s_resource_group_name=}, {endpoint_id=}, {deployment_name=}" + ) + k8s_resource_group_name = deployment_name + deployment_config = await apps_client.read_namespaced_deployment( + name=k8s_resource_group_name, + namespace=hmi_config.endpoint_namespace, + ) + else: + raise + return deployment_config + @staticmethod async def _get_all_config_maps() -> ( List[kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap] @@ -888,6 +1181,28 @@ def _translate_k8s_config_maps_to_user_config_data( endpoint_config=endpoint_config, ) + @staticmethod + async def _delete_lws(endpoint_id: str) -> bool: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await custom_objects_client.delete_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent LeaderWorkerSet {k8s_resource_group_name}" + ) + else: + logger.exception(f"Deletion of LeaderWorkerSet {k8s_resource_group_name} failed") + return False + return True + @staticmethod async def _delete_deployment(endpoint_id: str, deployment_name: str) -> bool: apps_client = get_kubernetes_apps_client() @@ -949,8 +1264,8 @@ async def _delete_config_maps(self, endpoint_id: str, deployment_name: str) -> b @staticmethod async def _delete_service(endpoint_id: str, deployment_name: str) -> bool: - core_client = get_kubernetes_core_client() k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + core_client = get_kubernetes_core_client() try: await core_client.delete_namespaced_service( name=k8s_resource_group_name, namespace=hmi_config.endpoint_namespace @@ -980,6 +1295,22 @@ async def _delete_service(endpoint_id: str, deployment_name: str) -> bool: return False return True + @staticmethod + async def _delete_lws_service(endpoint_id: str, deployment_name: str): + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + lws_service_name = K8SEndpointResourceDelegate._get_lws_service_resource_name( + k8s_resource_group_name + ) + core_client = get_kubernetes_core_client() + try: + await core_client.delete_namespaced_service( + name=lws_service_name, namespace=hmi_config.endpoint_namespace + ) + except ApiException: + logger.exception(f"Deletion of Service {lws_service_name} failed") + return False + return True + @staticmethod async def _delete_destination_rule(endpoint_id: str) -> bool: custom_objects_client = get_kubernetes_custom_objects_client() @@ -1024,6 +1355,28 @@ async def _delete_virtual_service(endpoint_id: str) -> bool: return False return True + @staticmethod + async def _delete_lws_service_entry(endpoint_id: str) -> bool: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await custom_objects_client.delete_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent ServiceEntry {k8s_resource_group_name}" + ) + else: + logger.exception(f"Deletion of ServiceEntry {k8s_resource_group_name} failed") + return False + return True + @staticmethod async def _delete_vpa(endpoint_id: str) -> bool: custom_objects_client = get_kubernetes_custom_objects_client() @@ -1146,12 +1499,45 @@ def _get_deployment_resource_name(request: CreateOrUpdateResourcesRequest) -> st deployment_resource_name = f"deployment-{flavor_class}-{mode}-{device}" return deployment_resource_name + @staticmethod + def _get_lws_resource_name(request: CreateOrUpdateResourcesRequest) -> str: + build_endpoint_request = request.build_endpoint_request + model_endpoint_record = build_endpoint_request.model_endpoint_record + flavor = model_endpoint_record.current_model_bundle.flavor + if isinstance(flavor, TritonEnhancedRunnableImageFlavor): + flavor_class = "triton-enhanced-runnable-image" + else: + flavor_class = "runnable-image" + if flavor_class == "triton-enhanced-runnable-image": + raise ValueError("LWS is not supported for Triton Enhanced Runnable Image") + # flavor not being triton-enhanced should already be checked in the endpoint create on the gateway + # but check again just in case + # Gateway should also guard against cloudpickle or zip being passed in here + + mode = model_endpoint_record.endpoint_type.value + device = "gpu" if build_endpoint_request.gpus > 0 else "cpu" + if mode not in ["streaming"]: + raise ValueError("LWS is not supported for async or sync endpoints") + if device not in ["gpu"]: + raise ValueError("LWS is not supported for CPU endpoints") + + lws_resource_name = f"leader-worker-set-{mode}-{device}" + return lws_resource_name + + @staticmethod + def _get_lws_service_resource_name(k8s_resource_group_name: str): + return f"{k8s_resource_group_name}-leader" + async def _create_or_update_resources( self, request: CreateOrUpdateResourcesRequest, sqs_queue_name: Optional[str] = None, sqs_queue_url: Optional[str] = None, - ) -> None: + ) -> str: + """ + Returns a "destination", which is how to address the endpoint, either through + sqs or through a k8s service. + """ sqs_queue_name_str = sqs_queue_name or "" sqs_queue_url_str = sqs_queue_url or "" build_endpoint_request = request.build_endpoint_request @@ -1159,29 +1545,56 @@ async def _create_or_update_resources( k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name( build_endpoint_request.model_endpoint_record.id ) + is_multinode = build_endpoint_request.nodes_per_worker > 1 - deployment_resource_name = self._get_deployment_resource_name(request) - deployment_arguments = get_endpoint_resource_arguments_from_request( - k8s_resource_group_name=k8s_resource_group_name, - request=request, - sqs_queue_name=sqs_queue_name_str, - sqs_queue_url=sqs_queue_url_str, - endpoint_resource_name=deployment_resource_name, - ) - deployment_template = load_k8s_yaml( - f"{deployment_resource_name}.yaml", deployment_arguments - ) - if isinstance( - request.build_endpoint_request.model_endpoint_record.current_model_bundle.flavor, - RunnableImageLike, - ): - add_datadog_env_to_main_container(deployment_template) - await self._create_deployment( - model_endpoint_record=request.build_endpoint_request.model_endpoint_record, - deployment=deployment_template, - name=k8s_resource_group_name, - ) + # Create LWS/Deployment + if is_multinode: + lws_resource_name = self._get_lws_resource_name(request) + lws_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name=lws_resource_name, + ) + lws_template = load_k8s_yaml(f"{lws_resource_name}.yaml", lws_arguments) + leader_template = get_leader_container_from_lws_template(lws_template) + worker_template = get_worker_container_from_lws_template(lws_template) + add_lws_default_env_vars_to_container(leader_template) + add_lws_default_env_vars_to_container(worker_template) + add_datadog_env_to_container(lws_template, leader_template) + add_datadog_env_to_container(lws_template, worker_template) + await self._create_lws( + lws=lws_template, + name=k8s_resource_group_name, + ) + k8s_service_name = self._get_lws_service_resource_name(k8s_resource_group_name) + else: + deployment_resource_name = self._get_deployment_resource_name(request) + deployment_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name=deployment_resource_name, + ) + deployment_template = load_k8s_yaml( + f"{deployment_resource_name}.yaml", deployment_arguments + ) + if isinstance( + request.build_endpoint_request.model_endpoint_record.current_model_bundle.flavor, + RunnableImageLike, + ): + user_container = get_main_container_from_deployment_template(deployment_template) + add_datadog_env_to_container(deployment_template, user_container) + await self._create_deployment( + model_endpoint_record=request.build_endpoint_request.model_endpoint_record, + deployment=deployment_template, + name=k8s_resource_group_name, + ) + k8s_service_name = k8s_resource_group_name + # Create ConfigMaps user_config_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, request=request, @@ -1208,6 +1621,7 @@ async def _create_or_update_resources( name=f"{k8s_resource_group_name}-endpoint-config", ) + # Create VPA if request.build_endpoint_request.optimize_costs: vpa_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, @@ -1222,23 +1636,33 @@ async def _create_or_update_resources( name=k8s_resource_group_name, ) - pdb_config_arguments = get_endpoint_resource_arguments_from_request( - k8s_resource_group_name=k8s_resource_group_name, - request=request, - sqs_queue_name=sqs_queue_name_str, - sqs_queue_url=sqs_queue_url_str, - endpoint_resource_name="pod-disruption-budget", - ) - pdb_template = load_k8s_yaml("pod-disruption-budget.yaml", pdb_config_arguments) - await self._create_pdb( - pdb=pdb_template, - name=k8s_resource_group_name, - ) + # Create PDB + if not is_multinode: + # Only create PDB if we're not using LWS + pdb_config_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="pod-disruption-budget", + ) + pdb_template = load_k8s_yaml("pod-disruption-budget.yaml", pdb_config_arguments) + await self._create_pdb( + pdb=pdb_template, + name=k8s_resource_group_name, + ) - if model_endpoint_record.endpoint_type in { - ModelEndpointType.SYNC, - ModelEndpointType.STREAMING, - }: + # Create HPA/Keda Scaled Object, Service (one of two types), VirtualService, DestinationRule, ServiceEntry + # as needed + if ( + model_endpoint_record.endpoint_type + in { + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + } + and not is_multinode + ): + # Don't need HPA, keda, istio resources for LWS or async endpoints cluster_version = get_kubernetes_cluster_version() # For k8s cluster versions 1.23 - 1.25 we need to use the v2beta2 api # For 1.26+ v2beta2 has been deperecated and merged into v2 @@ -1298,7 +1722,7 @@ async def _create_or_update_resources( service_template = load_k8s_yaml("service.yaml", service_arguments) await self._create_service( service=service_template, - name=k8s_resource_group_name, + name=k8s_service_name, ) # TODO wsong: add flag to use istio and use these arguments @@ -1332,6 +1756,59 @@ async def _create_or_update_resources( destination_rule=destination_rule_template, name=k8s_resource_group_name, ) + elif ( + model_endpoint_record.endpoint_type + in { + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + } + and is_multinode + ): + # Only create the service (and serviceEntry if istio is enabled) + service_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + service_name_override=k8s_service_name, + endpoint_resource_name="lws-service", + ) + service_template = load_k8s_yaml("lws-service.yaml", service_arguments) + await self._create_service( + service=service_template, + name=k8s_service_name, + ) + + if hmi_config.istio_enabled: + # If Istio is enabled, we also create a ServiceEntry. This is in service of the hack + # where we manually resolve the IP address of the K8s service created above. + # We empirically need to create this in order for the request to the service's IP address + # to go through. See live_{sync,streaming}_model_endpoint_inference_gateway.py for more details. + lws_service_entry_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="lws-service-entry", + service_name_override=k8s_service_name, + ) + lws_service_entry_template = load_k8s_yaml( + "lws-service-entry.yaml", lws_service_entry_arguments + ) + await self._create_lws_service_entry( + lws_service_entry=lws_service_entry_template, + name=k8s_resource_group_name, + ) + if model_endpoint_record.endpoint_type in { + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + }: + return k8s_service_name + elif model_endpoint_record.endpoint_type == ModelEndpointType.ASYNC: + return sqs_queue_name_str + else: + # We should never get here + raise ValueError(f"Unsupported endpoint type {model_endpoint_record.endpoint_type}") @staticmethod def _get_vertical_autoscaling_params( @@ -1390,25 +1867,54 @@ def _get_sync_autoscaling_params_from_keda( async def _get_resources( self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType ) -> ModelEndpointInfraState: - apps_client = get_kubernetes_apps_client() + custom_objects_client = get_kubernetes_custom_objects_client() k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + + logger.info( + f"trying to find lws at {k8s_resource_group_name}, {hmi_config.endpoint_namespace}" + ) try: - deployment_config = await apps_client.read_namespaced_deployment( - name=k8s_resource_group_name, namespace=hmi_config.endpoint_namespace + lws_config = await custom_objects_client.get_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=k8s_resource_group_name, ) except ApiException as e: - if e.status == 404: - logger.warning( - f"Could not find resource, falling back to legacy deployment_name: " - f"{k8s_resource_group_name=}, {endpoint_id=}, {deployment_name=}" - ) - k8s_resource_group_name = deployment_name - deployment_config = await apps_client.read_namespaced_deployment( - name=k8s_resource_group_name, - namespace=hmi_config.endpoint_namespace, - ) - else: - raise + # Need to handle the case where lws CRD isn't installed as well as the lws not existing. + logger.info(e) + lws_config = None + + # Make the call here so we can use it in both places, also this makes _get_resources_from_lws_type make zero requests to k8s + config_maps = await self._get_config_maps( + endpoint_id=endpoint_id, deployment_name=k8s_resource_group_name + ) + + if lws_config is None: + infra_state = await self._get_resources_from_deployment_type( + endpoint_id=endpoint_id, + deployment_name=deployment_name, + endpoint_type=endpoint_type, + config_maps=config_maps, + ) + else: + infra_state = await self._get_resources_from_lws_type( + endpoint_id=endpoint_id, + deployment_name=deployment_name, + endpoint_type=endpoint_type, + lws_config=lws_config, + config_maps=config_maps, + ) + return infra_state + + async def _get_resources_from_deployment_type( + self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType, config_maps + ) -> ModelEndpointInfraState: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + + deployment_config = await self._get_deployment(endpoint_id, deployment_name) common_params = self._get_common_endpoint_params(deployment_config) if endpoint_type == ModelEndpointType.ASYNC: @@ -1449,7 +1955,7 @@ async def _get_resources( raise ValueError(f"Unexpected endpoint type {endpoint_type}") vertical_autoscaling_params = None - custom_objects_client = get_kubernetes_custom_objects_client() + try: vpa_config = await custom_objects_client.get_namespaced_custom_object( group="autoscaling.k8s.io", @@ -1463,9 +1969,6 @@ async def _get_resources( if e.status == 404: pass - config_maps = await self._get_config_maps( - endpoint_id=endpoint_id, deployment_name=k8s_resource_group_name - ) launch_container = self._get_launch_container(deployment_config) envlist = launch_container.env # Note: the env var PREWARM is either "true" or "false" string (or doesn't exist for legacy) @@ -1497,6 +2000,7 @@ async def _get_resources( gpu_type=common_params["gpu_type"], # type: ignore memory=common_params["memory"], storage=common_params["storage"], + nodes_per_worker=1, # We're in "Deployment" case thus nodes_per_worker=1 optimize_costs=(vertical_autoscaling_params is not None), ), user_config_state=self._translate_k8s_config_maps_to_user_config_data( @@ -1505,6 +2009,63 @@ async def _get_resources( image=common_params["image"], num_queued_items=None, ) + + return infra_state + + async def _get_resources_from_lws_type( + self, + endpoint_id: str, + deployment_name: str, + endpoint_type: ModelEndpointType, + lws_config, + config_maps: List, + ) -> ModelEndpointInfraState: + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + + # Assume leader + worker share the same user-set env vars + common_params = self._get_common_endpoint_params_for_lws_type(lws_config) + + replicas = lws_config["spec"]["replicas"] + prewarm = False # not provided here + high_priority = ( + lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "priorityClassName" + ] + == LAUNCH_HIGH_PRIORITY_CLASS + ) + nodes_per_worker = lws_config["spec"]["leaderWorkerTemplate"]["size"] + + infra_state = ModelEndpointInfraState( + deployment_name=k8s_resource_group_name, + aws_role=common_params["aws_role"], + results_s3_bucket=common_params["results_s3_bucket"], + child_fn_info=None, + labels=common_params["labels"], + prewarm=prewarm, + high_priority=high_priority, + deployment_state=ModelEndpointDeploymentState( + min_workers=replicas, + max_workers=replicas, # We don't have any notion of autoscaling for LWS + per_worker=int(1), # TODO update this if we support LWS autoscaling + available_workers=replicas, # TODO unfortunately it doesn't look like we can get this from the LWS CRD, so this is kind of a dummy value + unavailable_workers=0, + ), + resource_state=ModelEndpointResourceState( + cpus=common_params["cpus"], + gpus=common_params["gpus"], + gpu_type=common_params["gpu_type"], # type: ignore + memory=common_params["memory"], + storage=common_params["storage"], + nodes_per_worker=nodes_per_worker, + optimize_costs=False, + ), + user_config_state=self._translate_k8s_config_maps_to_user_config_data( + k8s_resource_group_name, config_maps + ), + image=common_params["image"], + num_queued_items=None, + ) + return infra_state async def _get_all_resources( @@ -1550,16 +2111,34 @@ async def _get_all_resources( else: raise + try: + leader_worker_sets = ( + await custom_objects_client.list_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + ) + )["items"] + except ApiException as e: + if e.status == 404: + leader_worker_sets = [] + else: + raise + deployments_by_name = {deployment.metadata.name: deployment for deployment in deployments} hpas_by_name = {hpa.metadata.name: hpa for hpa in hpas} vpas_by_name = {vpa["metadata"]["name"]: vpa for vpa in vpas} keda_scaled_objects_by_name = {kso["metadata"]["name"]: kso for kso in keda_scaled_objects} + leader_worker_sets_by_name = {lws["metadata"]["name"]: lws for lws in leader_worker_sets} all_config_maps = await self._get_all_config_maps() # can safely assume hpa with same name as deployment corresponds to the same Launch Endpoint logger.info(f"Orphaned hpas: {set(hpas_by_name).difference(set(deployments_by_name))}") logger.info(f"Orphaned vpas: {set(vpas_by_name).difference(set(deployments_by_name))}") infra_states = {} - logger.info(f"Got data for {list(deployments_by_name.keys())}") + logger.info( + f"Got data for {list(deployments_by_name.keys())} and {list(leader_worker_sets_by_name.keys())}" + ) for name, deployment_config in deployments_by_name.items(): try: hpa_config = hpas_by_name.get(name, None) @@ -1616,6 +2195,7 @@ async def _get_all_resources( gpu_type=common_params["gpu_type"], # type: ignore memory=common_params["memory"], storage=common_params["storage"], + nodes_per_worker=1, # We're in a Deployment case, so nodes_per_worker is 1 optimize_costs=(vertical_autoscaling_params is not None), ), user_config_state=self._translate_k8s_config_maps_to_user_config_data( @@ -1634,9 +2214,27 @@ async def _get_all_resources( infra_states[key] = (is_key_an_endpoint_id, infra_state) except Exception: logger.exception(f"Error parsing deployment {name}") + for name, lws_config in leader_worker_sets_by_name.items(): + # name.startswith("launch-endpoint-id-") should always be true, the other case is a legacy. + key = _k8s_resource_group_name_to_endpoint_id(name) + is_key_an_endpoint_id = True + endpoint_id = key + deployment_name = name + endpoint_type = ( + ModelEndpointType.STREAMING + ) # TODO change if we ever support other endpoint types + infra_states[key] = ( + is_key_an_endpoint_id, + await self._get_resources_from_lws_type( + endpoint_id, deployment_name, endpoint_type, lws_config, all_config_maps + ), + ) return infra_states async def _delete_resources_async(self, endpoint_id: str, deployment_name: str) -> bool: + + # TODO check that this implementation actually works for multinode if/when we decide to support that + lws_delete_succeeded = await self._delete_lws(endpoint_id=endpoint_id) deployment_delete_succeeded = await self._delete_deployment( endpoint_id=endpoint_id, deployment_name=deployment_name ) @@ -1645,9 +2243,11 @@ async def _delete_resources_async(self, endpoint_id: str, deployment_name: str) ) await self._delete_vpa(endpoint_id=endpoint_id) await self._delete_pdb(endpoint_id=endpoint_id) - return deployment_delete_succeeded and config_map_delete_succeeded + return (deployment_delete_succeeded or lws_delete_succeeded) and config_map_delete_succeeded async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) -> bool: + lws_delete_succeeded = await self._delete_lws(endpoint_id=endpoint_id) + deployment_delete_succeeded = await self._delete_deployment( endpoint_id=endpoint_id, deployment_name=deployment_name, @@ -1658,6 +2258,9 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - service_delete_succeeded = await self._delete_service( endpoint_id=endpoint_id, deployment_name=deployment_name ) + lws_service_delete_succeeded = await self._delete_lws_service( + endpoint_id=endpoint_id, deployment_name=deployment_name + ) # we should have created exactly one of an HPA or a keda scaled object hpa_delete_succeeded = await self._delete_hpa( endpoint_id=endpoint_id, deployment_name=deployment_name @@ -1667,6 +2270,7 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - ) await self._delete_vpa(endpoint_id=endpoint_id) await self._delete_pdb(endpoint_id=endpoint_id) + await self._delete_lws_service_entry(endpoint_id=endpoint_id) destination_rule_delete_succeeded = await self._delete_destination_rule( endpoint_id=endpoint_id @@ -1682,4 +2286,4 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - and (hpa_delete_succeeded or keda_scaled_object_succeeded) and destination_rule_delete_succeeded and virtual_service_delete_succeeded - ) + ) or (lws_delete_succeeded and config_map_delete_succeeded and lws_service_delete_succeeded) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 8a6f0a8e..32af085e 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -211,6 +211,12 @@ class _TritonArguments(TypedDict): TRITON_COMMIT_TAG: str +class _LeaderWorkerSetArguments(TypedDict): + LWS_SIZE: int + WORKER_COMMAND: List[str] + WORKER_ENV: List[Dict[str, Any]] + + class DeploymentRunnableImageSyncCpuArguments( _RunnableImageDeploymentArguments, _SyncRunnableImageDeploymentArguments ): @@ -289,6 +295,15 @@ class DeploymentTritonEnhancedRunnableImageAsyncGpuArguments( """ +class LeaderWorkerSetRunnableImageStreamingGpuArguments( + _RunnableImageDeploymentArguments, + _StreamingDeploymentArguments, + _GpuArguments, + _LeaderWorkerSetArguments, +): + """Keyword-arguments for substituting into GPU streaming LeaderWorkerSet templates for runnable images.""" + + class HorizontalPodAutoscalerArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into horizontal pod autoscaler templates.""" @@ -334,6 +349,13 @@ class ServiceArguments(_BaseEndpointArguments): NODE_PORT_DICT: DictStrInt +class LwsServiceArguments(ServiceArguments): + """Keyword-arguments for substituting into service templates for LWS. + Need this to override the service name for LWS.""" + + SERVICE_NAME_OVERRIDE: str + + class DestinationRuleArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into destination-rule templates.""" @@ -357,6 +379,12 @@ class VirtualServiceArguments(_BaseEndpointArguments): DNS_HOST_DOMAIN: str +class LwsServiceEntryArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into istio service-entry templates to support LWS.""" + + SERVICE_NAME_OVERRIDE: str + + class BatchJobOrchestrationJobArguments(_JobArguments): """Keyword-arguments for substituting into batch-job-orchestration-job templates.""" @@ -483,6 +511,7 @@ def get_endpoint_resource_arguments_from_request( sqs_queue_url: str, endpoint_resource_name: str, api_version: str = "", + service_name_override: Optional[str] = None, ) -> EndpointResourceArguments: """Get the arguments for the endpoint resource templates from the request. @@ -502,6 +531,8 @@ def get_endpoint_resource_arguments_from_request( sqs_profile = f"eks-{infra_config().profile_ml_worker}" # TODO: Make this configurable s3_bucket = infra_config().s3_bucket + service_name_override = service_name_override or k8s_resource_group_name + storage_dict = DictStrStr("") if storage is not None: storage_dict = DictStrStr(f'ephemeral-storage: "{storage}"') @@ -543,6 +574,17 @@ def get_endpoint_resource_arguments_from_request( if abs_account_name is not None: main_env.append({"name": "ABS_ACCOUNT_NAME", "value": abs_account_name}) + # LeaderWorkerSet exclusive + worker_env = None + if isinstance(flavor, RunnableImageLike) and flavor.worker_env is not None: + worker_env = [{"name": key, "value": value} for key, value in flavor.worker_env.items()] + worker_env.append({"name": "AWS_PROFILE", "value": build_endpoint_request.aws_role}) + worker_env.append({"name": "AWS_CONFIG_FILE", "value": "/opt/.aws/config"}) + + worker_command = None + if isinstance(flavor, RunnableImageLike) and flavor.worker_command is not None: + worker_command = flavor.worker_command + infra_service_config_volume_mount_path = "/infra-config" forwarder_config_file_name = "service--forwarder.yaml" if ( @@ -1088,6 +1130,61 @@ def get_endpoint_resource_arguments_from_request( TRITON_COMMAND=triton_command, TRITON_COMMIT_TAG=flavor.triton_commit_tag, ) + elif endpoint_resource_name == "leader-worker-set-streaming-gpu": + assert isinstance(flavor, StreamingEnhancedRunnableImageFlavor) + assert build_endpoint_request.gpu_type is not None + assert worker_command is not None + assert worker_env is not None + return LeaderWorkerSetRunnableImageStreamingGpuArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + # Base deployment arguments + CHANGE_CAUSE_MESSAGE=change_cause_message, + AWS_ROLE=build_endpoint_request.aws_role, + PRIORITY=priority, + IMAGE=request.image, + IMAGE_HASH=image_hash, + DD_TRACE_ENABLED=str(dd_trace_enabled), + CPUS=str(build_endpoint_request.cpus), + MEMORY=str(build_endpoint_request.memory), + STORAGE_DICT=storage_dict, + PER_WORKER=build_endpoint_request.per_worker, + MIN_WORKERS=build_endpoint_request.min_workers, + MAX_WORKERS=build_endpoint_request.max_workers, + RESULTS_S3_BUCKET=s3_bucket, + # Runnable Image Arguments + MAIN_ENV=main_env, + COMMAND=flavor.streaming_command, + PREDICT_ROUTE=flavor.predict_route, + STREAMING_PREDICT_ROUTE=flavor.streaming_predict_route, + HEALTHCHECK_ROUTE=flavor.healthcheck_route, + READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, + INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, + FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, + FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, + FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, + FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, + USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, + # Streaming Arguments + FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, + # GPU Arguments + GPU_TYPE=build_endpoint_request.gpu_type.value, + GPUS=build_endpoint_request.gpus, + # Leader Worker Set Arguments + LWS_SIZE=build_endpoint_request.nodes_per_worker, + WORKER_COMMAND=worker_command, + WORKER_ENV=worker_env, + ) elif endpoint_resource_name == "user-config": app_config_serialized = python_json_to_b64(model_bundle.app_config) return UserConfigArguments( @@ -1198,6 +1295,33 @@ def get_endpoint_resource_arguments_from_request( SERVICE_TYPE=service_type, SERVICE_TARGET_PORT=FORWARDER_PORT, ) + elif endpoint_resource_name == "lws-service": + # Use ClusterIP by default for sync endpoint. + # In Circle CI, we use a NodePort to expose the service to CI. + service_type = "ClusterIP" if not CIRCLECI else "NodePort" + if service_type == "NodePort": + node_port = get_node_port(k8s_resource_group_name) + node_port_dict = DictStrInt(f"nodePort: {node_port}") + else: + node_port_dict = DictStrInt("") + return LwsServiceArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + # Service arguments + NODE_PORT_DICT=node_port_dict, + SERVICE_TYPE=service_type, + SERVICE_TARGET_PORT=FORWARDER_PORT, + # LWS Service args + SERVICE_NAME_OVERRIDE=service_name_override, + ) elif endpoint_resource_name == "virtual-service": return VirtualServiceArguments( # Base resource arguments @@ -1225,6 +1349,21 @@ def get_endpoint_resource_arguments_from_request( OWNER=owner, GIT_TAG=GIT_TAG, ) + elif endpoint_resource_name == "lws-service-entry": + return LwsServiceEntryArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + # LWS Service Entry args + SERVICE_NAME_OVERRIDE=service_name_override, + ) elif endpoint_resource_name == "vertical-pod-autoscaler": return VerticalPodAutoscalerArguments( # Base resource arguments diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index fb637c10..d884ab17 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -59,9 +59,7 @@ async def create_or_update_resources( q = await self.create_queue(endpoint_record, request.build_endpoint_request.labels) queue_name: Optional[str] = q.queue_name queue_url: Optional[str] = q.queue_url - destination: str = q.queue_name else: - destination = f"launch-endpoint-id-{endpoint_record.id.replace('_', '-')}" queue_name = None queue_url = None @@ -70,7 +68,7 @@ async def create_or_update_resources( endpoint_record.id ) - await self.k8s_delegate.create_or_update_resources( + destination: str = await self.k8s_delegate.create_or_update_resources( request=request, sqs_queue_name=queue_name, sqs_queue_url=queue_url, diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 4a61d564..dde09282 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -7,12 +7,13 @@ metadata: name: model-engine-service-template-config labels: team: infra + app.kubernetes.io/version: 852d5899343633774f6a3543e4ed9e977a533e5a + tags.datadoghq.com/version: 852d5899343633774f6a3543e4ed9e977a533e5a + tags.datadoghq.com/env: circleci + env: circleci product: model-engine - helm.sh/chart: model-engine-0.1.0 + helm.sh/chart: model-engine-0.1.3 app.kubernetes.io/managed-by: Helm - app.kubernetes.io/version: a93c7fe34529efde2b468b9cbbf3abf300308164 - tags.datadoghq.com/version: a93c7fe34529efde2b468b9cbbf3abf300308164 - tags.datadoghq.com/env: circleci annotations: "helm.sh/hook": pre-install,pre-upgrade "helm.sh/hook-weight": "-2" @@ -99,19 +100,16 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.celery_forwarder @@ -127,9 +125,13 @@ data: - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -142,6 +144,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -178,7 +182,7 @@ data: - name: infra-service-config-volume mountPath: /workspace/model-engine/model_engine_server/core/configs - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -189,6 +193,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -236,6 +242,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -253,10 +260,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -272,7 +275,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -370,19 +373,16 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.celery_forwarder @@ -398,9 +398,13 @@ data: - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -413,6 +417,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -463,6 +469,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -480,10 +487,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -499,7 +502,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -591,17 +594,14 @@ data: topologyKey: kubernetes.io/hostname terminationGracePeriodSeconds: 600 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -617,13 +617,11 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --set - - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" - - --set - - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -636,6 +634,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -650,9 +650,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -677,7 +678,7 @@ data: - containerPort: ${FORWARDER_PORT} name: http - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -688,6 +689,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -735,6 +738,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -752,10 +756,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -771,7 +771,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -863,17 +863,14 @@ data: topologyKey: kubernetes.io/hostname terminationGracePeriodSeconds: 600 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -889,13 +886,11 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --set - - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" - - --set - - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -908,6 +903,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -922,9 +919,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -963,6 +961,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -980,10 +979,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -999,7 +994,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -1091,17 +1086,14 @@ data: topologyKey: kubernetes.io/hostname terminationGracePeriodSeconds: 600 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -1126,6 +1118,8 @@ data: env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1138,6 +1132,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -1152,9 +1148,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -1193,6 +1190,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -1210,10 +1208,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -1229,7 +1223,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -1327,10 +1321,9 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -1339,12 +1332,11 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.celery_forwarder @@ -1360,9 +1352,13 @@ data: - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1375,6 +1371,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -1411,7 +1409,7 @@ data: - name: infra-service-config-volume mountPath: /workspace/model-engine/model_engine_server/core/configs - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1422,6 +1420,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -1469,6 +1469,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: nvidia.com/gpu: ${GPUS} @@ -1488,10 +1489,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -1507,7 +1504,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -1605,10 +1602,9 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -1617,12 +1613,11 @@ data: priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.celery_forwarder @@ -1638,9 +1633,13 @@ data: - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1653,6 +1652,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -1703,6 +1704,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: nvidia.com/gpu: ${GPUS} @@ -1722,10 +1724,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -1741,7 +1739,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -1834,7 +1832,6 @@ data: terminationGracePeriodSeconds: 600 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -1843,12 +1840,11 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -1864,13 +1860,11 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --set - - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" - - --set - - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -1883,6 +1877,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -1897,9 +1893,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -1924,7 +1921,7 @@ data: - containerPort: ${FORWARDER_PORT} name: http - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1935,6 +1932,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -1982,6 +1981,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: nvidia.com/gpu: ${GPUS} @@ -2001,10 +2001,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -2020,7 +2016,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -2113,7 +2109,6 @@ data: terminationGracePeriodSeconds: 600 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -2122,12 +2117,11 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -2143,13 +2137,11 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --set - - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" - - --set - - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -2162,6 +2154,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -2176,9 +2170,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -2217,6 +2212,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: nvidia.com/gpu: ${GPUS} @@ -2236,10 +2232,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -2255,7 +2247,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -2348,7 +2340,6 @@ data: terminationGracePeriodSeconds: 600 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -2357,12 +2348,11 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - model_engine_server.inference.forwarding.http_forwarder @@ -2387,6 +2377,8 @@ data: env: - name: DD_TRACE_ENABLED value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV @@ -2399,6 +2391,8 @@ data: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH @@ -2413,9 +2407,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -2454,6 +2449,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: nvidia.com/gpu: ${GPUS} @@ -2473,10 +2469,6 @@ data: name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/modelengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -2492,7 +2484,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -2628,6 +2620,335 @@ data: enableTLS: "false" unsafeSsl: "false" databaseIndex: "${REDIS_DB_INDEX}" + leader-worker-set-streaming-gpu.yaml: |- + apiVersion: leaderworkerset.x-k8s.io/v1 + kind: LeaderWorkerSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + replicas: ${MIN_WORKERS} + leaderWorkerTemplate: + size: ${LWS_SIZE} + restartPolicy: RecreateGroupOnPodRestart # TODO un-hardcode? if necessary + leaderTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: leader + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + sidecar.istio.io/inject: "false" # Never inject istio, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + podAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 1 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: app + operator: In + values: + - ${RESOURCE_NAME} + topologyKey: kubernetes.io/hostname + - weight: 100 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: ${IMAGE_HASH} + operator: In + values: + - "True" + topologyKey: kubernetes.io/hostname + terminationGracePeriodSeconds: 600 + serviceAccount: default + nodeSelector: + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + priorityClassName: ${PRIORITY} + containers: + - name: http-forwarder + image: model-engine:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + env: + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" + - name: DD_SERVICE + value: "${ENDPOINT_NAME}" + - name: DD_ENV + value: circleci + - name: DD_VERSION + value: "${GIT_TAG}" + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + - name: AWS_PROFILE + value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config + - name: RESULTS_S3_BUCKET + value: "${RESULTS_S3_BUCKET}" + - name: BASE_PATH + value: "/workspace" + - name: ML_INFRA_SERVICES_CONFIG_PATH + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" + - name: HTTP_HOST + value: "0.0.0.0" + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + + + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - name: user-config + mountPath: /workspace/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /workspace/endpoint_config + subPath: raw_data + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + ports: + - containerPort: ${FORWARDER_PORT} + name: http + - name: lws-leader + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} + readinessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + - name: config-volume + configMap: + name: default-config + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + - name: infra-service-config-volume + configMap: + name: model-engine-service-config + items: + - key: infra_service_config + path: config.yaml + workerTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: worker + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + sidecar.istio.io/inject: "false" # Never inject istio for LWS, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + podAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 1 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: app + operator: In + values: + - ${RESOURCE_NAME} + topologyKey: kubernetes.io/hostname + - weight: 100 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: ${IMAGE_HASH} + operator: In + values: + - "True" + topologyKey: kubernetes.io/hostname + terminationGracePeriodSeconds: 600 + serviceAccount: default + nodeSelector: + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + priorityClassName: ${PRIORITY} + containers: + - name: lws-worker + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${WORKER_COMMAND} + env: ${WORKER_ENV} + resources: + requests: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + - name: config-volume + configMap: + name: default-config + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + - name: infra-service-config-volume + configMap: + name: model-engine-service-config + items: + - key: infra_service_config + path: config.yaml # mode # device service.yaml: |- apiVersion: v1 kind: Service @@ -2658,6 +2979,37 @@ data: protocol: TCP name: http ${NODE_PORT_DICT} + lws-service.yaml: |- + apiVersion: v1 + kind: Service + metadata: + name: ${SERVICE_NAME_OVERRIDE} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + type: ${SERVICE_TYPE} + selector: + app: ${RESOURCE_NAME} + role: leader + ports: + - port: 80 + targetPort: ${SERVICE_TARGET_PORT} + protocol: TCP + name: http + ${NODE_PORT_DICT} virtual-service.yaml: |- apiVersion: networking.istio.io/v1alpha3 kind: VirtualService @@ -2714,6 +3066,35 @@ data: trafficPolicy: loadBalancer: simple: LEAST_REQUEST + lws-service-entry.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: ServiceEntry + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + hosts: + - "${SERVICE_NAME_OVERRIDE}.${NAMESPACE}.svc.cluster.local" + location: MESH_EXTERNAL + ports: + - number: 80 + name: http + protocol: HTTP + resolution: NONE vertical-pod-autoscaler.yaml: |- apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler @@ -2794,7 +3175,10 @@ data: tags.datadoghq.com/env: circleci tags.datadoghq.com/version: ${GIT_TAG} launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} spec: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} @@ -2813,7 +3197,10 @@ data: tags.datadoghq.com/env: circleci tags.datadoghq.com/version: ${GIT_TAG} launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} sidecar.istio.io/inject: "false" version: v1 annotations: @@ -2821,8 +3208,6 @@ data: cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never - nodeSelector: - node-lifecycle: normal serviceAccountName: model-engine volumes: - name: config-volume @@ -2830,11 +3215,15 @@ data: name: default-config containers: - name: main - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} env: - name: DD_SERVICE value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" - name: DD_TRACE_ENABLED + value: "true" + - name: DD_REMOTE_CONFIGURATION_ENABLED value: "false" - name: DD_ENV value: circleci @@ -2847,12 +3236,19 @@ data: value: http://model-engine.default:80 - name: AWS_PROFILE value: default + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: default + - name: DB_SECRET_AWS_PROFILE + value: default - name: S3_WRITE_AWS_PROFILE value: default - - name: DB_SECRET_NAME - value: prod/ml_infra_pg + - name: ML_INFRA_DATABASE_URL + valueFrom: + secretKeyRef: + key: database_url + name: model-engine-postgres-credentials - name: DEPLOY_SERVICE_CONFIG_PATH value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH @@ -2871,7 +3267,6 @@ data: command: - dumb-init - -- - - ddtrace-run args: - python - -m @@ -2915,11 +3310,17 @@ data: tags.datadoghq.com/env: circleci tags.datadoghq.com/version: ${GIT_TAG} launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} spec: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + completions: ${BATCH_JOB_NUM_WORKERS} + parallelism: ${BATCH_JOB_NUM_WORKERS} + completionMode: "Indexed" template: metadata: labels: @@ -2934,15 +3335,17 @@ data: tags.datadoghq.com/env: circleci tags.datadoghq.com/version: ${GIT_TAG} launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} sidecar.istio.io/inject: "false" version: v1 annotations: ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never - nodeSelector: - node-lifecycle: normal serviceAccountName: default volumes: - name: config-volume @@ -2959,7 +3362,11 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" - name: DD_TRACE_ENABLED + value: "true" + - name: DD_REMOTE_CONFIGURATION_ENABLED value: "false" - name: DD_ENV value: circleci @@ -2972,12 +3379,19 @@ data: value: http://model-engine.default:80 - name: AWS_PROFILE value: default + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: default + - name: DB_SECRET_AWS_PROFILE + value: default - name: S3_WRITE_AWS_PROFILE value: default - - name: DB_SECRET_NAME - value: prod/ml_infra_pg + - name: ML_INFRA_DATABASE_URL + valueFrom: + secretKeyRef: + key: database_url + name: model-engine-postgres-credentials - name: DEPLOY_SERVICE_CONFIG_PATH value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH @@ -3014,7 +3428,10 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} + env: + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" command: - python - -m @@ -3056,11 +3473,17 @@ data: tags.datadoghq.com/env: circleci tags.datadoghq.com/version: ${GIT_TAG} launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} spec: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + completions: ${BATCH_JOB_NUM_WORKERS} + parallelism: ${BATCH_JOB_NUM_WORKERS} + completionMode: "Indexed" template: metadata: labels: @@ -3075,15 +3498,18 @@ data: tags.datadoghq.com/env: circleci tags.datadoghq.com/version: ${GIT_TAG} launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} sidecar.istio.io/inject: "false" version: v1 annotations: ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -3105,7 +3531,11 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" - name: DD_TRACE_ENABLED + value: "true" + - name: DD_REMOTE_CONFIGURATION_ENABLED value: "false" - name: DD_ENV value: circleci @@ -3118,12 +3548,19 @@ data: value: http://model-engine.default:80 - name: AWS_PROFILE value: default + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: default + - name: DB_SECRET_AWS_PROFILE + value: default - name: S3_WRITE_AWS_PROFILE value: default - - name: DB_SECRET_NAME - value: prod/ml_infra_pg + - name: ML_INFRA_DATABASE_URL + valueFrom: + secretKeyRef: + key: database_url + name: model-engine-postgres-credentials - name: DEPLOY_SERVICE_CONFIG_PATH value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH @@ -3162,7 +3599,10 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/model-engine:${GIT_TAG} + image: model-engine:${GIT_TAG} + env: + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" command: - python - -m diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_service.py index 78cea2d1..aa5029c7 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_service.py @@ -104,6 +104,7 @@ async def create_batch_job( memory=memory, # type: ignore gpu_type=gpu_type, # type: ignore storage=storage, + nodes_per_worker=1, # TODO batch jobs currently doesn't support multinode, since async multinode isn't supported yet optimize_costs=False, min_workers=0, max_workers=max_workers, # type: ignore diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 7ae03645..ae3836c6 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -302,6 +302,7 @@ async def build_endpoint( memory=build_endpoint_request.memory, gpu_type=build_endpoint_request.gpu_type, storage=build_endpoint_request.storage, + nodes_per_worker=build_endpoint_request.nodes_per_worker, optimize_costs=build_endpoint_request.optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -793,6 +794,13 @@ def _validate_build_endpoint_request( raise ValueError( f"Runnable image endpoints cannot set the following env vars: {restriced_env_vars}" ) + if ( + not isinstance(model_bundle.flavor, RunnableImageLike) + and build_endpoint_request.nodes_per_worker > 1 + ): + raise ValueError( + "Multi-node deployment is only supported for RunnableImageLike model bundles." + ) @staticmethod def _get_restricted_env_vars(env_vars: Dict[str, str]) -> Set[str]: diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 0750fd84..3bf62b5e 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -146,6 +146,7 @@ async def create_model_endpoint( memory: StorageSpecificationType, gpu_type: Optional[GpuType], storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, min_workers: int, max_workers: int, @@ -193,6 +194,7 @@ async def create_model_endpoint( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, aws_role=aws_role, results_s3_bucket=results_s3_bucket, diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index d713b25a..3ce36f6d 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -779,6 +779,7 @@ def model_endpoint_1( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -836,6 +837,7 @@ def model_endpoint_1( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -884,6 +886,7 @@ def model_endpoint_2( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -942,6 +945,7 @@ def model_endpoint_2( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": False, }, "image": "test_image_2", @@ -1258,6 +1262,7 @@ def create_llm_model_endpoint_request_sync() -> Dict[str, Any]: "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "1Gi", + "nodes_per_worker": 1, "min_workers": 1, "max_workers": 5, "per_worker": 3, diff --git a/model-engine/tests/unit/api/test_model_endpoints.py b/model-engine/tests/unit/api/test_model_endpoints.py index d3d8b9a6..1cc02f0b 100644 --- a/model-engine/tests/unit/api/test_model_endpoints.py +++ b/model-engine/tests/unit/api/test_model_endpoints.py @@ -225,6 +225,32 @@ def test_create_model_endpoint_endpoint_already_exists_returns_400( assert response_1.status_code == 400 +def test_create_model_endpoint_multinode_from_nonmultinode_bundle_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + create_model_endpoint_request_sync: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + create_model_endpoint_request_sync["nodes_per_worker"] = 2 + response_1 = client.post( + "/v1/model-endpoints", + auth=(test_api_key, ""), + json=create_model_endpoint_request_sync, + ) + assert response_1.status_code == 400 + + def test_list_model_endpoints( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 34f00c92..e207d30c 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -774,6 +774,7 @@ def __init__(self): "llama-2-7b": ["model-fake.safetensors"], "mpt-7b": ["model-fake.safetensors"], "llama-3-70b": ["model-fake.safetensors"], + "llama-3-1-405b-instruct": ["model-fake.safetensors"], } self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} self.model_config = { @@ -1077,6 +1078,7 @@ def create_model_endpoint_infra( memory: StorageSpecificationType, gpu_type: Optional[GpuType], storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, aws_role: str, results_s3_bucket: str, @@ -1111,6 +1113,7 @@ def create_model_endpoint_infra( gpu_type=gpu_type, memory=memory, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -1277,6 +1280,7 @@ async def create_or_update_resources( gpu_type=build_endpoint_request.gpu_type, memory=build_endpoint_request.memory, storage=build_endpoint_request.storage, + nodes_per_worker=build_endpoint_request.nodes_per_worker, optimize_costs=build_endpoint_request.optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -1514,7 +1518,10 @@ def __init__(self): ] async def streaming_predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, + topic: str, + predict_request: EndpointPredictV1Request, + manually_resolve_dns: bool = False, ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a response. @@ -1535,7 +1542,10 @@ def __init__(self, fake_sync_inference_content=None): self.response = fake_sync_inference_content async def predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, + topic: str, + predict_request: EndpointPredictV1Request, + manually_resolve_dns: bool = False, ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. @@ -1743,6 +1753,7 @@ async def create_model_endpoint( memory: StorageSpecificationType, gpu_type: Optional[GpuType], storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, min_workers: int, max_workers: int, @@ -1801,6 +1812,7 @@ async def create_model_endpoint( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -2673,6 +2685,7 @@ def model_endpoint_1(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -2738,6 +2751,7 @@ def model_endpoint_2(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2793,6 +2807,7 @@ def model_endpoint_3(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2848,6 +2863,7 @@ def model_endpoint_4(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2903,6 +2919,7 @@ def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> Mod memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -2968,6 +2985,7 @@ def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) - memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3034,6 +3052,7 @@ def model_endpoint_runnable(test_api_key: str, model_bundle_4: ModelBundle) -> M memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -3090,6 +3109,7 @@ def model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -3256,6 +3276,7 @@ def build_endpoint_request_async_runnable_image( memory="3G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", @@ -3299,6 +3320,7 @@ def build_endpoint_request_streaming_runnable_image( memory="4G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", @@ -3342,6 +3364,7 @@ def build_endpoint_request_sync_runnable_image( memory="4G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", @@ -3385,6 +3408,7 @@ def build_endpoint_request_sync_pytorch( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", @@ -3428,6 +3452,7 @@ def build_endpoint_request_async_tensorflow( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, optimize_costs=False, default_callback_url="https://example.com/path", default_callback_auth=CallbackAuth( @@ -3470,6 +3495,7 @@ def build_endpoint_request_async_custom( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, optimize_costs=True, broker_type=BrokerType.SQS, default_callback_url=None, @@ -3512,6 +3538,7 @@ def build_endpoint_request_async_zipartifact_highpri( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, optimize_costs=True, broker_type=BrokerType.SQS, default_callback_url=None, @@ -3553,6 +3580,7 @@ def build_endpoint_request_sync_custom( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, optimize_costs=True, default_callback_url=None, default_callback_auth=None, @@ -3645,6 +3673,7 @@ def llm_model_endpoint_async( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3719,6 +3748,7 @@ def llm_model_endpoint_async( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -3777,6 +3807,7 @@ def llm_model_endpoint_sync( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3851,6 +3882,7 @@ def llm_model_endpoint_sync( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -3909,6 +3941,7 @@ def llm_model_endpoint_stream( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3983,6 +4016,7 @@ def llm_model_endpoint_stream( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -4041,6 +4075,7 @@ def llm_model_endpoint_sync_tgi( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -4115,6 +4150,7 @@ def llm_model_endpoint_sync_tgi( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -4173,6 +4209,7 @@ def llm_model_endpoint_sync_lightllm( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -4247,6 +4284,7 @@ def llm_model_endpoint_sync_lightllm( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -4305,6 +4343,7 @@ def llm_model_endpoint_sync_trt_llm( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -4379,6 +4418,7 @@ def llm_model_endpoint_sync_trt_llm( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -4437,6 +4477,7 @@ def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -4502,6 +4543,7 @@ def llm_model_endpoint_text_generation_inference( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -4575,6 +4617,7 @@ def llm_model_endpoint_trt_llm( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -4609,36 +4652,49 @@ async def async_mock(*args, **kwargs): # noqa memory: 20Gi storage: 40Gi gpu_type: nvidia-hopper-h100-1g20gb + nodes_per_worker: 1 - gpu_memory_le: 40 cpus: 10 gpus: 1 memory: 40Gi storage: 80Gi gpu_type: nvidia-hopper-h100-3g40gb + nodes_per_worker: 1 - gpu_memory_le: 80 cpus: 20 gpus: 1 memory: 80Gi storage: 96Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 160 cpus: 40 gpus: 2 memory: 160Gi storage: 160Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 320 cpus: 80 gpus: 4 memory: 320Gi storage: 320Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - gpu_memory_le: 640 cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 1280 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 2 """, "byModelName": """ - name: llama-3-8b-instruct-262k @@ -4647,18 +4703,21 @@ async def async_mock(*args, **kwargs): # noqa memory: 160Gi storage: 160Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - name: deepseek-coder-v2 cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 - name: deepseek-coder-v2-instruct cpus: 160 gpus: 8 memory: 800Gi storage: 640Gi gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 """, } diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 1e30911a..861e7c00 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -89,6 +89,7 @@ def create_model_endpoint_request_sync( memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -113,6 +114,7 @@ def create_model_endpoint_request_streaming( memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=1, @@ -137,6 +139,7 @@ def create_model_endpoint_request_async( memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -193,6 +196,7 @@ def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request: memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -220,6 +224,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=0, max_workers=3, per_worker=2, @@ -247,6 +252,7 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -298,6 +304,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -325,6 +332,7 @@ def create_llm_model_endpoint_request_llama_3_70b() -> CreateLLMModelEndpointV1R memory="8G", gpu_type=GpuType.NVIDIA_HOPPER_H100, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -352,6 +360,7 @@ def create_llm_model_endpoint_request_llama_3_70b_chat() -> CreateLLMModelEndpoi memory="8G", gpu_type=GpuType.NVIDIA_HOPPER_H100, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -363,6 +372,34 @@ def create_llm_model_endpoint_request_llama_3_70b_chat() -> CreateLLMModelEndpoi ) +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_1_405b_instruct() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_1_405b_instruct", + model_name="llama-3-1-405b-instruct", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=8, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=8, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + nodes_per_worker=2, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-1-405b-instruct", + ) + + @pytest.fixture def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( CreateLLMModelEndpointV1Request @@ -382,6 +419,7 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -412,6 +450,7 @@ def create_llm_model_endpoint_text_generation_inference_request_async() -> ( memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -438,6 +477,7 @@ def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpo memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -466,6 +506,7 @@ def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -493,6 +534,7 @@ def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndp memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, @@ -520,6 +562,7 @@ def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEn memory="8G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, min_workers=1, max_workers=3, per_worker=2, diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 06098ce4..4b4e45a8 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -93,6 +93,7 @@ async def test_create_model_endpoint_use_case_success( create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_llama_2: CreateLLMModelEndpointV1Request, create_llm_model_endpoint_request_llama_3_70b: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_llama_3_1_405b_instruct: CreateLLMModelEndpointV1Request, ): fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository bundle_use_case = CreateModelBundleV2UseCase( @@ -208,6 +209,20 @@ async def test_create_model_endpoint_use_case_success( ) assert " --gpu-memory-utilization 0.95" in bundle.flavor.command[-1] + response_6 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_llama_3_1_405b_instruct + ) + assert response_6.endpoint_creation_task_id + assert isinstance(response_6, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_llama_3_1_405b_instruct.name, + order_by=None, + ) + )[0] + assert endpoint.infra_state.resource_state.nodes_per_worker == 2 + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -282,6 +297,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint( quantize=request.quantize, checkpoint_path=checkpoint_path, chat_template_override=request.chat_template_override, + nodes_per_worker=1, ) @@ -2090,6 +2106,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "", is_batch_job=True @@ -2099,6 +2116,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 # deepseek lite https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct/raw/main/config.json fake_llm_artifact_gateway.model_config = { @@ -2152,6 +2170,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, @@ -2164,6 +2183,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, @@ -2214,6 +2234,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "20Gi" assert hardware.storage == "40Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "", is_batch_job=True @@ -2223,6 +2244,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 # Phi 3 small from https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/main/config.json fake_llm_artifact_gateway.model_config = { @@ -2273,6 +2295,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "20Gi" assert hardware.storage == "40Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "", is_batch_job=True @@ -2283,6 +2306,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["Phi3ForCausalLM"], @@ -2320,6 +2344,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "", is_batch_job=True @@ -2329,6 +2354,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["MixtralForCausalLM"], @@ -2359,6 +2385,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "mixtral-8x7b", "", is_batch_job=True @@ -2368,6 +2395,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["MixtralForCausalLM"], @@ -2399,6 +2427,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "mixtral-8x22b", "", is_batch_job=True @@ -2408,6 +2437,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "800Gi" assert hardware.storage == "640Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-7b-hf", @@ -2435,6 +2465,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "20Gi" assert hardware.storage == "40Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) assert hardware.cpus == 10 @@ -2442,6 +2473,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], @@ -2470,6 +2502,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "20Gi" assert hardware.storage == "40Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) assert hardware.cpus == 10 @@ -2477,6 +2510,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-13b-hf", @@ -2504,6 +2538,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "40Gi" assert hardware.storage == "80Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "llama-2-13b", "", is_batch_job=True @@ -2513,6 +2548,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], @@ -2540,6 +2576,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "80Gi" assert hardware.storage == "96Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "codellama-34b", "", is_batch_job=True @@ -2549,6 +2586,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "_name_or_path": "meta-llama/Llama-2-70b-hf", @@ -2576,6 +2614,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "llama-2-70b", "", is_batch_job=True @@ -2585,6 +2624,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "320Gi" assert hardware.storage == "320Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["LlamaForCausalLM"], @@ -2613,6 +2653,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 hardware = await _infer_hardware( fake_llm_artifact_gateway, "llama-3-70b", "", is_batch_job=True @@ -2622,6 +2663,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "320Gi" assert hardware.storage == "320Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "_name_or_path": "gradientai/llama3-8b-stage65k-chat", @@ -2651,6 +2693,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "160Gi" assert hardware.storage == "160Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 fake_llm_artifact_gateway.model_config = { "architectures": ["Qwen2ForCausalLM"], @@ -2683,6 +2726,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.memory == "320Gi" assert hardware.storage == "320Gi" assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 with pytest.raises(ObjectHasInvalidValueException): await _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") @@ -2713,6 +2757,7 @@ async def test_fill_hardware_info(fake_llm_artifact_gateway): assert request.memory == "160Gi" assert request.storage == "160Gi" assert request.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert request.nodes_per_worker == 1 request = CreateLLMModelEndpointV1Request( name="mixtral-8x7b", diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index e9958b11..f6ea8ab6 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -618,19 +618,77 @@ async def test_create_model_endpoint_use_case_sets_high_priority( await fake_model_endpoint_service.delete_model_endpoint(endpoints[0].record.id) +@pytest.mark.asyncio +async def test_create_multinode_endpoint_with_nonmultinode_bundle_fails( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_streaming: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + create_model_endpoint_request_streaming.nodes_per_worker = 2 + create_model_endpoint_request_streaming.model_bundle_id = model_bundle_1.id + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute(user=user, request=create_model_endpoint_request_streaming) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("nodes_per_worker", [1, 2]) +async def test_create_multinode_or_nonmultinode_endpoint_with_multinode_bundle_succeeds( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_5: ModelBundle, + create_model_endpoint_request_streaming: CreateModelEndpointV1Request, + nodes_per_worker: int, +): + # mb5 is a streaming runnable image bundle + model_bundle_5.flavor.worker_env = {"fake_env": "fake_value"} + model_bundle_5.flavor.worker_command = ["fake_command"] + fake_model_bundle_repository.add_model_bundle(model_bundle_5) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_5.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + create_model_endpoint_request_streaming.nodes_per_worker = nodes_per_worker + create_model_endpoint_request_streaming.model_bundle_id = model_bundle_5.id + response = await use_case.execute(user=user, request=create_model_endpoint_request_streaming) + assert response.endpoint_creation_task_id + assert isinstance(response, CreateModelEndpointV1Response) + + @pytest.mark.asyncio async def test_get_model_endpoint_use_case_success( test_api_key: str, fake_model_endpoint_service, model_endpoint_1: ModelEndpoint, + model_endpoint_2: ModelEndpoint, ): + # Tests single node + multinode fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + model_endpoint_2.infra_state.resource_state.nodes_per_worker = 2 + fake_model_endpoint_service.add_model_endpoint(model_endpoint_2) use_case = GetModelEndpointByIdV1UseCase(model_endpoint_service=fake_model_endpoint_service) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response = await use_case.execute(user=user, model_endpoint_id=model_endpoint_1.record.id) assert isinstance(response, GetModelEndpointV1Response) + response_2 = await use_case.execute(user=user, model_endpoint_id=model_endpoint_2.record.id) + assert isinstance(response_2, GetModelEndpointV1Response) + assert response_2.resource_state.nodes_per_worker == 2 + @pytest.mark.asyncio async def test_get_model_endpoint_use_case_same_team_finds_endpoint( diff --git a/model-engine/tests/unit/infra/gateways/resources/example_lws_config.json b/model-engine/tests/unit/infra/gateways/resources/example_lws_config.json new file mode 100644 index 00000000..41793478 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/resources/example_lws_config.json @@ -0,0 +1,829 @@ +{ + "apiVersion": "leaderworkerset.x-k8s.io/v1", + "kind": "LeaderWorkerSet", + "metadata": { + "creationTimestamp": "2000-01-01T01:38:14Z", + "generation": 1, + "labels": { + "created_by": "userid000000", + "endpoint_id": "end_abcdefg", + "endpoint_name": "endpoint_name", + "env": "training", + "managed-by": "model-engine", + "owner": "userid000000", + "product": "testing", + "tags.datadoghq.com/env": "training", + "tags.datadoghq.com/service": "endpoint_name", + "tags.datadoghq.com/version": "05fb96620692a205e52d33980eff475d6a52748a", + "team": "infra", + "use_scale_launch_endpoint_network_policy": "true", + "user_id": "userid000000" + }, + "managedFields": [ + { + "apiVersion": "leaderworkerset.x-k8s.io/v1", + "fieldsType": "FieldsV1", + "fieldsV1": { + "f:metadata": { + "f:labels": { + ".": {}, + "f:created_by": {}, + "f:endpoint_id": {}, + "f:endpoint_name": {}, + "f:env": {}, + "f:managed-by": {}, + "f:owner": {}, + "f:product": {}, + "f:tags.datadoghq.com/env": {}, + "f:tags.datadoghq.com/service": {}, + "f:tags.datadoghq.com/version": {}, + "f:team": {}, + "f:use_scale_launch_endpoint_network_policy": {}, + "f:user_id": {} + } + }, + "f:spec": { + ".": {}, + "f:leaderWorkerTemplate": { + ".": {}, + "f:leaderTemplate": { + ".": {}, + "f:metadata": { + ".": {}, + "f:annotations": { + ".": {}, + "f:ad.datadoghq.com/main.logs": {}, + "f:kubernetes.io/change-cause": {} + }, + "f:labels": { + ".": {}, + "f:app": {}, + "f:created_by": {}, + "f:endpoint_id": {}, + "f:endpoint_name": {}, + "f:env": {}, + "f:managed-by": {}, + "f:owner": {}, + "f:product": {}, + "f:sidecar.istio.io/inject": {}, + "f:tags.datadoghq.com/env": {}, + "f:tags.datadoghq.com/service": {}, + "f:tags.datadoghq.com/version": {}, + "f:team": {}, + "f:use_scale_launch_endpoint_network_policy": {}, + "f:user_id": {}, + "f:version": {} + } + }, + "f:spec": { + ".": {}, + "f:affinity": { + ".": {}, + "f:podAffinity": { + ".": {}, + "f:preferredDuringSchedulingIgnoredDuringExecution": {} + } + }, + "f:containers": {}, + "f:nodeSelector": {}, + "f:priorityClassName": {}, + "f:serviceAccount": {}, + "f:terminationGracePeriodSeconds": {}, + "f:tolerations": {}, + "f:volumes": {} + } + }, + "f:restartPolicy": {}, + "f:size": {}, + "f:workerTemplate": { + ".": {}, + "f:metadata": { + ".": {}, + "f:annotations": { + ".": {}, + "f:ad.datadoghq.com/main.logs": {}, + "f:kubernetes.io/change-cause": {} + }, + "f:labels": { + ".": {}, + "f:app": {}, + "f:created_by": {}, + "f:endpoint_id": {}, + "f:endpoint_name": {}, + "f:env": {}, + "f:managed-by": {}, + "f:owner": {}, + "f:product": {}, + "f:sidecar.istio.io/inject": {}, + "f:tags.datadoghq.com/env": {}, + "f:tags.datadoghq.com/service": {}, + "f:tags.datadoghq.com/version": {}, + "f:team": {}, + "f:use_scale_launch_endpoint_network_policy": {}, + "f:user_id": {}, + "f:version": {} + } + }, + "f:spec": { + ".": {}, + "f:affinity": { + ".": {}, + "f:podAffinity": { + ".": {}, + "f:preferredDuringSchedulingIgnoredDuringExecution": {} + } + }, + "f:containers": {}, + "f:nodeSelector": {}, + "f:priorityClassName": {}, + "f:serviceAccount": {}, + "f:terminationGracePeriodSeconds": {}, + "f:tolerations": {}, + "f:volumes": {} + } + } + }, + "f:replicas": {}, + "f:startupPolicy": {} + } + }, + "manager": "OpenAPI-Generator", + "operation": "Update", + "time": "2000-01-01T01:38:14Z" + }, + { + "apiVersion": "leaderworkerset.x-k8s.io/v1", + "fieldsType": "FieldsV1", + "fieldsV1": { + "f:status": { + ".": {}, + "f:conditions": {}, + "f:hpaPodSelector": {} + } + }, + "manager": "manager", + "operation": "Update", + "subresource": "status", + "time": "2000-01-01T01:38:14Z" + } + ], + "name": "launch-endpoint-id-end-abcdefg", + "namespace": "scale-deploy", + "resourceVersion": "2289583184", + "uid": "1d66ad78-3148-41b3-83fd-fb71d7656fb1" + }, + "spec": { + "leaderWorkerTemplate": { + "leaderTemplate": { + "metadata": { + "annotations": { + "ad.datadoghq.com/main.logs": "[{\"service\": \"endpoint_name\", \"source\": \"python\"}]", + "kubernetes.io/change-cause": "Deployment at 2000-01-01 01:38:13.814158 UTC. Using deployment constructed from model bundle ID bun_cqi4v12d6mt002nap720, model bundle name endpoint_name, endpoint ID end_abcdefg" + }, + "labels": { + "app": "launch-endpoint-id-end-abcdefg", + "created_by": "userid000000", + "endpoint_id": "end_abcdefg", + "endpoint_name": "endpoint_name", + "env": "training", + "managed-by": "model-engine", + "owner": "userid000000", + "product": "testing", + "sidecar.istio.io/inject": "false", + "tags.datadoghq.com/env": "training", + "tags.datadoghq.com/service": "endpoint_name", + "tags.datadoghq.com/version": "05fb96620692a205e52d33980eff475d6a52748a", + "team": "infra", + "use_scale_launch_endpoint_network_policy": "true", + "user_id": "userid000000", + "version": "v1" + } + }, + "spec": { + "affinity": { + "podAffinity": { + "preferredDuringSchedulingIgnoredDuringExecution": [ + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "app", + "operator": "In", + "values": [ + "launch-endpoint-id-end-abcdefg" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 1 + }, + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "3d45a96760a60018eb4a9d874e919aef", + "operator": "In", + "values": [ + "True" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 100 + } + ] + } + }, + "containers": [ + { + "command": [ + "/usr/bin/dumb-init", + "--", + "ddtrace-run", + "python", + "-m", + "model_engine_server.inference.forwarding.http_forwarder", + "--config", + "/workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml", + "--port", + "5000", + "--num-workers", + "2", + "--set", + "forwarder.sync.predict_route=/predict", + "--set", + "forwarder.stream.predict_route=/stream", + "--set", + "forwarder.sync.healthcheck_route=/health", + "--set", + "forwarder.stream.healthcheck_route=/health" + ], + "env": [ + { + "name": "DD_TRACE_ENABLED", + "value": "True" + }, + { + "name": "DD_REMOTE_CONFIGURATION_ENABLED", + "value": "false" + }, + { + "name": "DD_SERVICE", + "value": "endpoint_name" + }, + { + "name": "DD_ENV", + "value": "training" + }, + { + "name": "DD_VERSION", + "value": "05fb96620692a205e52d33980eff475d6a52748a" + }, + { + "name": "DD_AGENT_HOST", + "valueFrom": { + "fieldRef": { + "fieldPath": "status.hostIP" + } + } + }, + { + "name": "AWS_PROFILE", + "value": "aws-profile" + }, + { + "name": "AWS_CONFIG_FILE", + "value": "/opt/.aws/config" + }, + { + "name": "RESULTS_S3_BUCKET", + "value": "bucket" + }, + { + "name": "BASE_PATH", + "value": "/workspace" + }, + { + "name": "ML_INFRA_SERVICES_CONFIG_PATH", + "value": "/workspace/model-engine-internal/resources/configs/infra_config_training.yaml" + } + ], + "image": "000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:tag", + "imagePullPolicy": "IfNotPresent", + "name": "http-forwarder", + "ports": [ + { + "containerPort": 5000, + "name": "http", + "protocol": "TCP" + } + ], + "readinessProbe": { + "httpGet": { + "path": "/readyz", + "port": 5000 + }, + "initialDelaySeconds": 10, + "periodSeconds": 5, + "timeoutSeconds": 5 + }, + "resources": { + "limits": { + "cpu": "1", + "ephemeral-storage": "1G", + "memory": "2Gi" + }, + "requests": { + "cpu": "1", + "ephemeral-storage": "100M", + "memory": "100M" + } + }, + "volumeMounts": [ + { + "mountPath": "/opt/.aws/config", + "name": "config-volume", + "subPath": "config" + }, + { + "mountPath": "/workspace/user_config", + "name": "user-config", + "subPath": "raw_data" + }, + { + "mountPath": "/workspace/endpoint_config", + "name": "endpoint-config", + "subPath": "raw_data" + } + ] + }, + { + "command": [ + "/bin/bash", + "-c", + "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://bucket/tag/userid000000/model_weights/* model_files;/workspace/init_ray.sh leader --ray_cluster_size=$RAY_CLUSTER_SIZE --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local;python -m vllm_server --model model_files --tensor-parallel-size 1 --port 5005 --disable-log-requests--enforce-eager" + ], + "env": [ + { + "name": "VLLM_HOST_IP", + "value": "$(K8S_LEADER_NAME).$(K8S_LWS_NAME).$(K8S_OWN_NAMESPACE).svc.cluster.local" + }, + { + "name": "NCCL_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "GLOO_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "NCCL_DEBUG", + "value": "INFO" + }, + { + "name": "VLLM_LOGGING_LEVEL", + "value": "INFO" + }, + { + "name": "AWS_PROFILE", + "value": "aws-profile" + }, + { + "name": "AWS_CONFIG_FILE", + "value": "/opt/.aws/config" + }, + { + "name": "K8S_OWN_POD_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.name" + } + } + }, + { + "name": "K8S_OWN_NAMESPACE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.namespace" + } + } + }, + { + "name": "K8S_LWS_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.labels['leaderworkerset.sigs.k8s.io/name']" + } + } + }, + { + "name": "K8S_LWS_CLUSTER_SIZE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.annotations['leaderworkerset.sigs.k8s.io/size']" + } + } + }, + { + "name": "DD_TRACE_ENABLED", + "value": "true" + }, + { + "name": "DD_SERVICE", + "value": "endpoint_name" + }, + { + "name": "DD_ENV", + "value": "training" + }, + { + "name": "DD_VERSION", + "value": "05fb96620692a205e52d33980eff475d6a52748a" + }, + { + "name": "DD_AGENT_HOST", + "valueFrom": { + "fieldRef": { + "fieldPath": "status.hostIP" + } + } + } + ], + "image": "000000000000.dkr.ecr.us-west-2.amazonaws.com/vllm:0.5.3.post1", + "imagePullPolicy": "IfNotPresent", + "name": "lws-leader", + "ports": [ + { + "containerPort": 5005, + "name": "http", + "protocol": "TCP" + } + ], + "readinessProbe": { + "httpGet": { + "path": "/health", + "port": 5005 + }, + "initialDelaySeconds": 10, + "periodSeconds": 5, + "timeoutSeconds": 5 + }, + "resources": { + "limits": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + }, + "requests": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + } + }, + "volumeMounts": [ + { + "mountPath": "/opt/.aws/config", + "name": "config-volume", + "subPath": "config" + }, + { + "mountPath": "/dev/shm", + "name": "dshm" + }, + { + "mountPath": "/app/user_config", + "name": "user-config", + "subPath": "raw_data" + }, + { + "mountPath": "/app/endpoint_config", + "name": "endpoint-config", + "subPath": "raw_data" + } + ] + } + ], + "nodeSelector": { + "k8s.amazonaws.com/accelerator": "nvidia-hopper-h100", + "node-lifecycle": "normal" + }, + "priorityClassName": "model-engine-high-priority", + "serviceAccount": "aws-profile", + "terminationGracePeriodSeconds": 600, + "tolerations": [ + { + "effect": "NoSchedule", + "key": "nvidia.com/gpu", + "operator": "Exists" + } + ], + "volumes": [ + { + "configMap": { + "name": "aws-profile-config" + }, + "name": "config-volume" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg" + }, + "name": "user-config" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg-endpoint-config" + }, + "name": "endpoint-config" + }, + { + "emptyDir": { + "medium": "Memory" + }, + "name": "dshm" + } + ] + } + }, + "restartPolicy": "RecreateGroupOnPodRestart", + "size": 2, + "workerTemplate": { + "metadata": { + "annotations": { + "ad.datadoghq.com/main.logs": "[{\"service\": \"endpoint_name\", \"source\": \"python\"}]", + "kubernetes.io/change-cause": "Deployment at 2000-01-01 01:38:13.814158 UTC. Using deployment constructed from model bundle ID bun_cqi4v12d6mt002nap720, model bundle name endpoint_name, endpoint ID end_abcdefg" + }, + "labels": { + "app": "launch-endpoint-id-end-abcdefg", + "created_by": "userid000000", + "endpoint_id": "end_abcdefg", + "endpoint_name": "endpoint_name", + "env": "training", + "managed-by": "model-engine", + "owner": "userid000000", + "product": "testing", + "sidecar.istio.io/inject": "false", + "tags.datadoghq.com/env": "training", + "tags.datadoghq.com/service": "endpoint_name", + "tags.datadoghq.com/version": "05fb96620692a205e52d33980eff475d6a52748a", + "team": "infra", + "use_scale_launch_endpoint_network_policy": "true", + "user_id": "userid000000", + "version": "v1" + } + }, + "spec": { + "affinity": { + "podAffinity": { + "preferredDuringSchedulingIgnoredDuringExecution": [ + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "app", + "operator": "In", + "values": [ + "launch-endpoint-id-end-abcdefg" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 1 + }, + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "3d45a96760a60018eb4a9d874e919aef", + "operator": "In", + "values": [ + "True" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 100 + } + ] + } + }, + "containers": [ + { + "command": [ + "/bin/bash", + "-c", + "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://bucket/key/userid000000/model_weights/* model_files;/workspace/init_ray.sh worker --ray_cluster_size=$RAY_CLUSTER_SIZE --ray_address=$K8S_LEADER_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + ], + "env": [ + { + "name": "VLLM_HOST_IP", + "value": "$(K8S_LEADER_NAME).$(K8S_LWS_NAME).$(K8S_OWN_NAMESPACE).svc.cluster.local" + }, + { + "name": "NCCL_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "GLOO_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "NCCL_DEBUG", + "value": "INFO" + }, + { + "name": "VLLM_LOGGING_LEVEL", + "value": "INFO" + }, + { + "name": "AWS_PROFILE", + "value": "aws-profile" + }, + { + "name": "AWS_CONFIG_FILE", + "value": "/opt/.aws/config" + }, + { + "name": "K8S_OWN_POD_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.name" + } + } + }, + { + "name": "K8S_OWN_NAMESPACE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.namespace" + } + } + }, + { + "name": "K8S_LWS_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.labels['leaderworkerset.sigs.k8s.io/name']" + } + } + }, + { + "name": "K8S_LWS_CLUSTER_SIZE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.annotations['leaderworkerset.sigs.k8s.io/size']" + } + } + }, + { + "name": "DD_TRACE_ENABLED", + "value": "true" + }, + { + "name": "DD_SERVICE", + "value": "endpoint_name" + }, + { + "name": "DD_ENV", + "value": "training" + }, + { + "name": "DD_VERSION", + "value": "05fb96620692a205e52d33980eff475d6a52748a" + }, + { + "name": "DD_AGENT_HOST", + "valueFrom": { + "fieldRef": { + "fieldPath": "status.hostIP" + } + } + } + ], + "image": "000000000000.dkr.ecr.us-west-2.amazonaws.com/vllm:0.5.3.post1", + "imagePullPolicy": "IfNotPresent", + "name": "lws-worker", + "ports": [ + { + "containerPort": 5005, + "name": "http", + "protocol": "TCP" + } + ], + "resources": { + "limits": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + }, + "requests": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + } + }, + "volumeMounts": [ + { + "mountPath": "/opt/.aws/config", + "name": "config-volume", + "subPath": "config" + }, + { + "mountPath": "/dev/shm", + "name": "dshm" + }, + { + "mountPath": "/app/user_config", + "name": "user-config", + "subPath": "raw_data" + }, + { + "mountPath": "/app/endpoint_config", + "name": "endpoint-config", + "subPath": "raw_data" + } + ] + } + ], + "nodeSelector": { + "k8s.amazonaws.com/accelerator": "nvidia-hopper-h100", + "node-lifecycle": "normal" + }, + "priorityClassName": "model-engine-high-priority", + "serviceAccount": "aws-profile", + "terminationGracePeriodSeconds": 600, + "tolerations": [ + { + "effect": "NoSchedule", + "key": "nvidia.com/gpu", + "operator": "Exists" + } + ], + "volumes": [ + { + "configMap": { + "name": "aws-profile-config" + }, + "name": "config-volume" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg" + }, + "name": "user-config" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg-endpoint-config" + }, + "name": "endpoint-config" + }, + { + "emptyDir": { + "medium": "Memory" + }, + "name": "dshm" + } + ] + } + } + }, + "replicas": 0, + "rolloutStrategy": { + "rollingUpdateConfiguration": { + "maxSurge": 0, + "maxUnavailable": 1 + }, + "type": "RollingUpdate" + }, + "startupPolicy": "LeaderCreated" + }, + "status": { + "conditions": [ + { + "lastTransitionTime": "2000-01-01T01:38:14Z", + "message": "All replicas are ready", + "reason": "AllGroupsReady", + "status": "True", + "type": "Available" + } + ], + "hpaPodSelector": "leaderworkerset.sigs.k8s.io/name=launch-endpoint-id-end-abcdefg,leaderworkerset.sigs.k8s.io/worker-index=0" + } + } \ No newline at end of file diff --git a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index 4e3c6415..f9ac3284 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -1,3 +1,5 @@ +import json +import os from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch @@ -7,6 +9,7 @@ from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest from model_engine_server.common.env_vars import GIT_TAG from model_engine_server.domain.entities import ( + ModelBundle, ModelEndpointConfig, ModelEndpointType, ModelEndpointUserConfigState, @@ -15,7 +18,7 @@ from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( DATADOG_ENV_VAR, K8SEndpointResourceDelegate, - add_datadog_env_to_main_container, + add_datadog_env_to_container, get_main_container_from_deployment_template, load_k8s_yaml, ) @@ -28,6 +31,10 @@ MODULE_PATH = "model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate" +EXAMPLE_LWS_CONFIG_PATH = os.path.abspath(os.path.join(__file__, "..", "example_lws_config.json")) +with open(EXAMPLE_LWS_CONFIG_PATH, "r") as f: + EXAMPLE_LWS_CONFIG = json.load(f) + @pytest.fixture def mock_get_kubernetes_cluster_version(): @@ -165,7 +172,8 @@ def test_resource_arguments_type_and_add_datadog_env_to_main_container(resource_ deployment_template = load_k8s_yaml(f"{resource_arguments_type_name}.yaml", resource_arguments) if "runnable-image" in resource_arguments_type_name: - add_datadog_env_to_main_container(deployment_template) + user_container = get_main_container_from_deployment_template(deployment_template) + add_datadog_env_to_container(deployment_template, user_container) user_container = get_main_container_from_deployment_template(deployment_template) @@ -281,7 +289,7 @@ def _verify_custom_object_plurals(call_args_list, expected_plurals: List[str]) - @pytest.mark.asyncio -async def test_create_async_endpoint_has_correct_labels( +async def test_create_async_endpoint_has_correct_labels_and_dest( k8s_endpoint_resource_delegate, mock_apps_client, mock_core_client, @@ -294,9 +302,10 @@ async def test_create_async_endpoint_has_correct_labels( for request in [ create_resources_request_async_runnable_image, ]: - await k8s_endpoint_resource_delegate.create_or_update_resources( + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue" ) + assert dest == "my_queue" # Verify deployment labels create_deployment_call_args = mock_apps_client.create_namespaced_deployment.call_args @@ -350,7 +359,7 @@ async def test_create_async_endpoint_has_correct_labels( @pytest.mark.asyncio -async def test_create_streaming_endpoint_has_correct_labels( +async def test_create_streaming_endpoint_has_correct_labels_and_dest( k8s_endpoint_resource_delegate, mock_apps_client, mock_core_client, @@ -361,11 +370,15 @@ async def test_create_streaming_endpoint_has_correct_labels( create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest, ): request = create_resources_request_streaming_runnable_image - await k8s_endpoint_resource_delegate.create_or_update_resources( + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue", ) + service_name = mock_core_client.create_namespaced_service.call_args.kwargs["body"]["metadata"][ + "name" + ] + assert dest == service_name # Verify deployment labels create_deployment_call_args = mock_apps_client.create_namespaced_deployment.call_args @@ -423,7 +436,7 @@ async def test_create_streaming_endpoint_has_correct_labels( @pytest.mark.asyncio -async def test_create_sync_endpoint_has_correct_labels( +async def test_create_sync_endpoint_has_correct_labels_and_dest( k8s_endpoint_resource_delegate, mock_apps_client, mock_core_client, @@ -436,11 +449,15 @@ async def test_create_sync_endpoint_has_correct_labels( for request in [ create_resources_request_sync_runnable_image, ]: - await k8s_endpoint_resource_delegate.create_or_update_resources( + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue,", ) + service_name = mock_core_client.create_namespaced_service.call_args.kwargs["body"][ + "metadata" + ]["name"] + assert dest == service_name # Verify deployment labels create_deployment_call_args = mock_apps_client.create_namespaced_deployment.call_args @@ -523,6 +540,48 @@ async def test_create_sync_endpoint_has_correct_k8s_service_type( assert service_body["spec"] is not None +@pytest.mark.asyncio +async def test_create_multinode_endpoint_creates_lws_and_correct_dest( + k8s_endpoint_resource_delegate, + mock_apps_client, + mock_core_client, + mock_autoscaling_client, + mock_policy_client, + mock_custom_objects_client, + mock_get_kubernetes_cluster_version, + create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest, + model_bundle_5: ModelBundle, +): + # Patch model bundle so that it supports multinode + model_bundle_5.flavor.worker_env = {"fake_env": "fake_value"} + model_bundle_5.flavor.worker_command = ["fake_command"] + create_resources_request_streaming_runnable_image.build_endpoint_request.model_endpoint_record.current_model_bundle = ( + model_bundle_5 + ) + create_resources_request_streaming_runnable_image.build_endpoint_request.model_endpoint_record.endpoint_type = ( + ModelEndpointType.STREAMING + ) + + create_resources_request_streaming_runnable_image.build_endpoint_request.nodes_per_worker = 2 + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( + create_resources_request_streaming_runnable_image, + sqs_queue_name="my_queue", + sqs_queue_url="https://my_queue", + ) + service_name = mock_core_client.create_namespaced_service.call_args.kwargs["body"]["metadata"][ + "name" + ] + assert dest == service_name + # Verify call to custom objects client with LWS is made + create_custom_objects_call_args_list = ( + mock_custom_objects_client.create_namespaced_custom_object.call_args_list + ) + assert any( + call_args.kwargs["group"] == "leaderworkerset.x-k8s.io" + for call_args in create_custom_objects_call_args_list + ) + + @pytest.mark.asyncio async def test_create_endpoint_raises_k8s_endpoint_resource_delegate( k8s_endpoint_resource_delegate, @@ -563,6 +622,8 @@ async def test_get_resources_async_success( mock_policy_client, mock_custom_objects_client, ): + # Pretend that LWS get gives an ApiException since it doesn't exist + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock(side_effect=ApiException) k8s_endpoint_resource_delegate.__setattr__( "_get_common_endpoint_params", Mock( @@ -623,6 +684,8 @@ async def test_get_resources_sync_success( mock_policy_client, mock_custom_objects_client, ): + # Pretend that LWS get and keda get give an ApiException + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock(side_effect=ApiException) k8s_endpoint_resource_delegate.__setattr__( "_get_common_endpoint_params", Mock( @@ -668,6 +731,40 @@ async def test_get_resources_sync_success( assert infra_state +@pytest.mark.asyncio +async def test_get_resources_multinode_success( + k8s_endpoint_resource_delegate, + mock_apps_client, + mock_core_client, + mock_autoscaling_client, + mock_policy_client, + mock_custom_objects_client, +): + k8s_endpoint_resource_delegate.__setattr__( + "_translate_k8s_config_maps_to_user_config_data", + Mock( + return_value=ModelEndpointUserConfigState( + app_config=None, + endpoint_config=ModelEndpointConfig( + endpoint_name="test_endpoint", + bundle_name="test_bundle", + post_inference_hooks=["callback"], + ), + ) + ), + ) + + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock( + return_value=EXAMPLE_LWS_CONFIG + ) + + infra_state = await k8s_endpoint_resource_delegate.get_resources( + endpoint_id="", deployment_name="", endpoint_type=ModelEndpointType.STREAMING + ) + assert infra_state + assert infra_state.resource_state.nodes_per_worker == 2 + + @pytest.mark.asyncio async def test_delete_resources_invalid_endpoint_type_returns_false( k8s_endpoint_resource_delegate, @@ -708,6 +805,32 @@ async def test_delete_resources_sync_success( assert deleted +@pytest.mark.asyncio +async def test_delete_resources_multinode_success( + k8s_endpoint_resource_delegate, + mock_apps_client, + mock_core_client, + mock_autoscaling_client, + mock_policy_client, + mock_custom_objects_client, +): + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock( + return_value=EXAMPLE_LWS_CONFIG + ) + mock_custom_objects_client.delete_namespaced_custom_object = AsyncMock() + deleted = await k8s_endpoint_resource_delegate.delete_resources( + endpoint_id="", deployment_name="", endpoint_type=ModelEndpointType.STREAMING + ) + assert deleted + delete_called_for_lws = False + for call_args in mock_custom_objects_client.delete_namespaced_custom_object.call_args_list: + # 'group' is kwargs in delete_namespaced_custom_object + if call_args[1]["group"] == "leaderworkerset.x-k8s.io": + delete_called_for_lws = True + break + assert delete_called_for_lws + + @pytest.mark.asyncio async def test_create_pdb( k8s_endpoint_resource_delegate, diff --git a/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py index e03a9840..041b12aa 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py @@ -43,6 +43,7 @@ def test_create_model_endpoint_infra( memory=endpoint.infra_state.resource_state.memory, gpu_type=endpoint.infra_state.resource_state.gpu_type, storage=endpoint.infra_state.resource_state.storage, + nodes_per_worker=endpoint.infra_state.resource_state.nodes_per_worker, optimize_costs=bool(endpoint.infra_state.resource_state.optimize_costs), aws_role=endpoint.infra_state.aws_role, results_s3_bucket=endpoint.infra_state.results_s3_bucket, @@ -133,6 +134,37 @@ async def test_update_model_endpoint_infra( ) +@pytest.mark.asyncio +async def test_update_multinode_endpoint_keeps_nodes_per_worker( + model_endpoint_infra_gateway: LiveModelEndpointInfraGateway, + model_endpoint_1: ModelEndpoint, + fake_task_queue_gateway, +): + model_endpoint_1.infra_state.resource_state.nodes_per_worker = 2 + resource_gateway: Any = model_endpoint_infra_gateway.resource_gateway + existing_infra_state = model_endpoint_1.infra_state + assert existing_infra_state is not None + live_model_endpoint_infra_gateway.generate_deployment_name = Mock( + return_value=existing_infra_state.deployment_name + ) + resource_gateway.add_resource(model_endpoint_1.record.id, existing_infra_state) + + creation_task_id_1 = await model_endpoint_infra_gateway.update_model_endpoint_infra( + model_endpoint_record=model_endpoint_1.record, + max_workers=2, + cpus=2, + memory=2, + storage=2, + ) + assert creation_task_id_1 + assert ( + fake_task_queue_gateway.get_task_args(creation_task_id_1)["kwargs"][ + "build_endpoint_request_json" + ].get("nodes_per_worker") + == 2 + ) + + @pytest.mark.asyncio async def test_get_model_endpoint_infra_success( model_endpoint_infra_gateway: LiveModelEndpointInfraGateway, diff --git a/model-engine/tests/unit/infra/repositories/conftest.py b/model-engine/tests/unit/infra/repositories/conftest.py index 12c550b6..dbf0109e 100644 --- a/model-engine/tests/unit/infra/repositories/conftest.py +++ b/model-engine/tests/unit/infra/repositories/conftest.py @@ -276,6 +276,7 @@ def entity_model_endpoint_infra_state() -> ModelEndpointInfraState: memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( diff --git a/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py index 52886431..b67fc4cc 100644 --- a/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py +++ b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py @@ -52,6 +52,7 @@ async def _create_model_endpoint_helper( memory=infra_state.resource_state.memory, gpu_type=infra_state.resource_state.gpu_type, storage=infra_state.resource_state.storage, + nodes_per_worker=infra_state.resource_state.nodes_per_worker, optimize_costs=bool(infra_state.resource_state.optimize_costs), min_workers=infra_state.deployment_state.min_workers, max_workers=infra_state.deployment_state.max_workers, @@ -156,6 +157,7 @@ async def test_create_model_endpoint_raises_already_exists( memory=infra_state.resource_state.memory, gpu_type=infra_state.resource_state.gpu_type, storage=infra_state.resource_state.storage, + nodes_per_worker=infra_state.resource_state.nodes_per_worker, optimize_costs=bool(infra_state.resource_state.optimize_costs), min_workers=infra_state.deployment_state.min_workers, max_workers=infra_state.deployment_state.max_workers, From 8dc74c899de67f421e159701682f100614b6db1c Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 10 Oct 2024 10:17:09 -0700 Subject: [PATCH 396/425] Enable more vllm args to be passed through for batch completions (#630) * Enable more vllm args to be passed through * Enabling bypassing _infer_hardware check by providing own hardware specification * Add deepseek to model list * fix bug * Add tests * Add fake batch completions service * fix test * Fix --- .../common/dtos/llms/batch_completion.py | 39 ++++--- .../common/dtos/llms/vllm.py | 19 ++++ .../use_cases/llm_model_endpoint_use_cases.py | 91 ++++++++++++---- .../fake_llm_batch_completions_service.py | 87 +++++++++++++++ model-engine/tests/unit/conftest.py | 9 ++ model-engine/tests/unit/domain/conftest.py | 52 +++++++++ .../tests/unit/domain/test_llm_use_cases.py | 100 ++++++++++++++++++ 7 files changed, 360 insertions(+), 37 deletions(-) create mode 100644 model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py index e8eff546..9f1eea1e 100644 --- a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -10,8 +10,13 @@ CompletionV2Request, CompletionV2SyncResponse, ) -from model_engine_server.common.dtos.llms.vllm import VLLMModelConfig +from model_engine_server.common.dtos.llms.vllm import VLLMEngineAdditionalArgs, VLLMModelConfig from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field +from model_engine_server.domain.entities.common_types import ( + CpuSpecificationType, + StorageSpecificationType, +) +from model_engine_server.domain.entities.gpu_type import GpuType from typing_extensions import TypeAlias @@ -112,6 +117,25 @@ class BatchCompletionsRequestBase(BaseModel): NOTE: this config is highly experimental and signature will change significantly in future iterations.""", ) + cpus: Optional[CpuSpecificationType] = Field( + default=None, description="CPUs to use for the batch inference." + ) + gpus: Optional[int] = Field( + default=None, description="Number of GPUs to use for the batch inference." + ) + memory: Optional[StorageSpecificationType] = Field( + default=None, description="Amount of memory to use for the batch inference." + ) + gpu_type: Optional[GpuType] = Field( + default=None, description="GPU type to use for the batch inference." + ) + storage: Optional[StorageSpecificationType] = Field( + default=None, description="Storage to use for the batch inference." + ) + nodes_per_worker: Optional[int] = Field( + default=None, description="Number of nodes per worker for the batch inference." + ) + # V1 DTOs for batch completions CompletionV1Output = CompletionOutput @@ -297,19 +321,6 @@ class GetBatchCompletionV2Response(BaseModel): job: BatchCompletionsJob -class VLLMEngineAdditionalArgs(BaseModel): - max_gpu_memory_utilization: Optional[float] = Field( - default=0.9, - le=1.0, - description="Maximum GPU memory utilization for the batch inference. Default to 90%.", - ) - - attention_backend: Optional[str] = Field( - default=None, - description="Attention backend to use for vLLM. Default to None.", - ) - - class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase, VLLMEngineAdditionalArgs): """ Internal model for representing request to the inference framework. This contains additional fields that we want diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index f23059af..a904a597 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -34,6 +34,25 @@ class VLLMModelConfig(BaseModel): description="Maximum GPU memory utilization for the batch inference. Default to 90%.", ) + trust_remote_code: Optional[bool] = Field( + default=False, + description="Whether to trust remote code from Hugging face hub. This is only applicable to models whose code is not supported natively by the transformers library (e.g. deepseek). Default to False.", + ) + + +class VLLMEngineAdditionalArgs(BaseModel): + """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" + + max_gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%. Deprecated in favor of specifying this in VLLMModelConfig", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + class VLLMSamplingParams(BaseModel): best_of: Optional[int] = Field( diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 7f78dc50..9ec9efb0 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -55,6 +55,7 @@ CompletionV2StreamSuccessChunk, CompletionV2SyncResponse, ) +from model_engine_server.common.dtos.llms.vllm import VLLMModelConfig from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus @@ -255,6 +256,8 @@ "phi-3-small-128k-instruct", "phi-3-medium-4-instruct", "phi-3-medium-128k-instruct", + "deepseek-v2", + "deepseek-v2-chat", "deepseek-coder-v2", "deepseek-coder-v2-instruct", "deepseek-coder-v2-lite", @@ -878,9 +881,7 @@ def _create_vllm_bundle_command( subcommands.append(ray_cmd) if not is_worker: - vllm_cmd = "" - - vllm_cmd += f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" + vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" if multinode: vllm_cmd += f" --pipeline-parallel-size {nodes_per_worker}" @@ -906,11 +907,16 @@ def _create_vllm_bundle_command( additional_args = infer_addition_engine_args_from_model_name(model_name) - if additional_args.max_gpu_memory_utilization: - vllm_cmd += f" --gpu-memory-utilization {additional_args.max_gpu_memory_utilization} --enforce-eager" + for field in VLLMModelConfig.model_fields.keys(): + config_value = getattr(additional_args, field, None) + if config_value is not None: + vllm_cmd += f" --{field.replace('_', '-')} {config_value}" + + if field == "gpu_memory_utilization": + vllm_cmd += " --enforce-eager" - if additional_args.attention_backend: - vllm_cmd += " --attention-backend FLASHINFER" + if additional_args.attention_backend is not None: + vllm_cmd += f" --attention-backend {additional_args.attention_backend}" subcommands.append(vllm_cmd) @@ -3328,9 +3334,13 @@ async def _infer_hardware( ) +class VLLMAdditionalArgs(VLLMModelConfig, VLLMEngineAdditionalArgs): + pass + + def infer_addition_engine_args_from_model_name( model_name: str, -) -> VLLMEngineAdditionalArgs: +) -> VLLMAdditionalArgs: # Increase max gpu utilization for larger models model_param_count_b = get_model_param_count_b(model_name) if model_param_count_b >= 70: @@ -3343,9 +3353,15 @@ def infer_addition_engine_args_from_model_name( if model_name.startswith("gemma-2"): attention_backend = "FLASHINFER" - return VLLMEngineAdditionalArgs( - max_gpu_memory_utilization=gpu_memory_utilization, + trust_remote_code = None + # DeepSeek requires trust_remote_code + if model_name.startswith("deepseek"): + trust_remote_code = True + + return VLLMAdditionalArgs( + gpu_memory_utilization=gpu_memory_utilization, attention_backend=attention_backend, + trust_remote_code=trust_remote_code, ) @@ -3437,9 +3453,7 @@ async def execute( engine_request.model_cfg.model ) - engine_request.max_gpu_memory_utilization = ( - additional_engine_args.max_gpu_memory_utilization - ) + engine_request.max_gpu_memory_utilization = additional_engine_args.gpu_memory_utilization engine_request.attention_backend = additional_engine_args.attention_backend batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware) @@ -3490,13 +3504,40 @@ async def execute( request.model_cfg.checkpoint_path = get_checkpoint_path( request.model_cfg.model, request.model_cfg.checkpoint_path ) - hardware = await _infer_hardware( - self.llm_artifact_gateway, - request.model_cfg.model, - request.model_cfg.checkpoint_path, - is_batch_job=True, - max_context_length=request.model_cfg.max_context_length, - ) + + if ( + request.cpus is not None + and request.gpus is not None + and request.memory is not None + and request.storage is not None + and request.gpu_type is not None + ): + hardware = CreateDockerImageBatchJobResourceRequests( + cpus=request.cpus, + gpus=request.gpus, + memory=request.memory, + storage=request.storage, + gpu_type=request.gpu_type, + ) + else: + if ( + request.cpus is not None + or request.gpus is not None + or request.memory is not None + or request.storage is not None + or request.gpu_type is not None + ): + logger.warning( + "All hardware spec fields (cpus, gpus, memory, storage, gpu_type) must be provided if any hardware spec field is provided. Will attempt to infer hardware spec from checkpoint." + ) + + hardware = await _infer_hardware( + self.llm_artifact_gateway, + request.model_cfg.model, + request.model_cfg.checkpoint_path, + is_batch_job=True, + max_context_length=request.model_cfg.max_context_length, + ) engine_request = CreateBatchCompletionsEngineRequest.from_api_v2(request) engine_request.model_cfg.num_shards = hardware.gpus @@ -3520,9 +3561,13 @@ async def execute( additional_engine_args = infer_addition_engine_args_from_model_name( engine_request.model_cfg.model ) - engine_request.max_gpu_memory_utilization = ( - additional_engine_args.max_gpu_memory_utilization - ) + + # Overwrite model config fields with those determined by additional engine args + for field in VLLMModelConfig.model_fields.keys(): + config_value = getattr(additional_engine_args, field, None) + if config_value is not None and hasattr(engine_request.model_cfg, field): + setattr(engine_request.model_cfg, field, config_value) + engine_request.attention_backend = additional_engine_args.attention_backend return await self.llm_batch_completions_service.create_batch_job( diff --git a/model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py b/model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py new file mode 100644 index 00000000..bc72fc22 --- /dev/null +++ b/model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py @@ -0,0 +1,87 @@ +from typing import Dict, Optional + +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms.batch_completion import ( + BatchCompletionsJob, + CreateBatchCompletionsEngineRequest, + UpdateBatchCompletionsV2Request, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) + + +class FakeLLMBatchCompletionsService(LLMBatchCompletionsService): + def __init__( + self, + ): + self.jobs = [] + + async def create_batch_job( + self, + *, + user: User, + image_repo: str, + image_tag: str, + job_request: CreateBatchCompletionsEngineRequest, + resource_requests: CreateDockerImageBatchJobResourceRequests, + max_runtime_sec: int = 24 * 60 * 60, + labels: Dict[str, str] = {}, + num_workers: Optional[int] = 1, + ) -> BatchCompletionsJob: + """ + Create a batch completion job. + + Args: + owner: The user who requested the batch job + image_repo: The docker repo where the image is stored + image_tag: The tag of the batch completions image + job_config: The user-specified input to the batch job. Exposed as a file mounted at mount_location to the batch job + labels: Labels to apply to the batch job. + resource_requests: The resource requests for the batch job. + max_runtime_sec: The timeout of the batch job in seconds. + num_workers: The number of workers to run in the job. + + Returns: + The ID of the batch job. + """ + raise NotImplementedError() + + async def get_batch_job(self, batch_job_id: str, user: User) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + raise NotImplementedError() + + async def update_batch_job( + self, batch_job_id: str, request: UpdateBatchCompletionsV2Request, user: User + ) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + raise NotImplementedError() + + async def cancel_batch_job(self, batch_job_id: str, user: User) -> bool: + """ + Update a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + Whether the batch job was updated successfully. + """ + return False diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index e207d30c..81e5e1cc 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -152,6 +152,9 @@ translate_model_bundle_orm_to_model_bundle, ) from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService +from model_engine_server.infra.services.fake_llm_batch_completions_service import ( + FakeLLMBatchCompletionsService, +) from model_engine_server.infra.services.image_cache_service import ImageCacheService from model_engine_server.infra.services.live_llm_batch_completions_service import ( LiveLLMBatchCompletionsService, @@ -2098,6 +2101,12 @@ def fake_docker_image_batch_job_gateway() -> FakeDockerImageBatchJobGateway: return gateway +@pytest.fixture +def fake_llm_batch_completions_service() -> FakeLLMBatchCompletionsService: + service = FakeLLMBatchCompletionsService() + return service + + @pytest.fixture def fake_monitoring_metrics_gateway() -> FakeMonitoringMetricsGateway: gateway = FakeMonitoringMetricsGateway() diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 861e7c00..58c1a0de 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -12,6 +12,11 @@ CreateLLMModelEndpointV1Request, UpdateLLMModelEndpointV1Request, ) +from model_engine_server.common.dtos.llms.batch_completion import ( + CreateBatchCompletionsV2ModelConfig, + CreateBatchCompletionsV2Request, + FilteredCompletionV2Request, +) from model_engine_server.common.dtos.model_bundles import ( CreateModelBundleV1Request, CreateModelBundleV2Request, @@ -609,3 +614,50 @@ def create_batch_completions_v1_request() -> CreateBatchCompletionsV1Request: ), data_parallelism=1, ) + + +@pytest.fixture +def create_batch_completions_v2_request() -> CreateBatchCompletionsV2Request: + return CreateBatchCompletionsV2Request( + output_data_path="test_output_data_path", + content=[ + FilteredCompletionV2Request( + prompt="What is machine learning?", + max_tokens=10, + temperature=0.5, + ) + ], + model_config=CreateBatchCompletionsV2ModelConfig( + model="mpt-7b", + checkpoint_path="s3://test_checkpoint_path", + labels={}, + num_shards=1, + ), + data_parallelism=1, + ) + + +@pytest.fixture +def create_batch_completions_v2_request_with_hardware() -> CreateBatchCompletionsV2Request: + return CreateBatchCompletionsV2Request( + output_data_path="test_output_data_path", + content=[ + FilteredCompletionV2Request( + prompt="What is machine learning?", + max_tokens=10, + temperature=0.5, + ) + ], + model_config=CreateBatchCompletionsV2ModelConfig( + model="mpt-7b", + checkpoint_path="s3://test_checkpoint_path", + labels={}, + num_shards=1, + ), + data_parallelism=1, + cpus=1, + gpus=1, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4b4e45a8..1341ba46 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -3,6 +3,7 @@ from unittest import mock import pytest +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( CompletionOutput, CompletionStreamV1Request, @@ -15,6 +16,10 @@ TokenOutput, UpdateLLMModelEndpointV1Request, ) +from model_engine_server.common.dtos.llms.batch_completion import ( + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsV2Request, +) from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.core.auth.authentication_repository import User from model_engine_server.domain.entities import ( @@ -43,6 +48,7 @@ CompletionStreamV1UseCase, CompletionSyncV1UseCase, CreateBatchCompletionsUseCase, + CreateBatchCompletionsV2UseCase, CreateLLMModelBundleV1UseCase, CreateLLMModelEndpointV1UseCase, DeleteLLMEndpointByNameUseCase, @@ -62,6 +68,13 @@ from ..conftest import mocked__get_recommended_hardware_config_map +def mocked__get_latest_batch_v2_tag(): + async def async_mock(*args, **kwargs): # noqa + return "fake_docker_repository_latest_image_tag" + + return mock.AsyncMock(side_effect=async_mock) + + def mocked__get_latest_batch_tag(): async def async_mock(*args, **kwargs): # noqa return "fake_docker_repository_latest_image_tag" @@ -2815,6 +2828,93 @@ async def test_create_batch_completions_v1( ] +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_batch_v2_tag", + mocked__get_latest_batch_v2_tag(), +) +async def test_create_batch_completions_v2( + fake_llm_batch_completions_service, + fake_llm_artifact_gateway, + test_api_key: str, + create_batch_completions_v2_request: CreateBatchCompletionsV2Request, + create_batch_completions_v2_request_with_hardware: CreateBatchCompletionsV2Request, +): + fake_llm_batch_completions_service.create_batch_job = mock.AsyncMock() + use_case = CreateBatchCompletionsV2UseCase( + llm_batch_completions_service=fake_llm_batch_completions_service, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + await use_case.execute(create_batch_completions_v2_request, user) + + expected_engine_request = CreateBatchCompletionsEngineRequest( + model_cfg=create_batch_completions_v2_request.model_cfg, + max_runtime_sec=create_batch_completions_v2_request.max_runtime_sec, + data_parallelism=create_batch_completions_v2_request.data_parallelism, + labels=create_batch_completions_v2_request.labels, + content=create_batch_completions_v2_request.content, + output_data_path=create_batch_completions_v2_request.output_data_path, + ) + + expected_hardware = CreateDockerImageBatchJobResourceRequests( + cpus=10, + memory="40Gi", + gpus=1, + gpu_type=GpuType.NVIDIA_HOPPER_H100_3G_40GB, + storage="80Gi", + nodes_per_worker=1, + ) + + # assert fake_llm_batch_completions_service was called with the correct arguments + fake_llm_batch_completions_service.create_batch_job.assert_called_with( + user=user, + job_request=expected_engine_request, + image_repo="llm-engine/batch-infer-vllm", + image_tag="fake_docker_repository_latest_image_tag", + resource_requests=expected_hardware, + labels=create_batch_completions_v2_request.labels, + max_runtime_sec=create_batch_completions_v2_request.max_runtime_sec, + num_workers=create_batch_completions_v2_request.data_parallelism, + ) + + await use_case.execute(create_batch_completions_v2_request_with_hardware, user) + + expected_engine_request = CreateBatchCompletionsEngineRequest( + model_cfg=create_batch_completions_v2_request_with_hardware.model_cfg, + max_runtime_sec=create_batch_completions_v2_request_with_hardware.max_runtime_sec, + data_parallelism=create_batch_completions_v2_request_with_hardware.data_parallelism, + labels=create_batch_completions_v2_request_with_hardware.labels, + content=create_batch_completions_v2_request_with_hardware.content, + output_data_path=create_batch_completions_v2_request_with_hardware.output_data_path, + ) + + expected_hardware = CreateDockerImageBatchJobResourceRequests( + cpus=create_batch_completions_v2_request_with_hardware.cpus, + gpus=create_batch_completions_v2_request_with_hardware.gpus, + memory=create_batch_completions_v2_request_with_hardware.memory, + storage=create_batch_completions_v2_request_with_hardware.storage, + gpu_type=create_batch_completions_v2_request_with_hardware.gpu_type, + nodes_per_worker=create_batch_completions_v2_request_with_hardware.nodes_per_worker, + ) + # assert fake_llm_batch_completions_service was called with the correct arguments + fake_llm_batch_completions_service.create_batch_job.assert_called_with( + user=user, + job_request=expected_engine_request, + image_repo="llm-engine/batch-infer-vllm", + image_tag="fake_docker_repository_latest_image_tag", + resource_requests=expected_hardware, + labels=create_batch_completions_v2_request.labels, + max_runtime_sec=create_batch_completions_v2_request.max_runtime_sec, + num_workers=create_batch_completions_v2_request.data_parallelism, + ) + + def test_merge_metadata(): request_metadata = { "key1": "value1", From 2f62171cd3cb9b8507d07a96652113a893c920cd Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 10 Oct 2024 12:54:47 -0700 Subject: [PATCH 397/425] add rec hardware to the configmap yaml (#631) --- .../model-engine/templates/recommended_hardware_config_map.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/charts/model-engine/templates/recommended_hardware_config_map.yaml b/charts/model-engine/templates/recommended_hardware_config_map.yaml index 47474ceb..6c999145 100644 --- a/charts/model-engine/templates/recommended_hardware_config_map.yaml +++ b/charts/model-engine/templates/recommended_hardware_config_map.yaml @@ -15,6 +15,7 @@ data: memory: {{ .memory }} storage: {{ .storage }} gpu_type: {{ .gpu_type }} + nodes_per_worker: {{ .nodes_per_worker }} {{- end }} byModelName: |- {{- range $.Values.recommendedHardware.byModelName }} @@ -24,5 +25,6 @@ data: memory: {{ .memory }} storage: {{ .storage }} gpu_type: {{ .gpu_type }} + nodes_per_worker: {{ .nodes_per_worker }} {{- end }} {{- end }} \ No newline at end of file From 258862a49db8d5e8bebcb882947cfe5da0e5a509 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:51:24 -0700 Subject: [PATCH 398/425] fix bug in batch completions v2 (#633) need to deduplicate some arguments --- .../model_engine_server/inference/vllm/vllm_batch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py index ea0989fe..b9f3f086 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_batch.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -242,14 +242,17 @@ async def init_engine( print("VLLM additional configs:", parsed_configs.model_dump()) - engine_args = AsyncEngineArgs( + engine_args_dict = parsed_configs.model_dump(exclude_none=True) + default_engine_args_dict = dict( model=model, tensor_parallel_size=request.model_cfg.num_shards, seed=request.model_cfg.seed or 0, disable_log_requests=True, gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, - **parsed_configs.model_dump(exclude_none=True), ) + default_engine_args_dict.update(engine_args_dict) + + engine_args = AsyncEngineArgs(**default_engine_args_dict) engine_client = AsyncLLMEngine.from_engine_args(engine_args) model_config = await engine_client.get_model_config() From 5a69175fdd2ff6bcbc86d36e786e26fc3b3a28ef Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Fri, 11 Oct 2024 07:45:44 -0700 Subject: [PATCH 399/425] Add hardware spec to client (#632) * Add hardware spec to client * fix import and update version * fix import and update version --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/completion.py | 13 ++++++++++++ .../llmengine/data_types/batch_completion.py | 20 +++++++++++++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 5 files changed, 36 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 21bed5ce..86f99642 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta40" +__version__ = "0.0.0beta42" import os from typing import Sequence diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 29617f26..8f972e91 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -7,12 +7,15 @@ CompletionStreamV1Request, CompletionSyncResponse, CompletionSyncV1Request, + CpuSpecificationType, CreateBatchCompletionsModelConfig, CreateBatchCompletionsV1Request, CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV1Response, CreateBatchCompletionsV2Request, CreateBatchCompletionsV2Response, + GpuType, + StorageSpecificationType, ToolConfig, ) @@ -486,6 +489,11 @@ def batch_create( priority: Optional[str] = None, use_v2: bool = False, tool_config: Optional[ToolConfig] = None, + cpus: Optional[CpuSpecificationType] = None, + gpus: Optional[int] = None, + memory: Optional[StorageSpecificationType] = None, + gpu_type: Optional[GpuType] = None, + storage: Optional[StorageSpecificationType] = None, request_headers: Optional[Dict[str, str]] = None, ) -> Union[CreateBatchCompletionsV1Response, CreateBatchCompletionsV2Response]: """ @@ -636,6 +644,11 @@ def batch_create( max_runtime_sec=max_runtime_sec, tool_config=tool_config, priority=priority, + cpus=cpus, + gpus=gpus, + memory=memory, + gpu_type=gpu_type, + storage=storage, ).dict() response = cls.post_sync( resource_name="v2/batch-completions", diff --git a/clients/python/llmengine/data_types/batch_completion.py b/clients/python/llmengine/data_types/batch_completion.py index 6c14fcce..d72a3e82 100644 --- a/clients/python/llmengine/data_types/batch_completion.py +++ b/clients/python/llmengine/data_types/batch_completion.py @@ -6,6 +6,7 @@ from .chat_completion import ChatCompletionV2Request, ChatCompletionV2Response from .completion import CompletionOutput, CompletionV2Request, CompletionV2Response from .pydantic_types import BaseModel, Field +from .rest import CpuSpecificationType, GpuType, StorageSpecificationType # Common DTOs for batch completions @@ -105,6 +106,25 @@ class BatchCompletionsRequestBase(BaseModel): NOTE: this config is highly experimental and signature will change significantly in future iterations.""", ) + cpus: Optional[CpuSpecificationType] = Field( + default=None, description="CPUs to use for the batch inference." + ) + gpus: Optional[int] = Field( + default=None, description="Number of GPUs to use for the batch inference." + ) + memory: Optional[StorageSpecificationType] = Field( + default=None, description="Amount of memory to use for the batch inference." + ) + gpu_type: Optional[GpuType] = Field( + default=None, description="GPU type to use for the batch inference." + ) + storage: Optional[StorageSpecificationType] = Field( + default=None, description="Storage to use for the batch inference." + ) + nodes_per_worker: Optional[int] = Field( + default=None, description="Number of nodes per worker for the batch inference." + ) + # V1 DTOs for batch completions CompletionV1Output = CompletionOutput diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 9f963abb..9a150429 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta40" +version = "0.0.0.beta42" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index ea6c5e02..a1fc8bee 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.8", - version="0.0.0.beta40", + version="0.0.0.beta42", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From 89b9ddd5c78c1d57d498c262a47057b3bc842a56 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 14 Oct 2024 12:56:41 -0700 Subject: [PATCH 400/425] Add *.py files to model weights if trust_remote_code is provided (#635) * Add *.py files to model weights if trust_remote_code is provided * Add to azure * add test * Add additional tests --- .../common/dtos/llms/__init__.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 47 +++++++++++++++---- .../inference/vllm/vllm_batch.py | 7 +-- .../tests/unit/domain/test_llm_use_cases.py | 22 +++++++++ 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/model-engine/model_engine_server/common/dtos/llms/__init__.py b/model-engine/model_engine_server/common/dtos/llms/__init__.py index ae7bef45..663be186 100644 --- a/model-engine/model_engine_server/common/dtos/llms/__init__.py +++ b/model-engine/model_engine_server/common/dtos/llms/__init__.py @@ -6,3 +6,4 @@ from .chat_completion import * # noqa: F403 from .completion import * # noqa: F403 from .model_endpoints import * # noqa: F403 +from .vllm import * # noqa: F403 diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 9ec9efb0..60874560 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -665,15 +665,28 @@ async def create_text_generation_inference_bundle( ).model_bundle_id def load_model_weights_sub_commands( - self, framework, framework_image_tag, checkpoint_path, final_weights_folder + self, + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool = False, ): if checkpoint_path.startswith("s3://"): return self.load_model_weights_sub_commands_s3( - framework, framework_image_tag, checkpoint_path, final_weights_folder + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, ) elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path: return self.load_model_weights_sub_commands_abs( - framework, framework_image_tag, checkpoint_path, final_weights_folder + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, ) else: raise ObjectHasInvalidValueException( @@ -681,7 +694,12 @@ def load_model_weights_sub_commands( ) def load_model_weights_sub_commands_s3( - self, framework, framework_image_tag, checkpoint_path, final_weights_folder + self, + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool, ): subcommands = [] s5cmd = "s5cmd" @@ -700,14 +718,23 @@ def load_model_weights_sub_commands_s3( validate_checkpoint_files(checkpoint_files) # filter to configs ('*.model' and '*.json') and weights ('*.safetensors') + # For models that are not supported by transformers directly, we need to include '*.py' and '*.bin' + # to load the model. Only set this flag if "trust_remote_code" is set to True file_selection_str = '--include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*"' + if trust_remote_code: + file_selection_str += ' --include "*.py"' subcommands.append( f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) return subcommands def load_model_weights_sub_commands_abs( - self, framework, framework_image_tag, checkpoint_path, final_weights_folder + self, + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool, ): subcommands = [] @@ -729,9 +756,8 @@ def load_model_weights_sub_commands_abs( ] ) else: - file_selection_str = ( - '--include-pattern "*.model;*.json;*.safetensors" --exclude-pattern "optimizer*"' - ) + additional_pattern = ";*.py" if trust_remote_code else "" + file_selection_str = f'--include-pattern "*.model;*.json;*.safetensors{additional_pattern}" --exclude-pattern "optimizer*"' subcommands.append( f"azcopy copy --recursive {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) @@ -861,6 +887,8 @@ def _create_vllm_bundle_command( subcommands = [] checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + additional_args = infer_addition_engine_args_from_model_name(model_name) + # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder if "mistral" in model_name: final_weights_folder = "mistral_files" @@ -871,6 +899,7 @@ def _create_vllm_bundle_command( framework_image_tag, checkpoint_path, final_weights_folder, + trust_remote_code=additional_args.trust_remote_code or False, ) if multinode and not is_worker: @@ -905,8 +934,6 @@ def _create_vllm_bundle_command( if hmi_config.sensitive_log_mode: # pragma: no cover vllm_cmd += " --disable-log-requests" - additional_args = infer_addition_engine_args_from_model_name(model_name) - for field in VLLMModelConfig.model_fields.keys(): config_value = getattr(additional_args, field, None) if config_value is not None: diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py index b9f3f086..d24ca74e 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_batch.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -89,8 +89,9 @@ async def dummy_receive() -> MutableMapping[str, Any]: ) -async def download_model(checkpoint_path: str, target_dir: str) -> None: - s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}" +async def download_model(checkpoint_path: str, target_dir: str, trust_remote_code: bool) -> None: + additional_include = "--include '*.py'" if trust_remote_code else "" + s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' {additional_include} --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}" env = os.environ.copy() env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") # Need to override these env vars so s5cmd uses AWS_PROFILE @@ -319,11 +320,11 @@ async def handle_batch_job(request: CreateBatchCompletionsEngineRequest) -> None metrics_gateway = DatadogInferenceMonitoringMetricsGateway() model = get_model_name(request.model_cfg) - if request.model_cfg.checkpoint_path: await download_model( checkpoint_path=request.model_cfg.checkpoint_path, target_dir=MODEL_WEIGHTS_FOLDER, + trust_remote_code=request.model_cfg.trust_remote_code or False, ) content = load_batch_content(request) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 1341ba46..9fd09b73 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -524,6 +524,16 @@ def test_load_model_weights_sub_commands( ] assert expected_result == subcommands + trust_remote_code = True + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + ) + + expected_result = [ + './s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" --include "*.py" s3://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE framework_image_tag = "1.0.0" checkpoint_path = "s3://fake-checkpoint" @@ -555,6 +565,18 @@ def test_load_model_weights_sub_commands( ] assert expected_result == subcommands + trust_remote_code = True + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + ) + + expected_result = [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + 'azcopy copy --recursive --include-pattern "*.model;*.json;*.safetensors;*.py" --exclude-pattern "optimizer*" azure://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + @pytest.mark.asyncio async def test_create_model_endpoint_trt_llm_use_case_success( From 74a40e7276d1f1b5096578bed3ceac3a90b8a777 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 14 Oct 2024 15:29:25 -0700 Subject: [PATCH 401/425] vllm 0.6.3 (#636) --- .../model_engine_server/inference/vllm/Dockerfile.vllm | 5 +++-- .../inference/vllm/build_and_upload_image.sh | 7 +++++-- .../inference/vllm/requirements-dev.txt | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm index 8b005722..76085416 100644 --- a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -1,6 +1,7 @@ # syntax=docker/dockerfile:1 -ARG VLLM_VERSION=0.6.2 -ARG VLLM_BASE_IMAGE=vllm/vllm-openai:v${VLLM_VERSION} +ARG VLLM_VERSION=0.6.3 +ARG VLLM_BASE_REPO=vllm/vllm-openai +ARG VLLM_BASE_IMAGE=${VLLM_BASE_REPO}:v${VLLM_VERSION} FROM ${VLLM_BASE_IMAGE} AS base RUN apt-get update \ diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh index 65c49b32..10765cc0 100755 --- a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -4,7 +4,7 @@ set -eo pipefail # Build and push vLLM docker image to AWS ECR. # -# Usage: VLLM_VERSION=0.5.3.post1 ./build_and_upload_image.sh vllm|vllm_batch|vllm_batch_v2 +# Usage: VLLM_VERSION=0.6.3 ./build_and_upload_image.sh vllm|vllm_batch|vllm_batch_v2 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) PROJECT_DIR=$SCRIPT_DIR/../../../.. @@ -21,14 +21,16 @@ if [ -z "$2" ]; then fi if [ -z "$3" ]; then - echo "Must supply the build target (either vllm or vllm_batch)" + echo "Must supply the build target (either vllm or vllm_batch_v2)" exit 1; fi + ACCOUNT=$1 IMAGE_TAG=$2 BUILD_TARGET=$3 VLLM_VERSION=${VLLM_VERSION:-"0.6.2"} +VLLM_BASE_REPO=${VLLM_BASE_REPO:-"vllm/vllm-openai"} # if build target = vllm use vllm otherwise use vllm_batch if [ "$BUILD_TARGET" == "vllm" ]; then @@ -40,6 +42,7 @@ fi aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com DOCKER_BUILDKIT=1 docker build \ --build-arg VLLM_VERSION=${VLLM_VERSION} \ + --build-arg VLLM_BASE_REPO=${VLLM_BASE_REPO} \ -f Dockerfile.vllm \ --target ${BUILD_TARGET} \ -t $IMAGE ${PROJECT_DIR} diff --git a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt index d330101a..b75668a1 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt @@ -1 +1 @@ -vllm==0.6.2 \ No newline at end of file +vllm==0.6.3 From 4adc3f229ab62eba5bb60c985b17855ed7e685c1 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 15 Oct 2024 13:45:54 -0700 Subject: [PATCH 402/425] Refactor client data types + add vllm arg passthrough (#637) * Refactor client data types + add vllm arg passthrough * Bump client version * fix dict assignment * add test --- clients/python/llmengine/__init__.py | 6 +- .../python/llmengine/data_types/__init__.py | 3 + .../llmengine/data_types/batch_completion.py | 3 +- .../llmengine/data_types/chat_completion.py | 89 ++---- .../python/llmengine/data_types/completion.py | 81 +++--- clients/python/llmengine/data_types/core.py | 84 ++++++ .../llmengine/data_types/model_endpoints.py | 212 ++++++++++++++ clients/python/llmengine/data_types/rest.py | 272 +----------------- clients/python/llmengine/data_types/vllm.py | 251 ++++++++++++++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- .../common/dtos/llms/model_endpoints.py | 5 +- .../common/dtos/llms/vllm.py | 36 ++- .../use_cases/llm_model_endpoint_use_cases.py | 104 ++++--- model-engine/tests/unit/domain/conftest.py | 33 +++ .../tests/unit/domain/test_llm_use_cases.py | 64 ++++- 16 files changed, 830 insertions(+), 417 deletions(-) create mode 100644 clients/python/llmengine/data_types/core.py create mode 100644 clients/python/llmengine/data_types/model_endpoints.py create mode 100644 clients/python/llmengine/data_types/vllm.py diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 86f99642..e8d44235 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta42" +__version__ = "0.0.0beta43" import os from typing import Sequence @@ -61,7 +61,7 @@ ModelDownloadRequest, ModelDownloadResponse, UploadFileResponse, - VLLMAdditionalFields, + VLLMEndpointAdditionalArgs, ) from llmengine.file import File from llmengine.fine_tuning import FineTune @@ -80,7 +80,7 @@ "CancelFineTuneResponse", "ChatCompletionV2Request", "ChatCompletionV2Response", - "VLLMAdditionalFields", + "VLLMEndpointAdditionalArgs", "Completion", "CompletionOutput", "CompletionStreamOutput", diff --git a/clients/python/llmengine/data_types/__init__.py b/clients/python/llmengine/data_types/__init__.py index ff34f72b..a666add4 100644 --- a/clients/python/llmengine/data_types/__init__.py +++ b/clients/python/llmengine/data_types/__init__.py @@ -7,7 +7,10 @@ from .batch_completion import * # noqa: F403 from .chat_completion import * # noqa: F403 from .completion import * # noqa: F403 +from .core import * # noqa: F403 +from .model_endpoints import * # noqa: F403 from .rest import * # noqa: F403 +from .vllm import * # noqa: F403 # Alias for backwards compatibility CreateBatchCompletionsRequestContent: TypeAlias = ( diff --git a/clients/python/llmengine/data_types/batch_completion.py b/clients/python/llmengine/data_types/batch_completion.py index d72a3e82..6935351f 100644 --- a/clients/python/llmengine/data_types/batch_completion.py +++ b/clients/python/llmengine/data_types/batch_completion.py @@ -7,6 +7,7 @@ from .completion import CompletionOutput, CompletionV2Request, CompletionV2Response from .pydantic_types import BaseModel, Field from .rest import CpuSpecificationType, GpuType, StorageSpecificationType +from .vllm import VLLMModelConfig # Common DTOs for batch completions @@ -34,7 +35,7 @@ class ToolConfig(BaseModel): """ -class BatchCompletionsModelConfig(BaseModel): +class BatchCompletionsModelConfig(VLLMModelConfig): model: str = Field( description="ID of the model to use.", examples=["mixtral-8x7b-instruct"], diff --git a/clients/python/llmengine/data_types/chat_completion.py b/clients/python/llmengine/data_types/chat_completion.py index fdfa85b4..251b2bd5 100644 --- a/clients/python/llmengine/data_types/chat_completion.py +++ b/clients/python/llmengine/data_types/chat_completion.py @@ -1,73 +1,19 @@ -from typing import Any, Dict, List, Optional - -from .gen.openai import CreateChatCompletionRequest, CreateChatCompletionResponse +from typing import Optional, TypeAlias + +from .core import StreamError +from .gen.openai import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) from .pydantic_types import BaseModel, Field +from .vllm import VLLMChatCompletionAdditionalParams # Fields that are a part of OpenAI spec but are not supported by model engine UNSUPPORTED_FIELDS = ["service_tier"] -class VLLMAdditionalFields(BaseModel): - chat_template: Optional[str] = Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ) - chat_template_kwargs: Optional[Dict[str, Any]] = Field( - default=None, - description=( - "Additional kwargs to pass to the template renderer. " - "Will be accessible by the chat template." - ), - ) - - guided_json: Optional[Dict[str, Any]] = Field( - default=None, - description="JSON schema for guided decoding. Only supported in vllm.", - ) - - guided_regex: Optional[str] = Field( - default=None, - description="Regex for guided decoding. Only supported in vllm.", - ) - guided_choice: Optional[List[str]] = Field( - default=None, - description="Choices for guided decoding. Only supported in vllm.", - ) - - guided_grammar: Optional[str] = Field( - default=None, - description="Context-free grammar for guided decoding. Only supported in vllm.", - ) - - guided_decoding_backend: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be either " - "'outlines' / 'lm-format-enforcer'" - ), - ) - - guided_whitespace_pattern: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding." - ), - ) - - skip_special_tokens: Optional[bool] = Field( - True, - description="Whether to skip special tokens in the output. Only supported in vllm.", - ) - - -class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields): +class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMChatCompletionAdditionalParams): model: str = Field( description="ID of the model to use.", examples=["mixtral-8x7b-instruct"], @@ -89,5 +35,16 @@ class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMAdditionalFields) ) -class ChatCompletionV2Response(CreateChatCompletionResponse): - pass +ChatCompletionV2SyncResponse: TypeAlias = CreateChatCompletionResponse +ChatCompletionV2StreamSuccessChunk: TypeAlias = CreateChatCompletionStreamResponse + + +class ChatCompletionV2StreamErrorChunk(BaseModel): + error: StreamError + + +ChatCompletionV2Chunk: TypeAlias = ( + ChatCompletionV2StreamSuccessChunk | ChatCompletionV2StreamErrorChunk +) + +ChatCompletionV2Response: TypeAlias = ChatCompletionV2SyncResponse diff --git a/clients/python/llmengine/data_types/completion.py b/clients/python/llmengine/data_types/completion.py index fc92f711..94d7c62a 100644 --- a/clients/python/llmengine/data_types/completion.py +++ b/clients/python/llmengine/data_types/completion.py @@ -1,7 +1,11 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TypeAlias +from typing_extensions import Annotated + +from .core import StreamError from .gen.openai import CreateCompletionRequest, CreateCompletionResponse from .pydantic_types import BaseModel, Field +from .vllm import VLLMCompletionAdditionalParams # Fields that are a part of OpenAI spec but are not supported by model engine UNSUPPORTED_FIELDS = ["service_tier"] @@ -205,24 +209,6 @@ class CompletionStreamOutput(BaseModel): """Detailed token information.""" -class StreamErrorContent(BaseModel): - error: str - """Error message.""" - timestamp: str - """Timestamp of the error.""" - - -class StreamError(BaseModel): - """ - Error object for a stream prompt completion task. - """ - - status_code: int - """The HTTP status code of the error.""" - content: StreamErrorContent - """The error content.""" - - class CompletionStreamV1Response(BaseModel): """Error of the response (if any).""" @@ -285,27 +271,46 @@ def inter_token_latency(self) -> Optional[float]: # Only for streaming requests return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) -class CompletionV2Request(CreateCompletionRequest): - model: str = Field( - description="ID of the model to use.", - examples=["mixtral-8x7b-instruct"], - ) +class CompletionV2Request(CreateCompletionRequest, VLLMCompletionAdditionalParams): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + + include_stop_str_in_output: Annotated[ + Optional[bool], + Field(None, description="Whether to include the stop strings in output text."), + ] + + +CompletionV2SyncResponse: TypeAlias = CreateCompletionResponse +CompletionV2StreamSuccessChunk: TypeAlias = CreateCompletionResponse - stream: Optional[bool] = Field( - False, - description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", - ) - top_k: Optional[int] = Field( - None, - ge=-1, - description="Controls the number of top tokens to consider. -1 means consider all tokens.", - ) +class CompletionV2StreamErrorChunk(BaseModel): + error: StreamError - include_stop_str_in_output: Optional[bool] = Field( - None, description="Whether to include the stop strings in output text." - ) +CompletionV2StreamChunk: TypeAlias = CompletionV2StreamSuccessChunk | CompletionV2StreamErrorChunk -class CompletionV2Response(CreateCompletionResponse): - pass +CompletionV2Response: TypeAlias = CompletionV2SyncResponse diff --git a/clients/python/llmengine/data_types/core.py b/clients/python/llmengine/data_types/core.py new file mode 100644 index 00000000..93961a9e --- /dev/null +++ b/clients/python/llmengine/data_types/core.py @@ -0,0 +1,84 @@ +from enum import Enum +from typing import Literal, Union + +from .pydantic_types import BaseModel, Field + +CpuSpecificationType = Union[str, int, float] +StorageSpecificationType = Union[str, int, float] + + +class LLMInferenceFramework(str, Enum): + DEEPSPEED = "deepspeed" + TEXT_GENERATION_INFERENCE = "text_generation_inference" + VLLM = "vllm" + LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt_llm" + + +class LLMSource(str, Enum): + HUGGING_FACE = "hugging_face" + + +class Quantization(str, Enum): + BITSANDBYTES = "bitsandbytes" + AWQ = "awq" + + +class GpuType(str, Enum): + """Lists allowed GPU types for LLMEngine.""" + + NVIDIA_TESLA_T4 = "nvidia-tesla-t4" + NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" + NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" + NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" + NVIDIA_HOPPER_H100 = "nvidia-hopper-h100" + NVIDIA_HOPPER_H100_1G_20GB = "nvidia-hopper-h100-1g20gb" + NVIDIA_HOPPER_H100_3G_40GB = "nvidia-hopper-h100-3g40gb" + + +class ModelEndpointType(str, Enum): + STREAMING = "streaming" + + +class ModelEndpointStatus(str, Enum): + # Duplicates common/types::EndpointStatus, when refactor is done, delete the old type + # See EndpointStatus for status explanations + READY = "READY" + UPDATE_PENDING = "UPDATE_PENDING" + UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" + UPDATE_FAILED = "UPDATE_FAILED" + DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS" + + +class CallbackBasicAuth(BaseModel): + kind: Literal["basic"] + username: str + password: str + + +class CallbackmTLSAuth(BaseModel): + kind: Literal["mtls"] + cert: str + key: str + + +class CallbackAuth(BaseModel): + __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") + + +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" diff --git a/clients/python/llmengine/data_types/model_endpoints.py b/clients/python/llmengine/data_types/model_endpoints.py new file mode 100644 index 00000000..2e087773 --- /dev/null +++ b/clients/python/llmengine/data_types/model_endpoints.py @@ -0,0 +1,212 @@ +from typing import Any, Dict, List, Optional + +from .core import ( + CallbackAuth, + CpuSpecificationType, + GpuType, + LLMInferenceFramework, + LLMSource, + ModelEndpointStatus, + ModelEndpointType, + Quantization, + StorageSpecificationType, +) +from .pydantic_types import BaseModel, Field, HttpUrl +from .rest import GetModelEndpointResponse +from .vllm import VLLMEndpointAdditionalArgs + + +class CreateLLMEndpointRequest(VLLMEndpointAdditionalArgs, BaseModel): + name: str + + # LLM specific fields + model_name: str + source: LLMSource = LLMSource.HUGGING_FACE + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM + inference_framework_image_tag: str = "latest" + num_shards: int = 1 + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Dict[str, Any] # TODO: JSON type + post_inference_hooks: Optional[List[str]] = None + endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + nodes_per_worker: Optional[int] = None + optimize_costs: Optional[bool] = None + min_workers: int + max_workers: int + per_worker: int + labels: Dict[str, str] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrl] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = True # LLM endpoints are public by default. + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + +class CreateLLMEndpointResponse(BaseModel): + endpoint_creation_task_id: str + + +class GetLLMEndpointResponse(BaseModel): + """ + Response object for retrieving a Model. + """ + + id: Optional[str] = Field( + default=None, + description="(For self-hosted users) The autogenerated ID of the model.", + ) + """(For self-hosted users) The autogenerated ID of the model.""" + + name: str = Field( + description="The name of the model. Use this for making inference requests to the model." + ) + """The name of the model. Use this for making inference requests to the model.""" + + model_name: Optional[str] = Field( + default=None, + description="(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.", + ) + """(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.""" + + source: LLMSource = Field(description="The source of the model, e.g. Hugging Face.") + """The source of the model, e.g. Hugging Face.""" + + status: ModelEndpointStatus = Field(description="The status of the model.") + """The status of the model (can be one of "READY", "UPDATE_PENDING", "UPDATE_IN_PROGRESS", "UPDATE_FAILED", "DELETE_IN_PROGRESS").""" + + inference_framework: LLMInferenceFramework = Field( + description="The inference framework used by the model." + ) + """(For self-hosted users) The inference framework used by the model.""" + + inference_framework_tag: Optional[str] = Field( + default=None, + description="(For self-hosted users) The Docker image tag used to run the model.", + ) + """(For self-hosted users) The Docker image tag used to run the model.""" + + num_shards: Optional[int] = Field( + default=None, description="(For self-hosted users) The number of shards." + ) + """(For self-hosted users) The number of shards.""" + + quantize: Optional[Quantization] = Field( + default=None, description="(For self-hosted users) The quantization method." + ) + """(For self-hosted users) The quantization method.""" + + spec: Optional[GetModelEndpointResponse] = Field( + default=None, description="(For self-hosted users) Model endpoint details." + ) + """(For self-hosted users) Model endpoint details.""" + + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + +class ListLLMEndpointsResponse(BaseModel): + """ + Response object for listing Models. + """ + + model_endpoints: List[GetLLMEndpointResponse] = Field( + ..., + description="The list of models.", + ) + """ + A list of Models, represented as `GetLLMEndpointResponse`s. + """ + + +class UpdateLLMEndpointRequest(VLLMEndpointAdditionalArgs, BaseModel): + # LLM specific fields + model_name: Optional[str] = None + source: Optional[LLMSource] = None + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] = None + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None + min_workers: Optional[int] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrl] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + force_bundle_recreation: Optional[bool] = False + """ + Whether to force recreate the underlying bundle. + + If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created + that we would like to pick up for existing endpoints + """ + + +class UpdateLLMEndpointResponse(BaseModel): + endpoint_creation_task_id: str + + +class DeleteLLMEndpointResponse(BaseModel): + """ + Response object for deleting a Model. + """ + + deleted: bool = Field(..., description="Whether deletion was successful.") + """ + Whether the deletion succeeded. + """ diff --git a/clients/python/llmengine/data_types/rest.py b/clients/python/llmengine/data_types/rest.py index 88aa3ad6..f2978cd3 100644 --- a/clients/python/llmengine/data_types/rest.py +++ b/clients/python/llmengine/data_types/rest.py @@ -4,72 +4,18 @@ import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union - +from typing import Any, Dict, List, Optional + +from .core import ( + CallbackAuth, + CpuSpecificationType, + GpuType, + ModelEndpointStatus, + ModelEndpointType, + StorageSpecificationType, +) from .pydantic_types import BaseModel, Field, HttpUrl -CpuSpecificationType = Union[str, int, float] -StorageSpecificationType = Union[str, int, float] - - -class LLMInferenceFramework(str, Enum): - DEEPSPEED = "deepspeed" - TEXT_GENERATION_INFERENCE = "text_generation_inference" - VLLM = "vllm" - LIGHTLLM = "lightllm" - TENSORRT_LLM = "tensorrt_llm" - - -class LLMSource(str, Enum): - HUGGING_FACE = "hugging_face" - - -class Quantization(str, Enum): - BITSANDBYTES = "bitsandbytes" - AWQ = "awq" - - -class GpuType(str, Enum): - """Lists allowed GPU types for LLMEngine.""" - - NVIDIA_TESLA_T4 = "nvidia-tesla-t4" - NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" - NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" - NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" - NVIDIA_HOPPER_H100 = "nvidia-hopper-h100" - NVIDIA_HOPPER_H100_1G_20GB = "nvidia-hopper-h100-1g20gb" - NVIDIA_HOPPER_H100_3G_40GB = "nvidia-hopper-h100-3g40gb" - - -class ModelEndpointType(str, Enum): - STREAMING = "streaming" - - -class ModelEndpointStatus(str, Enum): - # Duplicates common/types::EndpointStatus, when refactor is done, delete the old type - # See EndpointStatus for status explanations - READY = "READY" - UPDATE_PENDING = "UPDATE_PENDING" - UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" - UPDATE_FAILED = "UPDATE_FAILED" - DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS" - - -class CallbackBasicAuth(BaseModel): - kind: Literal["basic"] - username: str - password: str - - -class CallbackmTLSAuth(BaseModel): - kind: Literal["mtls"] - cert: str - key: str - - -class CallbackAuth(BaseModel): - __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") - class ModelEndpointDeploymentState(BaseModel): """ @@ -137,204 +83,6 @@ class PostInferenceHooks(str, Enum): CALLBACK: str = "callback" -class CreateLLMEndpointRequest(BaseModel): - name: str - - # LLM specific fields - model_name: str - source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM - inference_framework_image_tag: str - num_shards: int = 1 - """ - Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models - """ - - quantize: Optional[Quantization] = None - """ - Quantization for the LLM. Only affects behavior for text-generation-inference models - """ - - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models - """ - - # General endpoint fields - metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] - endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING - cpus: Optional[CpuSpecificationType] - gpus: Optional[int] - memory: Optional[StorageSpecificationType] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - nodes_per_worker: Optional[int] = None - optimize_costs: Optional[bool] = None - min_workers: int - max_workers: int - per_worker: int - labels: Dict[str, str] - prewarm: Optional[bool] = None - high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] = None - default_callback_auth: Optional[CallbackAuth] = None - public_inference: Optional[bool] = True - """ - Whether the endpoint can be used for inference for all users. LLM endpoints are public by default. - """ - chat_template_override: Optional[str] = Field( - default=None, - description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", - ) - - -class CreateLLMEndpointResponse(BaseModel): - endpoint_creation_task_id: str - - -class GetLLMEndpointResponse(BaseModel): - """ - Response object for retrieving a Model. - """ - - id: Optional[str] = Field( - default=None, - description="(For self-hosted users) The autogenerated ID of the model.", - ) - """(For self-hosted users) The autogenerated ID of the model.""" - - name: str = Field( - description="The name of the model. Use this for making inference requests to the model." - ) - """The name of the model. Use this for making inference requests to the model.""" - - model_name: Optional[str] = Field( - default=None, - description="(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.", - ) - """(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.""" - - source: LLMSource = Field(description="The source of the model, e.g. Hugging Face.") - """The source of the model, e.g. Hugging Face.""" - - status: ModelEndpointStatus = Field(description="The status of the model.") - """The status of the model (can be one of "READY", "UPDATE_PENDING", "UPDATE_IN_PROGRESS", "UPDATE_FAILED", "DELETE_IN_PROGRESS").""" - - inference_framework: LLMInferenceFramework = Field( - description="The inference framework used by the model." - ) - """(For self-hosted users) The inference framework used by the model.""" - - inference_framework_tag: Optional[str] = Field( - default=None, - description="(For self-hosted users) The Docker image tag used to run the model.", - ) - """(For self-hosted users) The Docker image tag used to run the model.""" - - num_shards: Optional[int] = Field( - default=None, description="(For self-hosted users) The number of shards." - ) - """(For self-hosted users) The number of shards.""" - - quantize: Optional[Quantization] = Field( - default=None, description="(For self-hosted users) The quantization method." - ) - """(For self-hosted users) The quantization method.""" - - spec: Optional[GetModelEndpointResponse] = Field( - default=None, description="(For self-hosted users) Model endpoint details." - ) - """(For self-hosted users) Model endpoint details.""" - - chat_template_override: Optional[str] = Field( - default=None, - description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", - ) - - -class ListLLMEndpointsResponse(BaseModel): - """ - Response object for listing Models. - """ - - model_endpoints: List[GetLLMEndpointResponse] = Field( - ..., - description="The list of models.", - ) - """ - A list of Models, represented as `GetLLMEndpointResponse`s. - """ - - -class UpdateLLMEndpointRequest(BaseModel): - # LLM specific fields - model_name: Optional[str] - source: Optional[LLMSource] - inference_framework_image_tag: Optional[str] - num_shards: Optional[int] - """ - Number of shards to distribute the model onto GPUs. - """ - - quantize: Optional[Quantization] - """ - Whether to quantize the model. - """ - - checkpoint_path: Optional[str] - """ - Path to the checkpoint to load the model from. - """ - - # General endpoint fields - metadata: Optional[Dict[str, Any]] - post_inference_hooks: Optional[List[str]] - cpus: Optional[CpuSpecificationType] - gpus: Optional[int] - memory: Optional[StorageSpecificationType] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] - min_workers: Optional[int] - max_workers: Optional[int] - per_worker: Optional[int] - labels: Optional[Dict[str, str]] - prewarm: Optional[bool] - high_priority: Optional[bool] - billing_tags: Optional[Dict[str, Any]] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] - chat_template_override: Optional[str] = Field( - default=None, - description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", - ) - - force_bundle_recreation: Optional[bool] = False - """ - Whether to force recreate the underlying bundle. - - If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created - that we would like to pick up for existing endpoints - """ - - -class UpdateLLMEndpointResponse(BaseModel): - endpoint_creation_task_id: str - - -class DeleteLLMEndpointResponse(BaseModel): - """ - Response object for deleting a Model. - """ - - deleted: bool = Field(..., description="Whether deletion was successful.") - """ - Whether the deletion succeeded. - """ - - class CreateFineTuneRequest(BaseModel): """ Request object for creating a FineTune. diff --git a/clients/python/llmengine/data_types/vllm.py b/clients/python/llmengine/data_types/vllm.py new file mode 100644 index 00000000..56250a34 --- /dev/null +++ b/clients/python/llmengine/data_types/vllm.py @@ -0,0 +1,251 @@ +from typing import Any, Dict, List, Optional, Union + +from typing_extensions import Annotated + +from .gen.openai import ResponseFormatJsonObject, ResponseFormatJsonSchema, ResponseFormatText +from .pydantic_types import BaseModel, Field + +# This was last synced w/ vLLM v0.5.5 on 2024-09-03 + + +class VLLMModelConfig(BaseModel): + """Model configuration for VLLM""" + + max_model_len: Optional[int] = Field( + None, + description="""Model context length, If unspecified, will be automatically derived from the model config""", + ) + + max_num_seqs: Optional[int] = Field( + None, + description="""Maximum number of sequences per iteration""", + ) + + enforce_eager: Optional[bool] = Field( + None, + description="""Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal perforamnce and flexibility""", + ) + + gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%.", + ) + + trust_remote_code: Optional[bool] = Field( + default=False, + description="Whether to trust remote code from Hugging face hub. This is only applicable to models whose code is not supported natively by the transformers library (e.g. deepseek). Default to False.", + ) + + +class VLLMEngineAdditionalArgs(BaseModel): + """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" + + max_gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%. Deprecated in favor of specifying this in VLLMModelConfig", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + + +class VLLMEndpointAdditionalArgs(VLLMModelConfig, VLLMEngineAdditionalArgs, BaseModel): + pass + + +class VLLMSamplingParams(BaseModel): + best_of: Optional[int] = Field( + None, + description="""Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. This is treated as + the beam width when `use_beam_search` is True. By default, `best_of` + is set to `n`.""", + ) + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + min_p: Optional[float] = Field( + None, + description="""Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this.""", + ) + use_beam_search: Optional[bool] = Field( + None, + description="""Whether to use beam search for sampling.""", + ) + length_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes sequences based on their length. + Used in beam search.""", + ) + repetition_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens.""", + ) + early_stopping: Optional[bool] = Field( + None, + description="""Controls the stopping condition for beam search. It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm).""", + ) + stop_token_ids: Optional[List[int]] = Field( + default_factory=list, + description="""List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens.""", + ) + include_stop_str_in_output: Annotated[ + Optional[bool], + Field( + None, + description="""Whether to include the stop strings in + output text. Defaults to False.""", + ), + ] + ignore_eos: Optional[bool] = Field( + None, + description="""Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated.""", + ) + min_tokens: Optional[int] = Field( + None, + description="""Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated""", + ) + + skip_special_tokens: Optional[bool] = Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ) + + spaces_between_special_tokens: Optional[bool] = Field( + True, + description="Whether to add spaces between special tokens in the output. Only supported in vllm.", + ) + + +class VLLMChatCompletionAdditionalParams(VLLMSamplingParams): + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the model's tokenizer " + "does not define one and no override template is given" + ), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) + + +class VLLMCompletionAdditionalParams(VLLMSamplingParams): + add_special_tokens: Optional[bool] = Field( + default=None, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt." + ), + ) + + response_format: Optional[ + Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema] + ] = Field( + default=None, + description=( + "Similar to chat completion, this parameter specifies the format of " + "output. Only {'type': 'json_object'} or {'type': 'text' } is " + "supported." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 9a150429..a0e1794f 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta42" +version = "0.0.0.beta43" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index a1fc8bee..097bcd1e 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.8", - version="0.0.0.beta42", + version="0.0.0.beta43", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) diff --git a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py index 82619f47..71cf4e69 100644 --- a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional from model_engine_server.common.dtos.core import HttpUrlStr +from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs from model_engine_server.common.dtos.model_endpoints import ( CpuSpecificationType, GetModelEndpointV1Response, @@ -26,7 +27,7 @@ ) -class CreateLLMModelEndpointV1Request(BaseModel): +class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel): name: str # LLM specific fields @@ -106,7 +107,7 @@ class ListLLMModelEndpointsV1Response(BaseModel): model_endpoints: List[GetLLMModelEndpointV1Response] -class UpdateLLMModelEndpointV1Request(BaseModel): +class UpdateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel): # LLM specific fields model_name: Optional[str] = None source: Optional[LLMSource] = None diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index a904a597..72c13aa4 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -31,7 +31,7 @@ class VLLMModelConfig(BaseModel): gpu_memory_utilization: Optional[float] = Field( None, - description="Maximum GPU memory utilization for the batch inference. Default to 90%.", + description="Maximum GPU memory utilization use for the engine. Default to 90%.", ) trust_remote_code: Optional[bool] = Field( @@ -39,6 +39,36 @@ class VLLMModelConfig(BaseModel): description="Whether to trust remote code from Hugging face hub. This is only applicable to models whose code is not supported natively by the transformers library (e.g. deepseek). Default to False.", ) + pipeline_parallel_size: Optional[int] = Field( + None, + description="Number of pipeline stages. Default to None.", + ) + + tensor_parallel_size: Optional[int] = Field( + None, + description="Number of tensor parallel replicas. Default to None.", + ) + + quantization: Optional[str] = Field( + None, + description="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + + disable_log_requests: Optional[bool] = Field( + None, + description="Disable logging requests. Default to None.", + ) + + chat_template: Optional[str] = Field( + None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + class VLLMEngineAdditionalArgs(BaseModel): """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" @@ -54,6 +84,10 @@ class VLLMEngineAdditionalArgs(BaseModel): ) +class VLLMEndpointAdditionalArgs(VLLMModelConfig, VLLMEngineAdditionalArgs, BaseModel): + pass + + class VLLMSamplingParams(BaseModel): best_of: Optional[int] = Field( None, diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 60874560..8c736725 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -48,14 +48,13 @@ GetBatchCompletionV2Response, UpdateBatchCompletionsV2Request, UpdateBatchCompletionsV2Response, - VLLMEngineAdditionalArgs, ) from model_engine_server.common.dtos.llms.completion import ( CompletionV2Request, CompletionV2StreamSuccessChunk, CompletionV2SyncResponse, ) -from model_engine_server.common.dtos.llms.vllm import VLLMModelConfig +from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs, VLLMModelConfig from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus @@ -511,6 +510,7 @@ async def execute( checkpoint_path: Optional[str], chat_template_override: Optional[str], nodes_per_worker: int, + additional_args: Optional[Dict[str, Any]] = None, ) -> ModelBundle: multinode = nodes_per_worker > 1 if source == LLMSource.HUGGING_FACE: @@ -541,6 +541,11 @@ async def execute( checkpoint_path, ) elif framework == LLMInferenceFramework.VLLM: + additional_vllm_args = ( + VLLMEndpointAdditionalArgs.model_validate(additional_args) + if additional_args + else None + ) if multinode: bundle_id = await self.create_vllm_multinode_bundle( user, @@ -552,6 +557,7 @@ async def execute( quantize, checkpoint_path, chat_template_override, + additional_args=additional_vllm_args, ) else: bundle_id = await self.create_vllm_bundle( @@ -563,6 +569,7 @@ async def execute( quantize, checkpoint_path, chat_template_override, + additional_args=additional_vllm_args, ) elif framework == LLMInferenceFramework.LIGHTLLM: bundle_id = await self.create_lightllm_bundle( @@ -879,71 +886,84 @@ def _create_vllm_bundle_command( multinode: bool, is_worker: bool, nodes_per_worker: int = 1, # only used if multinode + additional_args: Optional[VLLMEndpointAdditionalArgs] = None, ): """ VLLM start command for the single worker, or the leader in a LeaderWorkerSet. """ - command = [] subcommands = [] checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) - additional_args = infer_addition_engine_args_from_model_name(model_name) + + # merge additional_args with inferred_additional_args + # We assume user provided additional args takes precedence over inferred args + vllm_args = VLLMEndpointAdditionalArgs.model_validate( + { + **( + infer_addition_engine_args_from_model_name(model_name).model_dump( + exclude_none=True + ) + ), + **(additional_args.model_dump(exclude_none=True) if additional_args else {}), + } + ) # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder - if "mistral" in model_name: - final_weights_folder = "mistral_files" - else: - final_weights_folder = "model_files" + final_weights_folder = "mistral_files" if "mistral" in model_name else "model_files" subcommands += self.load_model_weights_sub_commands( LLMInferenceFramework.VLLM, framework_image_tag, checkpoint_path, final_weights_folder, - trust_remote_code=additional_args.trust_remote_code or False, + trust_remote_code=vllm_args.trust_remote_code or False, ) - if multinode and not is_worker: - ray_cmd = "/workspace/init_ray.sh leader --ray_cluster_size=$RAY_CLUSTER_SIZE --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" - subcommands.append(ray_cmd) - elif multinode and is_worker: - ray_cmd = "/workspace/init_ray.sh worker --ray_address=$LWS_LEADER_ADDRESS.svc.cluster.local --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + if multinode: + if not is_worker: + ray_cmd = "/workspace/init_ray.sh leader --ray_cluster_size=$RAY_CLUSTER_SIZE --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + else: + ray_cmd = "/workspace/init_ray.sh worker --ray_address=$LWS_LEADER_ADDRESS.svc.cluster.local --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" subcommands.append(ray_cmd) if not is_worker: - vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --tensor-parallel-size {num_shards} --port 5005" + vllm_args.tensor_parallel_size = num_shards + + if vllm_args.gpu_memory_utilization is not None: + vllm_args.enforce_eager = True if multinode: - vllm_cmd += f" --pipeline-parallel-size {nodes_per_worker}" + vllm_args.pipeline_parallel_size = nodes_per_worker - chat_template_cmd = None if chat_template_override: - # We encode the chat template as base64 to avoid issues with special characters - # and decode it via bash - chat_template_cmd = f'export CHAT_TEMPLATE=$(echo "{encode_template(chat_template_override)}" | base64 --decode)' - subcommands.append(chat_template_cmd) - vllm_cmd += ' --chat-template "$CHAT_TEMPLATE"' + vllm_args.chat_template = chat_template_override - if quantize: # pragma: no cover + if quantize: if quantize != Quantization.AWQ: raise InvalidRequestException( f"Quantization {quantize} is not supported by vLLM." ) - vllm_cmd += f" --quantization {quantize}" + vllm_args.quantization = quantize - if hmi_config.sensitive_log_mode: # pragma: no cover - vllm_cmd += " --disable-log-requests" + if hmi_config.sensitive_log_mode: + vllm_args.disable_log_requests = True - for field in VLLMModelConfig.model_fields.keys(): - config_value = getattr(additional_args, field, None) + vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --port 5005" + for field in VLLMEndpointAdditionalArgs.model_fields.keys(): + config_value = getattr(vllm_args, field, None) if config_value is not None: - vllm_cmd += f" --{field.replace('_', '-')} {config_value}" - - if field == "gpu_memory_utilization": - vllm_cmd += " --enforce-eager" - - if additional_args.attention_backend is not None: - vllm_cmd += f" --attention-backend {additional_args.attention_backend}" + # Special handling for chat_template + # Need to encode the chat template as base64 to avoid issues with special characters + if field == "chat_template": + chat_template_cmd = f'export CHAT_TEMPLATE=$(echo "{encode_template(config_value)}" | base64 --decode)' + subcommands.append(chat_template_cmd) + config_value = '"$CHAT_TEMPLATE"' + + # if type of config_value is True, then only need to add the key + if isinstance(config_value, bool) and config_value: + vllm_cmd += f" --{field.replace('_', '-')}" + else: + vllm_cmd += f" --{field.replace('_', '-')} {config_value}" subcommands.append(vllm_cmd) @@ -965,6 +985,7 @@ async def create_vllm_bundle( quantize: Optional[Quantization], checkpoint_path: Optional[str], chat_template_override: Optional[str], + additional_args: Optional[VLLMEndpointAdditionalArgs] = None, ): command = self._create_vllm_bundle_command( model_name, @@ -976,6 +997,7 @@ async def create_vllm_bundle( multinode=False, is_worker=False, nodes_per_worker=1, + additional_args=additional_args, ) create_model_bundle_v2_request = CreateModelBundleV2Request( @@ -1023,6 +1045,7 @@ async def create_vllm_multinode_bundle( quantize: Optional[Quantization], checkpoint_path: Optional[str], chat_template_override: Optional[str], + additional_args: Optional[VLLMEndpointAdditionalArgs] = None, ): leader_command = self._create_vllm_bundle_command( model_name, @@ -1034,6 +1057,7 @@ async def create_vllm_multinode_bundle( multinode=True, is_worker=False, nodes_per_worker=nodes_per_worker, + additional_args=additional_args, ) worker_command = self._create_vllm_bundle_command( model_name, @@ -1312,6 +1336,7 @@ async def execute( checkpoint_path=request.checkpoint_path, chat_template_override=request.chat_template_override, nodes_per_worker=request.nodes_per_worker, + additional_args=request.model_dump(exclude_none=True), ) validate_resource_requests( bundle=bundle, @@ -1571,6 +1596,7 @@ async def execute( checkpoint_path=checkpoint_path, chat_template_override=chat_template_override, nodes_per_worker=model_endpoint.infra_state.resource_state.nodes_per_worker, + additional_args=request.model_dump(exclude_none=True), ) metadata = endpoint_record.metadata or {} @@ -3361,13 +3387,9 @@ async def _infer_hardware( ) -class VLLMAdditionalArgs(VLLMModelConfig, VLLMEngineAdditionalArgs): - pass - - def infer_addition_engine_args_from_model_name( model_name: str, -) -> VLLMAdditionalArgs: +) -> VLLMEndpointAdditionalArgs: # Increase max gpu utilization for larger models model_param_count_b = get_model_param_count_b(model_name) if model_param_count_b >= 70: @@ -3385,7 +3407,7 @@ def infer_addition_engine_args_from_model_name( if model_name.startswith("deepseek"): trust_remote_code = True - return VLLMAdditionalArgs( + return VLLMEndpointAdditionalArgs( gpu_memory_utilization=gpu_memory_utilization, attention_backend=attention_backend, trust_remote_code=trust_remote_code, diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py index 58c1a0de..f18808a7 100644 --- a/model-engine/tests/unit/domain/conftest.py +++ b/model-engine/tests/unit/domain/conftest.py @@ -377,6 +377,39 @@ def create_llm_model_endpoint_request_llama_3_70b_chat() -> CreateLLMModelEndpoi ) +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args() -> ( + CreateLLMModelEndpointV1Request +): + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_70b_chat", + model_name="llama-3-70b", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-70b", + chat_template_override="test-template", + max_model_len=1000, + max_num_seqs=10, + ) + + @pytest.fixture def create_llm_model_endpoint_request_llama_3_1_405b_instruct() -> CreateLLMModelEndpointV1Request: return CreateLLMModelEndpointV1Request( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 9fd09b73..9e160846 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -397,7 +397,6 @@ async def test_create_model_endpoint_w_chat_template( llm_artifact_gateway=fake_llm_artifact_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - print(create_llm_model_endpoint_request_llama_3_70b_chat) response = await use_case.execute( user=user, request=create_llm_model_endpoint_request_llama_3_70b_chat, @@ -411,6 +410,7 @@ async def test_create_model_endpoint_w_chat_template( order_by=None, ) )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING assert endpoint.record.metadata == { "_llm": { @@ -426,6 +426,68 @@ async def test_create_model_endpoint_w_chat_template( } +@pytest.mark.asyncio +async def test_create_model_endpoint_w_vllm_args( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response = await use_case.execute( + user=user, + request=create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args, + ) + assert response.endpoint_creation_task_id + assert isinstance(response, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.name, + order_by=None, + ) + )[0] + + bundle_command = endpoint.record.current_model_bundle.flavor.command[2] + expected_vllm_args = ["max-model-len", "max-num-seqs", "chat-template"] + for arg in expected_vllm_args: + assert arg in bundle_command + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.model_name, + "source": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.source, + "inference_framework": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.num_shards, + "quantize": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.quantize, + "checkpoint_path": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.chat_template_override, + } + } + + @pytest.mark.asyncio async def test_create_model_endpoint_text_generation_inference_use_case_success( test_api_key: str, From ff971eadce7afbf94ea5f8ee65d0a4b42c00e8c7 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 15 Oct 2024 17:04:28 -0700 Subject: [PATCH 403/425] Update oai spec to remove strict flag default to workaround vllm incompatibilty + additional flags to set through API (#638) --- .../model_engine_server/common/dtos/llms/vllm.py | 10 ++++++++++ .../model_engine_server/common/types/gen/openai.py | 4 ++-- .../domain/use_cases/llm_model_endpoint_use_cases.py | 5 +++-- scripts/openai-spec.yaml | 2 +- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index 72c13aa4..7494376e 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -69,6 +69,16 @@ class VLLMModelConfig(BaseModel): description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", ) + tool_call_parser: Optional[str] = Field( + None, + description="Tool call parser", + ) + + enable_auto_tool_choice: Optional[bool] = Field( + None, + description="Enable auto tool choice", + ) + class VLLMEngineAdditionalArgs(BaseModel): """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py index 964d0c33..c206d98f 100644 --- a/model-engine/model_engine_server/common/types/gen/openai.py +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openai-spec.yaml -# timestamp: 2024-10-04T21:01:02+00:00 +# timestamp: 2024-10-15T23:20:07+00:00 from __future__ import annotations @@ -309,7 +309,7 @@ class FunctionObject(BaseModel): Field( description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling)." ), - ] = False + ] = None class ResponseFormatText(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 8c736725..f6f4954a 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -960,8 +960,9 @@ def _create_vllm_bundle_command( config_value = '"$CHAT_TEMPLATE"' # if type of config_value is True, then only need to add the key - if isinstance(config_value, bool) and config_value: - vllm_cmd += f" --{field.replace('_', '-')}" + if isinstance(config_value, bool): + if config_value: + vllm_cmd += f" --{field.replace('_', '-')}" else: vllm_cmd += f" --{field.replace('_', '-')} {config_value}" diff --git a/scripts/openai-spec.yaml b/scripts/openai-spec.yaml index 6eb3f1cf..01cbcde6 100644 --- a/scripts/openai-spec.yaml +++ b/scripts/openai-spec.yaml @@ -9336,7 +9336,7 @@ components: strict: type: boolean nullable: true - default: false + # default: false (TODO: dmchoi) revert once vllm updates their spec description: Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling). required: - name From c6f87b89aa241c45ed05684ee42fb7f92c4985db Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:02:04 -0700 Subject: [PATCH 404/425] 0-N scaling for sync/streaming endpoints (#634) * quick todo * add prometheus metric for the 1-N scaling part * todo prometheus server addr * values sample * autogen tpl * add in a thing to the model endpoint service to say whether we can autoscale from zero * pass through validation * untested to get concurrency value * fix some tests, add some tests * clean up some things * fix a few bugs * autogen tpl * cleanup * comment dependency * rename * rename --- .../templates/service_template_config_map.yaml | 6 ++++++ charts/model-engine/values_sample.yaml | 3 +++ .../model_engine_server/api/dependencies.py | 1 + model-engine/model_engine_server/core/config.py | 1 + .../domain/services/model_endpoint_service.py | 8 ++++++++ .../domain/use_cases/batch_job_use_cases.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 2 ++ .../domain/use_cases/model_endpoint_use_cases.py | 12 ++++++------ .../entrypoints/start_batch_job_orchestration.py | 1 + .../resources/k8s_endpoint_resource_delegate.py | 9 ++++++++- .../infra/gateways/resources/k8s_resource_types.py | 10 ++++++++-- .../service_template_config_map_circleci.yaml | 12 ++++++++++-- .../infra/services/live_model_endpoint_service.py | 5 +++++ model-engine/tests/unit/conftest.py | 10 ++++++++++ .../unit/domain/test_model_endpoint_use_cases.py | 13 +++++++++++++ model-engine/tests/unit/infra/services/conftest.py | 1 + 16 files changed, 84 insertions(+), 11 deletions(-) diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index 6ef924df..a418557a 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -472,6 +472,12 @@ data: unsafeSsl: "false" databaseIndex: "${REDIS_DB_INDEX}" {{- end }} + - type: prometheus + metadata: + threshold: "${CONCURRENCY}" + metricName: request_concurrency_average + query: sum(rate(istio_request_duration_milliseconds_sum{destination_workload="${RESOURCE_NAME}"}[2m])) / 1000 + serverAddress: ${PROMETHEUS_SERVER_ADDRESS} {{- range $device := tuple "gpu" }} {{- range $mode := tuple "streaming"}} leader-worker-set-{{ $mode }}-{{ $device }}.yaml: |- diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index a9d8d7e0..5f9969b8 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -188,6 +188,9 @@ config: db_engine_echo: false db_engine_echo_pool: false db_engine_disconnect_strategy: "pessimistic" + # prometheus_server_address [optional, required if you want scale from zero for sync/streaming endpoints] + # is the address of the Prometheus server to query for endpoint metrics + prometheus_server_address: "http://prometheus-server.istio-system.svc.cluster.local:80" launch: # endpoint_namespace [required] is K8s namespace the endpoints will be created in endpoint_namespace: llm-engine diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 0fc61abb..5dce68fe 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -277,6 +277,7 @@ def _get_external_interfaces( sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + can_scale_http_endpoint_from_zero_flag=infra_config().prometheus_server_address is not None, ) llm_model_endpoint_service = LiveLLMModelEndpointService( model_endpoint_record_repository=model_endpoint_record_repo, diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index dc6a492b..0474976c 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -48,6 +48,7 @@ class _InfraConfig: identity_service_url: Optional[str] = None firehose_role_arn: Optional[str] = None firehose_stream_name: Optional[str] = None + prometheus_server_address: Optional[str] = None @dataclass diff --git a/model-engine/model_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py index 4cc89227..83a5cc2e 100644 --- a/model-engine/model_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -256,3 +256,11 @@ async def update_model_endpoint( ExistingEndpointOperationInProgressException: if the endpoint is currently being edited (corresponds to an HTTP 409) """ + + @abstractmethod + def can_scale_http_endpoint_from_zero(self) -> bool: + """ + Returns whether the service can autoscale sync/stream endpoints from zero. + For instance, if particular dependencies in the cluster are not installed, then this should + return False + """ diff --git a/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py index d1f98cc9..eeb7ce67 100644 --- a/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py @@ -68,6 +68,7 @@ async def execute( min_workers=0, max_workers=request.resource_requests.max_workers, endpoint_type=ModelEndpointType.ASYNC, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) bundle = await self.model_bundle_repository.get_model_bundle( diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index f6f4954a..20db4d68 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1284,6 +1284,7 @@ async def execute( min_workers=request.min_workers, max_workers=request.max_workers, endpoint_type=request.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) if request.gpu_type == GpuType.NVIDIA_AMPERE_A100E: # pragma: no cover raise ObjectHasInvalidValueException( @@ -1632,6 +1633,7 @@ async def execute( min_workers=request.min_workers, max_workers=request.max_workers, endpoint_type=endpoint_record.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) if request.metadata is not None: diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index d5318f74..69beac00 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -104,15 +104,13 @@ def validate_deployment_resources( min_workers: Optional[int], max_workers: Optional[int], endpoint_type: ModelEndpointType, + can_scale_http_endpoint_from_zero: bool, ) -> None: - if endpoint_type in [ModelEndpointType.STREAMING, ModelEndpointType.SYNC]: - # Special case for sync endpoints, where we can have 0, 1 min/max workers. - # Otherwise, fall through to the general case. - if min_workers == 0 and max_workers == 1: - return # TODO: we should be also validating the update request against the existing state in k8s (e.g. # so min_workers <= max_workers always) maybe this occurs already in update_model_endpoint. - min_endpoint_size = 0 if endpoint_type == ModelEndpointType.ASYNC else 1 + min_endpoint_size = ( + 0 if endpoint_type == ModelEndpointType.ASYNC or can_scale_http_endpoint_from_zero else 1 + ) if min_workers is not None and min_workers < min_endpoint_size: raise EndpointResourceInvalidRequestException( f"Requested min workers {min_workers} too low" @@ -275,6 +273,7 @@ async def execute( min_workers=request.min_workers, max_workers=request.max_workers, endpoint_type=request.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) if request.labels is None: raise EndpointLabelsException("Endpoint labels cannot be None!") @@ -457,6 +456,7 @@ async def execute( min_workers=request.min_workers, max_workers=request.max_workers, endpoint_type=endpoint_record.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata: diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index c059a9eb..af6eeef7 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -131,6 +131,7 @@ async def run_batch_job( sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + can_scale_http_endpoint_from_zero_flag=False, # shouldn't matter since we only use this to create async endpoints ) batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=False) batch_job_progress_gateway = LiveBatchJobProgressGateway(filesystem_gateway=filesystem_gateway) diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 112d3554..824db978 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -1858,10 +1858,17 @@ def _get_sync_autoscaling_params_from_keda( keda_config, ) -> HorizontalAutoscalingEndpointParams: spec = keda_config["spec"] + concurrency = 1 + for trigger in spec["triggers"]: + if trigger["metadata"].get("metricName") == "request_concurrency_average": + # Needs to match what is defined in the keda-scaled-obj section in + # service_template_config_map.yaml! + concurrency = trigger["metadata"]["threshold"] + break return dict( max_workers=spec.get("maxReplicaCount"), min_workers=spec.get("minReplicaCount"), - per_worker=1, # TODO dummy value, fill in when we autoscale from 0 to 1 + per_worker=concurrency, ) async def _get_resources( diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index 32af085e..c1c64c34 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -316,11 +316,12 @@ class HorizontalPodAutoscalerArguments(_BaseEndpointArguments): class KedaScaledObjectArguments(_BaseEndpointArguments): MIN_WORKERS: int MAX_WORKERS: int - # CONCURRENCY: float # TODO add in when we scale from 1 -> N pods + CONCURRENCY: float REDIS_HOST_PORT: str REDIS_DB_INDEX: str SERVICEBUS_NAMESPACE: Optional[str] AUTHENTICATION_REF: str + PROMETHEUS_SERVER_ADDRESS: str class UserConfigArguments(_BaseEndpointArguments): @@ -1250,6 +1251,9 @@ def get_endpoint_resource_arguments_from_request( MAX_WORKERS=build_endpoint_request.max_workers, ) elif endpoint_resource_name == "keda-scaled-object": + concurrency = get_target_concurrency_from_per_worker_value( + build_endpoint_request.per_worker + ) return KedaScaledObjectArguments( # Base resource arguments RESOURCE_NAME=k8s_resource_group_name, @@ -1264,11 +1268,13 @@ def get_endpoint_resource_arguments_from_request( # Scaled Object arguments MIN_WORKERS=build_endpoint_request.min_workers, MAX_WORKERS=build_endpoint_request.max_workers, - # CONCURRENCY=build_endpoint_request.concurrency, + CONCURRENCY=concurrency, REDIS_HOST_PORT=hmi_config.cache_redis_host_port, REDIS_DB_INDEX=str(hmi_config.cache_redis_db_index), SERVICEBUS_NAMESPACE=os.getenv("SERVICEBUS_NAMESPACE"), AUTHENTICATION_REF="azure-workload-identity", + PROMETHEUS_SERVER_ADDRESS=infra_config().prometheus_server_address + or "dummy-value", # We should never get to "dummy-value", validation should have taken place to ensure prom_server_addr is not None. ) elif endpoint_resource_name == "service": # Use ClusterIP by default for sync endpoint. diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index dde09282..b78a6545 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -7,8 +7,8 @@ metadata: name: model-engine-service-template-config labels: team: infra - app.kubernetes.io/version: 852d5899343633774f6a3543e4ed9e977a533e5a - tags.datadoghq.com/version: 852d5899343633774f6a3543e4ed9e977a533e5a + app.kubernetes.io/version: afe1df8a92039d3403e2f5ef266009231b02bf50 + tags.datadoghq.com/version: afe1df8a92039d3403e2f5ef266009231b02bf50 tags.datadoghq.com/env: circleci env: circleci product: model-engine @@ -2620,6 +2620,12 @@ data: enableTLS: "false" unsafeSsl: "false" databaseIndex: "${REDIS_DB_INDEX}" + - type: prometheus + metadata: + threshold: "${CONCURRENCY}" + metricName: request_concurrency_average + query: sum(rate(istio_request_duration_milliseconds_sum{destination_workload="${RESOURCE_NAME}"}[2m])) / 1000 + serverAddress: ${PROMETHEUS_SERVER_ADDRESS} leader-worker-set-streaming-gpu.yaml: |- apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet @@ -3286,9 +3292,11 @@ data: requests: cpu: 1 memory: 8Gi + ephemeral-storage: 10Gi limits: cpu: 4 memory: 32Gi + ephemeral-storage: 30Gi volumeMounts: - name: config-volume mountPath: /opt/.aws/config diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 3bf62b5e..8cda63c5 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -55,6 +55,7 @@ def __init__( sync_model_endpoint_inference_gateway: SyncModelEndpointInferenceGateway, model_endpoints_schema_gateway: ModelEndpointsSchemaGateway, inference_autoscaling_metrics_gateway: InferenceAutoscalingMetricsGateway, + can_scale_http_endpoint_from_zero_flag: bool, ): self.model_endpoint_record_repository = model_endpoint_record_repository self.model_endpoint_infra_gateway = model_endpoint_infra_gateway @@ -64,6 +65,7 @@ def __init__( self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway self.model_endpoints_schema_gateway = model_endpoints_schema_gateway self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway + self.can_scale_http_endpoint_from_zero_flag = can_scale_http_endpoint_from_zero_flag def get_async_model_endpoint_inference_gateway( self, @@ -400,3 +402,6 @@ async def delete_model_endpoint(self, model_endpoint_id: str) -> None: ) logger.info(f"Endpoint delete released lock for {created_by}, {name}") + + def can_scale_http_endpoint_from_zero(self) -> bool: + return self.can_scale_http_endpoint_from_zero_flag diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 81e5e1cc..6812e49e 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1680,6 +1680,7 @@ def __init__( ] = None, sync_model_endpoint_inference_gateway: Optional[SyncModelEndpointInferenceGateway] = None, inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway] = None, + can_scale_http_endpoint_from_zero_flag: bool = True, ): if contents: self.db = contents @@ -1718,6 +1719,8 @@ def __init__( filesystem_gateway=FakeFilesystemGateway() ) + self.can_scale_http_endpoint_from_zero_flag = can_scale_http_endpoint_from_zero_flag + def get_async_model_endpoint_inference_gateway( self, ) -> AsyncModelEndpointInferenceGateway: @@ -1935,6 +1938,12 @@ async def delete_model_endpoint(self, model_endpoint_id: str) -> None: raise ObjectNotFoundException del self.db[model_endpoint_id] + def set_can_scale_http_endpoint_from_zero_flag(self, flag: bool): + self.can_scale_http_endpoint_from_zero_flag = flag + + def can_scale_http_endpoint_from_zero(self) -> bool: + return self.can_scale_http_endpoint_from_zero_flag + class FakeTokenizerRepository(TokenizerRepository): def load_tokenizer(self, model_name: str) -> AutoTokenizer: @@ -2279,6 +2288,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + can_scale_http_endpoint_from_zero_flag=True, # reasonable default, gets overridden in individual tests if needed ) fake_batch_job_service = LiveBatchJobService( batch_job_record_repository=FakeBatchJobRecordRepository( diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index f6ea8ab6..ba2ddf7c 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -70,6 +70,7 @@ async def test_create_model_endpoint_use_case_success( assert isinstance(response_3, CreateModelEndpointV1Response) # test special case where sync/streaming endpoint that has 0-1 min-max workers works + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(True) request = create_model_endpoint_request_sync.copy() request.min_workers = 0 request.max_workers = 1 @@ -84,6 +85,14 @@ async def test_create_model_endpoint_use_case_success( assert response_5.endpoint_creation_task_id assert isinstance(response_5, CreateModelEndpointV1Response) + # test general case as well for 0-N + request = create_model_endpoint_request_sync.copy() + request.min_workers = 0 + request.max_workers = 5 + response_6 = await use_case.execute(user=user, request=request) + assert response_6.endpoint_creation_task_id + assert isinstance(response_6, CreateModelEndpointV1Response) + @pytest.mark.asyncio async def test_create_model_endpoint_use_case_raises_invalid_value_exception( @@ -184,10 +193,12 @@ async def test_create_model_endpoint_use_case_raises_resource_request_exception( with pytest.raises(EndpointResourceInvalidRequestException): await use_case.execute(user=user, request=request) + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(False) request = create_model_endpoint_request_sync.copy() request.min_workers = 0 with pytest.raises(EndpointResourceInvalidRequestException): await use_case.execute(user=user, request=request) + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(True) request = create_model_endpoint_request_async.copy() request.max_workers = 2**63 @@ -1028,6 +1039,7 @@ async def test_update_model_endpoint_use_case_raises_resource_request_exception( ) # specific to sync endpoint + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(False) request = update_model_endpoint_request.copy() request.min_workers = 0 with pytest.raises(EndpointResourceInvalidRequestException): @@ -1036,6 +1048,7 @@ async def test_update_model_endpoint_use_case_raises_resource_request_exception( model_endpoint_id=model_endpoint_2.record.id, request=request, ) + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(True) request = update_model_endpoint_request.copy() request.max_workers = 2**63 diff --git a/model-engine/tests/unit/infra/services/conftest.py b/model-engine/tests/unit/infra/services/conftest.py index acd43e5a..873eea9e 100644 --- a/model-engine/tests/unit/infra/services/conftest.py +++ b/model-engine/tests/unit/infra/services/conftest.py @@ -40,6 +40,7 @@ def fake_live_model_endpoint_service( sync_model_endpoint_inference_gateway=fake_sync_model_endpoint_inference_gateway, inference_autoscaling_metrics_gateway=fake_inference_autoscaling_metrics_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + can_scale_http_endpoint_from_zero_flag=True, # reasonable default, gets overridden in individual tests if needed ) return service From 9c07cadc070b98bd96d42fbd219de948f17a820f Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:38:05 -0700 Subject: [PATCH 405/425] Bump commit in integration tests (#640) --- integration_tests/rest_api_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py index 4fe29618..7db937dc 100644 --- a/integration_tests/rest_api_utils.py +++ b/integration_tests/rest_api_utils.py @@ -78,7 +78,7 @@ def my_model(**keyword_args): "flavor": { "flavor": "streaming_enhanced_runnable_image", "repository": "model-engine", - "tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", + "tag": "830c81ecba2a147022e504917c6ce18b00c2af44", "command": [ "dumb-init", "--", @@ -269,7 +269,7 @@ def my_model(**keyword_args): CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { "name": format_name("di_batch_job_bundle_1"), "image_repository": "model-engine", - "image_tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", + "image_tag": "830c81ecba2a147022e504917c6ce18b00c2af44", "command": ["jq", ".", "/launch_mount_location/file"], "env": {"ENV1": "VAL1"}, "mount_location": "/launch_mount_location/file", @@ -289,7 +289,7 @@ def my_model(**keyword_args): CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { "name": format_name("fine_tune_di_batch_job_bundle_1"), "image_repository": "model-engine", - "image_tag": "2c1951dfff7159d7d29dd13b4f888e8355f8d51e", + "image_tag": "830c81ecba2a147022e504917c6ce18b00c2af44", "command": ["cat", "/launch_mount_location/file"], "env": {"ENV1": "VAL1"}, "mount_location": "/launch_mount_location/file", From 841b4d48a96b5f4a59aabc981a1b7a3782d97d06 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 16 Oct 2024 12:01:09 -0700 Subject: [PATCH 406/425] Add served_model_name (#639) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 20db4d68..37c8f56d 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -948,7 +948,7 @@ def _create_vllm_bundle_command( if hmi_config.sensitive_log_mode: vllm_args.disable_log_requests = True - vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --port 5005" + vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005" for field in VLLMEndpointAdditionalArgs.model_fields.keys(): config_value = getattr(vllm_args, field, None) if config_value is not None: From 7cb43cd026a36c15dc036c6b8bd41e6c4c5593ad Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 16 Oct 2024 12:44:32 -0700 Subject: [PATCH 407/425] Remove model name override (#641) --- .../use_cases/llm_model_endpoint_use_cases.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 37c8f56d..0e9b6edd 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -143,7 +143,6 @@ LLM_METADATA_KEY = "_llm" RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] -VLLM_MODEL_WEIGHTS_FOLDER = "model_files" INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { LLMInferenceFramework.DEEPSPEED: "instant-llm", @@ -2791,10 +2790,6 @@ async def execute( validate_endpoint_supports_openai_completion(model_endpoint, endpoint_content) - # if inference framework is VLLM, we need to set the model to use the weights folder - if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: - request.model = VLLM_MODEL_WEIGHTS_FOLDER - inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_COMPLETION_PATH, @@ -2897,10 +2892,6 @@ async def execute( validate_endpoint_supports_openai_completion(model_endpoint, model_content) - # if inference framework is VLLM, we need to set the model to use the weights folder - if model_content.inference_framework == LLMInferenceFramework.VLLM: - request.model = VLLM_MODEL_WEIGHTS_FOLDER - inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_COMPLETION_PATH, @@ -3058,10 +3049,6 @@ async def execute( validate_endpoint_supports_chat_completion(model_endpoint, endpoint_content) - # if inference framework is VLLM, we need to set the model to use the weights folder - if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: - request.model = VLLM_MODEL_WEIGHTS_FOLDER - inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_CHAT_COMPLETION_PATH, @@ -3163,10 +3150,6 @@ async def execute( ) validate_endpoint_supports_chat_completion(model_endpoint, model_content) - # if inference framework is VLLM, we need to set the model to use the weights folder - if model_content.inference_framework == LLMInferenceFramework.VLLM: - request.model = VLLM_MODEL_WEIGHTS_FOLDER - inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_CHAT_COMPLETION_PATH, From 33ce5ab29d43736064a725b1f0358851f7521ed7 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 16 Oct 2024 19:36:46 -0700 Subject: [PATCH 408/425] Add 1b 3b to model zoo (#642) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 0e9b6edd..1631f7e8 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -217,6 +217,8 @@ "llama-3-1-70b-instruct", "llama-3-1-405b", "llama-3-1-405b-instruct", + "llama-3-2-1b-instruct", + "llama-3-2-3b-instruct", "llama-3-2-11b-vision-instruct", "llama-3-2-90b-vision-instruct", "falcon-7b", From 3ce747f88d7bf3db7aad4ab462f43ff0dc653187 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 16 Oct 2024 19:57:00 -0700 Subject: [PATCH 409/425] Fix guided decoding logit setup (#643) --- .../inference/vllm/vllm_server.py | 45 +++++-------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 6e3e8cbd..183d5c64 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -9,15 +9,16 @@ from logging import Logger from typing import AsyncGenerator, Dict, List, Optional -from fastapi import APIRouter, BackgroundTasks, HTTPException, Request +from fastapi import APIRouter, BackgroundTasks, Request from fastapi.responses import Response, StreamingResponse -from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.async_llm_engine import ( + AsyncEngineDeadError, + build_guided_decoding_logits_processor_async, +) from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import build_app, build_async_engine_client, init_app_state from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest -from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob @@ -51,42 +52,16 @@ async def generate(request: Request) -> Response: request_dict = await request.json() prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) - guided_json = request_dict.pop("guided_json", None) - guided_regex = request_dict.pop("guided_regex", None) - guided_choice = request_dict.pop("guided_choice", None) - guided_grammar = request_dict.pop("guided_grammar", None) - sampling_params = SamplingParams(**request_dict) - - # Dummy request to get guided decode logit processor - try: - partial_openai_request = OpenAICompletionRequest.model_validate( - { - "model": "", - "prompt": "", - "guided_json": guided_json, - "guided_regex": guided_regex, - "guided_choice": guided_choice, - "guided_grammar": guided_grammar, - } - ) - except Exception: - raise HTTPException( - status_code=400, - detail="Bad request: failed to parse guided decoding parameters.", - ) guided_decoding_backend = ( await engine_client.get_decoding_config() ).guided_decoding_backend - guided_decode_logit_processor = await get_guided_decoding_logits_processor( - guided_decoding_backend, - partial_openai_request, - await engine_client.get_tokenizer(lora_request=None), + + sampling_params = await build_guided_decoding_logits_processor_async( + sampling_params=SamplingParams(**request_dict), + tokenizer=await engine_client.get_tokenizer(lora_request=None), + default_guided_backend=guided_decoding_backend, ) - if guided_decode_logit_processor is not None: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append(guided_decode_logit_processor) request_id = random_uuid() From 18457abed6df345d23d540aaa994a96e7e1f305c Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 17 Oct 2024 10:15:06 -0700 Subject: [PATCH 410/425] Revert "Remove model name override (#641)" (#644) This reverts commit 7cb43cd026a36c15dc036c6b8bd41e6c4c5593ad. --- .../use_cases/llm_model_endpoint_use_cases.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 1631f7e8..d2d10b9d 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -143,6 +143,7 @@ LLM_METADATA_KEY = "_llm" RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] +VLLM_MODEL_WEIGHTS_FOLDER = "model_files" INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { LLMInferenceFramework.DEEPSPEED: "instant-llm", @@ -2792,6 +2793,10 @@ async def execute( validate_endpoint_supports_openai_completion(model_endpoint, endpoint_content) + # if inference framework is VLLM, we need to set the model to use the weights folder + if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_COMPLETION_PATH, @@ -2894,6 +2899,10 @@ async def execute( validate_endpoint_supports_openai_completion(model_endpoint, model_content) + # if inference framework is VLLM, we need to set the model to use the weights folder + if model_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_COMPLETION_PATH, @@ -3051,6 +3060,10 @@ async def execute( validate_endpoint_supports_chat_completion(model_endpoint, endpoint_content) + # if inference framework is VLLM, we need to set the model to use the weights folder + if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_CHAT_COMPLETION_PATH, @@ -3152,6 +3165,10 @@ async def execute( ) validate_endpoint_supports_chat_completion(model_endpoint, model_content) + # if inference framework is VLLM, we need to set the model to use the weights folder + if model_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + inference_request = SyncEndpointPredictV1Request( args=request.model_dump(exclude_none=True), destination_path=OPENAI_CHAT_COMPLETION_PATH, From 1d855ca073d4a504c811cc8d602b998d838a29aa Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 17 Oct 2024 11:05:31 -0700 Subject: [PATCH 411/425] Miscellaneous improvments (#645) * Revert "Remove model name override (#641)" This reverts commit 7cb43cd026a36c15dc036c6b8bd41e6c4c5593ad. * Remove 'Annotated' type usage - python 3.8/pydantic doesn't like it * Add request id on batch completions error --- .../python/llmengine/data_types/completion.py | 51 ++++++++----------- clients/python/llmengine/data_types/vllm.py | 26 ++++------ .../api/v2/batch_completion.py | 19 ++++--- 3 files changed, 42 insertions(+), 54 deletions(-) diff --git a/clients/python/llmengine/data_types/completion.py b/clients/python/llmengine/data_types/completion.py index 94d7c62a..67384427 100644 --- a/clients/python/llmengine/data_types/completion.py +++ b/clients/python/llmengine/data_types/completion.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional, TypeAlias -from typing_extensions import Annotated - from .core import StreamError from .gen.openai import CreateCompletionRequest, CreateCompletionResponse from .pydantic_types import BaseModel, Field @@ -272,35 +270,26 @@ def inter_token_latency(self) -> Optional[float]: # Only for streaming requests class CompletionV2Request(CreateCompletionRequest, VLLMCompletionAdditionalParams): - model: Annotated[ - str, - Field( - description="ID of the model to use.", - examples=["mixtral-8x7b-instruct"], - ), - ] - - stream: Annotated[ - Optional[bool], - Field( - False, - description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", - ), - ] - - top_k: Annotated[ - Optional[int], - Field( - None, - ge=-1, - description="Controls the number of top tokens to consider. -1 means consider all tokens.", - ), - ] - - include_stop_str_in_output: Annotated[ - Optional[bool], - Field(None, description="Whether to include the stop strings in output text."), - ] + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) + + include_stop_str_in_output: Optional[bool] = Field( + None, + description="Whether to include the stop strings in output text.", + ) CompletionV2SyncResponse: TypeAlias = CreateCompletionResponse diff --git a/clients/python/llmengine/data_types/vllm.py b/clients/python/llmengine/data_types/vllm.py index 56250a34..831dcf8d 100644 --- a/clients/python/llmengine/data_types/vllm.py +++ b/clients/python/llmengine/data_types/vllm.py @@ -1,7 +1,5 @@ from typing import Any, Dict, List, Optional, Union -from typing_extensions import Annotated - from .gen.openai import ResponseFormatJsonObject, ResponseFormatJsonSchema, ResponseFormatText from .pydantic_types import BaseModel, Field @@ -64,14 +62,11 @@ class VLLMSamplingParams(BaseModel): the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`.""", ) - top_k: Annotated[ - Optional[int], - Field( - None, - ge=-1, - description="Controls the number of top tokens to consider. -1 means consider all tokens.", - ), - ] + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) min_p: Optional[float] = Field( None, description="""Float that represents the minimum probability for a token to be @@ -110,14 +105,11 @@ class VLLMSamplingParams(BaseModel): generated. The returned output will contain the stop tokens unless the stop tokens are special tokens.""", ) - include_stop_str_in_output: Annotated[ - Optional[bool], - Field( - None, - description="""Whether to include the stop strings in + include_stop_str_in_output: Optional[bool] = Field( + None, + description="""Whether to include the stop strings in output text. Defaults to False.""", - ), - ] + ) ignore_eos: Optional[bool] = Field( None, description="""Whether to ignore the EOS token and continue generating diff --git a/model-engine/model_engine_server/api/v2/batch_completion.py b/model-engine/model_engine_server/api/v2/batch_completion.py index fb8262eb..78a8bfdf 100644 --- a/model-engine/model_engine_server/api/v2/batch_completion.py +++ b/model-engine/model_engine_server/api/v2/batch_completion.py @@ -14,7 +14,12 @@ UpdateBatchCompletionsV2Response, ) from model_engine_server.core.auth.authentication_repository import User -from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) from model_engine_server.domain.exceptions import ( ObjectHasInvalidValueException, ObjectNotAuthorizedException, @@ -43,7 +48,8 @@ async def batch_completions( request: CreateBatchCompletionsV2Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> CreateBatchCompletionsV2Response: +) -> CreateBatchCompletionsV2Response: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) logger.info(f"POST /v2/batch-completions {request} for {auth}") try: use_case = CreateBatchCompletionsV2UseCase( @@ -64,7 +70,7 @@ async def batch_completions( logger.exception(f"Error processing request {request} for {auth}") raise HTTPException( status_code=500, - detail="Internal server error", + detail=f"Internal server error. request_id: {request_id}", ) from exc @@ -100,7 +106,7 @@ async def update_batch_completion( request: UpdateBatchCompletionsV2Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> UpdateBatchCompletionsV2Response: +) -> UpdateBatchCompletionsV2Response: # pragma: no cover logger.info(f"POST /v2/batch-completions/{batch_completion_id} {request} for {auth}") try: use_case = UpdateBatchCompletionV2UseCase( @@ -130,7 +136,8 @@ async def cancel_batch_completion( batch_completion_id: str, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CancelBatchCompletionsV2Response: +) -> CancelBatchCompletionsV2Response: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) logger.info(f"POST /v2/batch-completions/{batch_completion_id}/actions/cancel for {auth}") try: use_case = CancelBatchCompletionV2UseCase( @@ -149,5 +156,5 @@ async def cancel_batch_completion( ) raise HTTPException( status_code=500, - detail="Internal server error", + detail=f"Internal server error. request_id: {request_id}", ) from exc From 36e088f04fd4839f7ea767f5840e1c0ba4050c1e Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 22 Oct 2024 17:18:38 -0700 Subject: [PATCH 412/425] Fix up the image caching functionality so it works with h100s (#646) * work on image cache gateway * some service stuff * rename * upd test * autogen tpl and add h100 * quick test * quick test * fix test * black * fix some config misnames * rename * another rename --- charts/model-engine/values_circleci.yaml | 7 ++ charts/model-engine/values_sample.yaml | 8 +-- .../gateways/resources/image_cache_gateway.py | 12 +++- .../k8s_endpoint_resource_delegate.py | 9 +++ .../service_template_config_map_circleci.yaml | 45 ++++++++++++- .../infra/services/image_cache_service.py | 12 +++- model-engine/tests/unit/conftest.py | 4 +- .../resources/test_image_cache_gateway.py | 66 +++++++++++++++++++ .../test_k8s_endpoint_resource_delegate.py | 10 +++ .../services/test_image_cache_service.py | 2 +- 10 files changed, 164 insertions(+), 11 deletions(-) create mode 100644 model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index ba7fa812..50e78b36 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -228,6 +228,13 @@ imageCache: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" + - name: h100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" celeryBrokerType: redis diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 5f9969b8..1a56c680 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -299,16 +299,16 @@ imageCache: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" - - name: h100-mig-1g-20gb + - name: h100-1g20gb nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-hopper-h100-mig-1g-20gb + k8s.amazonaws.com/accelerator: nvidia-hopper-h100-1g20gb tolerations: - key: "nvidia.com/gpu" operator: "Exists" effect: "NoSchedule" - - name: h100-mig-3g-40gb + - name: h100-3g40gb nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-hopper-h100-mig-3g-40gb + k8s.amazonaws.com/accelerator: nvidia-hopper-h100-3g40gb tolerations: - key: "nvidia.com/gpu" operator: "Exists" diff --git a/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py index 84f5c011..84af84bd 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py @@ -5,7 +5,10 @@ from model_engine_server.common.config import hmi_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( - get_kubernetes_apps_client, + get_kubernetes_apps_client, # If you ever add more imports here, update test_image_cache_gateway accordingly, otherwise you will likely mangle live cluster resources +) +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( + k8s_yaml_exists, load_k8s_yaml, ) from model_engine_server.infra.gateways.resources.k8s_resource_types import ( @@ -21,6 +24,9 @@ class CachedImages(TypedDict): a10: List[str] a100: List[str] t4: List[str] + h100: List[str] + h100_3g40gb: List[str] + h100_1g20gb: List[str] class ImageCacheGateway: @@ -39,6 +45,7 @@ async def create_or_update_image_cache(self, cached_images: CachedImages) -> Non for compute_type, images in cached_images.items(): # Required for mypy TypedDict compute_type = cast(str, compute_type) + compute_type = compute_type.replace("_", "-") # for k8s valid name images = cast(list, images) name = f"{base_name}-{compute_type}" @@ -47,6 +54,9 @@ async def create_or_update_image_cache(self, cached_images: CachedImages) -> Non NAMESPACE=hmi_config.endpoint_namespace, ) resource_key = f"image-cache-{compute_type}.yaml" + if not k8s_yaml_exists(resource_key): + logger.info(f"Didn't find yaml for {compute_type}, skipping") + continue image_cache = load_k8s_yaml(resource_key, substitution_kwargs) labels = image_cache["spec"]["template"]["metadata"]["labels"] diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 824db978..24eb4335 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -205,6 +205,15 @@ async def maybe_load_kube_config(): _kube_config_loaded = True +def k8s_yaml_exists(key: str) -> bool: + if LAUNCH_SERVICE_TEMPLATE_FOLDER is not None: + return os.path.exists(os.path.join(LAUNCH_SERVICE_TEMPLATE_FOLDER, key)) + else: + with open(LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH, "r") as f: + config_map_str = yaml.safe_load(f.read()) + return key in config_map_str["data"] + + def load_k8s_yaml(key: str, substitution_kwargs: ResourceArguments) -> Dict[str, Any]: if LAUNCH_SERVICE_TEMPLATE_FOLDER is not None: with open(os.path.join(LAUNCH_SERVICE_TEMPLATE_FOLDER, key), "r") as f: diff --git a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index b78a6545..83e0fa0d 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -7,8 +7,8 @@ metadata: name: model-engine-service-template-config labels: team: infra - app.kubernetes.io/version: afe1df8a92039d3403e2f5ef266009231b02bf50 - tags.datadoghq.com/version: afe1df8a92039d3403e2f5ef266009231b02bf50 + app.kubernetes.io/version: 88f8003b2b52c772e8f34d264b3dfb95da1c1e9b + tags.datadoghq.com/version: 88f8003b2b52c772e8f34d264b3dfb95da1c1e9b tags.datadoghq.com/env: circleci env: circleci product: model-engine @@ -3795,6 +3795,47 @@ data: name: busybox command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] terminationGracePeriodSeconds: 0 + image-cache-h100.yaml: |- + apiVersion: apps/v1 + kind: DaemonSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + team: infra + product: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/service: ${RESOURCE_NAME} + spec: + selector: + matchLabels: + app: ${RESOURCE_NAME} + version: v1 + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app: ${RESOURCE_NAME} + team: infra + product: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/service: ${RESOURCE_NAME} + version: v1 + sidecar.istio.io/inject: "false" + spec: + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100 + tolerations: + - effect: NoSchedule + key: nvidia.com/gpu + operator: Exists + containers: + - image: public.ecr.aws/docker/library/busybox:latest + imagePullPolicy: IfNotPresent + name: busybox + command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] + terminationGracePeriodSeconds: 0 cron-trigger.yaml: |- apiVersion: batch/v1 kind: CronJob diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py index 4966e9f4..f2b1dc28 100644 --- a/model-engine/model_engine_server/infra/services/image_cache_service.py +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -109,7 +109,7 @@ def _cache_finetune_llm_images( ) continue image = f"{llm_image.repo}:{llm_image.tag}" - for key in ["a10", "a100"]: + for key in ["a10", "a100", "h100", "h100_3g40gb", "h100_1g20gb"]: images_to_cache_priority[key][image] = llm_image_cache_priority async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]]): @@ -118,6 +118,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi "a10": {}, "a100": {}, "t4": {}, + "h100": {}, + "h100_3g40gb": {}, + "h100_1g20gb": {}, } self._cache_finetune_llm_images(images_to_cache_priority) @@ -167,6 +170,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi (GpuType.NVIDIA_AMPERE_A10, "a10"), (GpuType.NVIDIA_AMPERE_A100, "a100"), (GpuType.NVIDIA_TESLA_T4, "t4"), + (GpuType.NVIDIA_HOPPER_H100, "h100"), + (GpuType.NVIDIA_HOPPER_H100_3G_40GB, "h100_3g40gb"), + (GpuType.NVIDIA_HOPPER_H100_1G_20GB, "h100_1g20gb"), ]: if state.resource_state.gpu_type == gpu_type and ( ( @@ -179,7 +185,9 @@ async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpoi and self.docker_repository.image_exists(image_tag, repository_name) ): images_to_cache_priority[key][state.image] = cache_priority - images_to_cache = CachedImages(cpu=[], a10=[], a100=[], t4=[]) + images_to_cache = CachedImages( + cpu=[], a10=[], a100=[], t4=[], h100=[], h100_1g20gb=[], h100_3g40gb=[] + ) for key, val in images_to_cache_priority.items(): images_to_cache[key] = sorted( # type: ignore val.keys(), key=lambda image: val[image], reverse=True diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 6812e49e..119c4d7b 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -945,7 +945,9 @@ async def delete_trigger( class FakeImageCacheGateway(ImageCacheGateway): def __init__(self): - self.cached_images = CachedImages(cpu=[], a10=[], a100=[], t4=[]) + self.cached_images = CachedImages( + cpu=[], a10=[], a100=[], t4=[], h100=[], h100_1g20gb=[], h100_3g40gb=[] + ) async def create_or_update_image_cache(self, cached_images: CachedImages) -> None: self.cached_images = cached_images diff --git a/model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py b/model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py new file mode 100644 index 00000000..60fa39cb --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py @@ -0,0 +1,66 @@ +from typing import Dict, Set +from unittest.mock import AsyncMock, patch + +import pytest +from model_engine_server.infra.gateways.resources.image_cache_gateway import ( + CachedImages, + ImageCacheGateway, +) + +MODULE_PATH = "model_engine_server.infra.gateways.resources.image_cache_gateway" + + +@pytest.fixture +def mock_apps_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_apps_client", + return_value=mock_client, + ): + yield mock_client + + +@pytest.mark.asyncio +async def test_create_or_update_image_cache( + mock_apps_client, +): + gateway = ImageCacheGateway() + await gateway.create_or_update_image_cache( + CachedImages( + cpu=["cpu_image"], + a10=["a10_image"], + a100=["a100_image"], + t4=["t4_image"], + h100=["h100_image"], + ) + ) + + # Needs to correspond with model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml + expected_images: Dict[str, Set[str]] = { + "cpu": {"cpu_image"}, + "a10": {"a10_image"}, + "a100": {"a100_image"}, + "t4": {"t4_image"}, + "h100": {"h100_image"}, + } + + actual_images: Dict[str, Set[str]] = { + "cpu": set(), + "a10": set(), + "a100": set(), + "t4": set(), + "h100": set(), + } + + for call_args in mock_apps_client.create_namespaced_daemon_set.call_args_list: + _, kwargs = call_args + compute_type = kwargs["body"]["metadata"]["name"].split("-")[-1] + actual_images[compute_type] = set( + container["image"] + for container in kwargs["body"]["spec"]["template"]["spec"]["containers"] + ) + + for k in expected_images.keys(): + assert expected_images[k].issubset( + actual_images[k] + ), f"Missing {expected_images[k].difference(actual_images[k])}" diff --git a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index f9ac3284..84a5063f 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -20,6 +20,7 @@ K8SEndpointResourceDelegate, add_datadog_env_to_container, get_main_container_from_deployment_template, + k8s_yaml_exists, load_k8s_yaml, ) from model_engine_server.infra.gateways.resources.k8s_resource_types import ( @@ -135,6 +136,15 @@ def k8s_endpoint_resource_delegate( return gateway +def test_k8s_yaml_exists(): + # This is tied to + # llm-engine/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml + assert k8s_yaml_exists("image-cache-h100.yaml"), "image-cache-h100.yaml should exist" + assert not k8s_yaml_exists( + "image-cache-abc9001.yaml" + ), "image-cache-abc9001.yaml should not exist" + + @pytest.mark.parametrize("resource_arguments_type", ResourceArguments.__args__) def test_resource_arguments_type_and_add_datadog_env_to_main_container(resource_arguments_type): # Convert the name of the type to a kebab case string diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py index 3dd1913d..5f3bb72d 100644 --- a/model-engine/tests/unit/infra/services/test_image_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -68,7 +68,7 @@ async def test_caching_finetune_llm_images( ) forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/model-engine", GIT_TAG) - for key in ["a10", "a100"]: + for key in ["a10", "a100", "h100", "h100_3g40gb", "h100_1g20gb"]: for llm_image in [ istio_image, tgi_image_110, From 8f9a672a593f76af507b6a43a4c409201eef150a Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:57:39 -0700 Subject: [PATCH 413/425] increase storage limit for h100s (#648) * up storage limit + test * actually bump it again * bump recHardware also * oops --- charts/model-engine/values_circleci.yaml | 2 +- .../model_engine_server/common/resource_limits.py | 9 ++++++--- model-engine/tests/unit/conftest.py | 2 +- .../tests/unit/domain/test_model_endpoint_use_cases.py | 8 ++++++++ 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml index 50e78b36..9096c76f 100644 --- a/charts/model-engine/values_circleci.yaml +++ b/charts/model-engine/values_circleci.yaml @@ -289,7 +289,7 @@ recommendedHardware: cpus: 80 gpus: 8 memory: 800Gi - storage: 640Gi + storage: 900Gi gpu_type: nvidia-hopper-h100 nodes_per_worker: 2 byModelName: diff --git a/model-engine/model_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py index 57145c64..1ede52de 100644 --- a/model-engine/model_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -34,7 +34,7 @@ ) # Should we allow multi-gpu instances? This allows the largest single-gpu g5dn instance. # p4d.24xlarge, p4de.24xlarge A100_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") -H100_INSTANCE_LIMITS = dict(cpus=191, memory="2000Gi") +H100_INSTANCE_LIMITS = dict(cpus=191, memory="2000Gi", storage="1300Gi") H100_1G_20GB_INSTANCE_LIMITS = dict(cpus=47, memory="500Gi") H100_3G_40GB_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") STORAGE_LIMIT = "640Gi" # TODO: figure out an actual limit. @@ -150,7 +150,10 @@ def validate_resource_requests( if storage <= 0: raise EndpointResourceInvalidRequestException("Requested storage must be positive") - available_storage_for_user = parse_mem_request(STORAGE_LIMIT) + available_storage_for_user = parse_mem_request( + resource_limits.get("storage", STORAGE_LIMIT) # type: ignore + ) + total_available_storage = available_storage_for_user if isinstance(bundle, ModelBundle): storage += parse_mem_request(FORWARDER_STORAGE_USAGE) @@ -165,7 +168,7 @@ def validate_resource_requests( else: storage += parse_mem_request(bundle.flavor.triton_storage) - if storage > parse_mem_request(STORAGE_LIMIT): + if storage > total_available_storage: raise EndpointResourceInvalidRequestException( f"Requested {storage=} too high. The maximum for {gpu_type=} is {format_bytes(available_storage_for_user)}" ) diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 119c4d7b..0b086f40 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -4713,7 +4713,7 @@ async def async_mock(*args, **kwargs): # noqa cpus: 160 gpus: 8 memory: 800Gi - storage: 640Gi + storage: 900Gi gpu_type: nvidia-hopper-h100 nodes_per_worker: 2 """, diff --git a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index ba2ddf7c..8fd5cf19 100644 --- a/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -93,6 +93,14 @@ async def test_create_model_endpoint_use_case_success( assert response_6.endpoint_creation_task_id assert isinstance(response_6, CreateModelEndpointV1Response) + # test you can ask for more storage on H100s + request = create_model_endpoint_request_sync.copy() + request.storage = "950Gi" + request.gpu_type = "nvidia-hopper-h100" + response_7 = await use_case.execute(user=user, request=request) + assert response_7.endpoint_creation_task_id + assert isinstance(response_7, CreateModelEndpointV1Response) + @pytest.mark.asyncio async def test_create_model_endpoint_use_case_raises_invalid_value_exception( From 9233b9a60874c6fa35bbbf8c37f217154867594a Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 24 Oct 2024 15:16:30 -0700 Subject: [PATCH 414/425] Bearer auth for oai compatibility (#649) * Bearer auth for oai compatibility * fix test --- .../model_engine_server/api/dependencies.py | 76 +++++++++++++------ model-engine/tests/unit/api/conftest.py | 15 ++-- 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 5dce68fe..e708adf9 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -6,7 +6,7 @@ import aioredis from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.security import HTTPBasic, HTTPBasicCredentials, OAuth2PasswordBearer from model_engine_server.common.config import hmi_config from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.env_vars import CIRCLECI @@ -131,7 +131,8 @@ logger = make_logger(logger_name()) -AUTH = HTTPBasic(auto_error=False) +basic_auth = HTTPBasic(auto_error=False) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) @dataclass @@ -433,35 +434,60 @@ async def get_auth_repository(): async def verify_authentication( - credentials: HTTPBasicCredentials = Depends(AUTH), + credentials: Optional[HTTPBasicCredentials] = Depends(basic_auth), + tokens: Optional[str] = Depends(oauth2_scheme), auth_repo: AuthenticationRepository = Depends(get_auth_repository), ) -> User: """ Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, raises a 401. """ - username = credentials.username if credentials is not None else None - if username is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="No authentication was passed in", - headers={"WWW-Authenticate": "Basic"}, - ) - - auth = await auth_repo.get_auth_from_username_async(username=username) - - if not auth: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not authenticate user", - headers={"WWW-Authenticate": "Basic"}, - ) - - # set logger context with identity data - LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) - LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) - - return auth + # Basic Authentication + if credentials is not None: + username = credentials.username if credentials is not None else None + if username is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="No authentication was passed in", + headers={"WWW-Authenticate": "Basic"}, + ) + + auth = await auth_repo.get_auth_from_username_async(username=username) + + if not auth: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not authenticate user", + headers={"WWW-Authenticate": "Basic"}, + ) + + # set logger context with identity data + LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) + LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) + + return auth + + # bearer token + if tokens is not None: + auth = await auth_repo.get_auth_from_username_async(username=tokens) + if not auth: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not authenticate user", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # set logger context with identity data + LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) + LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) + + return auth + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="No authentication was passed in", + headers={"WWW-Authenticate": "Bearer"}, + ) _pool: Optional[aioredis.BlockingConnectionPool] = None diff --git a/model-engine/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py index 3ce36f6d..29454dca 100644 --- a/model-engine/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -1,6 +1,6 @@ import asyncio import datetime -from typing import Any, Dict, Iterator, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple import pytest import pytest_asyncio @@ -10,9 +10,10 @@ from httpx import AsyncClient from model_engine_server.api.app import app from model_engine_server.api.dependencies import ( - AUTH, + basic_auth, get_external_interfaces, get_external_interfaces_read_only, + oauth2_scheme, verify_authentication, ) from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User @@ -65,15 +66,19 @@ def get_test_auth_repository() -> Iterator[AuthenticationRepository]: def fake_verify_authentication( - credentials: HTTPBasicCredentials = Depends(AUTH), + credentials: Optional[HTTPBasicCredentials] = Depends(basic_auth), + tokens: Optional[str] = Depends(oauth2_scheme), auth_repo: AuthenticationRepository = Depends(get_test_auth_repository), ) -> User: """ Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, raises a 401. """ - auth_username = credentials.username if credentials is not None else None - if not auth_username: + if credentials is not None: + auth_username = credentials.username + elif tokens is not None: + auth_username = tokens + else: raise HTTPException(status_code=401, detail="No authentication was passed in") auth = auth_repo.get_auth_from_username(username=auth_username) From 785e0fa46235589ea56b886844b485b387333c71 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 29 Oct 2024 17:57:03 -0700 Subject: [PATCH 415/425] Updates to helm charts to sync with SGP (#651) * Updates to helm charts to sync with SGP * bump versiong --- charts/model-engine/Chart.yaml | 2 +- .../templates/balloon_deployments.yaml | 2 +- .../templates/celery_autoscaler_stateful_set.yaml | 2 +- .../recommended_hardware_config_map.yaml | 2 +- .../templates/service_account_image_builder.yaml | 2 +- .../templates/service_account_inference.yaml | 2 +- .../templates/service_template_config_map.yaml | 10 ++++++++-- .../templates/trigger_authentication.yaml | 15 ++++++++++++++- charts/model-engine/values.yaml | 4 ++++ 9 files changed, 32 insertions(+), 9 deletions(-) diff --git a/charts/model-engine/Chart.yaml b/charts/model-engine/Chart.yaml index 1ebd5db6..e346794d 100644 --- a/charts/model-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.3 +version: 0.1.4 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/model-engine/templates/balloon_deployments.yaml b/charts/model-engine/templates/balloon_deployments.yaml index 3a4e1f20..735aff86 100644 --- a/charts/model-engine/templates/balloon_deployments.yaml +++ b/charts/model-engine/templates/balloon_deployments.yaml @@ -52,4 +52,4 @@ spec: --- {{- end }} {{- end }} -{{- end }} \ No newline at end of file +{{- end }} diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml index 810e7e1f..93d359b5 100644 --- a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -109,4 +109,4 @@ spec: name: config-volume {{- end}} {{- end }} -{{- end }} \ No newline at end of file +{{- end }} diff --git a/charts/model-engine/templates/recommended_hardware_config_map.yaml b/charts/model-engine/templates/recommended_hardware_config_map.yaml index 6c999145..b185a5ab 100644 --- a/charts/model-engine/templates/recommended_hardware_config_map.yaml +++ b/charts/model-engine/templates/recommended_hardware_config_map.yaml @@ -27,4 +27,4 @@ data: gpu_type: {{ .gpu_type }} nodes_per_worker: {{ .nodes_per_worker }} {{- end }} -{{- end }} \ No newline at end of file +{{- end }} diff --git a/charts/model-engine/templates/service_account_image_builder.yaml b/charts/model-engine/templates/service_account_image_builder.yaml index 8cdec485..e68cd7b2 100644 --- a/charts/model-engine/templates/service_account_image_builder.yaml +++ b/charts/model-engine/templates/service_account_image_builder.yaml @@ -16,4 +16,4 @@ metadata: {{- end }} --- {{- end }} -{{- end }} \ No newline at end of file +{{- end }} diff --git a/charts/model-engine/templates/service_account_inference.yaml b/charts/model-engine/templates/service_account_inference.yaml index c9fa94fb..9a4a698c 100644 --- a/charts/model-engine/templates/service_account_inference.yaml +++ b/charts/model-engine/templates/service_account_inference.yaml @@ -22,4 +22,4 @@ imagePullSecrets: - name: egp-ecr-regcred {{- end }} --- -{{- end }} \ No newline at end of file +{{- end }} diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml index a418557a..6836784c 100644 --- a/charts/model-engine/templates/service_template_config_map.yaml +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -464,13 +464,19 @@ data: - type: redis metadata: address: ${REDIS_HOST_PORT} # Format must be host:port + {{- if not .Values.redis.enableAuth }} passwordFromEnv: "" + {{- end }} listName: "launch-endpoint-autoscaling:${ENDPOINT_ID}" listLength: "100" # something absurdly high so we don't scale past 1 pod activationListLength: "0" - enableTLS: "false" - unsafeSsl: "false" + enableTLS: "{{ .Values.redis.enableTLS }}" + unsafeSsl: "{{ .Values.redis.unsafeSsl }}" databaseIndex: "${REDIS_DB_INDEX}" + {{- if .Values.redis.enableAuth }} + authenticationRef: + name: "keda-trigger-auth-redis-secret" + {{- end }} {{- end }} - type: prometheus metadata: diff --git a/charts/model-engine/templates/trigger_authentication.yaml b/charts/model-engine/templates/trigger_authentication.yaml index 63209f68..088dee94 100644 --- a/charts/model-engine/templates/trigger_authentication.yaml +++ b/charts/model-engine/templates/trigger_authentication.yaml @@ -8,4 +8,17 @@ spec: podIdentity: provider: azure-workload identityId: {{ .Values.azure.client_id }} -{{- end }} \ No newline at end of file +{{- else if .Values.redis.enableAuth }} +apiVersion: keda.sh/v1alpha1 +kind: TriggerAuthentication +metadata: + name: keda-trigger-auth-redis-secret + namespace: {{ .Values.config.values.launch.endpoint_namespace }} +spec: + awsSecretManager: + podIdentity: + provider: aws + secrets: + - parameter: password + name: {{ .Values.redis.kedaSecretName }} +{{- end }} diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index 1ea7522e..c7c0ec0d 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -3,6 +3,10 @@ spellbook: enabled: false redis: auth: + enableTLS: false + enableAuth: false + kedaSecretName: "" + unsafeSsl: false db: runDbInitScript: false runDbMigrationScript: false From 05f2eccd116ef9b4a19911cc9e8163d577bae9aa Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 30 Oct 2024 16:56:12 -0700 Subject: [PATCH 416/425] Add script to stamp initial schema (#653) --- .../db/migrations/stamp_initial_schema.sh | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100755 model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh diff --git a/model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh b/model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh new file mode 100755 index 00000000..bf7d3781 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Usage +# ML_INFRA_DATABASE_URL="postgresql://postgres:password@localhost:54320/postgres" bash stamp_initial_schema.sh + +# Get the directory of this script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# Change directory to the directory of this script +cd $DIR + +# Stamps initial revision to new table +alembic stamp fa3267c80731 \ No newline at end of file From 0024b0cc9888abb245644b61f929089c3e06b6b7 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 30 Oct 2024 17:55:54 -0700 Subject: [PATCH 417/425] Remove ENV requirement for db migration (#654) --- .../model_engine_server/db/migrations/alembic/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/db/migrations/alembic/env.py b/model-engine/model_engine_server/db/migrations/alembic/env.py index 24c67aa5..6cb95d67 100644 --- a/model-engine/model_engine_server/db/migrations/alembic/env.py +++ b/model-engine/model_engine_server/db/migrations/alembic/env.py @@ -7,10 +7,10 @@ from sqlalchemy import engine_from_config, pool env = os.environ.get("ENV") -if env is None: - assert ( - os.getenv("ML_INFRA_DATABASE_URL") is not None - ), "Expected ML_INFRA_DATABASE_URL to be set if ENV is not set." +# if env is None: +# assert ( +# os.getenv("ML_INFRA_DATABASE_URL") is not None +# ), "Expected ML_INFRA_DATABASE_URL to be set if ENV is not set." # this is the Alembic Config object, which provides # access to the values within the .ini file in use. From 84f31a8496776d32884620c41dba9e988dfcf906 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 4 Nov 2024 10:27:19 -0800 Subject: [PATCH 418/425] Remove restricte model name check (#656) --- .../use_cases/llm_model_endpoint_use_cases.py | 5 +-- .../inference/vllm/build_and_upload_image.sh | 2 +- .../tests/unit/domain/test_llm_use_cases.py | 35 ------------------- 3 files changed, 4 insertions(+), 38 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index d2d10b9d..afdfb5ab 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -381,9 +381,10 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( def validate_model_name(model_name: str, inference_framework: LLMInferenceFramework) -> None: + # TODO: replace this logic to check if the model architecture is supported instead if model_name not in _SUPPORTED_MODELS_BY_FRAMEWORK[inference_framework]: - raise ObjectHasInvalidValueException( - f"Model name {model_name} is not supported for inference framework {inference_framework}." + logger.warning( + f"Model name {model_name} may not be supported by inference framework {inference_framework}." ) diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh index 10765cc0..3b1ab4cb 100755 --- a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -43,7 +43,7 @@ aws ecr get-login-password --region us-west-2 | docker login --username AWS --pa DOCKER_BUILDKIT=1 docker build \ --build-arg VLLM_VERSION=${VLLM_VERSION} \ --build-arg VLLM_BASE_REPO=${VLLM_BASE_REPO} \ - -f Dockerfile.vllm \ + -f ${DOCKERFILE} \ --target ${BUILD_TARGET} \ -t $IMAGE ${PROJECT_DIR} docker push $IMAGE diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 9e160846..f1392168 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -704,41 +704,6 @@ async def test_create_model_endpoint_trt_llm_use_case_success( ) -@pytest.mark.asyncio -async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception( - test_api_key: str, - fake_model_bundle_repository, - fake_model_endpoint_service, - fake_docker_repository_image_always_exists, - fake_model_primitive_gateway, - fake_llm_artifact_gateway, - create_llm_model_endpoint_request_invalid_model_name: CreateLLMModelEndpointV1Request, -): - fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository - bundle_use_case = CreateModelBundleV2UseCase( - model_bundle_repository=fake_model_bundle_repository, - docker_repository=fake_docker_repository_image_always_exists, - model_primitive_gateway=fake_model_primitive_gateway, - ) - llm_bundle_use_case = CreateLLMModelBundleV1UseCase( - create_model_bundle_use_case=bundle_use_case, - model_bundle_repository=fake_model_bundle_repository, - llm_artifact_gateway=fake_llm_artifact_gateway, - docker_repository=fake_docker_repository_image_always_exists, - ) - use_case = CreateLLMModelEndpointV1UseCase( - create_llm_model_bundle_use_case=llm_bundle_use_case, - model_endpoint_service=fake_model_endpoint_service, - docker_repository=fake_docker_repository_image_always_exists, - llm_artifact_gateway=fake_llm_artifact_gateway, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - with pytest.raises(ObjectHasInvalidValueException): - await use_case.execute( - user=user, request=create_llm_model_endpoint_request_invalid_model_name - ) - - @pytest.mark.asyncio async def test_create_llm_model_endpoint_use_case_quantization_exception( test_api_key: str, From cb699e898e4abff29c7c9cd2ec7501fa8bd6d382 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 4 Nov 2024 15:04:16 -0800 Subject: [PATCH 419/425] Safe handle model param (#657) --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index afdfb5ab..2fe693d2 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -3397,11 +3397,13 @@ def infer_addition_engine_args_from_model_name( model_name: str, ) -> VLLMEndpointAdditionalArgs: # Increase max gpu utilization for larger models - model_param_count_b = get_model_param_count_b(model_name) - if model_param_count_b >= 70: - gpu_memory_utilization = 0.95 - else: - gpu_memory_utilization = 0.9 + gpu_memory_utilization = 0.9 + try: + model_param_count_b = get_model_param_count_b(model_name) + if model_param_count_b >= 70: + gpu_memory_utilization = 0.95 + except ObjectHasInvalidValueException: # pragma: no cover + pass # Gemma 2 requires flashinfer attention backend attention_backend = None From c2692c48ed134a26f4e118ee42190b23a0c7f3aa Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Mon, 4 Nov 2024 18:13:53 -0800 Subject: [PATCH 420/425] More vllm args passthrough (#658) --- .../common/dtos/llms/vllm.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index 7494376e..bd228bec 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -79,6 +79,42 @@ class VLLMModelConfig(BaseModel): description="Enable auto tool choice", ) + load_format: Optional[str] = Field( + None, + description="The format of the model weights to load.\n\n" + '* "auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available.\n" + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading.\n" + '* "dummy" will initialize the weights with random values, ' + "which is mainly for profiling.\n" + '* "tensorizer" will load the weights using tensorizer from ' + "CoreWeave. See the Tensorize vLLM Model script in the Examples " + "section for more information.\n" + '* "bitsandbytes" will load the weights using bitsandbytes ' + "quantization.\n", + ) + + config_format: Optional[str] = Field( + None, + description="The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'.", + ) + + tokenizer_mode: Optional[str] = Field( + None, + description="Tokenizer mode. 'auto' will use the fast tokenizer if" + "available, 'slow' will always use the slow tokenizer, and" + "'mistral' will always use the tokenizer from `mistral_common`.", + ) + + limit_mm_per_prompt: Optional[str] = Field( + None, + description="Maximum number of data instances per modality per prompt. Only applicable for multimodal models.", + ) + class VLLMEngineAdditionalArgs(BaseModel): """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" From b6eac17438d6a96c85aa5f498d04cdbb1bbeae2e Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Tue, 5 Nov 2024 10:09:29 -0800 Subject: [PATCH 421/425] Changes to balloons to support a less "on-demand" style of compute (#655) * balloons can take up more than 1 gpu * setting to make balloons only for high priority * values.yaml default * bump helm chart version --- charts/model-engine/Chart.yaml | 2 +- charts/model-engine/templates/balloon_deployments.yaml | 2 +- .../templates/model_engine_default_priority_class.yaml | 2 ++ charts/model-engine/values.yaml | 2 ++ charts/model-engine/values_sample.yaml | 5 +++++ 5 files changed, 11 insertions(+), 2 deletions(-) diff --git a/charts/model-engine/Chart.yaml b/charts/model-engine/Chart.yaml index e346794d..e9ef0518 100644 --- a/charts/model-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.4 +version: 0.1.5 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/model-engine/templates/balloon_deployments.yaml b/charts/model-engine/templates/balloon_deployments.yaml index 735aff86..49a1890f 100644 --- a/charts/model-engine/templates/balloon_deployments.yaml +++ b/charts/model-engine/templates/balloon_deployments.yaml @@ -41,7 +41,7 @@ spec: resources: limits: memory: 28Gi - nvidia.com/gpu: 1 + nvidia.com/gpu: {{ .gpuCount | default 1 }} cpu: 4 command: - /bin/bash diff --git a/charts/model-engine/templates/model_engine_default_priority_class.yaml b/charts/model-engine/templates/model_engine_default_priority_class.yaml index a2d2dbb9..a2b80367 100644 --- a/charts/model-engine/templates/model_engine_default_priority_class.yaml +++ b/charts/model-engine/templates/model_engine_default_priority_class.yaml @@ -4,8 +4,10 @@ kind: PriorityClass metadata: name: "{{ include "modelEngine.fullname" . }}-default-priority" value: 1 +{{- if .Values.balloonConfig.reserveHighPriority }} # This ensures that the default launch pods will never preempt any pods, which means # they cannot take advantage of the dummy nodes. preemptionPolicy: Never +{{- end }} description: "Default Priority Class for Launch" {{- end }} diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index c7c0ec0d..b8e60da0 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -10,6 +10,8 @@ redis: db: runDbInitScript: false runDbMigrationScript: false +balloonConfig: + reserveHighPriority: true balloonNodeSelector: node-lifecycle: normal nodeSelector: diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index 1a56c680..f7e1fe58 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -75,6 +75,10 @@ replicaCount: # builder is the endpoint builder deployment builder: 1 +balloonConfig: + # If set to true, only high priority pods can preempt balloons. Otherwise, all pods can preempt balloons. + reserveHighPriority: true + balloons: # A low priority pod deployment for A10 GPU nodes - acceleratorName: nvidia-ampere-a10 @@ -91,6 +95,7 @@ balloons: # A low priority pod deployment for H100 GPU nodes - acceleratorName: nvidia-hopper-h100 replicaCount: 0 + gpuCount: 4 # autoscaling is the autoscaling configuration for LLM Engine server deployments (e.g gateway, cache, and builder deployments) autoscaling: From 3609e0878f4b438e0ed129c14d6095aaeefd3a28 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Thu, 7 Nov 2024 16:44:42 -0800 Subject: [PATCH 422/425] More vllm args passthrough (#659) --- .../model_engine_server/common/dtos/llms/vllm.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py index bd228bec..af6a6710 100644 --- a/model-engine/model_engine_server/common/dtos/llms/vllm.py +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -115,6 +115,15 @@ class VLLMModelConfig(BaseModel): description="Maximum number of data instances per modality per prompt. Only applicable for multimodal models.", ) + enable_prefix_caching: Optional[bool] = Field( + None, + description="Enables automatic prefix caching.", + ) + + max_num_batched_tokens: Optional[int] = Field( + None, description="Maximum number of batched tokens per iteration" + ) + class VLLMEngineAdditionalArgs(BaseModel): """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" From f2be2a9bf80d6ecd0788e91fb7922594bc3a6de5 Mon Sep 17 00:00:00 2001 From: Sean Shi <69175566+seanshi-scale@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:59:52 -0800 Subject: [PATCH 423/425] emit model name in dd traces, also emit error dd metrics on http timeouts (#660) * add this new trace dimension * try bumping ddtrace to the newest 1.x.y version * reset reqs to main * again * remove thing that doesn't work (rip) * emit sync call timeout metrics in monitoring metrics gateway * initialize the sync/streaming inference gateways to use the monitoring metrics gateway * Revert "initialize the sync/streaming inference gateways to use the monitoring metrics gateway" Let's just emit in the use case instead This reverts commit 0bf2a5446c37ad3bdfc9dcc1623b5a981c08bfb1. * wip try emitting from use cases, will probably abandon it * Revert "wip try emitting from use cases, will probably abandon it" This reverts commit 6b599bd5803721d0722dca8c9828aea4b2918c71. * Revert "Revert "initialize the sync/streaming inference gateways to use the monitoring metrics gateway"" ok let's actually just emit from the sync/streaming gateways This reverts commit 432c0b5792fbf6965766d7b74760d1720eedd83f. * small refactor * thread the readable endpoint name through everywhere * actually emit the metrics * rename * rename * comment + small type thing --- .../model_engine_server/api/dependencies.py | 2 + .../common/datadog_utils.py | 15 +++ .../gateways/monitoring_metrics_gateway.py | 9 ++ ...eaming_model_endpoint_inference_gateway.py | 8 +- .../sync_model_endpoint_inference_gateway.py | 7 +- .../use_cases/llm_model_endpoint_use_cases.py | 24 +++- .../streaming_inference_use_cases.py | 1 + .../use_cases/sync_inference_use_cases.py | 1 + .../start_batch_job_orchestration.py | 2 + .../datadog_monitoring_metrics_gateway.py | 5 + .../fake_monitoring_metrics_gateway.py | 5 + ...eaming_model_endpoint_inference_gateway.py | 13 +- ...e_sync_model_endpoint_inference_gateway.py | 13 +- model-engine/tests/unit/conftest.py | 2 + ...eaming_model_endpoint_inference_gateway.py | 93 ++++++++++---- ...e_sync_model_endpoint_inference_gateway.py | 115 +++++++++++++----- 16 files changed, 253 insertions(+), 62 deletions(-) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index e708adf9..e120fbf0 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -251,9 +251,11 @@ def _get_external_interfaces( ) # In CircleCI, we cannot use asyncio because aiohttp cannot connect to the sync endpoints. sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) filesystem_gateway = ( diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py index 5707d964..f5b2844e 100644 --- a/model-engine/model_engine_server/common/datadog_utils.py +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -13,3 +13,18 @@ def add_trace_request_id(request_id: Optional[str]): current_span = tracer.current_span() if current_span: current_span.set_tag("launch.request_id", request_id) + + +def add_trace_model_name(model_name: Optional[str]): + """Adds a custom tag to a given dd trace corresponding to the model name + so that we can filter in Datadog easier + + Only use this when the number of model names is small, otherwise it will + blow up the cardinality in Datadog + """ + if not model_name: + return + + current_span = tracer.current_span() + if current_span: + current_span.set_tag("launch.model_name", model_name) diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index c17e5b09..dcad95d5 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -81,3 +81,12 @@ def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMeta Token count metrics """ pass + + @abstractmethod + def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): + """ + Sync call timeout/error metrics, emitted when sync/streaming request + times out or otherwise errors (likely due to scale-from-zero not being + fast enough, or we're out of capacity, or the upstream svc is unhealthy) + """ + pass diff --git a/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py index b00470c3..5cb99973 100644 --- a/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import AsyncIterable +from typing import AsyncIterable, Optional from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, @@ -17,7 +17,11 @@ class StreamingModelEndpointInferenceGateway(ABC): @abstractmethod def streaming_predict( - self, topic: str, predict_request: SyncEndpointPredictV1Request, manually_resolve_dns: bool + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, + endpoint_name: Optional[str] = None, ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a streaming response. diff --git a/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py index 8df1277f..2bc9631e 100644 --- a/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional from model_engine_server.common.dtos.tasks import ( SyncEndpointPredictV1Request, @@ -16,7 +17,11 @@ class SyncModelEndpointInferenceGateway(ABC): @abstractmethod async def predict( - self, topic: str, predict_request: SyncEndpointPredictV1Request, manually_resolve_dns: bool + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, + endpoint_name: Optional[str] = None, ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 2fe693d2..e2069b39 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -119,7 +119,7 @@ get_models_s3_uri, ) -from ...common.datadog_utils import add_trace_request_id +from ...common.datadog_utils import add_trace_model_name, add_trace_request_id from ..authorization.live_authorization_module import LiveAuthorizationModule from .model_bundle_use_cases import CreateModelBundleV2UseCase from .model_endpoint_use_cases import ( @@ -2001,6 +2001,8 @@ async def execute( f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" ) + add_trace_model_name(model_endpoint_name) + model_endpoint = model_endpoints[0] if not self.authz_module.check_access_read_owned_entity( @@ -2065,6 +2067,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: @@ -2115,6 +2118,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2175,6 +2179,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2226,6 +2231,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2270,6 +2276,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2357,6 +2364,8 @@ async def execute( f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" ) + add_trace_model_name(model_endpoint_name) + model_endpoint = model_endpoints[0] if not self.authz_module.check_access_read_owned_entity( @@ -2550,6 +2559,7 @@ async def _response_chunk_generator( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) num_completion_tokens = 0 @@ -2760,6 +2770,8 @@ async def execute( f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" ) + add_trace_model_name(model_endpoint_name) + model_endpoint = model_endpoints[0] if not self.authz_module.check_access_read_owned_entity( @@ -2809,6 +2821,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -2866,6 +2879,8 @@ async def execute( f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" ) + add_trace_model_name(model_endpoint_name) + model_endpoint = model_endpoints[0] if not self.authz_module.check_access_read_owned_entity( @@ -2938,6 +2953,7 @@ async def _response_chunk_generator( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) except UpstreamServiceError as exc: # Expect upstream inference service to handle bulk of input validation @@ -3027,6 +3043,8 @@ async def execute( f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" ) + add_trace_model_name(model_endpoint_name) + model_endpoint = model_endpoints[0] if not self.authz_module.check_access_read_owned_entity( @@ -3076,6 +3094,7 @@ async def execute( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: @@ -3133,6 +3152,8 @@ async def execute( f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" ) + add_trace_model_name(model_endpoint_name) + model_endpoint = model_endpoints[0] if not self.authz_module.check_access_read_owned_entity( @@ -3204,6 +3225,7 @@ async def _response_chunk_generator( topic=model_endpoint.record.destination, predict_request=inference_request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) except UpstreamServiceError as exc: # Expect upstream inference service to handle bulk of input validation diff --git a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index bdd27476..baf68a06 100644 --- a/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -81,4 +81,5 @@ async def execute( topic=model_endpoint.record.destination, predict_request=request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) diff --git a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index 4985063a..a665a846 100644 --- a/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -85,4 +85,5 @@ async def execute( topic=model_endpoint.record.destination, predict_request=request, manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index af6eeef7..de1bd59b 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -109,9 +109,11 @@ async def run_batch_job( task_queue_gateway=inference_task_queue_gateway ) streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) filesystem_gateway = ( diff --git a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py index 8732615d..93a73970 100644 --- a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py @@ -85,3 +85,8 @@ def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMeta inter_token_latency = f"{self.prefix}.inter_token_latency" if token_usage.inter_token_latency is not None: statsd.distribution(inter_token_latency, token_usage.inter_token_latency, tags=tags) + + def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): + tags = self.tags + tags.extend([f"endpoint_name:{endpoint_name}", f"error_code:{error_code}"]) + statsd.increment(f"{self.prefix}.upstream_sync_error", tags=tags) diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py index dc419a07..25bf45fa 100644 --- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -22,6 +22,7 @@ def __init__(self): self.route_call = defaultdict(int) self.token_count = 0 self.total_tokens_per_second = 0 + self.sync_call_timeout = defaultdict(int) def reset(self): self.attempted_build = 0 @@ -37,6 +38,7 @@ def reset(self): self.route_call = defaultdict(int) self.token_count = 0 self.total_tokens_per_second = 0 + self.sync_call_timeout = defaultdict(int) def emit_attempted_build_metric(self): self.attempted_build += 1 @@ -74,3 +76,6 @@ def emit_route_call_metric(self, route: str, _metadata: MetricMetadata): def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata): self.token_count += token_usage.num_total_tokens self.total_tokens_per_second = token_usage.total_tokens_per_second + + def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): + self.sync_call_timeout[(endpoint_name, error_code)] += 1 diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index d4df0b6b..97f2b83a 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterable, Dict +from typing import Any, AsyncIterable, Dict, Optional import aiohttp import orjson @@ -20,6 +20,7 @@ TooManyRequestsException, UpstreamServiceError, ) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway from model_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( StreamingModelEndpointInferenceGateway, ) @@ -85,7 +86,8 @@ class LiveStreamingModelEndpointInferenceGateway(StreamingModelEndpointInference streaming_predict() wraps make_request_with_retries() and yields SyncEndpointPredictV1Response """ - def __init__(self, use_asyncio: bool): + def __init__(self, monitoring_metrics_gateway: MonitoringMetricsGateway, use_asyncio: bool): + self.monitoring_metrics_gateway = monitoring_metrics_gateway self.use_asyncio = use_asyncio async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): @@ -137,6 +139,7 @@ async def make_request_with_retries( payload_json: Dict[str, Any], timeout_seconds: float, num_retries: int, + endpoint_name: str, ) -> AsyncIterable[Dict[str, Any]]: # Copied from document-endpoint # More details at https://tenacity.readthedocs.io/en/latest/#retrying-code-block @@ -177,15 +180,19 @@ async def make_request_with_retries( except RetryError as e: if isinstance(e.last_attempt.exception(), TooManyRequestsException): logger.warning("Hit max # of retries, returning 429 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 429) raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") elif isinstance(e.last_attempt.exception(), NoHealthyUpstreamException): logger.warning("Pods didn't spin up in time, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") elif isinstance(e.last_attempt.exception(), aiohttp.ClientConnectorError): logger.warning("ClientConnectorError, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") else: logger.error("Unknown Exception Type") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 500) raise UpstreamServiceError(status_code=500, content=b"Unknown error") except JSONDecodeError: logger.exception("JSONDecodeError") @@ -201,6 +208,7 @@ async def streaming_predict( topic: str, predict_request: SyncEndpointPredictV1Request, manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, ) -> AsyncIterable[SyncEndpointPredictV1Response]: deployment_url = _get_streaming_endpoint_url( topic, @@ -224,6 +232,7 @@ async def streaming_predict( payload_json=predict_request.model_dump(exclude_none=True), timeout_seconds=timeout_seconds, num_retries=num_retries, + endpoint_name=endpoint_name or topic, ) async for item in response: yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index 2683123c..3e083fb6 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import aiohttp import orjson @@ -18,6 +18,7 @@ TooManyRequestsException, UpstreamServiceError, ) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway from model_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( SyncModelEndpointInferenceGateway, ) @@ -78,7 +79,8 @@ class LiveSyncModelEndpointInferenceGateway(SyncModelEndpointInferenceGateway): Concrete implementation for an SyncModelEndpointInferenceGateway. """ - def __init__(self, use_asyncio: bool): + def __init__(self, monitoring_metrics_gateway: MonitoringMetricsGateway, use_asyncio: bool): + self.monitoring_metrics_gateway = monitoring_metrics_gateway self.use_asyncio = use_asyncio async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): @@ -119,6 +121,7 @@ async def make_request_with_retries( payload_json: Dict[str, Any], timeout_seconds: float, num_retries: int, + endpoint_name: str, ) -> Dict[str, Any]: # Copied from document-endpoint # More details at https://tenacity.readthedocs.io/en/latest/#retrying-code-block @@ -154,15 +157,19 @@ async def make_request_with_retries( except RetryError as e: if isinstance(e.last_attempt.exception(), TooManyRequestsException): logger.warning("Hit max # of retries, returning 429 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 429) raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") elif isinstance(e.last_attempt.exception(), NoHealthyUpstreamException): logger.warning("Pods didn't spin up in time, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") elif isinstance(e.last_attempt.exception(), aiohttp.ClientConnectorError): logger.warning("ClientConnectorError, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") else: logger.error("Unknown Exception Type") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 500) raise UpstreamServiceError(status_code=500, content=b"Unknown error") # Never reached because tenacity should throw a RetryError if we exit the for loop. @@ -175,6 +182,7 @@ async def predict( topic: str, predict_request: SyncEndpointPredictV1Request, manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, ) -> SyncEndpointPredictV1Response: deployment_url = _get_sync_endpoint_url( topic, @@ -198,6 +206,7 @@ async def predict( payload_json=predict_request.model_dump(exclude_none=True), timeout_seconds=timeout_seconds, num_retries=num_retries, + endpoint_name=endpoint_name or topic, ) except UpstreamServiceError as exc: logger.error(f"Service error on sync task: {exc.content!r}") diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 0b086f40..d96c86fa 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -1527,6 +1527,7 @@ async def streaming_predict( topic: str, predict_request: EndpointPredictV1Request, manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a response. @@ -1551,6 +1552,7 @@ async def predict( topic: str, predict_request: EndpointPredictV1Request, manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. diff --git a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py index db48e7f8..2bddfffa 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py @@ -10,6 +10,7 @@ SyncEndpointPredictV1Response, ) from model_engine_server.domain.exceptions import InvalidRequestException, UpstreamServiceError +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway from model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway import ( LiveStreamingModelEndpointInferenceGateway, ) @@ -71,8 +72,12 @@ async def _aexit(*exc): @pytest.mark.asyncio -async def test_make_request_with_retries_success(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_success( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=200) mock_client_session = _get_mock_client_session(fake_response) @@ -81,7 +86,9 @@ async def test_make_request_with_retries_success(): "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - response = gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) + response = gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) count = 0 async for message in response: assert message == {"test": "content"} @@ -90,8 +97,12 @@ async def test_make_request_with_retries_success(): @pytest.mark.asyncio -async def test_make_request_with_retries_failed_429(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_failed_429( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=429) mock_client_session = _get_mock_client_session(fake_response) @@ -100,13 +111,19 @@ async def test_make_request_with_retries_failed_429(): "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): + async for response in gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ): response @pytest.mark.asyncio -async def test_make_request_with_retries_failed_traceback(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_failed_traceback( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=500) mock_client_session = _get_mock_client_session(fake_response) @@ -115,13 +132,19 @@ async def test_make_request_with_retries_failed_traceback(): "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): + async for response in gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ): response @pytest.mark.asyncio -async def test_make_request_with_retries_failed_with_client_connector_error(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_failed_with_client_connector_error( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) mock_client_session = _get_mock_client_session_with_client_connector_error() @@ -129,15 +152,20 @@ async def test_make_request_with_retries_failed_with_client_connector_error(): "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): + async for response in gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ): response @pytest.mark.asyncio async def test_streaming_predict_success( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=200) mock_client_session = _get_mock_client_session(fake_response) @@ -146,7 +174,9 @@ async def test_streaming_predict_success( mock_client_session, ): response = gateway.streaming_predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) count = 0 async for message in response: @@ -162,9 +192,12 @@ async def test_streaming_predict_success( @pytest.mark.asyncio async def test_predict_raises_traceback_json( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = json.dumps({"detail": {"traceback": "test_traceback"}}).encode("utf-8") fake_response = FakeResponse(status=500, message_content=content) @@ -174,7 +207,9 @@ async def test_predict_raises_traceback_json( mock_client_session, ): response = gateway.streaming_predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) count = 0 async for message in response: @@ -190,9 +225,12 @@ async def test_predict_raises_traceback_json( @pytest.mark.asyncio async def test_predict_raises_traceback_not_json( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = b"Test traceback content" fake_response = FakeResponse(status=500, message_content=content) @@ -202,7 +240,9 @@ async def test_predict_raises_traceback_not_json( mock_client_session, ): response = gateway.streaming_predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) count = 0 async for message in response: @@ -218,9 +258,12 @@ async def test_predict_raises_traceback_not_json( @pytest.mark.asyncio async def test_predict_upstream_raises_400( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = json.dumps({"result": json.dumps({"error": "error"})}).encode("utf-8") fake_response = FakeResponse(status=400, message_content=content) @@ -231,7 +274,9 @@ async def test_predict_upstream_raises_400( ): with pytest.raises(InvalidRequestException): response = gateway.streaming_predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) async for message in response: message diff --git a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py index fdc74288..608b73cf 100644 --- a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -10,6 +10,7 @@ SyncEndpointPredictV1Response, ) from model_engine_server.domain.exceptions import InvalidRequestException, UpstreamServiceError +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway from model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway import ( LiveSyncModelEndpointInferenceGateway, ) @@ -55,8 +56,12 @@ async def _aexit(*exc): @pytest.mark.asyncio -async def test_make_request_with_retries_success(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_success( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=200) mock_client_session = _get_mock_client_session(fake_response) @@ -65,13 +70,19 @@ async def test_make_request_with_retries_success(): "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - response = await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) + response = await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) assert response == {"test_key": "test_value"} @pytest.mark.asyncio -async def test_make_request_with_retries_failed_429(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_failed_429( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=429) mock_client_session = _get_mock_client_session(fake_response) @@ -80,12 +91,18 @@ async def test_make_request_with_retries_failed_429(): "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) + await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) @pytest.mark.asyncio -async def test_make_request_with_retries_failed_traceback(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_failed_traceback( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=500) mock_client_session = _get_mock_client_session(fake_response) @@ -94,12 +111,18 @@ async def test_make_request_with_retries_failed_traceback(): "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) + await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) @pytest.mark.asyncio -async def test_make_request_with_retries_failed_with_client_connector_error(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) +async def test_make_request_with_retries_failed_with_client_connector_error( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) mock_client_session = _get_mock_client_session_with_client_connector_error() @@ -107,14 +130,19 @@ async def test_make_request_with_retries_failed_with_client_connector_error(): "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", mock_client_session, ): - await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) + await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) @pytest.mark.asyncio async def test_predict_success( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) fake_response = FakeResponse(status=200, body={"test_key": "test_value"}) mock_client_session = _get_mock_client_session(fake_response) @@ -123,7 +151,9 @@ async def test_predict_success( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -135,9 +165,12 @@ async def test_predict_success( @pytest.mark.asyncio async def test_predict_raises_traceback_json( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = json.dumps({"detail": {"traceback": "test_traceback"}}).encode("utf-8") fake_response = FakeResponse(status=500, content=content) @@ -147,7 +180,9 @@ async def test_predict_raises_traceback_json( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -159,9 +194,12 @@ async def test_predict_raises_traceback_json( @pytest.mark.asyncio async def test_predict_raises_traceback_not_json( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = b"Test traceback content" fake_response = FakeResponse(status=500, content=content) @@ -171,7 +209,9 @@ async def test_predict_raises_traceback_not_json( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -183,9 +223,12 @@ async def test_predict_raises_traceback_not_json( @pytest.mark.asyncio async def test_predict_raises_traceback_wrapped( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = json.dumps( {"result": json.dumps({"detail": {"traceback": "test_traceback"}})} @@ -197,7 +240,9 @@ async def test_predict_raises_traceback_wrapped( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -209,9 +254,12 @@ async def test_predict_raises_traceback_wrapped( @pytest.mark.asyncio async def test_predict_raises_traceback_wrapped_detail_array( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = json.dumps({"result": json.dumps({"detail": [{"error": "error"}]})}).encode("utf-8") fake_response = FakeResponse(status=500, content=content) @@ -221,7 +269,9 @@ async def test_predict_raises_traceback_wrapped_detail_array( mock_client_session, ): response = await gateway.predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) assert isinstance(response, SyncEndpointPredictV1Response) assert response.dict() == { @@ -233,9 +283,12 @@ async def test_predict_raises_traceback_wrapped_detail_array( @pytest.mark.asyncio async def test_predict_upstream_raises_400( - sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, ): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) content = json.dumps({"result": json.dumps({"error": "error"})}).encode("utf-8") fake_response = FakeResponse(status=400, content=content) @@ -247,5 +300,7 @@ async def test_predict_upstream_raises_400( # assert that the exception is raised with pytest.raises(InvalidRequestException): await gateway.predict( - topic="test_topic", predict_request=sync_endpoint_predict_request_1[0] + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", ) From bd77a0a9c8535866f5d4487c9c10b8be95129e2b Mon Sep 17 00:00:00 2001 From: Sandesh Ghanta Date: Tue, 19 Nov 2024 14:04:47 -0800 Subject: [PATCH 424/425] Add max_model_len as Optional Argument for Model.create API (#661) Pass said parameter to vLLM engine if requested by user --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/model.py | 5 +++++ clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index e8d44235..206b405b 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0beta43" +__version__ = "0.0.0beta44" import os from typing import Sequence diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index c03be3f5..13527a67 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -42,6 +42,7 @@ def create( num_shards: int = 1, quantize: Optional[Quantization] = None, checkpoint_path: Optional[str] = None, + max_model_len: Optional[int] = None, # General endpoint fields cpus: Optional[int] = None, memory: Optional[str] = None, @@ -93,6 +94,9 @@ def create( Can be either a folder or a tar file. Folder is preferred since we don't need to untar and model loads faster. For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). + max_model_len (`Optional[int]`): + Model context length. If unspecified, will be automatically derived from the model config. + cpus (`Optional[int]`): Number of cpus each node in the worker should get, e.g. 1, 2, etc. This must be greater than or equal to 1. Recommendation is set it to 8 * GPU count. Can be inferred from the model size. @@ -307,6 +311,7 @@ def create( num_shards=num_shards, quantize=quantize, checkpoint_path=checkpoint_path, + max_model_len=max_model_len, cpus=cpus, endpoint_type=ModelEndpointType(endpoint_type), gpus=gpus, diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index a0e1794f..4c2bd931 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta43" +version = "0.0.0.beta44" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 097bcd1e..986694bb 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,7 +3,7 @@ setup( name="scale-llm-engine", python_requires=">=3.8", - version="0.0.0.beta43", + version="0.0.0.beta44", packages=find_packages(), package_data={"llmengine": ["py.typed"]}, ) From 36b8240c18a2c8b3ef413da9646930efdea1ae0b Mon Sep 17 00:00:00 2001 From: anagnoko23 Date: Wed, 20 Nov 2024 13:31:29 +0200 Subject: [PATCH 425/425] Update setup.py I removed an empty line. --- model-engine/setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model-engine/setup.py b/model-engine/setup.py index 190be7c7..bc0a0548 100644 --- a/model-engine/setup.py +++ b/model-engine/setup.py @@ -1,6 +1,5 @@ # To get circleci to work from setuptools import find_packages, setup - setup( name="model_engine_server", version="1.0.0", @@ -16,5 +15,5 @@ "autogen=model_engine_server.scripts.autogenerate_client_and_docs:entrypoint", "launch-admin=model_engine_server.cli.bin:entrypoint", ], - }, + } )