emit metrics on token counts #382

Merged
merged 3 commits on Nov 16, 2023
47 changes: 38 additions & 9 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -5,7 +5,7 @@
from typing import Optional

import pytz
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request
from model_engine_server.api.dependencies import (
ExternalInterfaces,
get_external_interfaces,
@@ -32,6 +32,7 @@
ModelDownloadResponse,
StreamError,
StreamErrorContent,
TokenUsage,
)
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
from model_engine_server.core.auth.authentication_repository import User
@@ -86,17 +87,20 @@ def format_request_route(request: Request) -> str:
return f"{request.method}_{url_path}".lower()


async def record_route_call(
async def get_metric_metadata(
request: Request,
auth: User = Depends(verify_authentication),
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
):
route = format_request_route(request)
) -> MetricMetadata:
model_name = request.query_params.get("model_endpoint_name", None)
return MetricMetadata(user=auth, model_name=model_name)

external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(
route, MetricMetadata(user=auth, model_name=model_name)
)

async def record_route_call(
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
route: str = Depends(format_request_route),
metric_metadata: MetricMetadata = Depends(get_metric_metadata),
):
external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(route, metric_metadata)


llm_router_v1 = APIRouter(prefix="/v1/llm", dependencies=[Depends(record_route_call)])
@@ -234,8 +238,10 @@ async def get_model_endpoint(
async def create_completion_sync_task(
model_endpoint_name: str,
request: CompletionSyncV1Request,
background_tasks: BackgroundTasks,
auth: User = Depends(verify_authentication),
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
metric_metadata: MetricMetadata = Depends(get_metric_metadata),
) -> CompletionSyncV1Response:
"""
Runs a sync prompt completion on an LLM.
@@ -249,9 +255,20 @@ async def create_completion_sync_task(
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
tokenizer_repository=external_interfaces.tokenizer_repository,
)
return await use_case.execute(
response = await use_case.execute(
user=auth, model_endpoint_name=model_endpoint_name, request=request
)
background_tasks.add_task(
external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
TokenUsage(
num_prompt_tokens=response.output.num_prompt_tokens if response.output else None,
num_completion_tokens=response.output.num_completion_tokens
if response.output
else None,
),
metric_metadata,
)
return response
except UpstreamServiceError:
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
logger.exception(f"Upstream service error for request {request_id}")
@@ -279,8 +296,10 @@ async def create_completion_sync_task(
async def create_completion_stream_task(
model_endpoint_name: str,
request: CompletionStreamV1Request,
background_tasks: BackgroundTasks,
auth: User = Depends(verify_authentication),
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only),
metric_metadata: MetricMetadata = Depends(get_metric_metadata),
) -> EventSourceResponse:
"""
Runs a stream prompt completion on an LLM.
@@ -299,6 +318,16 @@ async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
background_tasks.add_task(
external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics,
TokenUsage(
num_prompt_tokens=message.output.num_prompt_tokens if message.output else None,
num_completion_tokens=message.output.num_completion_tokens
if message.output
else None,
),
metric_metadata,
)
except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
yield handle_streaming_exception(exc, 400, str(exc))
except (
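The endpoint changes above inject FastAPI's BackgroundTasks and schedule the token-count emit after the response has been produced, keeping metric emission off the request's critical path. Below is a minimal, self-contained sketch of that pattern, assuming a toy route and a print-based stand-in for the monitoring gateway; none of these names are model-engine's real implementation.

# Standalone sketch of the BackgroundTasks pattern used in llms_v1.py.
# The route, the gateway stand-in, and the token numbers are illustrative only.
from typing import Optional

from fastapi import BackgroundTasks, FastAPI
from pydantic import BaseModel

app = FastAPI()


class TokenUsage(BaseModel):
    num_prompt_tokens: Optional[int] = 0
    num_completion_tokens: Optional[int] = 0


def emit_token_count_metrics(token_usage: TokenUsage, model_name: str) -> None:
    # Stand-in for monitoring_metrics_gateway.emit_token_count_metrics.
    total = (token_usage.num_prompt_tokens or 0) + (token_usage.num_completion_tokens or 0)
    print(
        f"model={model_name} prompt={token_usage.num_prompt_tokens} "
        f"completion={token_usage.num_completion_tokens} total={total}"
    )


@app.post("/v1/llm/completions-sync")
async def completions_sync(model_endpoint_name: str, background_tasks: BackgroundTasks):
    # ... run the completion and build the response here ...
    usage = TokenUsage(num_prompt_tokens=12, num_completion_tokens=34)
    # add_task only schedules the call; it runs after the response is sent,
    # so emitting the metric adds no latency for the caller.
    background_tasks.add_task(emit_token_count_metrics, usage, model_endpoint_name)
    return {"output": "..."}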
9 changes: 9 additions & 0 deletions model-engine/model_engine_server/common/dtos/llms.py
@@ -233,6 +233,15 @@ class CompletionStreamV1Response(BaseModel):
"""Error of the response (if any)."""


class TokenUsage(BaseModel):
num_prompt_tokens: Optional[int] = 0
num_completion_tokens: Optional[int] = 0

@property
def num_total_tokens(self) -> int:
return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0)


class CreateFineTuneRequest(BaseModel):
model: str
training_file: str
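The new num_total_tokens property coalesces missing counts to zero, so callers can report partial usage. A quick illustration (the numbers are made up):

# Illustration only: missing counts are treated as zero.
from model_engine_server.common.dtos.llms import TokenUsage

usage = TokenUsage(num_prompt_tokens=128, num_completion_tokens=None)
assert usage.num_total_tokens == 128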
model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py
@@ -9,6 +9,7 @@
from abc import ABC, abstractmethod
from typing import Optional

from model_engine_server.common.dtos.llms import TokenUsage
from model_engine_server.core.auth.authentication_repository import User
from pydantic import BaseModel

@@ -81,3 +82,11 @@ def emit_route_call_metric(self, route: str, metadata: MetricMetadata):

"""
pass

@abstractmethod
def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata):
"""
Token count metrics

"""
pass
@@ -1,5 +1,6 @@
from collections import defaultdict

from model_engine_server.common.dtos.llms import TokenUsage
from model_engine_server.domain.gateways.monitoring_metrics_gateway import (
MetricMetadata,
MonitoringMetricsGateway,
@@ -19,6 +20,7 @@ def __init__(self):
self.database_cache_hit = 0
self.database_cache_miss = 0
self.route_call = defaultdict(int)
self.token_count = 0

def reset(self):
self.attempted_build = 0
@@ -32,6 +34,7 @@ def reset(self):
self.database_cache_hit = 0
self.database_cache_miss = 0
self.route_call = defaultdict(int)
self.token_count = 0

def emit_attempted_build_metric(self):
self.attempted_build += 1
@@ -65,3 +68,6 @@ def emit_database_cache_miss_metric(self):

def emit_route_call_metric(self, route: str, _metadata: MetricMetadata):
self.route_call[route] += 1

def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata):
self.token_count += token_usage.num_total_tokens
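
Because the in-memory gateway simply accumulates num_total_tokens, the new code path is easy to assert on in tests. A hypothetical test sketch follows; the class name FakeMonitoringMetricsGateway and the make_metric_metadata() helper are assumptions, since neither appears in this hunk.

# Hypothetical test sketch; FakeMonitoringMetricsGateway and
# make_metric_metadata() are assumed names, not shown in this diff.
from model_engine_server.common.dtos.llms import TokenUsage


def test_emit_token_count_metrics():
    gateway = FakeMonitoringMetricsGateway()
    metadata = make_metric_metadata(model_name="llama-2-7b")  # builds a MetricMetadata with a test User
    gateway.emit_token_count_metrics(
        TokenUsage(num_prompt_tokens=10, num_completion_tokens=5), metadata
    )
    gateway.emit_token_count_metrics(
        TokenUsage(num_prompt_tokens=3, num_completion_tokens=None), metadata
    )
    assert gateway.token_count == 18  # 15 from the first call + 3 from the second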