Skip to content

Commit

Permalink
Implement /v1/chat/completions endpoint support
Browse files Browse the repository at this point in the history
  • Loading branch information
perk11 committed Feb 3, 2025
1 parent 9b38c54 commit 577b038
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 126 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
module large-model-proxy

go 1.18
45 changes: 32 additions & 13 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,8 @@ type LlmApiModel struct {
OwnedBy string `json:"owned_by"`
Created int64 `json:"created"`
}
type LlmCompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
type ModelContainingRequest struct {
Model string `json:"model"`
}

func (rm ResourceManager) getRunningService(name string) RunningService {
Expand Down Expand Up @@ -207,7 +205,6 @@ type contextKey string
var rawConnectionContextKey = contextKey("rawConn")

func startLlmApi(llmApi LlmApi, services []ServiceConfig) {

mux := http.NewServeMux()
modelToServiceMap := make(map[string]ServiceConfig)
models := make([]LlmApiModel, 0)
Expand Down Expand Up @@ -241,6 +238,18 @@ func startLlmApi(llmApi LlmApi, services []ServiceConfig) {
mux.HandleFunc("/v1/completions", func(responseWriter http.ResponseWriter, request *http.Request) {
handleCompletions(responseWriter, request, &modelToServiceMap)
})
mux.HandleFunc("/v1/chat/completions", func(responseWriter http.ResponseWriter, request *http.Request) {
handleCompletions(responseWriter, request, &modelToServiceMap)
})
mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) {
//404
log.Printf("[LLM Request API] %s request to unsupported URL: %s", request.Method, request.RequestURI)
http.Error(
responseWriter,
fmt.Sprintf("%s %s is not supoprted by large-model-proxy", request.Method, request.RequestURI),
http.StatusNotFound,
)
})

// Create a custom http.Server that uses ConnContext
// to attach the *rawCaptureConnection to each request's Context.
Expand Down Expand Up @@ -288,20 +297,19 @@ func handleCompletions(responseWriter http.ResponseWriter, request *http.Request
return
}

var completionRequest LlmCompletionRequest
if err := json.Unmarshal(bodyBytes, &completionRequest); err != nil {
log.Printf("[LLM API Server] Error decoding /v1/completions request: %v\n", err)
http.Error(responseWriter, fmt.Sprintf("Failed to parse request body: %v", err), http.StatusBadRequest)
model, ok := extractModelFromRequest(request.URL.String(), bodyBytes)
if !ok {
http.Error(responseWriter, fmt.Sprintf("Failed to parse request: %v", err), http.StatusBadRequest)
return
}

service, ok := (*modelToServiceMap)[completionRequest.Model]
service, ok := (*modelToServiceMap)[model]
if !ok {
log.Printf("[LLM API Server] Unknown model requested: %v\n", completionRequest.Model)
http.Error(responseWriter, fmt.Sprintf("Unknown model: %v", completionRequest.Model), http.StatusBadRequest)
log.Printf("[LLM API Server] Unknown model requested: %v\n", model)
http.Error(responseWriter, fmt.Sprintf("Unknown model: %v", model), http.StatusBadRequest)
return
}
log.Printf("[LLM API Server] Sending /v1/completions request through to %s\n", service.Name)
log.Printf("[LLM API Server] Sending %s request through to %s\n", request.URL, service.Name)
originalWriter := responseWriter
hijacker, ok := originalWriter.(http.Hijacker)
if !ok {
Expand Down Expand Up @@ -331,6 +339,17 @@ func handleCompletions(responseWriter http.ResponseWriter, request *http.Request
}
handleConnection(clientConnection, service, rawRequestBytes)
}

// extractModelFromRequest returns model name and whether reading model name was successful
func extractModelFromRequest(url string, bodyBytes []byte) (string, bool) {
var completionRequest ModelContainingRequest
if err := json.Unmarshal(bodyBytes, &completionRequest); err != nil {
log.Printf("[LLM API Server] Error decoding %s request: %v\n%s", url, err, bodyBytes)
return "", false
}
return completionRequest.Model, true
}

func signalToString(sig os.Signal) string {
switch sig {
case syscall.SIGINT:
Expand Down
Loading

0 comments on commit 577b038

Please # to comment.