Commit 1dc1012

Modularization, Batch servers and Documentation
1 parent f69de58 commit 1dc1012

9 files changed: +777 -706 lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ RUN apt-get update && apt-get install -y \
 
 # Install Python dependencies
 RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt && \
-    pip install "fastapi[standard]" "uvicorn[standard]" httpx fastapi-mcp
+    pip install "fastapi[standard]" "uvicorn[standard]" httpx fastapi-mcp psutil
 
 # (Optional) Run your setup_env.py if needed
 RUN python /code/setup_env.py -md /code/models/BitNet-b1.58-2B-4T -q i2_s
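
psutil is the only new package in this layer; the new process-management module later in this commit uses it to size the pool of batch servers against free RAM. A minimal sketch of the call it relies on (the one-GiB-per-server budget is that module's default, not anything psutil defines):

# Estimate how many ~1 GiB server instances fit in currently free RAM.
import psutil

mem = psutil.virtual_memory()
available_gb = (mem.total - mem.used) / (1024 ** 3)
print(int(available_gb // 1))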

app/lib/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .endpoints import ChatRequest
+from .endpoints.chat_endpoints import ChatRequest
 from typing import List
 from pydantic import BaseModel

app/lib/endpoints.py

Lines changed: 0 additions & 672 deletions
This file was deleted; its contents were split into the new app/lib/endpoints/ package below.

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
import asyncio
import logging
import os
import subprocess  # kept for subprocess.CalledProcessError

from fastapi import Depends, HTTPException, Query

from ..models import ModelEnum, BenchmarkRequest, PerplexityRequest
from ..utils import parse_benchmark_data, parse_perplexity_data

# --- Logging configuration for this module ---
logger = logging.getLogger(__name__)


def validate_prompt_length(
    prompt: str = Query(..., description="Input text for perplexity calculation"),
    ctx_size: int = Query(10, gt=3),
) -> str:
    """Reject prompts with fewer whitespace tokens than twice the context size."""
    token_count = len(prompt.split())
    min_tokens = 2 * ctx_size
    if token_count < min_tokens:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt too short. Needs at least {min_tokens} tokens, got {token_count}"
        )
    return prompt


async def run_benchmark(
    model: ModelEnum,
    n_token: int = Query(128, gt=0),
    threads: int = Query(2, gt=0),
    n_prompt: int = Query(32, gt=0)
):
    """Run benchmark on the specified model."""
    request = BenchmarkRequest(model=model, n_token=n_token, threads=threads, n_prompt=n_prompt)
    build_dir = os.getenv("BUILD_DIR", "build")
    bench_path = os.path.join(build_dir, "bin", "llama-bench")
    if not os.path.exists(bench_path):
        logger.error(f"Benchmark binary not found at '{bench_path}'.")
        raise HTTPException(status_code=500, detail="Benchmark binary not found")
    command = [
        bench_path,
        '-m', request.model.value,
        '-n', str(request.n_token),
        '-ngl', '0',
        '-b', '1',
        '-t', str(request.threads),
        '-p', str(request.n_prompt),
        '-r', '5'
    ]
    try:
        logger.info(f"Running benchmark with command: {' '.join(command)}")
        # Use asyncio.create_subprocess_exec instead of subprocess.run so the
        # event loop is not blocked while the benchmark runs.
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        stdout_bytes, stderr_bytes = await process.communicate()  # wait for completion

        if process.returncode != 0:
            logger.error(f"Benchmark failed. RC: {process.returncode}. Stderr: {stderr_bytes.decode(errors='ignore')}")
            raise subprocess.CalledProcessError(
                process.returncode, cmd=command, output=stdout_bytes, stderr=stderr_bytes
            )

        parsed_data = parse_benchmark_data(stdout_bytes.decode(errors='ignore'))
        logger.info("Benchmark completed successfully.")
        return parsed_data
    except subprocess.CalledProcessError as e:
        logger.error(f"Benchmark failed: {str(e)}. Command: {e.cmd}. RC: {e.returncode}. Stdout: {e.stdout.decode(errors='ignore') if e.stdout else ''}. Stderr: {e.stderr.decode(errors='ignore') if e.stderr else ''}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Benchmark failed: {e.stderr.decode(errors='ignore') if e.stderr else str(e)}")
    except Exception as e:  # catch any other unexpected errors
        logger.error(f"Unexpected error during benchmark: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during benchmark: {str(e)}")


async def run_perplexity(
    model: ModelEnum,
    prompt: str = Depends(validate_prompt_length),
    threads: int = Query(2, gt=0),
    ctx_size: int = Query(10, gt=3),
    ppl_stride: int = Query(0, ge=0)
):
    """Calculate perplexity for the given text and model."""
    try:
        request = PerplexityRequest(
            model=model,
            prompt=prompt,
            threads=threads,
            ctx_size=ctx_size,
            ppl_stride=ppl_stride
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    build_dir = os.getenv("BUILD_DIR", "build")
    ppl_path = os.path.join(build_dir, "bin", "llama-perplexity")
    if not os.path.exists(ppl_path):
        logger.error(f"Perplexity binary not found at '{ppl_path}'.")
        raise HTTPException(status_code=500, detail="Perplexity binary not found")

    command = [
        ppl_path,
        '--model', request.model.value,
        '--prompt', request.prompt,
        '--threads', str(request.threads),
        '--ctx-size', str(request.ctx_size),
        '--perplexity',
        '--ppl-stride', str(request.ppl_stride)
    ]

    try:
        logger.info(f"Running perplexity calculation with command: {' '.join(command)}")
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.PIPE,  # llama-perplexity may write to stdout or stderr
            stderr=asyncio.subprocess.PIPE
        )
        stdout_bytes, stderr_bytes = await process.communicate()

        if process.returncode != 0:
            logger.error(f"Perplexity calculation failed. RC: {process.returncode}. Stderr: {stderr_bytes.decode(errors='ignore')}")
            raise subprocess.CalledProcessError(
                process.returncode, cmd=command, output=stdout_bytes, stderr=stderr_bytes
            )

        # The original implementation parsed results from stderr; keep that behavior.
        parsed_data = parse_perplexity_data(stderr_bytes.decode(errors='ignore'))
        logger.info("Perplexity calculation completed successfully.")
        return parsed_data
    except subprocess.CalledProcessError as e:
        logger.error(f"Perplexity calculation failed: {str(e)}. Command: {e.cmd}. RC: {e.returncode}. Stderr: {e.stderr.decode(errors='ignore') if e.stderr else ''}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Perplexity calculation failed: {e.stderr.decode(errors='ignore') if e.stderr else str(e)}")
    except Exception as e:  # catch any other unexpected errors
        logger.error(f"Unexpected error during perplexity calculation: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during perplexity calculation: {str(e)}")


def get_model_sizes():
    """Report the file sizes of supported .gguf models."""
    model_sizes = {}
    models_dir = "models"
    for subdir in os.listdir(models_dir):
        subdir_path = os.path.join(models_dir, subdir)
        if os.path.isdir(subdir_path):
            for file in os.listdir(subdir_path):
                if file.endswith(".gguf"):
                    file_path = os.path.join(subdir_path, file)
                    file_size_bytes = os.path.getsize(file_path)
                    model_sizes[file] = {
                        "bytes": file_size_bytes,
                        "MB": round(file_size_bytes / (1024 * 1024), 3),
                        "GB": round(file_size_bytes / (1024 ** 3), 3),
                    }
    return model_sizes
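
The handlers in this module are plain async functions rather than decorated routes, so the application can mount them wherever it wants. One possible wiring is sketched below; the route paths and the router object are assumptions, since the commit's actual registration code is not in the hunks shown here.

# Hypothetical wiring sketch; the paths are assumptions, not taken from this commit.
from fastapi import APIRouter

router = APIRouter()
router.get("/benchmark")(run_benchmark)      # Query/Depends defaults become query parameters
router.get("/perplexity")(run_perplexity)
router.get("/model-sizes")(get_model_sizes)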

app/lib/endpoints/chat_endpoints.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
import asyncio
import logging
from typing import List

import httpx
from fastapi import HTTPException
from pydantic import BaseModel

from .process_management import get_server_processes, get_server_configs

logger = logging.getLogger(__name__)


class ChatRequest(BaseModel):
    message: str
    port: int = 8081
    threads: int = 1
    ctx_size: int = 2048
    n_predict: int = 256
    temperature: float = 0.8


async def chat_with_bitnet(chat: ChatRequest):
    """Forward a single chat message to the BitNet server listening on chat.port."""
    host = "127.0.0.1"
    key = (host, chat.port)
    proc_entry = get_server_processes().get(key)
    cfg = get_server_configs().get(key)
    if not (proc_entry and proc_entry["process"].returncode is None and cfg):
        logger.warning(f"Chat request to non-existent or stopped server on port {chat.port}.")
        raise HTTPException(status_code=404, detail=f"Server on port {chat.port} not running or not configured.")
    server_url = f"http://{host}:{chat.port}/completion"
    payload = {
        "prompt": chat.message,
        "threads": chat.threads,
        "ctx_size": chat.ctx_size,
        "n_predict": chat.n_predict,
        "temperature": chat.temperature
    }
    async with httpx.AsyncClient() as client:
        try:
            logger.info(f"Forwarding chat message to BitNet server on port {chat.port}.")
            response = await client.post(server_url, json=payload, timeout=300.0)
            response.raise_for_status()
            return response.json()
        except httpx.ReadTimeout:
            logger.error(f"ReadTimeout when communicating with BitNet server on port {chat.port}.")
            raise HTTPException(status_code=504, detail=f"Request to BitNet server on port {chat.port} timed out.")
        except httpx.ConnectError:
            logger.error(f"ConnectError when communicating with BitNet server on port {chat.port}.")
            raise HTTPException(status_code=503, detail=f"Could not connect to BitNet server on port {chat.port}.")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTPStatusError from BitNet server on port {chat.port}: {e.response.status_code} - {e.response.text}", exc_info=True)
            raise HTTPException(status_code=e.response.status_code, detail=f"BitNet server error: {e.response.text}")
        except Exception as e:
            logger.error(f"Unexpected error during chat with BitNet server on port {chat.port}: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"An unexpected error occurred while communicating with BitNet server on port {chat.port}: {str(e)}")


class MultiChatRequest(BaseModel):
    requests: List[ChatRequest]


async def multichat_with_bitnet(multichat: MultiChatRequest):
    """Fan a batch of chat requests out concurrently and collect per-request results."""
    logger.info(f"Multichat request received for {len(multichat.requests)} chats.")
    results = await asyncio.gather(
        *(chat_with_bitnet(req) for req in multichat.requests),
        return_exceptions=True
    )
    formatted = []
    for res in results:
        if isinstance(res, HTTPException):
            formatted.append({"error": res.detail})
        elif isinstance(res, Exception):
            formatted.append({"error": str(res)})
        elif isinstance(res, dict) and "content" in res:
            formatted.append(res["content"])
        else:
            formatted.append(res)
    logger.info("Multichat processing completed.")
    return {"results": formatted}
app/lib/endpoints/process_management.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
import asyncio
import atexit
import logging
import psutil
from typing import Any, Dict, Tuple

logger = logging.getLogger(__name__)

# Registry of managed BitNet server subprocesses, keyed by (host, port).
server_processes: Dict[Tuple[str, int], Dict[str, Any]] = {}
server_configs: Dict[Tuple[str, int], Dict[str, Any]] = {}

FASTAPI_PORT = 8080
_atexit_cleanup_completed = False


def get_server_processes():
    return server_processes


def get_server_configs():
    return server_configs


def _max_server_instances_by_ram(per_server_gb=1):
    """Estimate how many server instances fit in the currently free RAM."""
    mem = psutil.virtual_memory()
    available_gb = (mem.total - mem.used) / (1024 ** 3)
    return int(available_gb // per_server_gb)


async def _terminate_server_process(key: Tuple[str, int]):
    """Terminate the tracked server at (host, port), escalating SIGTERM to SIGKILL."""
    host, port = key
    if port == FASTAPI_PORT:
        logger.warning(f"Attempt to terminate FastAPI server on port {port} denied.")
        return f"Operation denied: Port {port} is used by the FastAPI application and cannot be terminated via this endpoint."
    proc_entry = server_processes.get(key)
    if not proc_entry:
        server_configs.pop(key, None)
        logger.info(f"No server process found for key {key} (port {port}) during termination attempt.")
        return f"No server process found for key {key} (port {port})."
    proc_to_terminate = proc_entry["process"]
    pid = proc_entry["pid"]
    if proc_to_terminate.returncode is None:
        logger.info(f"Attempting to terminate server on port {port} (PID: {pid}).")
        try:
            proc_to_terminate.terminate()
            await asyncio.wait_for(proc_to_terminate.wait(), timeout=5.0)
            logger.info(f"Server on port {port} (PID: {pid}) terminated successfully after SIGTERM.")
            server_processes.pop(key, None)
            server_configs.pop(key, None)
            return f"Server on port {port} (PID: {pid}) terminated successfully."
        except asyncio.TimeoutError:
            logger.warning(f"Server on port {port} (PID: {pid}) did not respond to SIGTERM within timeout. Attempting SIGKILL.")
            try:
                proc_to_terminate.kill()
                await proc_to_terminate.wait()
                logger.info(f"Server on port {port} (PID: {pid}) forcefully killed.")
                server_processes.pop(key, None)
                server_configs.pop(key, None)
                return f"Server on port {port} (PID: {pid}) forcefully killed as it did not respond to SIGTERM."
            except Exception as e_kill:
                logger.error(f"Error forcefully killing server on port {port} (PID: {pid}): {str(e_kill)}", exc_info=True)
                return f"Error forcefully killing server on port {port} (PID: {pid}): {str(e_kill)}. Process may still be running."
        except Exception as e_term:
            logger.error(f"Error terminating server on port {port} (PID: {pid}) with SIGTERM: {str(e_term)}", exc_info=True)
            return f"Error terminating server on port {port} (PID: {pid}): {str(e_term)}. Process may still be running."
    else:
        logger.info(f"Server on port {port} was already stopped (return code: {proc_to_terminate.returncode}). Cleaned up tracking.")
        server_processes.pop(key, None)
        server_configs.pop(key, None)
        return f"Server on port {port} was already stopped. Cleaned up tracking."


async def _terminate_all_servers():
    """Terminate every tracked server; runs at most once, at interpreter exit."""
    global _atexit_cleanup_completed
    if _atexit_cleanup_completed:
        return
    logger.info("Attempting to terminate all running server processes asynchronously at exit.")
    keys_to_terminate = list(server_processes.keys())
    tasks = [_terminate_server_process(key) for key in keys_to_terminate]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for key, result in zip(keys_to_terminate, results):
        if isinstance(result, Exception):
            logger.error(f"Error during atexit termination for server {key}: {result}", exc_info=result)
        else:
            logger.info(f"Atexit termination for server {key}: {result}")
    logger.info("Asynchronous termination of all server processes at exit completed.")
    _atexit_cleanup_completed = True


def _run_async_cleanup_on_exit():
    try:
        asyncio.run(_terminate_all_servers())
    except RuntimeError as e:
        if ("cannot schedule new futures after shutdown" in str(e).lower()
                or "event loop is closed" in str(e).lower()):
            logger.warning(f"Could not run async cleanup at exit because event loop was closed or shutting down: {e}")
        else:
            logger.error(f"Unexpected RuntimeError during atexit async cleanup: {e}", exc_info=True)
    except Exception as e:
        logger.error(f"Unexpected Exception during atexit async cleanup: {e}", exc_info=True)


atexit.register(_run_async_cleanup_on_exit)
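
The registries above only track servers; the code that launches them is not part of the hunks shown here. Below is a sketch of how a start path might consult _max_server_instances_by_ram before spawning another instance; the binary name, flags, and one-GB budget are all assumptions:

# Hypothetical launcher sketch; binary path and flags are assumptions.
import asyncio

async def start_bitnet_server(host: str, port: int, model_path: str):
    if len(server_processes) >= _max_server_instances_by_ram(per_server_gb=1):
        raise RuntimeError("Not enough free RAM for another server instance")
    proc = await asyncio.create_subprocess_exec(
        "build/bin/llama-server", "-m", model_path,
        "--host", host, "--port", str(port),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    # Register under the same (host, port) key the termination helpers use.
    server_processes[(host, port)] = {"process": proc, "pid": proc.pid}
    server_configs[(host, port)] = {"model": model_path}
    return proc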
