diff --git a/go.mod b/go.mod
index 069597d..1515a18 100644
--- a/go.mod
+++ b/go.mod
@@ -1 +1,3 @@
module large-model-proxy
+
+go 1.18
diff --git a/main.go b/main.go
index bf046fe..98c3464 100644
--- a/main.go
+++ b/main.go
@@ -71,10 +71,8 @@ type LlmApiModel struct {
OwnedBy string `json:"owned_by"`
Created int64 `json:"created"`
}
-type LlmCompletionRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt"`
- Stream bool `json:"stream"`
+type ModelContainingRequest struct {
+ Model string `json:"model"`
}
func (rm ResourceManager) getRunningService(name string) RunningService {
@@ -207,7 +205,6 @@ type contextKey string
var rawConnectionContextKey = contextKey("rawConn")
func startLlmApi(llmApi LlmApi, services []ServiceConfig) {
-
mux := http.NewServeMux()
modelToServiceMap := make(map[string]ServiceConfig)
models := make([]LlmApiModel, 0)
@@ -241,6 +238,18 @@ func startLlmApi(llmApi LlmApi, services []ServiceConfig) {
mux.HandleFunc("/v1/completions", func(responseWriter http.ResponseWriter, request *http.Request) {
handleCompletions(responseWriter, request, &modelToServiceMap)
})
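+	// /v1/chat/completions is served by the same handler, since only the "model" field is needed for routing.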
+ mux.HandleFunc("/v1/chat/completions", func(responseWriter http.ResponseWriter, request *http.Request) {
+ handleCompletions(responseWriter, request, &modelToServiceMap)
+ })
+ mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) {
+		// Unknown path: log it and respond with 404.
+ log.Printf("[LLM Request API] %s request to unsupported URL: %s", request.Method, request.RequestURI)
+ http.Error(
+ responseWriter,
+			fmt.Sprintf("%s %s is not supported by large-model-proxy", request.Method, request.RequestURI),
+ http.StatusNotFound,
+ )
+ })
// Create a custom http.Server that uses ConnContext
// to attach the *rawCaptureConnection to each request's Context.
@@ -288,20 +297,19 @@ func handleCompletions(responseWriter http.ResponseWriter, request *http.Request
return
}
- var completionRequest LlmCompletionRequest
- if err := json.Unmarshal(bodyBytes, &completionRequest); err != nil {
- log.Printf("[LLM API Server] Error decoding /v1/completions request: %v\n", err)
- http.Error(responseWriter, fmt.Sprintf("Failed to parse request body: %v", err), http.StatusBadRequest)
+ model, ok := extractModelFromRequest(request.URL.String(), bodyBytes)
+ if !ok {
+		http.Error(responseWriter, "Failed to parse request body", http.StatusBadRequest)
return
}
- service, ok := (*modelToServiceMap)[completionRequest.Model]
+ service, ok := (*modelToServiceMap)[model]
if !ok {
- log.Printf("[LLM API Server] Unknown model requested: %v\n", completionRequest.Model)
- http.Error(responseWriter, fmt.Sprintf("Unknown model: %v", completionRequest.Model), http.StatusBadRequest)
+ log.Printf("[LLM API Server] Unknown model requested: %v\n", model)
+ http.Error(responseWriter, fmt.Sprintf("Unknown model: %v", model), http.StatusBadRequest)
return
}
- log.Printf("[LLM API Server] Sending /v1/completions request through to %s\n", service.Name)
+ log.Printf("[LLM API Server] Sending %s request through to %s\n", request.URL, service.Name)
originalWriter := responseWriter
hijacker, ok := originalWriter.(http.Hijacker)
if !ok {
@@ -331,6 +339,17 @@ func handleCompletions(responseWriter http.ResponseWriter, request *http.Request
}
handleConnection(clientConnection, service, rawRequestBytes)
}
+
+// extractModelFromRequest returns the model name parsed from the request body and whether parsing succeeded.
+func extractModelFromRequest(url string, bodyBytes []byte) (string, bool) {
+ var completionRequest ModelContainingRequest
+ if err := json.Unmarshal(bodyBytes, &completionRequest); err != nil {
+ log.Printf("[LLM API Server] Error decoding %s request: %v\n%s", url, err, bodyBytes)
+ return "", false
+ }
+ return completionRequest.Model, true
+}
+
func signalToString(sig os.Signal) string {
switch sig {
case syscall.SIGINT:
diff --git a/main_test.go b/main_test.go
index 73c77b0..316700a 100644
--- a/main_test.go
+++ b/main_test.go
@@ -19,6 +19,7 @@ import (
"time"
)
+// LlmCompletionResponse is what /v1/completions returns
type LlmCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
@@ -36,6 +37,44 @@ type LlmCompletionResponse struct {
} `json:"usage"`
}
+// LlmCompletionRequest is used by /v1/completions
+type LlmCompletionRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ Stream bool `json:"stream"`
+}
+
+// LlmChatCompletionRequest is used by /v1/chat/completions
+type LlmChatCompletionRequest struct {
+ Model string `json:"model,omitempty"`
+ Messages []ChatMessage `json:"messages,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+}
+
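+// ChatMessage is a single role/content message in a chat completion conversation.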
+type ChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// LlmChatCompletionResponse is what /v1/chat/completions returns
+type LlmChatCompletionResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []struct {
+ Index int `json:"index"`
+ Message ChatMessage `json:"message"`
+ Delta ChatMessage `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+ Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ } `json:"usage"`
+}
+
func connectOnly(test *testing.T, proxyAddress string) {
_, err := net.Dial("tcp", proxyAddress)
if err != nil {
@@ -158,8 +197,8 @@ func idleTimeoutMultipleServices(test *testing.T, serviceOneAddress string, serv
func llmApi(test *testing.T) {
//sanity check that nothing is running before initial connection
assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+
client := &http.Client{}
- // Create a new GET request
req, err := http.NewRequest("GET", "http://localhost:2016/v1/models", nil)
if err != nil {
test.Fatalf("Failed to create request: %v", err)
@@ -228,8 +267,15 @@ func llmApi(test *testing.T) {
assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
testCompletionRequest(test, "http://localhost:2016", "test-llm-1")
+ assertPortsAreClosed(test, []string{"localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+
+ testCompletionStreamingExpectingSuccess(test, "test-llm-1")
+ testChatCompletionRequestExpectingSuccess(test, "http://localhost:2016", "test-llm-1")
+ testChatCompletionStreamingExpectingSuccess(test, "http://localhost:2016", "test-llm-1")
+
llm1Pid := runReadPidCloseConnection(test, "localhost:12018")
assertPortsAreClosed(test, []string{"localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+
time.Sleep(4 * time.Second)
if isProcessRunning(llm1Pid) {
@@ -237,9 +283,18 @@ func llmApi(test *testing.T) {
}
assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+ testChatCompletionRequestExpectingSuccess(test, "http://localhost:2016", "fizz")
+ assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+
testCompletionRequest(test, "http://localhost:2016", "fizz")
- llm2Pid := runReadPidCloseConnection(test, "localhost:12020")
assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+
+ testChatCompletionStreamingExpectingSuccess(test, "http://localhost:2016", "fizz")
+ assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+
+ testCompletionStreamingExpectingSuccess(test, "fizz")
+ assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+ llm2Pid := runReadPidCloseConnection(test, "localhost:12020")
time.Sleep(4 * time.Second)
if isProcessRunning(llm2Pid) {
test.Fatalf("test-llm-2 service is still running, but inactivity timeout should have shut it down by now")
@@ -255,8 +310,6 @@ func llmApi(test *testing.T) {
testCompletionRequest(test, "http://localhost:2019", "foo")
llm2Pid = runReadPidCloseConnection(test, "localhost:12020")
-
- testCompletionStreaming(test)
time.Sleep(4 * time.Second)
if isProcessRunning(llm2Pid) {
test.Fatalf("test-llm-2 service is still running, but inactivity timeout should have shut it down by now")
@@ -273,10 +326,9 @@ func assertPortsAreClosed(test *testing.T, servicesToCheckForClosedPorts []strin
}
}
-func testCompletionStreaming(t *testing.T) {
+// testCompletionStreamingExpectingSuccess checks streaming completions from /v1/completions
+func testCompletionStreamingExpectingSuccess(t *testing.T, model string) {
address := "http://localhost:2016"
- model := "fizz"
-
testPrompt := "This is a test prompt\nЭто проверочный промт\n这是一个测试提示"
reqBodyStruct := LlmCompletionRequest{
Model: model,
@@ -284,15 +336,124 @@ func testCompletionStreaming(t *testing.T) {
Stream: true,
}
- reqBody, err := json.Marshal(reqBodyStruct)
- if err != nil {
- t.Fatalf("Failed to marshal JSON: %v", err)
+ url := fmt.Sprintf("%s/v1/completions", address)
+ testStreamingRequest(t, url, reqBodyStruct, []string{
+ "Hello, this is chunk #1. ",
+ "Now chunk #2 arrives. ",
+ "Finally, chunk #3 completes the message.",
+ fmt.Sprintf("Your prompt was:\n%s", testPrompt),
+ },
+ func(t *testing.T, payload string) string {
+ var chunkResp LlmCompletionResponse
+ if err := json.Unmarshal([]byte(payload), &chunkResp); err != nil {
+ t.Fatalf("Error unmarshalling SSE chunk JSON: %v", err)
+ }
+ if len(chunkResp.Choices) == 0 {
+ t.Fatalf("Received chunk without choices: %+v", chunkResp)
+ }
+ return chunkResp.Choices[0].Text
+ },
+ )
+}
+func testCompletionRequest(test *testing.T, address string, model string) {
+ testPrompt := "This is a test prompt\nЭто проверочный промт\n这是一个测试提示"
+
+ // Prepare request body
+ completionReq := LlmCompletionRequest{
+ Model: model,
+ Prompt: testPrompt,
+ Stream: false,
}
+ completionResp := sendCompletionRequestExpectingSuccess(test, address, completionReq)
+ if len(completionResp.Choices) == 0 {
+ test.Fatalf("No choices returned in completion response: %+v", completionResp)
+ }
+ expected := fmt.Sprintf(
+ "\nThis is a test completion text.\n Your prompt was:\n%s",
+ testPrompt,
+ )
- url := fmt.Sprintf("%s/v1/completions", address)
+ got := completionResp.Choices[0].Text
+ if got != expected {
+ test.Fatalf("Completion text mismatch.\nExpected:\n%q\nGot:\n%q", expected, got)
+ }
+
+ if completionResp.Model != model {
+ test.Fatalf("Model mismatch.\nExpected:\n%q\nGot:\n%q", model, completionResp.Model)
+ }
+}
+
+// testChatCompletionRequestExpectingSuccess checks a non-streaming chat completion
+func testChatCompletionRequestExpectingSuccess(t *testing.T, address, model string) {
+ messages := []ChatMessage{
+ {Role: "system", Content: "You are a helpful AI assistant."},
+ {Role: "user", Content: "Hello, how are you?"},
+ }
+
+ chatReq := LlmChatCompletionRequest{
+ Model: model,
+ Messages: messages,
+ Stream: false,
+ }
+
+ chatResp := sendChatCompletionRequestExpectingSuccess(t, address, chatReq)
+ if len(chatResp.Choices) == 0 {
+ t.Fatalf("No choices returned in chat completion response: %+v", chatResp)
+ }
+
+ expected := fmt.Sprintf("Hello! This is a response from the test Chat endpoint. The last message was: %q", messages[len(messages)-1].Content)
+ got := chatResp.Choices[0].Message.Content
+ if got != expected {
+ t.Fatalf("Chat completion text mismatch.\nExpected:\n%q\nGot:\n%q", expected, got)
+ }
+
+ if chatResp.Model != model {
+ t.Fatalf("Model mismatch.\nExpected:\n%q\nGot:\n%q", model, chatResp.Model)
+ }
+}
+
+// testChatCompletionStreamingExpectingSuccess checks streaming chat completions from /v1/chat/completions
+func testChatCompletionStreamingExpectingSuccess(t *testing.T, address, model string) {
+ messages := []ChatMessage{
+ {Role: "system", Content: "You are a helpful AI assistant."},
+ {Role: "user", Content: "Tell me something interesting."},
+ {Role: "assistant", Content: "I absolutely will not"},
+ {Role: "user", Content: "Thanks\nfor\nnothing!"},
+ }
+
+ url := fmt.Sprintf("%s/v1/chat/completions", address)
+ testStreamingRequest(t, url, LlmChatCompletionRequest{
+ Model: model,
+ Messages: messages,
+ Stream: true,
+ }, []string{
+ "Hello, this is chunk #1.",
+ "Your last message was:\n",
+ "Thanks\nfor\nnothing!",
+		"", // final chunk, which carries no delta content
+ }, func(t *testing.T, payload string) string {
+ var chunkResp LlmChatCompletionResponse
+ if err := json.Unmarshal([]byte(payload), &chunkResp); err != nil {
+ t.Fatalf("Error unmarshalling SSE chunk JSON: %v", err)
+ }
+ if len(chunkResp.Choices) == 0 {
+ t.Fatalf("Received chunk without choices: %+v", chunkResp)
+ }
+ chunk := chunkResp.Choices[0].Delta.Content
+ return chunk
+ },
+ )
+}
+
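+// testStreamingRequest posts requestBodyObject to url, reads the SSE stream, and compares each chunk extracted by readChunkFunc against expectedChunks.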
+func testStreamingRequest(t *testing.T, url string, requestBodyObject any, expectedChunks []string, readChunkFunc func(t *testing.T, payload string) string) {
+
+ reqBody, err := json.Marshal(requestBodyObject)
+ if err != nil {
+ t.Fatalf("%s: Failed to marshal JSON: %v", url, err)
+ }
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(reqBody))
if err != nil {
- t.Fatalf("Failed to create request: %v", err)
+		t.Fatalf("%s: Failed to create request: %v", url, err)
}
req.Header.Set("Content-Type", "application/json")
@@ -301,101 +462,90 @@ func testCompletionStreaming(t *testing.T) {
}
resp, err := client.Do(req)
if err != nil {
- t.Fatalf("Streaming request failed: %v", err)
+ t.Fatalf("%s: Streaming request failed: %v", url, err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
- t.Fatalf("Expected status code 200, got %d", resp.StatusCode)
+ t.Fatalf("%s: Expected status code 200, got %d", url, resp.StatusCode)
}
- // We expect multiple SSE “data:” lines. Let’s read them line-by-line.
scanner := bufio.NewScanner(resp.Body)
-
var allChunks []string
doneReceived := false
for scanner.Scan() {
line := scanner.Text()
-
- // Skip empty lines (SSE typically separates events by a blank line)
if line == "" {
continue
}
- // SSE lines that carry data start with "data:"
if strings.HasPrefix(line, "data: ") {
- // The rest after "data: " can be JSON or the [DONE] marker
payload := strings.TrimPrefix(line, "data: ")
-
- // Check if we're done
if payload == "[DONE]" {
doneReceived = true
break
}
- // Otherwise, parse JSON chunk
- var chunkResp LlmCompletionResponse
- if err := json.Unmarshal([]byte(payload), &chunkResp); err != nil {
- t.Fatalf("Error unmarshalling SSE chunk JSON: %v\nLine: %s", err, line)
- }
-
- if len(chunkResp.Choices) == 0 {
- t.Fatalf("Received chunk without choices: %+v", chunkResp)
- }
-
- allChunks = append(allChunks, chunkResp.Choices[0].Text)
+ chunk := readChunkFunc(t, payload)
+ allChunks = append(allChunks, chunk)
}
}
if !doneReceived {
- t.Fatalf("Did not receive [DONE] marker in SSE stream")
- }
- expectedChunks := []string{
- "Hello, this is chunk #1. ",
- "Now chunk #2 arrives. ",
- "Finally, chunk #3 completes the message.",
- fmt.Sprintf("Your prompt was:\n%s", testPrompt),
+ t.Fatalf("%s: Did not receive [DONE] marker in SSE stream", url)
}
if len(allChunks) != len(expectedChunks) {
- t.Fatalf("Expected %d chunks, got %d\nChunks: %+v", len(expectedChunks), len(allChunks), allChunks)
+ t.Fatalf("%s: Expected %d chunks, got %d\nChunks: %+v", url, len(expectedChunks), len(allChunks), allChunks)
}
for i, expected := range expectedChunks {
if allChunks[i] != expected {
- t.Fatalf("Mismatch in chunk #%d.\nExpected: %q\nGot: %q", i+1, expected, allChunks[i])
+ t.Fatalf("%s: Mismatch in chunk #%d.\nExpected: %q\nGot: %q", url, i+1, expected, allChunks[i])
}
}
}
-func testCompletionRequest(test *testing.T, address string, model string) {
- testPrompt := "This is a test prompt\nЭто проверочный промт\n这是一个测试提示"
- // Prepare request body
- completionReq := LlmCompletionRequest{
- Model: model,
- Prompt: testPrompt,
- Stream: false,
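+// sendChatCompletionRequestExpectingSuccess sends a chat completion request and decodes the response, failing the test on a non-200 status.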
+func sendChatCompletionRequestExpectingSuccess(t *testing.T, address string, chatReq LlmChatCompletionRequest) LlmChatCompletionResponse {
+ resp := sendChatCompletionRequest(t, address, chatReq)
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("Expected status code 200, got %d", resp.StatusCode)
}
- completionResp := sendCompletionRequestExpectingSuccess(test, address, completionReq)
- if len(completionResp.Choices) == 0 {
- test.Fatalf("No choices returned in completion response: %+v", completionResp)
+
+ var chatResp LlmChatCompletionResponse
+ if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+ t.Fatalf("Failed to decode /v1/chat/completions response: %v", err)
}
- expected := fmt.Sprintf(
- "\nThis is a test completion text.\n Your prompt was:\n%s",
- testPrompt,
- )
+ return chatResp
+}
- got := completionResp.Choices[0].Text
- if got != expected {
- test.Fatalf("Completion text mismatch.\nExpected:\n%q\nGot:\n%q", expected, got)
+// sendChatCompletionRequest sends a POST to /v1/chat/completions with the given JSON body
+func sendChatCompletionRequest(t *testing.T, address string, chatReq LlmChatCompletionRequest) *http.Response {
+ reqBody, err := json.Marshal(chatReq)
+ if err != nil {
+ t.Fatalf("Failed to marshal JSON body: %v", err)
}
- if completionResp.Model != model {
- test.Fatalf("Model mismatch.\nExpected:\n%q\nGot:\n%q", model, completionResp.Model)
+ url := fmt.Sprintf("%s/v1/chat/completions", address)
+ req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(reqBody))
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
}
+ req.Header.Set("Content-Type", "application/json")
+
+ client := &http.Client{Timeout: 5 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("/v1/chat/completions request failed: %v", err)
+ }
+ return resp
}
func sendCompletionRequestExpectingSuccess(test *testing.T, address string, completionReq LlmCompletionRequest) LlmCompletionResponse {
@@ -436,7 +586,6 @@ func sendCompletionRequest(test *testing.T, address string, completionReq LlmCom
if err != nil {
test.Fatalf("/v1/completions Request failed: %v", err)
}
-
return resp
}
diff --git a/test-server/main.go b/test-server/main.go
index 8f57c49..3bbdf72 100644
--- a/test-server/main.go
+++ b/test-server/main.go
@@ -176,19 +176,19 @@ type ChatCompletionResponse struct {
}
type ChatCompletionChunk struct {
- ID string `json:"id"`
- Object string `json:"object"` // e.g. "chat.completion.chunk"
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []struct {
- Index int `json:"index"`
- // "delta" is how OpenAI streams partial content
- Delta struct {
- Role string `json:"role,omitempty"`
- Content string `json:"content,omitempty"`
- } `json:"delta"`
- FinishReason *string `json:"finish_reason,omitempty"`
- } `json:"choices"`
+ ID string `json:"id"`
+ Object string `json:"object"` // e.g. "chat.completion.chunk"
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []ChatCompletionChoice `json:"choices"`
+}
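+// ChatCompletionChoice is a single choice in a streamed chunk; Delta carries the partial content.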
+type ChatCompletionChoice struct {
+ Index int `json:"index"`
+ Delta struct {
+ Role string `json:"role,omitempty"`
+ Content string `json:"content,omitempty"`
+ } `json:"delta"`
+ FinishReason *string `json:"finish_reason,omitempty"`
}
func llmApiListen(port *string) {
@@ -392,14 +392,7 @@ func handleStreamChat(w http.ResponseWriter, chatRequest LlmChatRequest) {
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: chatRequest.Model,
- Choices: []struct {
- Index int `json:"index"`
- Delta struct {
- Role string `json:"role,omitempty"`
- Content string `json:"content,omitempty"`
- } `json:"delta"`
- FinishReason *string `json:"finish_reason,omitempty"`
- }{
+ Choices: []ChatCompletionChoice{
{
Index: 0,
},
@@ -411,53 +404,48 @@ func handleStreamChat(w http.ResponseWriter, chatRequest LlmChatRequest) {
}
response.Choices[0].Delta.Content = chunk
- data, err := json.Marshal(response)
- if err != nil {
- log.Printf("Failed to encode response: %v", err)
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- // SSE requires each message to start with `data: `
- _, err = fmt.Fprintf(w, "data: %s\n\n", data)
- if err != nil {
- log.Printf("Failed to write SSE to client: %v", err)
+ if !sendResponseChunk(w, response, flusher) {
return
}
- flusher.Flush()
time.Sleep(time.Millisecond * 300)
}
-
- doneChunkData := map[string]interface{}{
- "id": "chatcmpl-test-id",
- "object": "chat.completion.chunk",
- "created": time.Now().Unix(),
- "model": chatRequest.Model,
- "choices": []map[string]interface{}{
+ finishReason := "stop"
+ sendResponseChunk(w, ChatCompletionChunk{
+ ID: "chatcmpl-test-id",
+ Object: "chat.completion.chunk",
+ Created: time.Now().Unix(),
+ Model: chatRequest.Model,
+ Choices: []ChatCompletionChoice{
{
- "index": 0,
- "delta": map[string]interface{}{},
- "finish_reason": "stop",
+ Index: 0,
+ FinishReason: &finishReason,
},
},
- }
+ }, flusher)
- doneChunkBytes, err := json.Marshal(doneChunkData)
+ _, err := fmt.Fprint(w, "data: [DONE]\n\n")
if err != nil {
- log.Println("Error converting DONE chunk to JSON:", err)
- return
+ log.Printf("Failed to write [DONE] to client: %v", err)
}
+ flusher.Flush()
+}
- _, err = w.Write(doneChunkBytes)
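+// sendResponseChunk marshals chatCompletionChunk and writes it as an SSE "data:" event, returning false if encoding or writing fails.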
+func sendResponseChunk(responseWriter http.ResponseWriter, chatCompletionChunk ChatCompletionChunk, flusher http.Flusher) bool {
+ data, err := json.Marshal(chatCompletionChunk)
if err != nil {
- log.Printf("Failed to write done chunk to client: %v", err)
+ log.Printf("Failed to encode chatCompletionChunk: %v", err)
+ http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
+ return false
}
- _, err = fmt.Fprint(w, "data: [DONE]\n\n")
+ // SSE requires each message to start with `data: `
+ _, err = fmt.Fprintf(responseWriter, "data: %s\n\n", data)
if err != nil {
- log.Printf("Failed to write [DONE] to client: %v", err)
+ log.Printf("Failed to write SSE to client: %v", err)
+ return false
}
flusher.Flush()
+ return true
}
func parseAndValidateChatRequest(w http.ResponseWriter, r *http.Request) (LlmChatRequest, bool) {