From 73f5b26a7e38074b4fb48573b11e337118f87415 Mon Sep 17 00:00:00 2001 From: jimpang Date: Mon, 1 Apr 2024 20:50:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=BF=94=E5=9B=9Elog?= =?UTF-8?q?probs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vllm/entrypoints/api_server.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 3a0d05f48491e..77129e2dcf393 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -9,6 +9,7 @@ import argparse import json import ssl +from dataclasses import asdict from typing import AsyncGenerator import uvicorn @@ -75,7 +76,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: output.text for output in request_output.outputs ] output_tokens = [output.token_ids for output in request_output.outputs] - logprobs = [output.logprobs for output in request_output.outputs] + logprobs = [[{k: asdict(v) for k, v in logprobs.items()} for logprobs in + output.logprobs] if output.logprobs is not None else None for output in request_output.outputs] ret = {"text": text_outputs, "output_token_ids": output_tokens, "logprobs": logprobs} yield (json.dumps(ret) + "\0").encode("utf-8") @@ -94,7 +96,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: assert final_output is not None text_outputs = [output.text for output in final_output.outputs] output_tokens = [output.token_ids for output in final_output.outputs] - logprobs = [output.logprobs for output in final_output.outputs] + logprobs = [[{k: asdict(v) for k, v in logprobs.items()} for logprobs in + output.logprobs] if output.logprobs is not None else None for output in final_output.outputs] ret = {"text": text_outputs, "output_token_ids": output_tokens, "logprobs": logprobs} return JSONResponse(ret)