From 338bce6da8d95902fab8192dc7352729bac1cff2 Mon Sep 17 00:00:00 2001 From: jimpang Date: Tue, 21 Nov 2023 17:05:05 +0800 Subject: [PATCH] feat: return output_token_ids in generate api --- vllm/entrypoints/api_server.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index b4b42e52b5e90..d5a59e81dda7d 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -2,9 +2,9 @@ import json from typing import AsyncGenerator +import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse -import uvicorn from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -55,7 +55,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]: text_outputs = [ output.text for output in request_output.outputs ] - ret = {"text": text_outputs} + output_tokens = [output.token_ids for output in request_output.outputs] + + ret = {"text": text_outputs, "output_token_ids": output_tokens} yield (json.dumps(ret) + "\0").encode("utf-8") if stream: @@ -72,7 +74,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: assert final_output is not None text_outputs = [output.text for output in final_output.outputs] - ret = {"text": text_outputs} + output_tokens = [output.token_ids for output in final_output.outputs] + ret = {"text": text_outputs, "output_token_ids": output_tokens} return JSONResponse(ret)