From 34728ad645b92d9a5d00cf5f958a057dfe1f95cc Mon Sep 17 00:00:00 2001
From: Sai Atmakuri
Date: Thu, 16 Nov 2023 01:10:31 +0000
Subject: [PATCH 1/2] emit metrics on token counts

---
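Notes (commentary after the three-dash line is ignored by git am; kept here for
reviewers): the snippet below is a minimal standalone sketch of the FastAPI
pattern this patch relies on. A shared Depends() builds per-request metric
metadata once, and BackgroundTasks defers metric emission until after the
response is sent, so token accounting adds no latency to the completion path.
All names in the sketch (DemoMetadata, get_demo_metadata, emit_tokens,
/complete) are hypothetical and exist only for illustration; they are not part
of this diff.

    from typing import Optional

    from fastapi import BackgroundTasks, Depends, FastAPI, Request
    from pydantic import BaseModel

    app = FastAPI()


    class DemoMetadata(BaseModel):
        model_name: Optional[str] = None


    async def get_demo_metadata(request: Request) -> DemoMetadata:
        # Mirrors get_metric_metadata: pull the model name off the query string.
        return DemoMetadata(model_name=request.query_params.get("model_endpoint_name"))


    def emit_tokens(num_tokens: int, metadata: DemoMetadata) -> None:
        # Stand-in for monitoring_metrics_gateway.emit_token_count_metrics.
        print(f"model={metadata.model_name} tokens={num_tokens}")


    @app.post("/complete")
    async def complete(
        background_tasks: BackgroundTasks,
        metadata: DemoMetadata = Depends(get_demo_metadata),
    ) -> dict:
        response = {"text": "hello", "num_tokens": 2}
        # Scheduled work runs only after the response body has been sent,
        # so the client never waits on metric emission.
        background_tasks.add_task(emit_tokens, response["num_tokens"], metadata)
        return response

The same shared dependency is what lets record_route_call and both completion
handlers below reuse a single MetricMetadata per request.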
 .../model_engine_server/api/llms_v1.py             | 47 +++++++++++++++----
 .../model_engine_server/common/dtos/llms.py        |  9 ++++
 .../gateways/monitoring_metrics_gateway.py         |  9 ++++
 .../fake_monitoring_metrics_gateway.py             |  7 +++
 4 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 153b4c2c..6fc2a1d1 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -5,7 +5,7 @@
 from typing import Optional
 
 import pytz
-from fastapi import APIRouter, Depends, HTTPException, Query, Request
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request
 from model_engine_server.api.dependencies import (
     ExternalInterfaces,
     get_external_interfaces,
@@ -32,6 +32,7 @@
     ModelDownloadResponse,
     StreamError,
     StreamErrorContent,
+    TokenUsage,
 )
 from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
 from model_engine_server.core.auth.authentication_repository import User
@@ -86,17 +87,20 @@ def format_request_route(request: Request) -> str:
     return f"{request.method}_{url_path}".lower()
 
 
-async def record_route_call(
+async def get_metric_metadata(
     request: Request,
     auth: User = Depends(verify_authentication),
-    external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
-):
-    route = format_request_route(request)
+) -> MetricMetadata:
     model_name = request.query_params.get("model_endpoint_name", None)
+    return MetricMetadata(user=auth, model_name=model_name)
 
-    external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(
-        route, MetricMetadata(user=auth, model_name=model_name)
-    )
+
+async def record_route_call(
+    external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
+    route: str = Depends(format_request_route),
+    metric_metadata: MetricMetadata = Depends(get_metric_metadata),
+):
+    external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(route, metric_metadata)
 
 
 llm_router_v1 = APIRouter(prefix="/v1/llm", dependencies=[Depends(record_route_call)])
@@ -234,8 +238,10 @@ async def get_model_endpoint(
 async def create_completion_sync_task(
     model_endpoint_name: str,
     request: CompletionSyncV1Request,
+    background_tasks: BackgroundTasks,
     auth: User = Depends(verify_authentication),
     external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
+    metric_metadata: MetricMetadata = Depends(get_metric_metadata),
 ) -> CompletionSyncV1Response:
     """
     Runs a sync prompt completion on an LLM.
@@ -249,9 +255,20 @@ async def create_completion_sync_task(
             llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
             tokenizer_repository=external_interfaces.tokenizer_repository,
         )
-        return await use_case.execute(
+        response = await use_case.execute(
             user=auth, model_endpoint_name=model_endpoint_name, request=request
         )
+        background_tasks.add_task(
+            external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
+            TokenUsage(
+                num_prompt_tokens=response.output.num_prompt_tokens if response.output else None,
+                num_completion_tokens=response.output.num_completion_tokens
+                if response.output
+                else None,
+            ),
+            metric_metadata,
+        )
+        return response
     except UpstreamServiceError:
         request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
         logger.exception(f"Upstream service error for request {request_id}")
@@ -279,8 +296,10 @@ async def create_completion_sync_task(
 async def create_completion_stream_task(
     model_endpoint_name: str,
     request: CompletionStreamV1Request,
+    background_tasks: BackgroundTasks,
     auth: User = Depends(verify_authentication),
     external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
+    metric_metadata: MetricMetadata = Depends(get_metric_metadata),
 ) -> EventSourceResponse:
     """
     Runs a stream prompt completion on an LLM.
@@ -299,6 +318,16 @@ async def event_generator():
         try:
             async for message in response:
                 yield {"data": message.json()}
+            background_tasks.add_task(
+                external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
+                TokenUsage(
+                    num_prompt_tokens=message.output.num_prompt_tokens if message.output else None,
+                    num_completion_tokens=message.output.num_completion_tokens
+                    if message.output
+                    else None,
+                ),
+                metric_metadata,
+            )
         except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
             yield handle_streaming_exception(exc, 400, str(exc))
         except (
diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py
index c0e6b9fc..dd2e06a0 100644
--- a/model-engine/model_engine_server/common/dtos/llms.py
+++ b/model-engine/model_engine_server/common/dtos/llms.py
@@ -233,6 +233,15 @@ class CompletionStreamV1Response(BaseModel):
     """Error of the response (if any)."""
 
 
+class TokenUsage(BaseModel):
+    num_prompt_tokens: Optional[int] = 0
+    num_completion_tokens: Optional[int] = 0
+
+    @property
+    def num_total_tokens(self) -> int:
+        return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0)
+
+
 class CreateFineTuneRequest(BaseModel):
     model: str
     training_file: str
diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py
index 5e7e0382..9bca6a0d 100644
--- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py
+++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py
@@ -9,6 +9,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
+from model_engine_server.common.dtos.llms import TokenUsage
 from model_engine_server.core.auth.authentication_repository import User
 from pydantic import BaseModel
 
@@ -81,3 +82,11 @@ def emit_route_call_metric(self, route: str, metadata: MetricMetadata):
 
         """
         pass
+
+    @abstractmethod
+    def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata):
+        """
+        Token count metrics
+
+        """
+        pass
diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
index 32c2b6f3..bdc1fc79 100644
--- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
@@ -1,5 +1,6 @@
 from collections import defaultdict
 
+from model_engine_server.common.dtos.llms import TokenUsage
 from model_engine_server.domain.gateways.monitoring_metrics_gateway import (
     MetricMetadata,
     MonitoringMetricsGateway,
@@ -19,6 +20,7 @@ def __init__(self):
         self.database_cache_hit = 0
         self.database_cache_miss = 0
         self.route_call = defaultdict(int)
+        self.token_count = 0
 
     def reset(self):
         self.attempted_build = 0
@@ -32,6 +34,7 @@ def reset(self):
         self.database_cache_hit = 0
         self.database_cache_miss = 0
         self.route_call = defaultdict(int)
+        self.token_count = 0
 
     def emit_attempted_build_metric(self):
         self.attempted_build += 1
@@ -65,3 +68,7 @@ def emit_database_cache_miss_metric(self):
 
     def emit_route_call_metric(self, route: str, _metadata: MetricMetadata):
         self.route_call[route] += 1
+        print("route_call", route, self.route_call[route])
+
+    def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata):
+        self.token_count += token_usage.num_total_tokens
From d65363b80b3ec454069b539345deb615aec4b973 Mon Sep 17 00:00:00 2001
From: Sai Atmakuri
Date: Thu, 16 Nov 2023 01:39:59 +0000
Subject: [PATCH 2/2] remove print

---
 .../infra/gateways/fake_monitoring_metrics_gateway.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
index bdc1fc79..9b63a135 100644
--- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py
@@ -68,7 +68,6 @@ def emit_database_cache_miss_metric(self):
 
     def emit_route_call_metric(self, route: str, _metadata: MetricMetadata):
         self.route_call[route] += 1
-        print("route_call", route, self.route_call[route])
 
     def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata):
         self.token_count += token_usage.num_total_tokens
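
Trailing notes (for reviewers, not part of either commit): a hypothetical unit
test sketching how the fake gateway accumulates totals via
TokenUsage.num_total_tokens, including the None fallback. The class name
FakeMonitoringMetricsGateway is assumed from the module name (the diff only
shows its methods), the model name is made up, and the
MetricMetadata.construct() shortcut (pydantic v1, skips validation so no real
User is needed) is an illustrative assumption rather than code from this
series.

    from model_engine_server.common.dtos.llms import TokenUsage
    from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata
    from model_engine_server.infra.gateways.fake_monitoring_metrics_gateway import (
        FakeMonitoringMetricsGateway,
    )


    def test_emit_token_count_metrics():
        gateway = FakeMonitoringMetricsGateway()
        # construct() skips validation, so the required User field can be
        # omitted; the fake gateway ignores the metadata argument anyway.
        metadata = MetricMetadata.construct(model_name="llama-2-7b")
        gateway.emit_token_count_metrics(
            TokenUsage(num_prompt_tokens=7, num_completion_tokens=None), metadata
        )
        gateway.emit_token_count_metrics(
            TokenUsage(num_prompt_tokens=3, num_completion_tokens=5), metadata
        )
        # 7 + 0 (None falls back to 0), then 3 + 5.
        assert gateway.token_count == 15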