diff --git a/go.mod b/go.mod
index 069597d..1515a18 100644
--- a/go.mod
+++ b/go.mod
@@ -1 +1,3 @@
 module large-model-proxy
+
+go 1.18
diff --git a/main.go b/main.go
index bf046fe..98c3464 100644
--- a/main.go
+++ b/main.go
@@ -71,10 +71,8 @@ type LlmApiModel struct {
 	OwnedBy string `json:"owned_by"`
 	Created int64  `json:"created"`
 }
-type LlmCompletionRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
-	Stream bool   `json:"stream"`
+type ModelContainingRequest struct {
+	Model string `json:"model"`
 }
 
 func (rm ResourceManager) getRunningService(name string) RunningService {
@@ -207,7 +205,6 @@ type contextKey string
 var rawConnectionContextKey = contextKey("rawConn")
 
 func startLlmApi(llmApi LlmApi, services []ServiceConfig) {
-
 	mux := http.NewServeMux()
 	modelToServiceMap := make(map[string]ServiceConfig)
 	models := make([]LlmApiModel, 0)
@@ -241,6 +238,18 @@ func startLlmApi(llmApi LlmApi, services []ServiceConfig) {
 	mux.HandleFunc("/v1/completions", func(responseWriter http.ResponseWriter, request *http.Request) {
 		handleCompletions(responseWriter, request, &modelToServiceMap)
 	})
+	mux.HandleFunc("/v1/chat/completions", func(responseWriter http.ResponseWriter, request *http.Request) {
+		handleCompletions(responseWriter, request, &modelToServiceMap)
+	})
+	mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) {
+		//404
+		log.Printf("[LLM Request API] %s request to unsupported URL: %s", request.Method, request.RequestURI)
+		http.Error(
+			responseWriter,
+			fmt.Sprintf("%s %s is not supported by large-model-proxy", request.Method, request.RequestURI),
+			http.StatusNotFound,
+		)
+	})
 
 	// Create a custom http.Server that uses ConnContext
 	// to attach the *rawCaptureConnection to each request's Context.
@@ -288,20 +297,19 @@ func handleCompletions(responseWriter http.ResponseWriter, request *http.Request
 		return
 	}
 
-	var completionRequest LlmCompletionRequest
-	if err := json.Unmarshal(bodyBytes, &completionRequest); err != nil {
-		log.Printf("[LLM API Server] Error decoding /v1/completions request: %v\n", err)
-		http.Error(responseWriter, fmt.Sprintf("Failed to parse request body: %v", err), http.StatusBadRequest)
+	model, ok := extractModelFromRequest(request.URL.String(), bodyBytes)
+	if !ok {
+		http.Error(responseWriter, "Failed to parse request body", http.StatusBadRequest)
 		return
 	}
-	service, ok := (*modelToServiceMap)[completionRequest.Model]
+	service, ok := (*modelToServiceMap)[model]
 	if !ok {
-		log.Printf("[LLM API Server] Unknown model requested: %v\n", completionRequest.Model)
-		http.Error(responseWriter, fmt.Sprintf("Unknown model: %v", completionRequest.Model), http.StatusBadRequest)
+		log.Printf("[LLM API Server] Unknown model requested: %v\n", model)
+		http.Error(responseWriter, fmt.Sprintf("Unknown model: %v", model), http.StatusBadRequest)
 		return
 	}
-	log.Printf("[LLM API Server] Sending /v1/completions request through to %s\n", service.Name)
+	log.Printf("[LLM API Server] Sending %s request through to %s\n", request.URL, service.Name)
 	originalWriter := responseWriter
 	hijacker, ok := originalWriter.(http.Hijacker)
 	if !ok {
@@ -331,6 +339,17 @@ func handleCompletions(responseWriter http.ResponseWriter, request *http.Request
 	}
 	handleConnection(clientConnection, service, rawRequestBytes)
 }
+
+// extractModelFromRequest returns the model name and whether it was read successfully
+func extractModelFromRequest(url string, bodyBytes []byte) (string, bool) {
+	var completionRequest ModelContainingRequest
+	if err := json.Unmarshal(bodyBytes, &completionRequest); err != nil {
+		log.Printf("[LLM API Server] Error decoding %s request: %v\n%s", url, err, bodyBytes)
+		return "", false
+	}
+	return completionRequest.Model, true
+}
+
 func signalToString(sig os.Signal) string {
 	switch sig {
 	case syscall.SIGINT:
diff --git a/main_test.go b/main_test.go
index 73c77b0..316700a 100644
--- a/main_test.go
+++ b/main_test.go
@@ -19,6 +19,7 @@ import (
 	"time"
 )
 
+// LlmCompletionResponse is what /v1/completions returns
 type LlmCompletionResponse struct {
 	ID      string `json:"id"`
 	Object  string `json:"object"`
 	Created int64  `json:"created"`
 	Model   string `json:"model"`
 	Choices []struct {
 		Text         string      `json:"text"`
 		Index        int         `json:"index"`
 		Logprobs     interface{} `json:"logprobs"`
 		FinishReason string      `json:"finish_reason"`
 	} `json:"choices"`
 	Usage struct {
 		PromptTokens     int `json:"prompt_tokens"`
 		CompletionTokens int `json:"completion_tokens"`
 		TotalTokens      int `json:"total_tokens"`
 	} `json:"usage"`
 }
@@ -36,6 +37,44 @@
+// LlmCompletionRequest is used by /v1/completions
+type LlmCompletionRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+	Stream bool   `json:"stream"`
+}
+
+// LlmChatCompletionRequest is used by /v1/chat/completions
+type LlmChatCompletionRequest struct {
+	Model    string        `json:"model,omitempty"`
+	Messages []ChatMessage `json:"messages,omitempty"`
+	Stream   bool          `json:"stream,omitempty"`
+}
+
+type ChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// LlmChatCompletionResponse is what /v1/chat/completions returns
+type LlmChatCompletionResponse struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	Model   string `json:"model"`
+	Choices []struct {
+		Index        int         `json:"index"`
+		Message      ChatMessage `json:"message"`
+		Delta        ChatMessage `json:"delta"`
+		FinishReason string      `json:"finish_reason"`
+	} `json:"choices"`
+	Usage struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+}
+
 func connectOnly(test *testing.T, proxyAddress string) {
 	_, err := net.Dial("tcp", proxyAddress)
 	if err != nil {
@@ -158,8 +197,8 @@ func idleTimeoutMultipleServices(test *testing.T, serviceOneAddress string, serv
 func llmApi(test *testing.T) {
 	//sanity check that nothing is running before initial connection
 	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+	client := &http.Client{}
 
-	// Create a new GET request
 	req, err := http.NewRequest("GET", "http://localhost:2016/v1/models", nil)
 	if err != nil {
 		test.Fatalf("Failed to create request: %v", err)
 	}
@@ -228,8 +267,15 @@ func llmApi(test *testing.T) {
 	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
 
 	testCompletionRequest(test, "http://localhost:2016", "test-llm-1")
+	assertPortsAreClosed(test, []string{"localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+
+	testCompletionStreamingExpectingSuccess(test, "test-llm-1")
+	testChatCompletionRequestExpectingSuccess(test, "http://localhost:2016", "test-llm-1")
+	testChatCompletionStreamingExpectingSuccess(test, "http://localhost:2016", "test-llm-1")
+
 	llm1Pid := runReadPidCloseConnection(test, "localhost:12018")
 	assertPortsAreClosed(test, []string{"localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+
 	time.Sleep(4 * time.Second)
 
 	if isProcessRunning(llm1Pid) {
@@ -237,9 +283,18 @@
 	}
 	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12019", "localhost:12020", "localhost:12021", "localhost:12022", "localhost:12023"})
+	testChatCompletionRequestExpectingSuccess(test, "http://localhost:2016", "fizz")
+	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+
 	testCompletionRequest(test, "http://localhost:2016", "fizz")
-	llm2Pid := runReadPidCloseConnection(test, "localhost:12020")
 	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+
+	testChatCompletionStreamingExpectingSuccess(test, "http://localhost:2016", "fizz")
+	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+
+	testCompletionStreamingExpectingSuccess(test, "fizz")
+	assertPortsAreClosed(test, []string{"localhost:12017", "localhost:12018", "localhost:12021", "localhost:12022", "localhost:12023"})
+	llm2Pid := runReadPidCloseConnection(test, "localhost:12020")
 	time.Sleep(4 * time.Second)
 	if isProcessRunning(llm2Pid) {
 		test.Fatalf("test-llm-2 service is still running, but inactivity timeout should have shut it down by now")
 	}
@@ -255,8 +310,6 @@ func llmApi(test *testing.T) {
 	testCompletionRequest(test, "http://localhost:2019", "foo")
 	llm2Pid = runReadPidCloseConnection(test, "localhost:12020")
-
-	testCompletionStreaming(test)
 	time.Sleep(4 * time.Second)
 	if isProcessRunning(llm2Pid) {
 		test.Fatalf("test-llm-2 service is still running, but inactivity timeout should have shut it down by now")
@@ -273,10 +326,9 @@ func assertPortsAreClosed(test *testing.T, servicesToCheckForClosedPorts []strin
 	}
 }
 
-func testCompletionStreaming(t *testing.T) {
+// testCompletionStreamingExpectingSuccess checks streaming completions from /v1/completions
+func testCompletionStreamingExpectingSuccess(t *testing.T, model string) {
 	address := "http://localhost:2016"
-	model := "fizz"
-
 	testPrompt := "This is a test prompt\nЭто проверочный промт\n这是一个测试提示"
 	reqBodyStruct := LlmCompletionRequest{
 		Model:  model,
@@ -284,15 +336,124 @@
 		Stream: true,
 	}
 
-	reqBody, err := json.Marshal(reqBodyStruct)
-	if err != nil {
-		t.Fatalf("Failed to marshal JSON: %v", err)
+	url := fmt.Sprintf("%s/v1/completions", address)
+	testStreamingRequest(t, url, reqBodyStruct, []string{
+		"Hello, this is chunk #1. ",
+		"Now chunk #2 arrives. ",
+		"Finally, chunk #3 completes the message.",
+		fmt.Sprintf("Your prompt was:\n%s", testPrompt),
+	},
+		func(t *testing.T, payload string) string {
+			var chunkResp LlmCompletionResponse
+			if err := json.Unmarshal([]byte(payload), &chunkResp); err != nil {
+				t.Fatalf("Error unmarshalling SSE chunk JSON: %v", err)
+			}
+			if len(chunkResp.Choices) == 0 {
+				t.Fatalf("Received chunk without choices: %+v", chunkResp)
+			}
+			return chunkResp.Choices[0].Text
+		},
+	)
+}
+func testCompletionRequest(test *testing.T, address string, model string) {
+	testPrompt := "This is a test prompt\nЭто проверочный промт\n这是一个测试提示"
+
+	// Prepare request body
+	completionReq := LlmCompletionRequest{
+		Model:  model,
+		Prompt: testPrompt,
+		Stream: false,
 	}
+	completionResp := sendCompletionRequestExpectingSuccess(test, address, completionReq)
+	if len(completionResp.Choices) == 0 {
+		test.Fatalf("No choices returned in completion response: %+v", completionResp)
+	}
+	expected := fmt.Sprintf(
+		"\nThis is a test completion text.\n Your prompt was:\n%s",
+		testPrompt,
+	)
 
-	url := fmt.Sprintf("%s/v1/completions", address)
+	got := completionResp.Choices[0].Text
+	if got != expected {
+		test.Fatalf("Completion text mismatch.\nExpected:\n%q\nGot:\n%q", expected, got)
+	}
+
+	if completionResp.Model != model {
+		test.Fatalf("Model mismatch.\nExpected:\n%q\nGot:\n%q", model, completionResp.Model)
+	}
+}
+
+// testChatCompletionRequestExpectingSuccess checks a non-streaming chat completion
+func testChatCompletionRequestExpectingSuccess(t *testing.T, address, model string) {
+	messages := []ChatMessage{
+		{Role: "system", Content: "You are a helpful AI assistant."},
+		{Role: "user", Content: "Hello, how are you?"},
+	}
+
+	chatReq := LlmChatCompletionRequest{
+		Model:    model,
+		Messages: messages,
+		Stream:   false,
+	}
+
+	chatResp := sendChatCompletionRequestExpectingSuccess(t, address, chatReq)
+	if len(chatResp.Choices) == 0 {
+		t.Fatalf("No choices returned in chat completion response: %+v", chatResp)
+	}
+
+	expected := fmt.Sprintf("Hello! This is a response from the test Chat endpoint. The last message was: %q", messages[len(messages)-1].Content)
+	got := chatResp.Choices[0].Message.Content
+	if got != expected {
+		t.Fatalf("Chat completion text mismatch.\nExpected:\n%q\nGot:\n%q", expected, got)
+	}
+
+	if chatResp.Model != model {
+		t.Fatalf("Model mismatch.\nExpected:\n%q\nGot:\n%q", model, chatResp.Model)
+	}
+}
+
+// testChatCompletionStreamingExpectingSuccess checks streaming chat completions from /v1/chat/completions
+func testChatCompletionStreamingExpectingSuccess(t *testing.T, address, model string) {
+	messages := []ChatMessage{
+		{Role: "system", Content: "You are a helpful AI assistant."},
+		{Role: "user", Content: "Tell me something interesting."},
+		{Role: "assistant", Content: "I absolutely will not"},
+		{Role: "user", Content: "Thanks\nfor\nnothing!"},
+	}
+
+	url := fmt.Sprintf("%s/v1/chat/completions", address)
+	testStreamingRequest(t, url, LlmChatCompletionRequest{
+		Model:    model,
+		Messages: messages,
+		Stream:   true,
+	}, []string{
+		"Hello, this is chunk #1.",
+		"Your last message was:\n",
+		"Thanks\nfor\nnothing!",
+		"", //done chunk which doesn't have a delta
+	}, func(t *testing.T, payload string) string {
+		var chunkResp LlmChatCompletionResponse
+		if err := json.Unmarshal([]byte(payload), &chunkResp); err != nil {
+			t.Fatalf("Error unmarshalling SSE chunk JSON: %v", err)
+		}
+		if len(chunkResp.Choices) == 0 {
+			t.Fatalf("Received chunk without choices: %+v", chunkResp)
+		}
+		chunk := chunkResp.Choices[0].Delta.Content
+		return chunk
+	},
+	)
+}
+
+func testStreamingRequest(t *testing.T, url string, requestBodyObject any, expectedChunks []string, readChunkFunc func(t *testing.T, payload string) string) {
+
+	reqBody, err := json.Marshal(requestBodyObject)
+	if err != nil {
+		t.Fatalf("%s: Failed to marshal JSON: %v", url, err)
+	}
 	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(reqBody))
 	if err != nil {
-		t.Fatalf("Failed to create request: %v", err)
+		t.Fatalf("%s: Failed to create request: %v", url, err)
 	}
 	req.Header.Set("Content-Type", "application/json")
@@ -301,101 +462,90 @@
 	}
 	resp, err := client.Do(req)
 	if err != nil {
-		t.Fatalf("Streaming request failed: %v", err)
+		t.Fatalf("%s: Streaming request failed: %v", url, err)
 	}
 	defer func() {
 		_ = resp.Body.Close()
 	}()
 
 	if resp.StatusCode != http.StatusOK {
-		t.Fatalf("Expected status code 200, got %d", resp.StatusCode)
+		t.Fatalf("%s: Expected status code 200, got %d", url, resp.StatusCode)
 	}
 
-	// We expect multiple SSE “data:” lines. Let’s read them line-by-line.
 	scanner := bufio.NewScanner(resp.Body)
-
 	var allChunks []string
 	doneReceived := false
 
 	for scanner.Scan() {
 		line := scanner.Text()
-
-		// Skip empty lines (SSE typically separates events by a blank line)
 		if line == "" {
 			continue
 		}
 
-		// SSE lines that carry data start with "data:"
 		if strings.HasPrefix(line, "data: ") {
-			// The rest after "data: " can be JSON or the [DONE] marker
 			payload := strings.TrimPrefix(line, "data: ")
-
-			// Check if we're done
 			if payload == "[DONE]" {
 				doneReceived = true
 				break
 			}
 
-			// Otherwise, parse JSON chunk
-			var chunkResp LlmCompletionResponse
-			if err := json.Unmarshal([]byte(payload), &chunkResp); err != nil {
-				t.Fatalf("Error unmarshalling SSE chunk JSON: %v\nLine: %s", err, line)
-			}
-
-			if len(chunkResp.Choices) == 0 {
-				t.Fatalf("Received chunk without choices: %+v", chunkResp)
-			}
-
-			allChunks = append(allChunks, chunkResp.Choices[0].Text)
+			chunk := readChunkFunc(t, payload)
+			allChunks = append(allChunks, chunk)
 		}
 	}
 
 	if !doneReceived {
-		t.Fatalf("Did not receive [DONE] marker in SSE stream")
-	}
-	expectedChunks := []string{
-		"Hello, this is chunk #1. ",
-		"Now chunk #2 arrives. ",
-		"Finally, chunk #3 completes the message.",
-		fmt.Sprintf("Your prompt was:\n%s", testPrompt),
+		t.Fatalf("%s: Did not receive [DONE] marker in SSE stream", url)
 	}
 
 	if len(allChunks) != len(expectedChunks) {
-		t.Fatalf("Expected %d chunks, got %d\nChunks: %+v", len(expectedChunks), len(allChunks), allChunks)
+		t.Fatalf("%s: Expected %d chunks, got %d\nChunks: %+v", url, len(expectedChunks), len(allChunks), allChunks)
	}
 
 	for i, expected := range expectedChunks {
 		if allChunks[i] != expected {
-			t.Fatalf("Mismatch in chunk #%d.\nExpected: %q\nGot: %q", i+1, expected, allChunks[i])
+			t.Fatalf("%s: Mismatch in chunk #%d.\nExpected: %q\nGot: %q", url, i+1, expected, allChunks[i])
 		}
 	}
 }
 
-func testCompletionRequest(test *testing.T, address string, model string) {
-	testPrompt := "This is a test prompt\nЭто проверочный промт\n这是一个测试提示"
-	// Prepare request body
-	completionReq := LlmCompletionRequest{
-		Model:  model,
-		Prompt: testPrompt,
-		Stream: false,
+func sendChatCompletionRequestExpectingSuccess(t *testing.T, address string, chatReq LlmChatCompletionRequest) LlmChatCompletionResponse {
+	resp := sendChatCompletionRequest(t, address, chatReq)
+	defer func() {
+		_ = resp.Body.Close()
+	}()
+
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("Expected status code 200, got %d", resp.StatusCode)
 	}
-	completionResp := sendCompletionRequestExpectingSuccess(test, address, completionReq)
-	if len(completionResp.Choices) == 0 {
-		test.Fatalf("No choices returned in completion response: %+v", completionResp)
+
+	var chatResp LlmChatCompletionResponse
+	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+		t.Fatalf("Failed to decode /v1/chat/completions response: %v", err)
 	}
-	expected := fmt.Sprintf(
-		"\nThis is a test completion text.\n Your prompt was:\n%s",
-		testPrompt,
-	)
+	return chatResp
+}
 
-	got := completionResp.Choices[0].Text
-	if got != expected {
-		test.Fatalf("Completion text mismatch.\nExpected:\n%q\nGot:\n%q", expected, got)
+// sendChatCompletionRequest sends a POST to /v1/chat/completions with the given JSON body
+func sendChatCompletionRequest(t *testing.T, address string, chatReq LlmChatCompletionRequest) *http.Response {
+	reqBody, err := json.Marshal(chatReq)
+	if err != nil {
+		t.Fatalf("Failed to marshal JSON body: %v", err)
 	}
-	if completionResp.Model != model {
-		test.Fatalf("Model mismatch.\nExpected:\n%q\nGot:\n%q", model, completionResp.Model)
+
+	url := fmt.Sprintf("%s/v1/chat/completions", address)
+	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(reqBody))
+	if err != nil {
+		t.Fatalf("Failed to create request: %v", err)
 	}
+	req.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{Timeout: 5 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("/v1/chat/completions request failed: %v", err)
+	}
+	return resp
 }
 
 func sendCompletionRequestExpectingSuccess(test *testing.T, address string, completionReq LlmCompletionRequest) LlmCompletionResponse {
@@ -436,7 +586,6 @@ func sendCompletionRequest(test *testing.T, address string, completionReq LlmCom
 	if err != nil {
 		test.Fatalf("/v1/completions Request failed: %v", err)
 	}
-
 	return resp
 }
diff --git a/test-server/main.go b/test-server/main.go
index 8f57c49..3bbdf72 100644
--- a/test-server/main.go
+++ b/test-server/main.go
@@ -176,19 +176,19 @@ type ChatCompletionResponse struct {
 }
 
 type ChatCompletionChunk struct {
-	ID      string `json:"id"`
-	Object  string `json:"object"` // e.g. "chat.completion.chunk"
-	Created int64  `json:"created"`
-	Model   string `json:"model"`
-	Choices []struct {
-		Index int `json:"index"`
-		// "delta" is how OpenAI streams partial content
-		Delta struct {
-			Role    string `json:"role,omitempty"`
-			Content string `json:"content,omitempty"`
-		} `json:"delta"`
-		FinishReason *string `json:"finish_reason,omitempty"`
-	} `json:"choices"`
+	ID      string                 `json:"id"`
+	Object  string                 `json:"object"` // e.g. "chat.completion.chunk"
+	Created int64                  `json:"created"`
+	Model   string                 `json:"model"`
+	Choices []ChatCompletionChoice `json:"choices"`
+}
+type ChatCompletionChoice struct {
+	Index int `json:"index"`
+	Delta struct {
+		Role    string `json:"role,omitempty"`
+		Content string `json:"content,omitempty"`
+	} `json:"delta"`
+	FinishReason *string `json:"finish_reason,omitempty"`
 }
 
 func llmApiListen(port *string) {
@@ -392,14 +392,7 @@ func handleStreamChat(w http.ResponseWriter, chatRequest LlmChatRequest) {
 		Object:  "chat.completion.chunk",
 		Created: time.Now().Unix(),
 		Model:   chatRequest.Model,
-		Choices: []struct {
-			Index int `json:"index"`
-			Delta struct {
-				Role    string `json:"role,omitempty"`
-				Content string `json:"content,omitempty"`
-			} `json:"delta"`
-			FinishReason *string `json:"finish_reason,omitempty"`
-		}{
+		Choices: []ChatCompletionChoice{
 			{
 				Index: 0,
 			},
@@ -411,53 +404,48 @@
 		}
 		response.Choices[0].Delta.Content = chunk
 
-		data, err := json.Marshal(response)
-		if err != nil {
-			log.Printf("Failed to encode response: %v", err)
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
-
-		// SSE requires each message to start with `data: `
-		_, err = fmt.Fprintf(w, "data: %s\n\n", data)
-		if err != nil {
-			log.Printf("Failed to write SSE to client: %v", err)
+		if !sendResponseChunk(w, response, flusher) {
 			return
 		}
-		flusher.Flush()
 		time.Sleep(time.Millisecond * 300)
 	}
-
-	doneChunkData := map[string]interface{}{
-		"id":      "chatcmpl-test-id",
-		"object":  "chat.completion.chunk",
-		"created": time.Now().Unix(),
-		"model":   chatRequest.Model,
-		"choices": []map[string]interface{}{
+	finishReason := "stop"
+	sendResponseChunk(w, ChatCompletionChunk{
+		ID:      "chatcmpl-test-id",
+		Object:  "chat.completion.chunk",
+		Created: time.Now().Unix(),
+		Model:   chatRequest.Model,
+		Choices: []ChatCompletionChoice{
 			{
-				"index":         0,
-				"delta":         map[string]interface{}{},
-				"finish_reason": "stop",
+				Index:        0,
+				FinishReason: &finishReason,
 			},
 		},
-	}
+	}, flusher)
 
-	doneChunkBytes, err := json.Marshal(doneChunkData)
+	_, err := fmt.Fprint(w, "data: [DONE]\n\n")
 	if err != nil {
-		log.Println("Error converting DONE chunk to JSON:", err)
-		return
+		log.Printf("Failed to write [DONE] to client: %v", err)
 	}
+	flusher.Flush()
+}
 
-	_, err = w.Write(doneChunkBytes)
+func sendResponseChunk(responseWriter http.ResponseWriter, chatCompletionChunk ChatCompletionChunk, flusher http.Flusher) bool {
+	data, err := json.Marshal(chatCompletionChunk)
 	if err != nil {
-		log.Printf("Failed to write done chunk to client: %v", err)
+		log.Printf("Failed to encode chatCompletionChunk: %v", err)
+		http.Error(responseWriter, err.Error(), http.StatusInternalServerError)
+		return false
 	}
-	_, err = fmt.Fprint(w, "data: [DONE]\n\n")
+	// SSE requires each message to start with `data: `
+	_, err = fmt.Fprintf(responseWriter, "data: %s\n\n", data)
 	if err != nil {
-		log.Printf("Failed to write [DONE] to client: %v", err)
+		log.Printf("Failed to write SSE to client: %v", err)
+		return false
 	}
 	flusher.Flush()
+	return true
 }
 
 func parseAndValidateChatRequest(w http.ResponseWriter, r *http.Request) (LlmChatRequest, bool) {