From e43a9c12836cdcb47371cbcf4e99eadd26891104 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Tue, 2 Apr 2024 12:32:39 +0200 Subject: [PATCH 01/14] Add a generic Provider interface --- backend/executor/queue.go | 45 +++--- backend/main.go | 2 - .../{agent/agent.go => providers/openai.go} | 134 ++++++++++-------- backend/providers/providers.go | 104 ++++++++++++++ backend/services/openai.go | 95 ------------- 5 files changed, 201 insertions(+), 179 deletions(-) rename backend/{agent/agent.go => providers/openai.go} (79%) create mode 100644 backend/providers/providers.go delete mode 100644 backend/services/openai.go diff --git a/backend/executor/queue.go b/backend/executor/queue.go index 6b26023..669a6c9 100644 --- a/backend/executor/queue.go +++ b/backend/executor/queue.go @@ -8,11 +8,10 @@ import ( "log" "github.com/docker/docker/api/types/container" - "github.com/semanser/ai-coder/agent" "github.com/semanser/ai-coder/database" gmodel "github.com/semanser/ai-coder/graph/model" "github.com/semanser/ai-coder/graph/subscriptions" - "github.com/semanser/ai-coder/services" + "github.com/semanser/ai-coder/providers" "github.com/semanser/ai-coder/websocket" ) @@ -52,6 +51,14 @@ func CleanQueue(flowId int64) { func ProcessQueue(flowId int64, db *database.Queries) { log.Println("Starting tasks processor for queue %d", flowId) + provider, err := providers.ProviderFactory(providers.ProviderOpenAI) + + if err != nil { + log.Printf("failed to get provider: %v", err) + CleanQueue(flowId) + return + } + go func() { for { select { @@ -78,14 +85,14 @@ func ProcessQueue(flowId int64, db *database.Queries) { }) if task.Type.String == "input" { - err := processInputTask(db, task) + err := processInputTask(provider, db, task) if err != nil { log.Printf("failed to process input: %w", err) continue } - nextTask, err := getNextTask(db, task.FlowID.Int64) + nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { log.Printf("failed to get next task: %w", err) @@ 
-111,7 +118,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { log.Printf("failed to process terminal: %w", err) continue } - nextTask, err := getNextTask(db, task.FlowID.Int64) + nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { log.Printf("failed to get next task: %w", err) @@ -129,7 +136,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { continue } - nextTask, err := getNextTask(db, task.FlowID.Int64) + nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { log.Printf("failed to get next task: %w", err) @@ -156,7 +163,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { continue } - nextTask, err := getNextTask(db, task.FlowID.Int64) + nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { log.Printf("failed to get next task: %w", err) @@ -171,7 +178,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { } func processBrowserTask(db *database.Queries, task database.Task) error { - var args = agent.BrowserArgs{} + var args = providers.BrowserArgs{} err := json.Unmarshal([]byte(task.Args.String), &args) if err != nil { return fmt.Errorf("failed to unmarshal args: %v", err) @@ -180,7 +187,7 @@ func processBrowserTask(db *database.Queries, task database.Task) error { var url = args.Url var screenshotName string - if args.Action == agent.Read { + if args.Action == providers.Read { content, screenshot, err := Content(url) if err != nil { @@ -200,7 +207,7 @@ func processBrowserTask(db *database.Queries, task database.Task) error { } } - if args.Action == agent.Url { + if args.Action == providers.Url { content, screenshot, err := URLs(url) if err != nil { @@ -247,7 +254,7 @@ func processDoneTask(db *database.Queries, task database.Task) error { return nil } -func processInputTask(db *database.Queries, task database.Task) error { +func processInputTask(provider providers.Provider, db *database.Queries, task database.Task) error { tasks, err := 
db.ReadTasksByFlowId(context.Background(), sql.NullInt64{ Int64: task.FlowID.Int64, Valid: true, @@ -260,13 +267,13 @@ func processInputTask(db *database.Queries, task database.Task) error { // This is the first task in the flow. // We need to get the basic flow data as well as spin up the container if len(tasks) == 1 { - summary, err := services.GetMessageSummary(task.Message.String, 10) + summary, err := provider.Summary(task.Message.String, 10) if err != nil { return fmt.Errorf("failed to get message summary: %w", err) } - dockerImage, err := services.GetDockerImageName(task.Message.String) + dockerImage, err := provider.DockerImageName(task.Message.String) if err != nil { return fmt.Errorf("failed to get docker image name: %w", err) @@ -376,7 +383,7 @@ func processAskTask(db *database.Queries, task database.Task) error { } func processTerminalTask(db *database.Queries, task database.Task) error { - var args = agent.TerminalArgs{} + var args = providers.TerminalArgs{} err := json.Unmarshal([]byte(task.Args.String), &args) if err != nil { return fmt.Errorf("failed to unmarshal args: %v", err) @@ -401,7 +408,7 @@ func processTerminalTask(db *database.Queries, task database.Task) error { } func processCodeTask(db *database.Queries, task database.Task) error { - var args = agent.CodeArgs{} + var args = providers.CodeArgs{} err := json.Unmarshal([]byte(task.Args.String), &args) if err != nil { return fmt.Errorf("failed to unmarshal args: %v", err) @@ -410,7 +417,7 @@ func processCodeTask(db *database.Queries, task database.Task) error { var cmd = "" var results = "" - if args.Action == agent.ReadFile { + if args.Action == providers.ReadFile { // TODO consider using dockerClient.CopyFromContainer command instead cmd = fmt.Sprintf("cat %s", args.Path) results, err = ExecCommand(task.FlowID.Int64, cmd, db) @@ -420,7 +427,7 @@ func processCodeTask(db *database.Queries, task database.Task) error { } } - if args.Action == agent.UpdateFile { + if args.Action == 
providers.UpdateFile { err = WriteFile(task.FlowID.Int64, args.Content, args.Path, db) if err != nil { @@ -446,7 +453,7 @@ func processCodeTask(db *database.Queries, task database.Task) error { return nil } -func getNextTask(db *database.Queries, flowId int64) (*database.Task, error) { +func getNextTask(provider providers.Provider, db *database.Queries, flowId int64) (*database.Task, error) { flow, err := db.ReadFlow(context.Background(), flowId) if err != nil { @@ -472,7 +479,7 @@ func getNextTask(db *database.Queries, flowId int64) (*database.Task, error) { } } - c := agent.NextTask(agent.AgentPrompt{ + c := provider.NextTask(providers.NextTaskOptions{ Tasks: tasks, DockerImage: flow.ContainerImage.String, }) diff --git a/backend/main.go b/backend/main.go index 00065bd..d0e1024 100644 --- a/backend/main.go +++ b/backend/main.go @@ -17,7 +17,6 @@ import ( "github.com/semanser/ai-coder/database" "github.com/semanser/ai-coder/executor" "github.com/semanser/ai-coder/router" - "github.com/semanser/ai-coder/services" ) //go:embed templates/prompts/*.tmpl @@ -55,7 +54,6 @@ func main() { r := router.New(queries) assets.Init(promptTemplates, scriptTemplates) - services.Init() err = executor.InitClient() if err != nil { diff --git a/backend/agent/agent.go b/backend/providers/openai.go similarity index 79% rename from backend/agent/agent.go rename to backend/providers/openai.go index 30792a1..c406765 100644 --- a/backend/agent/agent.go +++ b/backend/providers/openai.go @@ -1,8 +1,7 @@ -package agent +package providers import ( "context" - "database/sql" "encoding/json" "fmt" "log" @@ -12,62 +11,93 @@ import ( "github.com/semanser/ai-coder/assets" "github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" - "github.com/semanser/ai-coder/services" "github.com/semanser/ai-coder/templates" ) -type Message string - -type InputArgs struct { - Query string +type OpenAIProvider struct { + client *openai.Client } -type TerminalArgs struct { - Input string - 
Message +func (p OpenAIProvider) New() Provider { + cfg := openai.DefaultConfig(config.Config.OpenAIKey) + cfg.BaseURL = config.Config.OpenAIServerURL + client := openai.NewClientWithConfig(cfg) + + return OpenAIProvider{client} } -type BrowserAction string +func (p OpenAIProvider) Summary(query string, n int) (string, error) { + prompt, err := templates.Render(assets.PromptTemplates, "prompts/summary.tmpl", map[string]any{ + "Text": query, + "N": n, + }) + if err != nil { + return "", err + } -const ( - Read BrowserAction = "read" - Url BrowserAction = "url" -) + req := openai.ChatCompletionRequest{ + Temperature: 0.0, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: prompt, + }, + }, + TopP: 0.2, + N: 1, + } -type BrowserArgs struct { - Url string - Action BrowserAction - Message -} + resp, err := p.client.CreateChatCompletion(context.Background(), req) + if err != nil { + return "", fmt.Errorf("completion error: %v", err) + } -type CodeAction string + choices := resp.Choices -const ( - ReadFile CodeAction = "read_file" - UpdateFile CodeAction = "update_file" -) + if len(choices) == 0 { + return "", fmt.Errorf("no choices found") + } -type CodeArgs struct { - Action CodeAction - Content string - Path string - Message + return choices[0].Message.Content, nil } -type AskArgs struct { - Message -} +func (p OpenAIProvider) DockerImageName(task string) (string, error) { + prompt, err := templates.Render(assets.PromptTemplates, "prompts/docker.tmpl", map[string]any{ + "Task": task, + }) + if err != nil { + return "", err + } -type DoneArgs struct { - Message -} + req := openai.ChatCompletionRequest{ + Temperature: 0.0, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: prompt, + }, + }, + TopP: 0.2, + N: 1, + } + + resp, err := p.client.CreateChatCompletion(context.Background(), req) + if err != nil { + return "", 
fmt.Errorf("completion error: %v", err) + } + + choices := resp.Choices + + if len(choices) == 0 { + return "", fmt.Errorf("no choices found") + } -type AgentPrompt struct { - Tasks []database.Task - DockerImage string + return choices[0].Message.Content, nil } -func NextTask(args AgentPrompt) *database.Task { +func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { log.Println("Getting next task") prompt, err := templates.Render(assets.PromptTemplates, "prompts/agent.tmpl", args) @@ -184,7 +214,7 @@ func NextTask(args AgentPrompt) *database.Task { N: 1, } - resp, err := services.OpenAIclient.CreateChatCompletion(context.Background(), req) + resp, err := p.client.CreateChatCompletion(context.Background(), req) if err != nil { log.Printf("Failed to get response from OpenAI %v", err) return defaultAskTask("There was an error getting the next task") @@ -295,25 +325,3 @@ func NextTask(args AgentPrompt) *database.Task { return &task } - -func defaultAskTask(message string) *database.Task { - task := database.Task{ - Type: database.StringToNullString("ask"), - } - - task.Args = database.StringToNullString("{}") - task.Message = sql.NullString{ - String: fmt.Sprintf("%s. 
What should I do next?", message), - Valid: true, - } - - return &task -} - -func extractArgs[T any](openAIargs string, args *T) (*T, error) { - err := json.Unmarshal([]byte(openAIargs), args) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal args: %v", err) - } - return args, nil -} diff --git a/backend/providers/providers.go b/backend/providers/providers.go new file mode 100644 index 0000000..3b6d97e --- /dev/null +++ b/backend/providers/providers.go @@ -0,0 +1,104 @@ +package providers + +import ( + "database/sql" + "encoding/json" + "fmt" + + "github.com/semanser/ai-coder/database" +) + +type ProviderType string + +const ( + ProviderOpenAI ProviderType = "openai" +) + +type Provider interface { + New() Provider + Summary(query string, n int) (string, error) + DockerImageName(task string) (string, error) + NextTask(args NextTaskOptions) *database.Task +} + +type NextTaskOptions struct { + Tasks []database.Task + DockerImage string +} + +func ProviderFactory(provider ProviderType) (Provider, error) { + switch provider { + case "openai": + return OpenAIProvider{}.New(), nil + default: + return nil, fmt.Errorf("unknown provider: %s", provider) + } +} + +type Message string + +type InputArgs struct { + Query string +} + +type TerminalArgs struct { + Input string + Message +} + +type BrowserAction string + +const ( + Read BrowserAction = "read" + Url BrowserAction = "url" +) + +type BrowserArgs struct { + Url string + Action BrowserAction + Message +} + +type CodeAction string + +const ( + ReadFile CodeAction = "read_file" + UpdateFile CodeAction = "update_file" +) + +type CodeArgs struct { + Action CodeAction + Content string + Path string + Message +} + +type AskArgs struct { + Message +} + +type DoneArgs struct { + Message +} + +func defaultAskTask(message string) *database.Task { + task := database.Task{ + Type: database.StringToNullString("ask"), + } + + task.Args = database.StringToNullString("{}") + task.Message = sql.NullString{ + String: 
fmt.Sprintf("%s. What should I do next?", message), + Valid: true, + } + + return &task +} + +func extractArgs[T any](openAIargs string, args *T) (*T, error) { + err := json.Unmarshal([]byte(openAIargs), args) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal args: %v", err) + } + return args, nil +} diff --git a/backend/services/openai.go b/backend/services/openai.go deleted file mode 100644 index b638216..0000000 --- a/backend/services/openai.go +++ /dev/null @@ -1,95 +0,0 @@ -package services - -import ( - "context" - "fmt" - "log" - - "github.com/sashabaranov/go-openai" - "github.com/semanser/ai-coder/assets" - "github.com/semanser/ai-coder/config" - "github.com/semanser/ai-coder/templates" -) - -var OpenAIclient *openai.Client - -func Init() { - cfg := openai.DefaultConfig(config.Config.OpenAIKey) - cfg.BaseURL = config.Config.OpenAIServerURL - OpenAIclient = openai.NewClientWithConfig(cfg) - - if config.Config.OpenAIKey == "" { - log.Fatal("OPEN_AI_KEY is not set") - } -} - -func GetMessageSummary(query string, n int) (string, error) { - prompt, err := templates.Render(assets.PromptTemplates, "prompts/summary.tmpl", map[string]any{ - "Text": query, - "N": n, - }) - if err != nil { - return "", err - } - - req := openai.ChatCompletionRequest{ - Temperature: 0.0, - Model: openai.GPT3Dot5Turbo, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: prompt, - }, - }, - TopP: 0.2, - N: 1, - } - - resp, err := OpenAIclient.CreateChatCompletion(context.Background(), req) - if err != nil { - return "", fmt.Errorf("completion error: %v", err) - } - - choices := resp.Choices - - if len(choices) == 0 { - return "", fmt.Errorf("no choices found") - } - - return choices[0].Message.Content, nil -} - -func GetDockerImageName(task string) (string, error) { - prompt, err := templates.Render(assets.PromptTemplates, "prompts/docker.tmpl", map[string]any{ - "Task": task, - }) - if err != nil { - return "", err - } - - req 
:= openai.ChatCompletionRequest{ - Temperature: 0.0, - Model: openai.GPT3Dot5Turbo, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: prompt, - }, - }, - TopP: 0.2, - N: 1, - } - - resp, err := OpenAIclient.CreateChatCompletion(context.Background(), req) - if err != nil { - return "", fmt.Errorf("completion error: %v", err) - } - - choices := resp.Choices - - if len(choices) == 0 { - return "", fmt.Errorf("no choices found") - } - - return choices[0].Message.Content, nil -} From e53c759b0bb8666e3aa7f513bd3ea87ae6184f07 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Tue, 2 Apr 2024 13:46:36 +0200 Subject: [PATCH 02/14] Implement OpenAI client --- backend/go.mod | 7 +- backend/go.sum | 26 +++-- backend/providers/openai.go | 206 +++++++++++++++++------------------- 3 files changed, 117 insertions(+), 122 deletions(-) diff --git a/backend/go.mod b/backend/go.mod index e7c5f5b..e894e64 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -16,7 +16,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/mattn/go-sqlite3 v1.14.22 github.com/pressly/goose/v3 v3.19.2 - github.com/sashabaranov/go-openai v1.20.4 + github.com/tmc/langchaingo v0.1.8 github.com/vektah/gqlparser/v2 v2.5.11 ) @@ -28,9 +28,9 @@ require ( github.com/bytedance/sonic v1.11.3 // indirect github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect - github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/distribution/reference v0.5.0 // indirect + github.com/dlclark/regexp2 v1.11.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect @@ -54,11 +54,11 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // 
indirect - github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pkoukk/tiktoken-go v0.1.6 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sethvargo/go-retry v0.2.4 // indirect @@ -87,7 +87,6 @@ require ( golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.19.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index f1b55b1..921ae35 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -36,6 +36,7 @@ github.com/bytedance/sonic v1.11.3 h1:jRN+yEjakWh8aK5FzrciUHG8OFXK+4/KrAX/ysEtHA github.com/bytedance/sonic v1.11.3/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= @@ -58,6 +59,8 @@ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0= github.com/distribution/reference v0.5.0/go.mod 
h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= +github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/docker/cli v24.0.7+incompatible h1:wa/nIwYFW7BVTGa7SWPVyyXU9lgORqUb1xfI36MSkFg= github.com/docker/cli v24.0.7+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker v25.0.5+incompatible h1:UmQydMduGkrD5nQde1mecF/YnSbTOaPeFIeP5C4W+DE= @@ -155,8 +158,8 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -200,6 +203,8 @@ github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= +github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pressly/goose/v3 v3.19.2 h1:z1yuD41jS4iaqLkyjkzGkKBz4rgyz/BYtCyMMGHlgzQ= @@ -212,8 +217,6 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.20.4 h1:095xQ/fAtRa0+Rj21sezVJABgKfGPNbyx/sAN/hJUmg= -github.com/sashabaranov/go-openai v1.20.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= @@ -238,6 +241,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tmc/langchaingo v0.1.8 h1:nrImgh0aWdu3stJTHz80N60WGwPWY8HXCK10gQny7bA= +github.com/tmc/langchaingo v0.1.8/go.mod h1:iNBfS9e6jxBKsJSPWnlqNhoVWgdA3D1g5cdFJjbIZNQ= github.com/tursodatabase/libsql-client-go v0.0.0-20240220085343-4ae0eb9d0898 h1:1MvEhzI5pvP27e9Dzz861mxk9WzXZLSJwzOU67cKTbU= github.com/tursodatabase/libsql-client-go v0.0.0-20240220085343-4ae0eb9d0898/go.mod h1:9bKuHS7eZh/0mJndbUOrCx8Ej3PlsRDszj4L7oVYMPQ= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= @@ -346,12 +351,13 @@ golang.org/x/xerrors 
v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917 h1:rcS6EyEaoCO52hQDupoSfrxI3R6C2Tq741is7X8OvnM= -google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917/go.mod h1:CmlNWB9lSezaYELKS5Ym1r44VrrbPUa7JTvw+6MbpJ0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917 h1:6G8oQ016D88m1xAKljMlBOOGWDZkes4kMhgGFlf8WcQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917/go.mod h1:xtjpI3tXFPP051KaWnhvxkiubL/6dJ18vLVf7q2pTOU= -google.golang.org/grpc v1.61.1 h1:kLAiWrZs7YeDM6MumDe7m3y4aM6wacLzM1Y/wiLP9XY= -google.golang.org/grpc v1.61.1/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= +google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 h1:Lj5rbfG876hIAYFjqiJnPHfhXbv+nzTWfm04Fg/XSVU= +google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80/go.mod h1:4jWUdICTdgc3Ibxmr8nAJiiLHwQBY0UI0XZcEMaFKaA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 h1:AjyfHzEPEFp/NpvfN5g+KDla3EMojjhRVZc1i7cj+oM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s= +google.golang.org/grpc v1.62.0 h1:HQKZ/fa1bXkX1oFOvSjmZEUL8wLSaZTjCcLAlmZRtdk= +google.golang.org/grpc v1.62.0/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod 
h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/backend/providers/openai.go b/backend/providers/openai.go index c406765..a2b0a87 100644 --- a/backend/providers/openai.go +++ b/backend/providers/openai.go @@ -3,25 +3,33 @@ package providers import ( "context" "encoding/json" - "fmt" "log" "github.com/invopop/jsonschema" - openai "github.com/sashabaranov/go-openai" "github.com/semanser/ai-coder/assets" "github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" "github.com/semanser/ai-coder/templates" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/openai" + "github.com/tmc/langchaingo/schema" ) type OpenAIProvider struct { - client *openai.Client + client *openai.LLM } func (p OpenAIProvider) New() Provider { - cfg := openai.DefaultConfig(config.Config.OpenAIKey) - cfg.BaseURL = config.Config.OpenAIServerURL - client := openai.NewClientWithConfig(cfg) + client, err := openai.New( + openai.WithToken(config.Config.OpenAIKey), + openai.WithModel(config.Config.OpenAIModel), + openai.WithBaseURL(config.Config.OpenAIServerURL), + ) + + if err != nil { + log.Fatalf("Failed to create OpenAI client: %v", err) + } return OpenAIProvider{client} } @@ -35,31 +43,18 @@ func (p OpenAIProvider) Summary(query string, n int) (string, error) { return "", err } - req := openai.ChatCompletionRequest{ - Temperature: 0.0, - Model: openai.GPT3Dot5Turbo, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: prompt, - }, - }, - TopP: 0.2, - N: 1, - } - - resp, err := p.client.CreateChatCompletion(context.Background(), req) - if err != nil { - return "", fmt.Errorf("completion error: %v", err) - } - - choices := resp.Choices - - if len(choices) == 0 { - return "", fmt.Errorf("no choices found") - } - - return choices[0].Message.Content, nil + response, err := llms.GenerateFromSinglePrompt( 
+ context.Background(), + p.client, + prompt, + llms.WithTemperature(0.0), + // Use a simpler model for this task + llms.WithModel(config.Config.OpenAIModel), + llms.WithTopP(0.2), + llms.WithN(1), + ) + + return response, err } func (p OpenAIProvider) DockerImageName(task string) (string, error) { @@ -70,31 +65,18 @@ func (p OpenAIProvider) DockerImageName(task string) (string, error) { return "", err } - req := openai.ChatCompletionRequest{ - Temperature: 0.0, - Model: openai.GPT3Dot5Turbo, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - Content: prompt, - }, - }, - TopP: 0.2, - N: 1, - } - - resp, err := p.client.CreateChatCompletion(context.Background(), req) - if err != nil { - return "", fmt.Errorf("completion error: %v", err) - } - - choices := resp.Choices - - if len(choices) == 0 { - return "", fmt.Errorf("no choices found") - } - - return choices[0].Message.Content, nil + response, err := llms.GenerateFromSinglePrompt( + context.Background(), + p.client, + prompt, + llms.WithTemperature(0.0), + // Use a simpler model for this task + llms.WithModel(config.Config.OpenAIModel), + llms.WithTopP(0.2), + llms.WithN(1), + ) + + return response, err } func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { @@ -113,42 +95,42 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { return defaultAskTask("There was an error getting the next task") } - tools := []openai.Tool{ + tools := []llms.Tool{ { - Type: openai.ToolTypeFunction, - Function: &openai.FunctionDefinition{ + Type: "function", + Function: &llms.FunctionDefinition{ Name: "terminal", Description: "Calls a terminal command", Parameters: jsonschema.Reflect(&TerminalArgs{}).Definitions["TerminalArgs"], }, }, { - Type: openai.ToolTypeFunction, - Function: &openai.FunctionDefinition{ + Type: "function", + Function: &llms.FunctionDefinition{ Name: "browser", Description: "Opens a browser to look for additional information", Parameters: 
jsonschema.Reflect(&BrowserArgs{}).Definitions["BrowserArgs"], }, }, { - Type: openai.ToolTypeFunction, - Function: &openai.FunctionDefinition{ + Type: "function", + Function: &llms.FunctionDefinition{ Name: "code", Description: "Modifies or reads code files", Parameters: jsonschema.Reflect(&CodeArgs{}).Definitions["CodeArgs"], }, }, { - Type: openai.ToolTypeFunction, - Function: &openai.FunctionDefinition{ + Type: "function", + Function: &llms.FunctionDefinition{ Name: "ask", Description: "Sends a question to the user for additional information", Parameters: jsonschema.Reflect(&AskArgs{}).Definitions["AskArgs"], }, }, { - Type: openai.ToolTypeFunction, - Function: &openai.FunctionDefinition{ + Type: "function", + Function: &llms.FunctionDefinition{ Name: "done", Description: "Mark the whole task as done. Should be called at the very end when everything is completed", Parameters: jsonschema.Reflect(&DoneArgs{}).Definitions["DoneArgs"], @@ -156,65 +138,73 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { }, } - var messages []openai.ChatCompletionMessage + var messages []llms.MessageContent - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleSystem, - Content: prompt, + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart(prompt), + }, }) for _, task := range args.Tasks { if task.Type.String == "input" { - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, - Content: task.Args.String, + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextPart(prompt), + }, }) } if task.ToolCallID.String != "" { - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - ToolCalls: []openai.ToolCall{ - { + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeAI, 
+ Parts: []llms.ContentPart{ + llms.ToolCall{ ID: task.ToolCallID.String, - Function: openai.FunctionCall{ + FunctionCall: &schema.FunctionCall{ Name: task.Type.String, Arguments: task.Args.String, }, - Type: openai.ToolTypeFunction, + Type: "function", }, }, }) - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleTool, - ToolCallID: task.ToolCallID.String, - Content: task.Results.String, + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ + ToolCallID: task.ToolCallID.String, + Name: task.Type.String, + Content: task.Results.String, + }, + }, }) } // This Ask was generated by the agent itself in case of some error (not the OpenAI) if task.Type.String == "ask" && task.ToolCallID.String == "" { - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: task.Message.String, + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.TextPart(task.Message.String), + }, }) } } - req := openai.ChatCompletionRequest{ - Temperature: 0.0, - Model: config.Config.OpenAIModel, - Messages: messages, - ResponseFormat: &openai.ChatCompletionResponseFormat{ - Type: openai.ChatCompletionResponseFormatTypeJSONObject, - }, - TopP: 0.2, - Tools: tools, - N: 1, - } + resp, err := p.client.GenerateContent( + context.Background(), + messages, + llms.WithTemperature(0.0), + llms.WithModel(config.Config.OpenAIModel), + llms.WithTopP(0.2), + llms.WithN(1), + llms.WithTools(tools), + ) - resp, err := p.client.CreateChatCompletion(context.Background(), req) if err != nil { log.Printf("Failed to get response from OpenAI %v", err) return defaultAskTask("There was an error getting the next task") @@ -227,7 +217,7 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { return defaultAskTask("Looks like I couldn't find a task to run") } - 
toolCalls := choices[0].Message.ToolCalls + toolCalls := choices[0].ToolCalls if len(toolCalls) == 0 { log.Println("No tool calls found, asking user") @@ -236,18 +226,18 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { tool := toolCalls[0] - if tool.Function.Name == "" { + if tool.FunctionCall.Name == "" { log.Println("No tool found, asking user") return defaultAskTask("The next task is empty, I don't know what to do next") } task := database.Task{ - Type: database.StringToNullString(tool.Function.Name), + Type: database.StringToNullString(tool.FunctionCall.Name), } - switch tool.Function.Name { + switch tool.FunctionCall.Name { case "terminal": - params, err := extractArgs(tool.Function.Arguments, &TerminalArgs{}) + params, err := extractArgs(tool.FunctionCall.Arguments, &TerminalArgs{}) if err != nil { log.Printf("Failed to extract terminal args, asking user: %v", err) return defaultAskTask("There was an error running the terminal command") @@ -269,7 +259,7 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { task.Status = database.StringToNullString("in_progress") case "browser": - params, err := extractArgs(tool.Function.Arguments, &BrowserArgs{}) + params, err := extractArgs(tool.FunctionCall.Arguments, &BrowserArgs{}) if err != nil { log.Printf("Failed to extract browser args, asking user: %v", err) return defaultAskTask("There was an error opening the browser") @@ -282,7 +272,7 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { task.Args = database.StringToNullString(string(args)) task.Message = database.StringToNullString(string(params.Message)) case "code": - params, err := extractArgs(tool.Function.Arguments, &CodeArgs{}) + params, err := extractArgs(tool.FunctionCall.Arguments, &CodeArgs{}) if err != nil { log.Printf("Failed to extract code args, asking user: %v", err) return defaultAskTask("There was an error reading or updating the file") @@ -295,7 +285,7 @@ func (p OpenAIProvider) 
NextTask(args NextTaskOptions) *database.Task { task.Args = database.StringToNullString(string(args)) task.Message = database.StringToNullString(string(params.Message)) case "ask": - params, err := extractArgs(tool.Function.Arguments, &AskArgs{}) + params, err := extractArgs(tool.FunctionCall.Arguments, &AskArgs{}) if err != nil { log.Printf("Failed to extract ask args, asking user: %v", err) return defaultAskTask("There was an error asking the user for additional information") @@ -308,7 +298,7 @@ func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { task.Args = database.StringToNullString(string(args)) task.Message = database.StringToNullString(string(params.Message)) case "done": - params, err := extractArgs(tool.Function.Arguments, &DoneArgs{}) + params, err := extractArgs(tool.FunctionCall.Arguments, &DoneArgs{}) if err != nil { log.Printf("Failed to extract done args, asking user: %v", err) return defaultAskTask("There was an error marking the task as done") From 14dcafd6b802c35698e6b6c189a1ad27cffb8c8f Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Tue, 2 Apr 2024 14:07:28 +0200 Subject: [PATCH 03/14] Code cleaning --- backend/executor/processor.go | 288 ++++++++++++++++++++++++++++++ backend/executor/queue.go | 305 ++------------------------------ backend/providers/openai.go | 308 +++----------------------------- backend/providers/providers.go | 311 +++++++++++++++++++++++++++++---- backend/providers/types.go | 47 +++++ 5 files changed, 651 insertions(+), 608 deletions(-) create mode 100644 backend/executor/processor.go create mode 100644 backend/providers/types.go diff --git a/backend/executor/processor.go b/backend/executor/processor.go new file mode 100644 index 0000000..f51b363 --- /dev/null +++ b/backend/executor/processor.go @@ -0,0 +1,288 @@ +package executor + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log" + + "github.com/docker/docker/api/types/container" + "github.com/semanser/ai-coder/database" + 
gmodel "github.com/semanser/ai-coder/graph/model" + "github.com/semanser/ai-coder/graph/subscriptions" + "github.com/semanser/ai-coder/providers" + "github.com/semanser/ai-coder/websocket" +) + +func processBrowserTask(db *database.Queries, task database.Task) error { + var args = providers.BrowserArgs{} + err := json.Unmarshal([]byte(task.Args.String), &args) + if err != nil { + return fmt.Errorf("failed to unmarshal args: %v", err) + } + + var url = args.Url + var screenshotName string + + if args.Action == providers.Read { + content, screenshot, err := Content(url) + + if err != nil { + return fmt.Errorf("failed to get content: %w", err) + } + + log.Println("Screenshot taken") + screenshotName = screenshot + + _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ + ID: task.ID, + Results: database.StringToNullString(content), + }) + + if err != nil { + return fmt.Errorf("failed to update task results: %w", err) + } + } + + if args.Action == providers.Url { + content, screenshot, err := URLs(url) + + if err != nil { + return fmt.Errorf("failed to get content: %w", err) + } + + screenshotName = screenshot + + _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ + ID: task.ID, + Results: database.StringToNullString(content), + }) + + if err != nil { + return fmt.Errorf("failed to update task results: %w", err) + } + } + + subscriptions.BroadcastBrowserUpdated(task.FlowID.Int64, &gmodel.Browser{ + URL: url, + // TODO Use a dynamic URL + ScreenshotURL: "http://localhost:8080/browser/" + screenshotName, + }) + + return nil +} + +func processDoneTask(db *database.Queries, task database.Task) error { + flow, err := db.UpdateFlowStatus(context.Background(), database.UpdateFlowStatusParams{ + ID: task.FlowID.Int64, + Status: database.StringToNullString("finished"), + }) + + if err != nil { + return fmt.Errorf("failed to update task status: %w", err) + } + + 
subscriptions.BroadcastFlowUpdated(task.FlowID.Int64, &gmodel.Flow{ + ID: uint(flow.ID), + Status: gmodel.FlowStatus("finished"), + Terminal: &gmodel.Terminal{}, + }) + + return nil +} + +func processInputTask(provider providers.Provider, db *database.Queries, task database.Task) error { + tasks, err := db.ReadTasksByFlowId(context.Background(), sql.NullInt64{ + Int64: task.FlowID.Int64, + Valid: true, + }) + + if err != nil { + return fmt.Errorf("failed to get tasks by flow id: %w", err) + } + + // This is the first task in the flow. + // We need to get the basic flow data as well as spin up the container + if len(tasks) == 1 { + summary, err := provider.Summary(task.Message.String, 10) + + if err != nil { + return fmt.Errorf("failed to get message summary: %w", err) + } + + dockerImage, err := provider.DockerImageName(task.Message.String) + + if err != nil { + return fmt.Errorf("failed to get docker image name: %w", err) + } + + flow, err := db.UpdateFlowName(context.Background(), database.UpdateFlowNameParams{ + ID: task.FlowID.Int64, + Name: database.StringToNullString(summary), + }) + + if err != nil { + return fmt.Errorf("failed to update flow: %w", err) + } + + subscriptions.BroadcastFlowUpdated(flow.ID, &gmodel.Flow{ + ID: uint(flow.ID), + Name: summary, + Terminal: &gmodel.Terminal{ + ContainerName: dockerImage, + Connected: false, + }, + }) + + msg := websocket.FormatTerminalSystemOutput(fmt.Sprintf("Initializing the docker image %s...", dockerImage)) + l, err := db.CreateLog(context.Background(), database.CreateLogParams{ + FlowID: task.FlowID, + Message: msg, + Type: "system", + }) + + if err != nil { + return fmt.Errorf("error creating log: %w", err) + } + + subscriptions.BroadcastTerminalLogsAdded(flow.ID, &gmodel.Log{ + ID: uint(l.ID), + Text: msg, + }) + + terminalContainerName := TerminalName(flow.ID) + terminalContainerID, err := SpawnContainer(context.Background(), + terminalContainerName, + &container.Config{ + Image: dockerImage, + Cmd: 
[]string{"tail", "-f", "/dev/null"}, + }, + &container.HostConfig{}, + db, + ) + + if err != nil { + return fmt.Errorf("failed to spawn container: %w", err) + } + + subscriptions.BroadcastFlowUpdated(flow.ID, &gmodel.Flow{ + ID: uint(flow.ID), + Name: summary, + Terminal: &gmodel.Terminal{ + Connected: true, + ContainerName: dockerImage, + }, + }) + + _, err = db.UpdateFlowContainer(context.Background(), database.UpdateFlowContainerParams{ + ID: flow.ID, + ContainerID: sql.NullInt64{Int64: terminalContainerID, Valid: true}, + }) + + if err != nil { + return fmt.Errorf("failed to update flow container: %w", err) + } + + msg = websocket.FormatTerminalSystemOutput("Container initialized. Ready to execute commands.") + l, err = db.CreateLog(context.Background(), database.CreateLogParams{ + FlowID: task.FlowID, + Message: msg, + Type: "system", + }) + + if err != nil { + return fmt.Errorf("error creating log: %w", err) + } + subscriptions.BroadcastTerminalLogsAdded(flow.ID, &gmodel.Log{ + ID: uint(l.ID), + Text: msg, + }) + } + + return nil +} + +func processAskTask(db *database.Queries, task database.Task) error { + task, err := db.UpdateTaskStatus(context.Background(), database.UpdateTaskStatusParams{ + Status: database.StringToNullString("finished"), + ID: task.ID, + }) + + if err != nil { + return fmt.Errorf("failed to find task with id %d: %w", task.ID, err) + } + + return nil +} + +func processTerminalTask(db *database.Queries, task database.Task) error { + var args = providers.TerminalArgs{} + err := json.Unmarshal([]byte(task.Args.String), &args) + if err != nil { + return fmt.Errorf("failed to unmarshal args: %v", err) + } + + results, err := ExecCommand(task.FlowID.Int64, args.Input, db) + + if err != nil { + return fmt.Errorf("failed to execute command: %w", err) + } + + _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ + ID: task.ID, + Results: database.StringToNullString(results), + }) + + if err != nil { + return 
fmt.Errorf("failed to update task results: %w", err) + } + + return nil +} + +func processCodeTask(db *database.Queries, task database.Task) error { + var args = providers.CodeArgs{} + err := json.Unmarshal([]byte(task.Args.String), &args) + if err != nil { + return fmt.Errorf("failed to unmarshal args: %v", err) + } + + var cmd = "" + var results = "" + + if args.Action == providers.ReadFile { + // TODO consider using dockerClient.CopyFromContainer command instead + cmd = fmt.Sprintf("cat %s", args.Path) + results, err = ExecCommand(task.FlowID.Int64, cmd, db) + + if err != nil { + return fmt.Errorf("error executing cat command: %w", err) + } + } + + if args.Action == providers.UpdateFile { + err = WriteFile(task.FlowID.Int64, args.Content, args.Path, db) + + if err != nil { + return fmt.Errorf("error writing a file: %w", err) + } + + results = "File updated" + } + + if err != nil { + return fmt.Errorf("failed to execute command: %w", err) + } + + _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ + ID: task.ID, + Results: database.StringToNullString(results), + }) + + if err != nil { + return fmt.Errorf("failed to update task results: %w", err) + } + + return nil +} diff --git a/backend/executor/queue.go b/backend/executor/queue.go index 669a6c9..b658b91 100644 --- a/backend/executor/queue.go +++ b/backend/executor/queue.go @@ -3,16 +3,13 @@ package executor import ( "context" "database/sql" - "encoding/json" "fmt" "log" - "github.com/docker/docker/api/types/container" "github.com/semanser/ai-coder/database" gmodel "github.com/semanser/ai-coder/graph/model" "github.com/semanser/ai-coder/graph/subscriptions" "github.com/semanser/ai-coder/providers" - "github.com/semanser/ai-coder/websocket" ) var queue = make(map[int64]chan database.Task) @@ -45,14 +42,16 @@ func CleanQueue(flowId int64) { stopChannels[flowId] = nil } - log.Println(fmt.Sprintf("Queue for flow %d cleaned", flowId)) + log.Printf("Queue %d cleaned", flowId) } func 
ProcessQueue(flowId int64, db *database.Queries) { - log.Println("Starting tasks processor for queue %d", flowId) + log.Println("Starting tasks processor for queue", flowId) provider, err := providers.ProviderFactory(providers.ProviderOpenAI) + log.Println("Using provider: ", provider.Name()) + if err != nil { log.Printf("failed to get provider: %v", err) CleanQueue(flowId) @@ -88,14 +87,14 @@ func ProcessQueue(flowId int64, db *database.Queries) { err := processInputTask(provider, db, task) if err != nil { - log.Printf("failed to process input: %w", err) + log.Printf("failed to process input: %v", err) continue } nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { - log.Printf("failed to get next task: %w", err) + log.Printf("failed to get next task: %v", err) continue } @@ -106,7 +105,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { err := processAskTask(db, task) if err != nil { - log.Printf("failed to process ask: %w", err) + log.Printf("failed to process ask: %v", err) continue } } @@ -115,13 +114,13 @@ func ProcessQueue(flowId int64, db *database.Queries) { err := processTerminalTask(db, task) if err != nil { - log.Printf("failed to process terminal: %w", err) + log.Printf("failed to process terminal: %v", err) continue } nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { - log.Printf("failed to get next task: %w", err) + log.Printf("failed to get next task: %v", err) continue } @@ -132,14 +131,14 @@ func ProcessQueue(flowId int64, db *database.Queries) { err := processCodeTask(db, task) if err != nil { - log.Printf("failed to process code: %w", err) + log.Printf("failed to process code: %v", err) continue } nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { - log.Printf("failed to get next task: %w", err) + log.Printf("failed to get next task: %v", err) continue } @@ -150,7 +149,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { err := processDoneTask(db, 
task) if err != nil { - log.Printf("failed to process done: %w", err) + log.Printf("failed to process done: %v", err) continue } } @@ -159,14 +158,14 @@ func ProcessQueue(flowId int64, db *database.Queries) { err := processBrowserTask(db, task) if err != nil { - log.Printf("failed to process browser: %w", err) + log.Printf("failed to process browser: %v", err) continue } nextTask, err := getNextTask(provider, db, task.FlowID.Int64) if err != nil { - log.Printf("failed to get next task: %w", err) + log.Printf("failed to get next task: %v", err) continue } @@ -177,282 +176,6 @@ func ProcessQueue(flowId int64, db *database.Queries) { }() } -func processBrowserTask(db *database.Queries, task database.Task) error { - var args = providers.BrowserArgs{} - err := json.Unmarshal([]byte(task.Args.String), &args) - if err != nil { - return fmt.Errorf("failed to unmarshal args: %v", err) - } - - var url = args.Url - var screenshotName string - - if args.Action == providers.Read { - content, screenshot, err := Content(url) - - if err != nil { - return fmt.Errorf("failed to get content: %w", err) - } - - log.Println("Screenshot taken") - screenshotName = screenshot - - _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ - ID: task.ID, - Results: database.StringToNullString(content), - }) - - if err != nil { - return fmt.Errorf("failed to update task results: %w", err) - } - } - - if args.Action == providers.Url { - content, screenshot, err := URLs(url) - - if err != nil { - return fmt.Errorf("failed to get content: %w", err) - } - - screenshotName = screenshot - - _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ - ID: task.ID, - Results: database.StringToNullString(content), - }) - - if err != nil { - return fmt.Errorf("failed to update task results: %w", err) - } - } - - subscriptions.BroadcastBrowserUpdated(task.FlowID.Int64, &gmodel.Browser{ - URL: url, - // TODO Use a dynamic URL - ScreenshotURL: 
"http://localhost:8080/browser/" + screenshotName, - }) - - return nil -} - -func processDoneTask(db *database.Queries, task database.Task) error { - flow, err := db.UpdateFlowStatus(context.Background(), database.UpdateFlowStatusParams{ - ID: task.FlowID.Int64, - Status: database.StringToNullString("finished"), - }) - - if err != nil { - return fmt.Errorf("failed to update task status: %w", err) - } - - subscriptions.BroadcastFlowUpdated(task.FlowID.Int64, &gmodel.Flow{ - ID: uint(flow.ID), - Status: gmodel.FlowStatus("finished"), - Terminal: &gmodel.Terminal{}, - }) - - return nil -} - -func processInputTask(provider providers.Provider, db *database.Queries, task database.Task) error { - tasks, err := db.ReadTasksByFlowId(context.Background(), sql.NullInt64{ - Int64: task.FlowID.Int64, - Valid: true, - }) - - if err != nil { - return fmt.Errorf("failed to get tasks by flow id: %w", err) - } - - // This is the first task in the flow. - // We need to get the basic flow data as well as spin up the container - if len(tasks) == 1 { - summary, err := provider.Summary(task.Message.String, 10) - - if err != nil { - return fmt.Errorf("failed to get message summary: %w", err) - } - - dockerImage, err := provider.DockerImageName(task.Message.String) - - if err != nil { - return fmt.Errorf("failed to get docker image name: %w", err) - } - - flow, err := db.UpdateFlowName(context.Background(), database.UpdateFlowNameParams{ - ID: task.FlowID.Int64, - Name: database.StringToNullString(summary), - }) - - if err != nil { - return fmt.Errorf("failed to update flow: %w", err) - } - - subscriptions.BroadcastFlowUpdated(flow.ID, &gmodel.Flow{ - ID: uint(flow.ID), - Name: summary, - Terminal: &gmodel.Terminal{ - ContainerName: dockerImage, - Connected: false, - }, - }) - - msg := websocket.FormatTerminalSystemOutput(fmt.Sprintf("Initializing the docker image %s...", dockerImage)) - l, err := db.CreateLog(context.Background(), database.CreateLogParams{ - FlowID: task.FlowID, - 
Message: msg, - Type: "system", - }) - - if err != nil { - return fmt.Errorf("Error creating log: %w", err) - } - - subscriptions.BroadcastTerminalLogsAdded(flow.ID, &gmodel.Log{ - ID: uint(l.ID), - Text: msg, - }) - - terminalContainerName := TerminalName(flow.ID) - terminalContainerID, err := SpawnContainer(context.Background(), - terminalContainerName, - &container.Config{ - Image: dockerImage, - Cmd: []string{"tail", "-f", "/dev/null"}, - }, - &container.HostConfig{}, - db, - ) - - if err != nil { - return fmt.Errorf("failed to spawn container: %w", err) - } - - subscriptions.BroadcastFlowUpdated(flow.ID, &gmodel.Flow{ - ID: uint(flow.ID), - Name: summary, - Terminal: &gmodel.Terminal{ - Connected: true, - ContainerName: dockerImage, - }, - }) - - _, err = db.UpdateFlowContainer(context.Background(), database.UpdateFlowContainerParams{ - ID: flow.ID, - ContainerID: sql.NullInt64{Int64: terminalContainerID, Valid: true}, - }) - - if err != nil { - return fmt.Errorf("failed to update flow container: %w", err) - } - - msg = websocket.FormatTerminalSystemOutput("Container initialized. 
Ready to execute commands.") - l, err = db.CreateLog(context.Background(), database.CreateLogParams{ - FlowID: task.FlowID, - Message: msg, - Type: "system", - }) - - if err != nil { - return fmt.Errorf("Error creating log: %w", err) - } - subscriptions.BroadcastTerminalLogsAdded(flow.ID, &gmodel.Log{ - ID: uint(l.ID), - Text: msg, - }) - - if err != nil { - log.Printf("failed to send initialized message to the channel: %w", err) - } - } - - return nil -} - -func processAskTask(db *database.Queries, task database.Task) error { - task, err := db.UpdateTaskStatus(context.Background(), database.UpdateTaskStatusParams{ - Status: database.StringToNullString("finished"), - ID: task.ID, - }) - - if err != nil { - return fmt.Errorf("failed to find task with id %d: %w", task.ID, err) - } - - return nil -} - -func processTerminalTask(db *database.Queries, task database.Task) error { - var args = providers.TerminalArgs{} - err := json.Unmarshal([]byte(task.Args.String), &args) - if err != nil { - return fmt.Errorf("failed to unmarshal args: %v", err) - } - - results, err := ExecCommand(task.FlowID.Int64, args.Input, db) - - if err != nil { - return fmt.Errorf("failed to execute command: %w", err) - } - - _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ - ID: task.ID, - Results: database.StringToNullString(results), - }) - - if err != nil { - return fmt.Errorf("failed to update task results: %w", err) - } - - return nil -} - -func processCodeTask(db *database.Queries, task database.Task) error { - var args = providers.CodeArgs{} - err := json.Unmarshal([]byte(task.Args.String), &args) - if err != nil { - return fmt.Errorf("failed to unmarshal args: %v", err) - } - - var cmd = "" - var results = "" - - if args.Action == providers.ReadFile { - // TODO consider using dockerClient.CopyFromContainer command instead - cmd = fmt.Sprintf("cat %s", args.Path) - results, err = ExecCommand(task.FlowID.Int64, cmd, db) - - if err != nil { - return 
fmt.Errorf("error executing cat command: %w", err) - } - } - - if args.Action == providers.UpdateFile { - err = WriteFile(task.FlowID.Int64, args.Content, args.Path, db) - - if err != nil { - return fmt.Errorf("error writing a file: %w", err) - } - - results = "File updated" - } - - if err != nil { - return fmt.Errorf("failed to execute command: %w", err) - } - - _, err = db.UpdateTaskResults(context.Background(), database.UpdateTaskResultsParams{ - ID: task.ID, - Results: database.StringToNullString(results), - }) - - if err != nil { - return fmt.Errorf("failed to update task results: %w", err) - } - - return nil -} - func getNextTask(provider providers.Provider, db *database.Queries, flowId int64) (*database.Task, error) { flow, err := db.ReadFlow(context.Background(), flowId) diff --git a/backend/providers/openai.go b/backend/providers/openai.go index a2b0a87..54b6c20 100644 --- a/backend/providers/openai.go +++ b/backend/providers/openai.go @@ -1,317 +1,57 @@ package providers import ( - "context" - "encoding/json" "log" - "github.com/invopop/jsonschema" - "github.com/semanser/ai-coder/assets" "github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" - "github.com/semanser/ai-coder/templates" - "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai" - "github.com/tmc/langchaingo/schema" ) type OpenAIProvider struct { - client *openai.LLM + client *openai.LLM + model string + baseURL string + name ProviderType } func (p OpenAIProvider) New() Provider { + model := config.Config.OpenAIModel + baseURL := config.Config.OpenAIServerURL + client, err := openai.New( openai.WithToken(config.Config.OpenAIKey), - openai.WithModel(config.Config.OpenAIModel), - openai.WithBaseURL(config.Config.OpenAIServerURL), + openai.WithModel(model), + openai.WithBaseURL(baseURL), ) if err != nil { log.Fatalf("Failed to create OpenAI client: %v", err) } - return OpenAIProvider{client} -} - -func (p OpenAIProvider) Summary(query string, n int) 
(string, error) { - prompt, err := templates.Render(assets.PromptTemplates, "prompts/summary.tmpl", map[string]any{ - "Text": query, - "N": n, - }) - if err != nil { - return "", err + return OpenAIProvider{ + client: client, + model: model, + baseURL: baseURL, + name: ProviderOpenAI, } +} - response, err := llms.GenerateFromSinglePrompt( - context.Background(), - p.client, - prompt, - llms.WithTemperature(0.0), - // Use a simpler model for this task - llms.WithModel(config.Config.OpenAIModel), - llms.WithTopP(0.2), - llms.WithN(1), - ) +func (p OpenAIProvider) Name() ProviderType { + return p.name +} - return response, err +func (p OpenAIProvider) Summary(query string, n int) (string, error) { + // TODO Use more basic model for this task + return Summary(p.client, config.Config.OpenAIModel, query, n) } func (p OpenAIProvider) DockerImageName(task string) (string, error) { - prompt, err := templates.Render(assets.PromptTemplates, "prompts/docker.tmpl", map[string]any{ - "Task": task, - }) - if err != nil { - return "", err - } - - response, err := llms.GenerateFromSinglePrompt( - context.Background(), - p.client, - prompt, - llms.WithTemperature(0.0), - // Use a simpler model for this task - llms.WithModel(config.Config.OpenAIModel), - llms.WithTopP(0.2), - llms.WithN(1), - ) - - return response, err + // TODO Use more basic model for this task + return DockerImageName(p.client, config.Config.OpenAIModel, task) } func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { - log.Println("Getting next task") - - prompt, err := templates.Render(assets.PromptTemplates, "prompts/agent.tmpl", args) - - // TODO In case of lots of tasks, we should try to get a summary using gpt-3.5 - if len(prompt) > 30000 { - log.Println("Prompt too long, asking user") - return defaultAskTask("My prompt is too long and I can't process it") - } - - if err != nil { - log.Println("Failed to render prompt, asking user, %w", err) - return defaultAskTask("There was an error getting 
the next task") - } - - tools := []llms.Tool{ - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "terminal", - Description: "Calls a terminal command", - Parameters: jsonschema.Reflect(&TerminalArgs{}).Definitions["TerminalArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "browser", - Description: "Opens a browser to look for additional information", - Parameters: jsonschema.Reflect(&BrowserArgs{}).Definitions["BrowserArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "code", - Description: "Modifies or reads code files", - Parameters: jsonschema.Reflect(&CodeArgs{}).Definitions["CodeArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "ask", - Description: "Sends a question to the user for additional information", - Parameters: jsonschema.Reflect(&AskArgs{}).Definitions["AskArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "done", - Description: "Mark the whole task as done. 
Should be called at the very end when everything is completed", - Parameters: jsonschema.Reflect(&DoneArgs{}).Definitions["DoneArgs"], - }, - }, - } - - var messages []llms.MessageContent - - messages = append(messages, llms.MessageContent{ - Role: schema.ChatMessageTypeSystem, - Parts: []llms.ContentPart{ - llms.TextPart(prompt), - }, - }) - - for _, task := range args.Tasks { - if task.Type.String == "input" { - messages = append(messages, llms.MessageContent{ - Role: schema.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextPart(prompt), - }, - }) - } - - if task.ToolCallID.String != "" { - messages = append(messages, llms.MessageContent{ - Role: schema.ChatMessageTypeAI, - Parts: []llms.ContentPart{ - llms.ToolCall{ - ID: task.ToolCallID.String, - FunctionCall: &schema.FunctionCall{ - Name: task.Type.String, - Arguments: task.Args.String, - }, - Type: "function", - }, - }, - }) - - messages = append(messages, llms.MessageContent{ - Role: schema.ChatMessageTypeTool, - Parts: []llms.ContentPart{ - llms.ToolCallResponse{ - ToolCallID: task.ToolCallID.String, - Name: task.Type.String, - Content: task.Results.String, - }, - }, - }) - } - - // This Ask was generated by the agent itself in case of some error (not the OpenAI) - if task.Type.String == "ask" && task.ToolCallID.String == "" { - messages = append(messages, llms.MessageContent{ - Role: schema.ChatMessageTypeAI, - Parts: []llms.ContentPart{ - llms.TextPart(task.Message.String), - }, - }) - } - } - - resp, err := p.client.GenerateContent( - context.Background(), - messages, - llms.WithTemperature(0.0), - llms.WithModel(config.Config.OpenAIModel), - llms.WithTopP(0.2), - llms.WithN(1), - llms.WithTools(tools), - ) - - if err != nil { - log.Printf("Failed to get response from OpenAI %v", err) - return defaultAskTask("There was an error getting the next task") - } - - choices := resp.Choices - - if len(choices) == 0 { - log.Println("No choices found, asking user") - return defaultAskTask("Looks like I 
couldn't find a task to run") - } - - toolCalls := choices[0].ToolCalls - - if len(toolCalls) == 0 { - log.Println("No tool calls found, asking user") - return defaultAskTask("I couln't find a task to run") - } - - tool := toolCalls[0] - - if tool.FunctionCall.Name == "" { - log.Println("No tool found, asking user") - return defaultAskTask("The next task is empty, I don't know what to do next") - } - - task := database.Task{ - Type: database.StringToNullString(tool.FunctionCall.Name), - } - - switch tool.FunctionCall.Name { - case "terminal": - params, err := extractArgs(tool.FunctionCall.Arguments, &TerminalArgs{}) - if err != nil { - log.Printf("Failed to extract terminal args, asking user: %v", err) - return defaultAskTask("There was an error running the terminal command") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal terminal args, asking user: %v", err) - return defaultAskTask("There was an error running the terminal command") - } - task.Args = database.StringToNullString(string(args)) - - // Sometimes the model returns an empty string for the message - msg := string(params.Message) - if msg == "" { - msg = params.Input - } - - task.Message = database.StringToNullString(msg) - task.Status = database.StringToNullString("in_progress") - - case "browser": - params, err := extractArgs(tool.FunctionCall.Arguments, &BrowserArgs{}) - if err != nil { - log.Printf("Failed to extract browser args, asking user: %v", err) - return defaultAskTask("There was an error opening the browser") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal browser args, asking user: %v", err) - return defaultAskTask("There was an error opening the browser") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - case "code": - params, err := extractArgs(tool.FunctionCall.Arguments, &CodeArgs{}) - if err != nil { - log.Printf("Failed to 
extract code args, asking user: %v", err) - return defaultAskTask("There was an error reading or updating the file") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal code args, asking user: %v", err) - return defaultAskTask("There was an error reading or updating the file") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - case "ask": - params, err := extractArgs(tool.FunctionCall.Arguments, &AskArgs{}) - if err != nil { - log.Printf("Failed to extract ask args, asking user: %v", err) - return defaultAskTask("There was an error asking the user for additional information") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal ask args, asking user: %v", err) - return defaultAskTask("There was an error asking the user for additional information") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - case "done": - params, err := extractArgs(tool.FunctionCall.Arguments, &DoneArgs{}) - if err != nil { - log.Printf("Failed to extract done args, asking user: %v", err) - return defaultAskTask("There was an error marking the task as done") - } - args, err := json.Marshal(params) - if err != nil { - return defaultAskTask("There was an error marking the task as done") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - } - - task.ToolCallID = database.StringToNullString(tool.ID) - - return &task + return NextTask(args, p.client) } diff --git a/backend/providers/providers.go b/backend/providers/providers.go index 3b6d97e..9c6c48e 100644 --- a/backend/providers/providers.go +++ b/backend/providers/providers.go @@ -1,11 +1,20 @@ package providers import ( + "context" "database/sql" "encoding/json" "fmt" + "log" + "github.com/semanser/ai-coder/assets" + 
"github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" + "github.com/semanser/ai-coder/templates" + + "github.com/invopop/jsonschema" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" ) type ProviderType string @@ -16,6 +25,7 @@ const ( type Provider interface { New() Provider + Name() ProviderType Summary(query string, n int) (string, error) DockerImageName(task string) (string, error) NextTask(args NextTaskOptions) *database.Task @@ -28,57 +38,292 @@ type NextTaskOptions struct { func ProviderFactory(provider ProviderType) (Provider, error) { switch provider { - case "openai": + case ProviderOpenAI: return OpenAIProvider{}.New(), nil default: return nil, fmt.Errorf("unknown provider: %s", provider) } } -type Message string +func Summary(llm llms.Model, model string, query string, n int) (string, error) { + prompt, err := templates.Render(assets.PromptTemplates, "prompts/summary.tmpl", map[string]any{ + "Text": query, + "N": n, + }) + if err != nil { + return "", err + } -type InputArgs struct { - Query string -} + response, err := llms.GenerateFromSinglePrompt( + context.Background(), + llm, + prompt, + llms.WithTemperature(0.0), + // Use a simpler model for this task + llms.WithModel(model), + llms.WithTopP(0.2), + llms.WithN(1), + ) -type TerminalArgs struct { - Input string - Message + return response, err } -type BrowserAction string +func DockerImageName(llm llms.Model, model string, task string) (string, error) { + prompt, err := templates.Render(assets.PromptTemplates, "prompts/docker.tmpl", map[string]any{ + "Task": task, + }) + if err != nil { + return "", err + } -const ( - Read BrowserAction = "read" - Url BrowserAction = "url" -) + response, err := llms.GenerateFromSinglePrompt( + context.Background(), + llm, + prompt, + llms.WithTemperature(0.0), + llms.WithModel(model), + llms.WithTopP(0.2), + llms.WithN(1), + ) -type BrowserArgs struct { - Url string - Action BrowserAction - Message + return response, 
err } -type CodeAction string +func NextTask(args NextTaskOptions, llm llms.Model) *database.Task { + log.Println("Getting next task") -const ( - ReadFile CodeAction = "read_file" - UpdateFile CodeAction = "update_file" -) + prompt, err := templates.Render(assets.PromptTemplates, "prompts/agent.tmpl", args) -type CodeArgs struct { - Action CodeAction - Content string - Path string - Message -} + // TODO In case of lots of tasks, we should try to get a summary using gpt-3.5 + if len(prompt) > 30000 { + log.Println("Prompt too long, asking user") + return defaultAskTask("My prompt is too long and I can't process it") + } -type AskArgs struct { - Message -} + if err != nil { + log.Println("Failed to render prompt, asking user, %w", err) + return defaultAskTask("There was an error getting the next task") + } + + tools := []llms.Tool{ + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "terminal", + Description: "Calls a terminal command", + Parameters: jsonschema.Reflect(&TerminalArgs{}).Definitions["TerminalArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "browser", + Description: "Opens a browser to look for additional information", + Parameters: jsonschema.Reflect(&BrowserArgs{}).Definitions["BrowserArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "code", + Description: "Modifies or reads code files", + Parameters: jsonschema.Reflect(&CodeArgs{}).Definitions["CodeArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "ask", + Description: "Sends a question to the user for additional information", + Parameters: jsonschema.Reflect(&AskArgs{}).Definitions["AskArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "done", + Description: "Mark the whole task as done. 
Should be called at the very end when everything is completed", + Parameters: jsonschema.Reflect(&DoneArgs{}).Definitions["DoneArgs"], + }, + }, + } + + var messages []llms.MessageContent + + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeSystem, + Parts: []llms.ContentPart{ + llms.TextPart(prompt), + }, + }) + + for _, task := range args.Tasks { + if task.Type.String == "input" { + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextPart(prompt), + }, + }) + } + + if task.ToolCallID.String != "" { + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ + ID: task.ToolCallID.String, + FunctionCall: &schema.FunctionCall{ + Name: task.Type.String, + Arguments: task.Args.String, + }, + Type: "function", + }, + }, + }) + + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ + ToolCallID: task.ToolCallID.String, + Name: task.Type.String, + Content: task.Results.String, + }, + }, + }) + } + + // This Ask was generated by the agent itself in case of some error (not the OpenAI) + if task.Type.String == "ask" && task.ToolCallID.String == "" { + messages = append(messages, llms.MessageContent{ + Role: schema.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.TextPart(task.Message.String), + }, + }) + } + } -type DoneArgs struct { - Message + resp, err := llm.GenerateContent( + context.Background(), + messages, + llms.WithTemperature(0.0), + llms.WithModel(config.Config.OpenAIModel), + llms.WithTopP(0.2), + llms.WithN(1), + llms.WithTools(tools), + ) + + if err != nil { + log.Printf("Failed to get response from OpenAI %v", err) + return defaultAskTask("There was an error getting the next task") + } + + choices := resp.Choices + + if len(choices) == 0 { + log.Println("No choices found, asking user") + return 
defaultAskTask("Looks like I couldn't find a task to run") + } + + toolCalls := choices[0].ToolCalls + + if len(toolCalls) == 0 { + log.Println("No tool calls found, asking user") + return defaultAskTask("I couln't find a task to run") + } + + tool := toolCalls[0] + + if tool.FunctionCall.Name == "" { + log.Println("No tool found, asking user") + return defaultAskTask("The next task is empty, I don't know what to do next") + } + + task := database.Task{ + Type: database.StringToNullString(tool.FunctionCall.Name), + } + + switch tool.FunctionCall.Name { + case "terminal": + params, err := extractArgs(tool.FunctionCall.Arguments, &TerminalArgs{}) + if err != nil { + log.Printf("Failed to extract terminal args, asking user: %v", err) + return defaultAskTask("There was an error running the terminal command") + } + args, err := json.Marshal(params) + if err != nil { + log.Printf("Failed to marshal terminal args, asking user: %v", err) + return defaultAskTask("There was an error running the terminal command") + } + task.Args = database.StringToNullString(string(args)) + + // Sometimes the model returns an empty string for the message + msg := string(params.Message) + if msg == "" { + msg = params.Input + } + + task.Message = database.StringToNullString(msg) + task.Status = database.StringToNullString("in_progress") + + case "browser": + params, err := extractArgs(tool.FunctionCall.Arguments, &BrowserArgs{}) + if err != nil { + log.Printf("Failed to extract browser args, asking user: %v", err) + return defaultAskTask("There was an error opening the browser") + } + args, err := json.Marshal(params) + if err != nil { + log.Printf("Failed to marshal browser args, asking user: %v", err) + return defaultAskTask("There was an error opening the browser") + } + task.Args = database.StringToNullString(string(args)) + task.Message = database.StringToNullString(string(params.Message)) + case "code": + params, err := extractArgs(tool.FunctionCall.Arguments, &CodeArgs{}) + if err != 
nil { + log.Printf("Failed to extract code args, asking user: %v", err) + return defaultAskTask("There was an error reading or updating the file") + } + args, err := json.Marshal(params) + if err != nil { + log.Printf("Failed to marshal code args, asking user: %v", err) + return defaultAskTask("There was an error reading or updating the file") + } + task.Args = database.StringToNullString(string(args)) + task.Message = database.StringToNullString(string(params.Message)) + case "ask": + params, err := extractArgs(tool.FunctionCall.Arguments, &AskArgs{}) + if err != nil { + log.Printf("Failed to extract ask args, asking user: %v", err) + return defaultAskTask("There was an error asking the user for additional information") + } + args, err := json.Marshal(params) + if err != nil { + log.Printf("Failed to marshal ask args, asking user: %v", err) + return defaultAskTask("There was an error asking the user for additional information") + } + task.Args = database.StringToNullString(string(args)) + task.Message = database.StringToNullString(string(params.Message)) + case "done": + params, err := extractArgs(tool.FunctionCall.Arguments, &DoneArgs{}) + if err != nil { + log.Printf("Failed to extract done args, asking user: %v", err) + return defaultAskTask("There was an error marking the task as done") + } + args, err := json.Marshal(params) + if err != nil { + return defaultAskTask("There was an error marking the task as done") + } + task.Args = database.StringToNullString(string(args)) + task.Message = database.StringToNullString(string(params.Message)) + } + + task.ToolCallID = database.StringToNullString(tool.ID) + + return &task } func defaultAskTask(message string) *database.Task { diff --git a/backend/providers/types.go b/backend/providers/types.go new file mode 100644 index 0000000..3bc27fc --- /dev/null +++ b/backend/providers/types.go @@ -0,0 +1,47 @@ +package providers + +type Message string + +type InputArgs struct { + Query string +} + +type TerminalArgs struct { 
+ Input string + Message +} + +type BrowserAction string + +const ( + Read BrowserAction = "read" + Url BrowserAction = "url" +) + +type BrowserArgs struct { + Url string + Action BrowserAction + Message +} + +type CodeAction string + +const ( + ReadFile CodeAction = "read_file" + UpdateFile CodeAction = "update_file" +) + +type CodeArgs struct { + Action CodeAction + Content string + Path string + Message +} + +type AskArgs struct { + Message +} + +type DoneArgs struct { + Message +} From b511283e33242311ac160492a63ef191dea18628 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Wed, 3 Apr 2024 13:37:41 +0200 Subject: [PATCH 04/14] Add Ollama support --- backend/config/config.go | 11 +- backend/executor/browser.go | 50 ++-- backend/executor/queue.go | 2 +- backend/providers/common.go | 53 ++++ backend/providers/ollama.go | 155 ++++++++++++ backend/providers/openai.go | 47 +++- backend/providers/providers.go | 357 +++++++++++---------------- backend/templates/prompts/agent.tmpl | 39 +-- 8 files changed, 454 insertions(+), 260 deletions(-) create mode 100644 backend/providers/common.go create mode 100644 backend/providers/ollama.go diff --git a/backend/config/config.go b/backend/config/config.go index 74eec4e..6e6c649 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -8,11 +8,18 @@ import ( ) type config struct { + // General + DatabaseURL string `env:"DATABASE_URL" envDefault:"database.db"` + Port int `env:"PORT" envDefault:"8080"` + + // OpenAI OpenAIKey string `env:"OPEN_AI_KEY"` OpenAIModel string `env:"OPEN_AI_MODEL" envDefault:"gpt-4-0125-preview"` OpenAIServerURL string `env:"OPEN_AI_SERVER_URL" envDefault:"https://api.openai.com/v1"` - DatabaseURL string `env:"DATABASE_URL" envDefault:"database.db"` - Port int `env:"PORT" envDefault:"8080"` + + // Ollama + OllamaModel string `env:"OLLAMA_MODEL" envDefault:"llama2"` + OllamaServerURL string `env:"OLLAMA_SERVER_URL" envDefault:"http://localhost:11434"` } var Config config diff --git 
a/backend/executor/browser.go b/backend/executor/browser.go index 8677e81..7fe21dd 100644 --- a/backend/executor/browser.go +++ b/backend/executor/browser.go @@ -57,37 +57,37 @@ func Content(url string) (result string, screenshotName string, err error) { page, err := loadPage() if err != nil { - return "", "", fmt.Errorf("Error loading page: %w", err) + return "", "", fmt.Errorf("error loading page: %w", err) } err = loadUrl(page, url) if err != nil { - return "", "", fmt.Errorf("Error loading url: %w", err) + return "", "", fmt.Errorf("error loading url: %w", err) } script, err := templates.Render(assets.ScriptTemplates, "scripts/content.js", nil) if err != nil { - return "", "", fmt.Errorf("Error reading script: %w", err) + return "", "", fmt.Errorf("error reading script: %w", err) } pageText, err := page.Eval(string(script)) if err != nil { - return "", "", fmt.Errorf("Error evaluating script: %w", err) + return "", "", fmt.Errorf("error evaluating script: %w", err) } screenshot, err := page.Screenshot(false, nil) if err != nil { - return "", "", fmt.Errorf("Error taking screenshot: %w", err) + return "", "", fmt.Errorf("error taking screenshot: %w", err) } screenshotName, err = writeScreenshotToFile(screenshot) if err != nil { - return "", "", fmt.Errorf("Error writing screenshot to file: %w", err) + return "", "", fmt.Errorf("error writing screenshot to file: %w", err) } return pageText.Value.Str(), screenshotName, nil @@ -99,37 +99,37 @@ func URLs(url string) (result string, screenshotName string, err error) { page, err := loadPage() if err != nil { - return "", "", fmt.Errorf("Error loading page: %w", err) + return "", "", fmt.Errorf("error loading page: %w", err) } err = loadUrl(page, url) if err != nil { - return "", "", fmt.Errorf("Error loading url: %w", err) + return "", "", fmt.Errorf("error loading url: %w", err) } script, err := templates.Render(assets.ScriptTemplates, "scripts/urls.js", nil) if err != nil { - return "", "", fmt.Errorf("Error reading 
script: %w", err) + return "", "", fmt.Errorf("error reading script: %w", err) } urls, err := page.Eval(string(script)) if err != nil { - return "", "", fmt.Errorf("Error evaluating script: %w", err) + return "", "", fmt.Errorf("error evaluating script: %w", err) } screenshot, err := page.Screenshot(true, nil) if err != nil { - return "", "", fmt.Errorf("Error taking screenshot: %w", err) + return "", "", fmt.Errorf("error taking screenshot: %w", err) } screenshotName, err = writeScreenshotToFile(screenshot) if err != nil { - return "", "", fmt.Errorf("Error writing screenshot to file: %w", err) + return "", "", fmt.Errorf("error writing screenshot to file: %w", err) } return urls.Value.Str(), screenshotName, nil @@ -143,13 +143,13 @@ func writeScreenshotToFile(screenshot []byte) (filename string, err error) { err = os.MkdirAll(path, os.ModePerm) if err != nil { - return "", fmt.Errorf("Error creating directory: %w", err) + return "", fmt.Errorf("error creating directory: %w", err) } file, err := os.Create(filepath) if err != nil { - return "", fmt.Errorf("Error creating file: %w", err) + return "", fmt.Errorf("error creating file: %w", err) } defer file.Close() @@ -157,41 +157,41 @@ func writeScreenshotToFile(screenshot []byte) (filename string, err error) { _, err = file.Write(screenshot) if err != nil { - return "", fmt.Errorf("Error writing to file: %w", err) + return "", fmt.Errorf("error writing to file: %w", err) } return filename, nil } func BrowserName() string { - return fmt.Sprintf("codel-browser") + return "codel-browser" } func loadPage() (*rod.Page, error) { u, err := launcher.ResolveURL("") if err != nil { - return nil, fmt.Errorf("Error resolving url: %w", err) + return nil, fmt.Errorf("error resolving url: %w", err) } browser := rod.New().ControlURL(u) err = browser.Connect() - version, err := browser.Version() - if err != nil { - return nil, fmt.Errorf("Error getting browser version: %w", err) + return nil, fmt.Errorf("error connecting to browser: 
%w", err) } - log.Println("Connected to browser %s", version.Product) + + version, err := browser.Version() if err != nil { - return nil, fmt.Errorf("Error connecting to browser: %w", err) + return nil, fmt.Errorf("error getting browser version: %w", err) } + log.Printf("Connected to browser %s", version.Product) page, err := browser.Page(proto.TargetCreateTarget{}) if err != nil { - return nil, fmt.Errorf("Error opening page: %w", err) + return nil, fmt.Errorf("error opening page: %w", err) } return page, nil @@ -222,13 +222,13 @@ func loadUrl(page *rod.Page, url string) error { err := page.Navigate(url) if err != nil { - return fmt.Errorf("Error navigating to page: %w", err) + return fmt.Errorf("error navigating to page: %w", err) } err = page.WaitDOMStable(time.Second*1, 5) if err != nil { - return fmt.Errorf("Error waiting for page to stabilize: %w", err) + return fmt.Errorf("error waiting for page to stabilize: %w", err) } return nil diff --git a/backend/executor/queue.go b/backend/executor/queue.go index b658b91..b16b9c6 100644 --- a/backend/executor/queue.go +++ b/backend/executor/queue.go @@ -48,7 +48,7 @@ func CleanQueue(flowId int64) { func ProcessQueue(flowId int64, db *database.Queries) { log.Println("Starting tasks processor for queue", flowId) - provider, err := providers.ProviderFactory(providers.ProviderOpenAI) + provider, err := providers.ProviderFactory(providers.ProviderOllama) log.Println("Using provider: ", provider.Name()) diff --git a/backend/providers/common.go b/backend/providers/common.go new file mode 100644 index 0000000..8d83025 --- /dev/null +++ b/backend/providers/common.go @@ -0,0 +1,53 @@ +package providers + +import ( + "context" + + "github.com/semanser/ai-coder/assets" + "github.com/semanser/ai-coder/templates" + "github.com/tmc/langchaingo/llms" +) + +func Summary(llm llms.Model, model string, query string, n int) (string, error) { + prompt, err := templates.Render(assets.PromptTemplates, "prompts/summary.tmpl", map[string]any{ 
+ "Text": query, + "N": n, + }) + if err != nil { + return "", err + } + + response, err := llms.GenerateFromSinglePrompt( + context.Background(), + llm, + prompt, + llms.WithTemperature(0.0), + // TODO Use a simpler model for this task + llms.WithModel(model), + llms.WithTopP(0.2), + llms.WithN(1), + ) + + return response, err +} + +func DockerImageName(llm llms.Model, model string, task string) (string, error) { + prompt, err := templates.Render(assets.PromptTemplates, "prompts/docker.tmpl", map[string]any{ + "Task": task, + }) + if err != nil { + return "", err + } + + response, err := llms.GenerateFromSinglePrompt( + context.Background(), + llm, + prompt, + llms.WithTemperature(0.0), + llms.WithModel(model), + llms.WithTopP(0.2), + llms.WithN(1), + ) + + return response, err +} diff --git a/backend/providers/ollama.go b/backend/providers/ollama.go new file mode 100644 index 0000000..8fc5cd2 --- /dev/null +++ b/backend/providers/ollama.go @@ -0,0 +1,155 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "log" + + "github.com/semanser/ai-coder/assets" + "github.com/semanser/ai-coder/config" + "github.com/semanser/ai-coder/database" + "github.com/semanser/ai-coder/templates" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/ollama" +) + +type OllamaProvider struct { + client *ollama.LLM + model string + baseURL string + name ProviderType +} + +func (p OllamaProvider) New() Provider { + model := config.Config.OllamaModel + baseURL := config.Config.OllamaServerURL + + client, err := ollama.New( + ollama.WithModel(model), + ollama.WithFormat("json"), + ollama.WithServerURL(baseURL), + ) + + if err != nil { + log.Fatalf("Failed to create Ollama client: %v", err) + } + + return OllamaProvider{ + client: client, + model: model, + baseURL: baseURL, + name: ProviderOllama, + } +} + +func (p OllamaProvider) Name() ProviderType { + return p.name +} + +func (p OllamaProvider) Summary(query string, n int) (string, error) { + 
client, err := ollama.New( + ollama.WithModel(p.model), + ) + + if err != nil { + return "", fmt.Errorf("failed to create Ollama client: %v", err) + } + + return Summary(client, p.model, query, n) +} + +func (p OllamaProvider) DockerImageName(task string) (string, error) { + client, err := ollama.New( + ollama.WithModel(p.model), + ) + + if err != nil { + return "", fmt.Errorf("failed to create Ollama client: %v", err) + } + + return DockerImageName(client, p.model, task) +} + +type Call struct { + Tool string `json:"tool"` + Input map[string]string `json:"tool_input"` + Message string `json:"message"` +} + +func (p OllamaProvider) NextTask(args NextTaskOptions) *database.Task { + log.Println("Getting next task") + + promptArgs := map[string]interface{}{ + "DockerImage": args.DockerImage, + "ToolPlaceholder": getToolPlaceholder(), + "Tasks": args.Tasks, + } + + prompt, err := templates.Render(assets.PromptTemplates, "prompts/agent.tmpl", promptArgs) + + // TODO In case of lots of tasks, we should try to get a summary using gpt-3.5 + if len(prompt) > 30000 { + log.Println("Prompt too long, asking user") + return defaultAskTask("My prompt is too long and I can't process it") + } + + if err != nil { + log.Println("Failed to render prompt, asking user, %w", err) + return defaultAskTask("There was an error getting the next task") + } + + messages := tasksToMessages(args.Tasks, prompt) + + resp, err := p.client.GenerateContent( + context.Background(), + messages, + llms.WithTemperature(0.0), + llms.WithModel(p.model), + llms.WithTopP(0.2), + llms.WithN(1), + ) + + if err != nil { + log.Printf("Failed to get response from model %v", err) + return defaultAskTask("There was an error getting the next task") + } + + choices := resp.Choices + + if len(choices) == 0 { + log.Println("No choices found, asking user") + return defaultAskTask("Looks like I couldn't find a task to run") + } + + task, err := textToTask(choices[0].Content) + + if err != nil { + log.Println("Failed to 
convert text to the next task, asking user") + return defaultAskTask("There was an error getting the next task") + } + + return task +} + +func getToolPlaceholder() string { + bs, err := json.Marshal(Tools) + if err != nil { + log.Fatal(err) + } + + return fmt.Sprintf(`You have access to the following tools: + +%s + +To use a tool, respond with a JSON object with the following structure: +{ + "tool": , + "tool_input": , + "message": +} + +Always use a tool. Always reply with valid JOSN. Always include a message. +`, string(bs)) +} diff --git a/backend/providers/openai.go b/backend/providers/openai.go index 54b6c20..88e882e 100644 --- a/backend/providers/openai.go +++ b/backend/providers/openai.go @@ -1,11 +1,15 @@ package providers import ( + "context" "log" + "github.com/semanser/ai-coder/assets" "github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" + "github.com/semanser/ai-coder/templates" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai" ) @@ -43,15 +47,52 @@ func (p OpenAIProvider) Name() ProviderType { } func (p OpenAIProvider) Summary(query string, n int) (string, error) { - // TODO Use more basic model for this task return Summary(p.client, config.Config.OpenAIModel, query, n) } func (p OpenAIProvider) DockerImageName(task string) (string, error) { - // TODO Use more basic model for this task return DockerImageName(p.client, config.Config.OpenAIModel, task) } func (p OpenAIProvider) NextTask(args NextTaskOptions) *database.Task { - return NextTask(args, p.client) + log.Println("Getting next task") + + prompt, err := templates.Render(assets.PromptTemplates, "prompts/agent.tmpl", args) + + // TODO In case of lots of tasks, we should try to get a summary using gpt-3.5 + if len(prompt) > 30000 { + log.Println("Prompt too long, asking user") + return defaultAskTask("My prompt is too long and I can't process it") + } + + if err != nil { + log.Println("Failed to render prompt, asking user, %w", err) + return 
defaultAskTask("There was an error getting the next task") + } + + messages := tasksToMessages(args.Tasks, prompt) + + resp, err := p.client.GenerateContent( + context.Background(), + messages, + llms.WithTemperature(0.0), + llms.WithModel(p.model), + llms.WithTopP(0.2), + llms.WithN(1), + llms.WithTools(Tools), + ) + + if err != nil { + log.Printf("Failed to get response from model %v", err) + return defaultAskTask("There was an error getting the next task") + } + + task, err := toolToTask(resp.Choices) + + if err != nil { + log.Printf("Failed to convert tool to task %v", err) + return defaultAskTask("There was an error getting the next task") + } + + return task } diff --git a/backend/providers/providers.go b/backend/providers/providers.go index 9c6c48e..cdc5a92 100644 --- a/backend/providers/providers.go +++ b/backend/providers/providers.go @@ -1,16 +1,12 @@ package providers import ( - "context" "database/sql" "encoding/json" "fmt" "log" - "github.com/semanser/ai-coder/assets" - "github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" - "github.com/semanser/ai-coder/templates" "github.com/invopop/jsonschema" "github.com/tmc/langchaingo/llms" @@ -21,6 +17,7 @@ type ProviderType string const ( ProviderOpenAI ProviderType = "openai" + ProviderOllama ProviderType = "ollama" ) type Provider interface { @@ -36,120 +33,76 @@ type NextTaskOptions struct { DockerImage string } +var Tools = []llms.Tool{ + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "terminal", + Description: "Calls a terminal command", + Parameters: jsonschema.Reflect(&TerminalArgs{}).Definitions["TerminalArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "browser", + Description: "Opens a browser to look for additional information", + Parameters: jsonschema.Reflect(&BrowserArgs{}).Definitions["BrowserArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "code", + Description: "Modifies or 
reads code files", + Parameters: jsonschema.Reflect(&CodeArgs{}).Definitions["CodeArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "ask", + Description: "Sends a question to the user for additional information", + Parameters: jsonschema.Reflect(&AskArgs{}).Definitions["AskArgs"], + }, + }, + { + Type: "function", + Function: &llms.FunctionDefinition{ + Name: "done", + Description: "Mark the whole task as done. Should be called at the very end when everything is completed", + Parameters: jsonschema.Reflect(&DoneArgs{}).Definitions["DoneArgs"], + }, + }, +} + func ProviderFactory(provider ProviderType) (Provider, error) { switch provider { case ProviderOpenAI: return OpenAIProvider{}.New(), nil + case ProviderOllama: + return OllamaProvider{}.New(), nil default: return nil, fmt.Errorf("unknown provider: %s", provider) } } -func Summary(llm llms.Model, model string, query string, n int) (string, error) { - prompt, err := templates.Render(assets.PromptTemplates, "prompts/summary.tmpl", map[string]any{ - "Text": query, - "N": n, - }) - if err != nil { - return "", err +func defaultAskTask(message string) *database.Task { + task := database.Task{ + Type: database.StringToNullString("ask"), } - response, err := llms.GenerateFromSinglePrompt( - context.Background(), - llm, - prompt, - llms.WithTemperature(0.0), - // Use a simpler model for this task - llms.WithModel(model), - llms.WithTopP(0.2), - llms.WithN(1), - ) - - return response, err -} - -func DockerImageName(llm llms.Model, model string, task string) (string, error) { - prompt, err := templates.Render(assets.PromptTemplates, "prompts/docker.tmpl", map[string]any{ - "Task": task, - }) - if err != nil { - return "", err + task.Args = database.StringToNullString("{}") + task.Message = sql.NullString{ + String: fmt.Sprintf("%s. 
What should I do next?", message), + Valid: true, } - response, err := llms.GenerateFromSinglePrompt( - context.Background(), - llm, - prompt, - llms.WithTemperature(0.0), - llms.WithModel(model), - llms.WithTopP(0.2), - llms.WithN(1), - ) - - return response, err + return &task } -func NextTask(args NextTaskOptions, llm llms.Model) *database.Task { - log.Println("Getting next task") - - prompt, err := templates.Render(assets.PromptTemplates, "prompts/agent.tmpl", args) - - // TODO In case of lots of tasks, we should try to get a summary using gpt-3.5 - if len(prompt) > 30000 { - log.Println("Prompt too long, asking user") - return defaultAskTask("My prompt is too long and I can't process it") - } - - if err != nil { - log.Println("Failed to render prompt, asking user, %w", err) - return defaultAskTask("There was an error getting the next task") - } - - tools := []llms.Tool{ - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "terminal", - Description: "Calls a terminal command", - Parameters: jsonschema.Reflect(&TerminalArgs{}).Definitions["TerminalArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "browser", - Description: "Opens a browser to look for additional information", - Parameters: jsonschema.Reflect(&BrowserArgs{}).Definitions["BrowserArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "code", - Description: "Modifies or reads code files", - Parameters: jsonschema.Reflect(&CodeArgs{}).Definitions["CodeArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "ask", - Description: "Sends a question to the user for additional information", - Parameters: jsonschema.Reflect(&AskArgs{}).Definitions["AskArgs"], - }, - }, - { - Type: "function", - Function: &llms.FunctionDefinition{ - Name: "done", - Description: "Mark the whole task as done. 
Should be called at the very end when everything is completed", - Parameters: jsonschema.Reflect(&DoneArgs{}).Definitions["DoneArgs"], - }, - }, - } - +func tasksToMessages(tasks []database.Task, prompt string) []llms.MessageContent { var messages []llms.MessageContent - messages = append(messages, llms.MessageContent{ Role: schema.ChatMessageTypeSystem, Parts: []llms.ContentPart{ @@ -157,7 +110,7 @@ func NextTask(args NextTaskOptions, llm llms.Model) *database.Task { }, }) - for _, task := range args.Tasks { + for _, task := range tasks { if task.Type.String == "input" { messages = append(messages, llms.MessageContent{ Role: schema.ChatMessageTypeHuman, @@ -205,143 +158,123 @@ func NextTask(args NextTaskOptions, llm llms.Model) *database.Task { } } - resp, err := llm.GenerateContent( - context.Background(), - messages, - llms.WithTemperature(0.0), - llms.WithModel(config.Config.OpenAIModel), - llms.WithTopP(0.2), - llms.WithN(1), - llms.WithTools(tools), - ) + return messages +} + +func textToTask(text string) (*database.Task, error) { + c := unmarshalCall(text) + if c == nil { + return nil, fmt.Errorf("can't unmarshalCall %s", text) + } + + task := database.Task{ + // TODO validate tool name + Type: database.StringToNullString(c.Tool), + } + + arg, err := json.Marshal(c.Input) if err != nil { - log.Printf("Failed to get response from OpenAI %v", err) - return defaultAskTask("There was an error getting the next task") + log.Printf("Failed to marshal terminal args, asking user: %v", err) + return defaultAskTask("There was an error running the terminal command"), nil } + task.Args = database.StringToNullString(string(arg)) - choices := resp.Choices + // Sometimes the model returns an empty string for the message + // In that case, we use the input as the message + msg := c.Message + if msg == "" { + msg = string(arg) + } - if len(choices) == 0 { - log.Println("No choices found, asking user") - return defaultAskTask("Looks like I couldn't find a task to run") + 
task.Message = database.StringToNullString(msg) + task.Status = database.StringToNullString("in_progress") + + return &task, nil +} + +func extractJSONArgs[T any](functionArgs map[string]string, args *T) (*T, error) { + b, err := json.Marshal(functionArgs) + + if err != nil { + return nil, fmt.Errorf("failed to marshal args: %v", err) } - toolCalls := choices[0].ToolCalls + err = json.Unmarshal(b, args) - if len(toolCalls) == 0 { - log.Println("No tool calls found, asking user") - return defaultAskTask("I couln't find a task to run") + if err != nil { + return nil, fmt.Errorf("failed to unmarshal args: %v", err) } + return args, nil +} - tool := toolCalls[0] +func unmarshalCall(input string) *Call { + log.Printf("Unmarshalling tool call: %v", input) - if tool.FunctionCall.Name == "" { - log.Println("No tool found, asking user") - return defaultAskTask("The next task is empty, I don't know what to do next") + var c Call + + err := json.Unmarshal([]byte(input), &c) + if err != nil { + log.Printf("Failed to unmarshal tool call: %v", err) + return nil } - task := database.Task{ - Type: database.StringToNullString(tool.FunctionCall.Name), + if c.Tool != "" { + log.Printf("Unmarshalled tool call: %v", c) + return &c } - switch tool.FunctionCall.Name { - case "terminal": - params, err := extractArgs(tool.FunctionCall.Arguments, &TerminalArgs{}) - if err != nil { - log.Printf("Failed to extract terminal args, asking user: %v", err) - return defaultAskTask("There was an error running the terminal command") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal terminal args, asking user: %v", err) - return defaultAskTask("There was an error running the terminal command") - } - task.Args = database.StringToNullString(string(args)) + return nil +} - // Sometimes the model returns an empty string for the message - msg := string(params.Message) - if msg == "" { - msg = params.Input - } +func toolToTask(choices []*llms.ContentChoice) 
(*database.Task, error) { + if len(choices) == 0 { + return nil, fmt.Errorf("no choices found, asking user") + } - task.Message = database.StringToNullString(msg) - task.Status = database.StringToNullString("in_progress") + toolCalls := choices[0].ToolCalls - case "browser": - params, err := extractArgs(tool.FunctionCall.Arguments, &BrowserArgs{}) - if err != nil { - log.Printf("Failed to extract browser args, asking user: %v", err) - return defaultAskTask("There was an error opening the browser") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal browser args, asking user: %v", err) - return defaultAskTask("There was an error opening the browser") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - case "code": - params, err := extractArgs(tool.FunctionCall.Arguments, &CodeArgs{}) - if err != nil { - log.Printf("Failed to extract code args, asking user: %v", err) - return defaultAskTask("There was an error reading or updating the file") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal code args, asking user: %v", err) - return defaultAskTask("There was an error reading or updating the file") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - case "ask": - params, err := extractArgs(tool.FunctionCall.Arguments, &AskArgs{}) - if err != nil { - log.Printf("Failed to extract ask args, asking user: %v", err) - return defaultAskTask("There was an error asking the user for additional information") - } - args, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal ask args, asking user: %v", err) - return defaultAskTask("There was an error asking the user for additional information") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) - 
case "done": - params, err := extractArgs(tool.FunctionCall.Arguments, &DoneArgs{}) - if err != nil { - log.Printf("Failed to extract done args, asking user: %v", err) - return defaultAskTask("There was an error marking the task as done") - } - args, err := json.Marshal(params) - if err != nil { - return defaultAskTask("There was an error marking the task as done") - } - task.Args = database.StringToNullString(string(args)) - task.Message = database.StringToNullString(string(params.Message)) + if len(toolCalls) == 0 { + return nil, fmt.Errorf("no tool calls found, asking user") } - task.ToolCallID = database.StringToNullString(tool.ID) - - return &task -} + tool := toolCalls[0] -func defaultAskTask(message string) *database.Task { task := database.Task{ - Type: database.StringToNullString("ask"), + Type: database.StringToNullString(tool.FunctionCall.Name), } - task.Args = database.StringToNullString("{}") - task.Message = sql.NullString{ - String: fmt.Sprintf("%s. What should I do next?", message), - Valid: true, + if tool.FunctionCall.Name == "" { + return nil, fmt.Errorf("no tool name found, asking user") } - return &task + // We use AskArgs to extract the message + params, err := extractToolArgs(tool.FunctionCall.Arguments, &AskArgs{}) + if err != nil { + return nil, fmt.Errorf("failed to extract args: %v", err) + } + args, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal terminal args, asking user: %v", err) + } + task.Args = database.StringToNullString(string(args)) + + // Sometimes the model returns an empty string for the message + msg := string(params.Message) + if msg == "" { + msg = tool.FunctionCall.Arguments + } + + task.Message = database.StringToNullString(msg) + task.Status = database.StringToNullString("in_progress") + + task.ToolCallID = database.StringToNullString(tool.ID) + + return &task, nil } -func extractArgs[T any](openAIargs string, args *T) (*T, error) { - err := json.Unmarshal([]byte(openAIargs), 
args) +func extractToolArgs[T any](functionArgs string, args *T) (*T, error) { + err := json.Unmarshal([]byte(functionArgs), args) if err != nil { return nil, fmt.Errorf("failed to unmarshal args: %v", err) } diff --git a/backend/templates/prompts/agent.tmpl b/backend/templates/prompts/agent.tmpl index 5b9de9b..3753e42 100644 --- a/backend/templates/prompts/agent.tmpl +++ b/backend/templates/prompts/agent.tmpl @@ -2,8 +2,8 @@ You're a robot that performs engineering work to successfully finish a user-defi You have access to the terminal, browser, and text editor. You have to perform step-by-step work execution to achieve the end goal that is determined by the user. You will be provided with a list of previous commands (generated by LLM) and inputs (generated by the user). -Your goal is to figure out what is the best next step in this flow. -You can try multiple commands if you encounter some errors. +Your goal is to give the next best step in this flow. +You can try multiple commands if you encounter errors. Your goal is to make progress on each step, so your steps should NOT be repetitive. You can install packages and libraries when needed without asking for permissions using apt. Don't run apt-update to update packages. Assume that you're using the latest versions of everything. @@ -14,34 +14,39 @@ You don't want to spend much time on a single task. Never repeat the same command more than 3 times. All your commands will be executed inside a Docker {{.DockerImage}} image. Always use your function calling functionality instead of returning JSON. - -These are the possible types of commands for your next steps: - -- `terminal` - Use this command to execute a new command in a terminal that you're provided with. You will have an output of the command so you can use it in future commands. -- `browser` - Use the browser to get additional information from the internet. 
Use Google as the default search engine when you need more information but you're not sure what URL to open. -- `code` - Use this command to modify or read file content. -- `ask` - Use this command when you need to get more information from the user such as inputs, and any clarifications or questions that you may have. -- `stop` - Stop another action specified by its id. Eg, when you need to stop some long-running command. -- `done` - Mark the whole user task as done. Use this command only when the initial (main) task that the user wanted to accomplish is done. - Always include a `message` field that describes what you are planning to achieve with this command. Use conversation-like (chat) style of communication. For example: "My plan is to read the documentation. Looking for it on the web.", "Let me try to use the terminal to do that.", or "It seems like I'm having issues with npm. Are you sure it's installed?". The `message` field is always shown to the user, so you have to communicate clearly. It's mandatory to have it. +These are the possible types of commands for your next steps and their arguments: + Each command has a set of arguments that you always have to include: -- `terminal` +- `terminal` - Use this command to execute a new command in a terminal that you're provided with. You will have an output of the command so you can use it in future commands. - `input`: Command to be run in the terminal. -- `browser` +- `browser` - Use the browser to get additional information from the internet. Use Google as the default search engine when you need more information but you're not sure what URL to open. - `url`: URL to be opened in a browser. - `action`: Possible values: - `read` - Returns the content of the page. - `url` - Get the list of all URLs on the page to be used in later calls (e.g., open search results after the initial search lookup) -- `code` +- `code` - Use this command to modify or read file content. 
- `action`: Possible values: - `read_file` - Read the entire file - `update_file` - Update the entire file - `content`: Should be used only if action is update. This content will be used to replace the content of the entire file. - `path`: Path to the file that you want to work on. -- `ask` +- `ask` - Use this command when you need to get more information from the user such as inputs, and any clarifications or questions that you may have. - `input`: Question or any other information that should be sent to the user for clarifications. -- `done`: No arguments are needed. +- `done`: Mark the whole user task as done. Use this command only when the initial (main) task that the user wanted to accomplish is done. No arguments are needed. + +{{.ToolPlaceholder}} + +The history of all the previous commands and user inputs: +{{ range .Tasks }} +{ + "id": {{ .ID }}, + "type": "{{ .Type }}", + "args": {{ if .Args }}{{ .Args }}{{ else }}{}{{ end }}, + "results": {{ if .Results }}{{ .Results }}{{ else }}{}{{ end }}, + "message": "{{ .Message }}" +} +{{ end }} From 348585ddf7917a1a90c31e2a2682774cb476ad9e Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Wed, 3 Apr 2024 13:38:43 +0200 Subject: [PATCH 05/14] Do not fail if there are no openai key --- backend/config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/config/config.go b/backend/config/config.go index 6e6c649..31faebe 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -28,7 +28,7 @@ func Init() { godotenv.Load() if err := env.ParseWithOptions(&Config, env.Options{ - RequiredIfNoDef: true, + RequiredIfNoDef: false, }); err != nil { log.Fatalf("Unable to parse config: %v\n", err) } From efa4052a7ed791e5c7b968c6450236e21b15a425 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Wed, 3 Apr 2024 14:00:16 +0200 Subject: [PATCH 06/14] Do not run Docker in detached mode --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md 
b/README.md index 7ec99b1..1d88883 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ The simplest way to run Codel is to use a pre-built Docker image. You can find t > Don't forget to set the required environment variables. ```bash -docker run -d \ +docker run \ -e OPEN_AI_KEY= \ -p 3000:8080 \ -v /var/run/docker.sock:/var/run/docker.sock \ @@ -31,7 +31,7 @@ docker run -d \ Alternatively, you can create a .env file and run the Docker image with the following command: ```bash -docker run -d \ +docker run \ --env-file .env \ -p 3000:8080 \ -v /var/run/docker.sock:/var/run/docker.sock \ From 36bd3a85d40106dd81444b9ce012c84525dd4220 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Wed, 3 Apr 2024 14:29:38 +0200 Subject: [PATCH 07/14] Add model for each flow. Expose it via graphql api --- backend/config/config.go | 2 +- backend/database/containers.sql.go | 2 +- backend/database/db.go | 2 +- backend/database/flows.sql.go | 34 ++- backend/database/logs.sql.go | 2 +- backend/database/models.go | 3 +- backend/database/tasks.sql.go | 2 +- backend/executor/queue.go | 12 +- backend/graph/generated.go | 217 +++++++++++++++++- backend/graph/model/models_gen.go | 1 + backend/graph/schema.graphqls | 4 +- backend/graph/schema.resolvers.go | 30 ++- .../20240403115154_add_model_to_each_flow.sql | 11 + backend/models/flows.sql | 4 +- frontend/generated/graphql.schema.json | 59 ++++- frontend/generated/graphql.ts | 75 +++++- .../components/Sidebar/NewTask/NewTask.css.ts | 6 +- .../src/layouts/AppLayout/AppLayout.graphql | 4 + frontend/src/layouts/AppLayout/AppLayout.tsx | 7 +- frontend/src/pages/ChatPage/ChatPage.graphql | 5 +- frontend/src/pages/ChatPage/ChatPage.tsx | 1 + 21 files changed, 438 insertions(+), 45 deletions(-) create mode 100644 backend/migrations/20240403115154_add_model_to_each_flow.sql diff --git a/backend/config/config.go b/backend/config/config.go index 31faebe..1e73e20 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -18,7 +18,7 
@@ type config struct { OpenAIServerURL string `env:"OPEN_AI_SERVER_URL" envDefault:"https://api.openai.com/v1"` // Ollama - OllamaModel string `env:"OLLAMA_MODEL" envDefault:"llama2"` + OllamaModel string `env:"OLLAMA_MODEL"` OllamaServerURL string `env:"OLLAMA_SERVER_URL" envDefault:"http://localhost:11434"` } diff --git a/backend/database/containers.sql.go b/backend/database/containers.sql.go index 82dad5a..b6646e7 100644 --- a/backend/database/containers.sql.go +++ b/backend/database/containers.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: containers.sql package database diff --git a/backend/database/db.go b/backend/database/db.go index 61f5bf4..a3cc795 100644 --- a/backend/database/db.go +++ b/backend/database/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database diff --git a/backend/database/flows.sql.go b/backend/database/flows.sql.go index f404ee1..ae29731 100644 --- a/backend/database/flows.sql.go +++ b/backend/database/flows.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: flows.sql package database @@ -12,22 +12,28 @@ import ( const createFlow = `-- name: CreateFlow :one INSERT INTO flows ( - name, status, container_id + name, status, container_id, model ) VALUES ( - ?, ?, ? + ?, ?, ?, ? 
) -RETURNING id, created_at, updated_at, name, status, container_id +RETURNING id, created_at, updated_at, name, status, container_id, model ` type CreateFlowParams struct { Name sql.NullString Status sql.NullString ContainerID sql.NullInt64 + Model sql.NullString } func (q *Queries) CreateFlow(ctx context.Context, arg CreateFlowParams) (Flow, error) { - row := q.db.QueryRowContext(ctx, createFlow, arg.Name, arg.Status, arg.ContainerID) + row := q.db.QueryRowContext(ctx, createFlow, + arg.Name, + arg.Status, + arg.ContainerID, + arg.Model, + ) var i Flow err := row.Scan( &i.ID, @@ -36,13 +42,14 @@ func (q *Queries) CreateFlow(ctx context.Context, arg CreateFlowParams) (Flow, e &i.Name, &i.Status, &i.ContainerID, + &i.Model, ) return i, err } const readAllFlows = `-- name: ReadAllFlows :many SELECT - f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, + f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, f.model, c.name AS container_name FROM flows f LEFT JOIN containers c ON f.container_id = c.id @@ -56,6 +63,7 @@ type ReadAllFlowsRow struct { Name sql.NullString Status sql.NullString ContainerID sql.NullInt64 + Model sql.NullString ContainerName sql.NullString } @@ -75,6 +83,7 @@ func (q *Queries) ReadAllFlows(ctx context.Context) ([]ReadAllFlowsRow, error) { &i.Name, &i.Status, &i.ContainerID, + &i.Model, &i.ContainerName, ); err != nil { return nil, err @@ -92,7 +101,7 @@ func (q *Queries) ReadAllFlows(ctx context.Context) ([]ReadAllFlowsRow, error) { const readFlow = `-- name: ReadFlow :one SELECT - f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, + f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, f.model, c.name AS container_name, c.image AS container_image, c.status AS container_status, @@ -109,6 +118,7 @@ type ReadFlowRow struct { Name sql.NullString Status sql.NullString ContainerID sql.NullInt64 + Model sql.NullString ContainerName sql.NullString ContainerImage sql.NullString 
ContainerStatus sql.NullString @@ -125,6 +135,7 @@ func (q *Queries) ReadFlow(ctx context.Context, id int64) (ReadFlowRow, error) { &i.Name, &i.Status, &i.ContainerID, + &i.Model, &i.ContainerName, &i.ContainerImage, &i.ContainerStatus, @@ -137,7 +148,7 @@ const updateFlowContainer = `-- name: UpdateFlowContainer :one UPDATE flows SET container_id = ? WHERE id = ? -RETURNING id, created_at, updated_at, name, status, container_id +RETURNING id, created_at, updated_at, name, status, container_id, model ` type UpdateFlowContainerParams struct { @@ -155,6 +166,7 @@ func (q *Queries) UpdateFlowContainer(ctx context.Context, arg UpdateFlowContain &i.Name, &i.Status, &i.ContainerID, + &i.Model, ) return i, err } @@ -163,7 +175,7 @@ const updateFlowName = `-- name: UpdateFlowName :one UPDATE flows SET name = ? WHERE id = ? -RETURNING id, created_at, updated_at, name, status, container_id +RETURNING id, created_at, updated_at, name, status, container_id, model ` type UpdateFlowNameParams struct { @@ -181,6 +193,7 @@ func (q *Queries) UpdateFlowName(ctx context.Context, arg UpdateFlowNameParams) &i.Name, &i.Status, &i.ContainerID, + &i.Model, ) return i, err } @@ -189,7 +202,7 @@ const updateFlowStatus = `-- name: UpdateFlowStatus :one UPDATE flows SET status = ? WHERE id = ? -RETURNING id, created_at, updated_at, name, status, container_id +RETURNING id, created_at, updated_at, name, status, container_id, model ` type UpdateFlowStatusParams struct { @@ -207,6 +220,7 @@ func (q *Queries) UpdateFlowStatus(ctx context.Context, arg UpdateFlowStatusPara &i.Name, &i.Status, &i.ContainerID, + &i.Model, ) return i, err } diff --git a/backend/database/logs.sql.go b/backend/database/logs.sql.go index ad63cee..587f689 100644 --- a/backend/database/logs.sql.go +++ b/backend/database/logs.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. 
// versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: logs.sql package database diff --git a/backend/database/models.go b/backend/database/models.go index 795487b..a80c884 100644 --- a/backend/database/models.go +++ b/backend/database/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package database @@ -24,6 +24,7 @@ type Flow struct { Name sql.NullString Status sql.NullString ContainerID sql.NullInt64 + Model sql.NullString } type Log struct { diff --git a/backend/database/tasks.sql.go b/backend/database/tasks.sql.go index 2203c7c..3e67685 100644 --- a/backend/database/tasks.sql.go +++ b/backend/database/tasks.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: tasks.sql package database diff --git a/backend/executor/queue.go b/backend/executor/queue.go index b16b9c6..babf42a 100644 --- a/backend/executor/queue.go +++ b/backend/executor/queue.go @@ -48,9 +48,15 @@ func CleanQueue(flowId int64) { func ProcessQueue(flowId int64, db *database.Queries) { log.Println("Starting tasks processor for queue", flowId) - provider, err := providers.ProviderFactory(providers.ProviderOllama) + flow, err := db.ReadFlow(context.Background(), flowId) - log.Println("Using provider: ", provider.Name()) + if err != nil { + log.Printf("failed to get provider: %v", err) + CleanQueue(flowId) + return + } + + provider, err := providers.ProviderFactory(providers.ProviderType(flow.Model.String)) if err != nil { log.Printf("failed to get provider: %v", err) @@ -58,6 +64,8 @@ func ProcessQueue(flowId int64, db *database.Queries) { return } + log.Println("Using provider: ", provider.Name()) + go func() { for { select { diff --git a/backend/graph/generated.go b/backend/graph/generated.go index 79632f6..cbff079 100644 --- a/backend/graph/generated.go +++ b/backend/graph/generated.go @@ -58,6 +58,7 @@ type ComplexityRoot struct { Flow struct { Browser 
func(childComplexity int) int ID func(childComplexity int) int + Model func(childComplexity int) int Name func(childComplexity int) int Status func(childComplexity int) int Tasks func(childComplexity int) int @@ -70,15 +71,16 @@ type ComplexityRoot struct { } Mutation struct { - CreateFlow func(childComplexity int) int + CreateFlow func(childComplexity int, model string) int CreateTask func(childComplexity int, flowID uint, query string) int Exec func(childComplexity int, containerID string, command string) int FinishFlow func(childComplexity int, flowID uint) int } Query struct { - Flow func(childComplexity int, id uint) int - Flows func(childComplexity int) int + AvailableModels func(childComplexity int) int + Flow func(childComplexity int, id uint) int + Flows func(childComplexity int) int } Subscription struct { @@ -107,12 +109,13 @@ type ComplexityRoot struct { } type MutationResolver interface { - CreateFlow(ctx context.Context) (*gmodel.Flow, error) + CreateFlow(ctx context.Context, model string) (*gmodel.Flow, error) CreateTask(ctx context.Context, flowID uint, query string) (*gmodel.Task, error) FinishFlow(ctx context.Context, flowID uint) (*gmodel.Flow, error) Exec(ctx context.Context, containerID string, command string) (string, error) } type QueryResolver interface { + AvailableModels(ctx context.Context) ([]string, error) Flows(ctx context.Context) ([]*gmodel.Flow, error) Flow(ctx context.Context, id uint) (*gmodel.Flow, error) } @@ -171,6 +174,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Flow.ID(childComplexity), true + case "Flow.model": + if e.complexity.Flow.Model == nil { + break + } + + return e.complexity.Flow.Model(childComplexity), true + case "Flow.name": if e.complexity.Flow.Name == nil { break @@ -218,7 +228,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in break } - return e.complexity.Mutation.CreateFlow(childComplexity), true + args, err := 
ec.field_Mutation_createFlow_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.CreateFlow(childComplexity, args["model"].(string)), true case "Mutation.createTask": if e.complexity.Mutation.CreateTask == nil { @@ -256,6 +271,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.FinishFlow(childComplexity, args["flowId"].(uint)), true + case "Query.availableModels": + if e.complexity.Query.AvailableModels == nil { + break + } + + return e.complexity.Query.AvailableModels(childComplexity), true + case "Query.flow": if e.complexity.Query.Flow == nil { break @@ -564,6 +586,21 @@ func (ec *executionContext) field_Mutation__exec_args(ctx context.Context, rawAr return args, nil } +func (ec *executionContext) field_Mutation_createFlow_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["model"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("model")) + arg0, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["model"] = arg0 + return args, nil +} + func (ec *executionContext) field_Mutation_createTask_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -1113,6 +1150,50 @@ func (ec *executionContext) fieldContext_Flow_status(ctx context.Context, field return fc, nil } +func (ec *executionContext) _Flow_model(ctx context.Context, field graphql.CollectedField, obj *gmodel.Flow) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Flow_model(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := 
ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Model, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Flow_model(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Flow", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Log_id(ctx context.Context, field graphql.CollectedField, obj *gmodel.Log) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Log_id(ctx, field) if err != nil { @@ -1215,7 +1296,7 @@ func (ec *executionContext) _Mutation_createFlow(ctx context.Context, field grap }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Mutation().CreateFlow(rctx) + return ec.resolvers.Mutation().CreateFlow(rctx, fc.Args["model"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -1252,10 +1333,23 @@ func (ec *executionContext) fieldContext_Mutation_createFlow(ctx context.Context return ec.fieldContext_Flow_browser(ctx, field) case "status": return ec.fieldContext_Flow_status(ctx, field) + case "model": + return ec.fieldContext_Flow_model(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Flow", field.Name) }, } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, 
err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Mutation_createFlow_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } return fc, nil } @@ -1381,6 +1475,8 @@ func (ec *executionContext) fieldContext_Mutation_finishFlow(ctx context.Context return ec.fieldContext_Flow_browser(ctx, field) case "status": return ec.fieldContext_Flow_status(ctx, field) + case "model": + return ec.fieldContext_Flow_model(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Flow", field.Name) }, @@ -1454,6 +1550,50 @@ func (ec *executionContext) fieldContext_Mutation__exec(ctx context.Context, fie return fc, nil } +func (ec *executionContext) _Query_availableModels(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_availableModels(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().AvailableModels(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]string) + fc.Result = res + return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_availableModels(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, 
errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Query_flows(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Query_flows(ctx, field) if err != nil { @@ -1505,6 +1645,8 @@ func (ec *executionContext) fieldContext_Query_flows(ctx context.Context, field return ec.fieldContext_Flow_browser(ctx, field) case "status": return ec.fieldContext_Flow_status(ctx, field) + case "model": + return ec.fieldContext_Flow_model(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Flow", field.Name) }, @@ -1563,6 +1705,8 @@ func (ec *executionContext) fieldContext_Query_flow(ctx context.Context, field g return ec.fieldContext_Flow_browser(ctx, field) case "status": return ec.fieldContext_Flow_status(ctx, field) + case "model": + return ec.fieldContext_Flow_model(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Flow", field.Name) }, @@ -1934,6 +2078,8 @@ func (ec *executionContext) fieldContext_Subscription_flowUpdated(ctx context.Co return ec.fieldContext_Flow_browser(ctx, field) case "status": return ec.fieldContext_Flow_status(ctx, field) + case "model": + return ec.fieldContext_Flow_model(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Flow", field.Name) }, @@ -4414,6 +4560,11 @@ func (ec *executionContext) _Flow(ctx context.Context, sel ast.SelectionSet, obj if out.Values[i] == graphql.Null { out.Invalids++ } + case "model": + out.Values[i] = ec._Flow_model(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -4570,6 +4721,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr switch field.Name { case "__typename": out.Values[i] = graphql.MarshalString("Query") + case "availableModels": + field := field + + innerFunc := func(ctx context.Context, fs 
*graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_availableModels(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, + func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) case "flows": field := field @@ -5302,6 +5475,38 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, sel ast.S return res } +func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) { + var vSlice []interface{} + if v != nil { + vSlice = graphql.CoerceList(v) + } + var err error + res := make([]string, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalNString2string(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + for i := range v { + ret[i] = ec.marshalNString2string(ctx, sel, v[i]) + } + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + func (ec *executionContext) marshalNTask2githubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐTask(ctx context.Context, sel ast.SelectionSet, v gmodel.Task) graphql.Marshaler { return ec._Task(ctx, sel, &v) } diff --git a/backend/graph/model/models_gen.go b/backend/graph/model/models_gen.go index 9f9be6e..dd14c90 100644 --- a/backend/graph/model/models_gen.go +++ b/backend/graph/model/models_gen.go @@ -21,6 +21,7 @@ type Flow struct { Terminal *Terminal `json:"terminal"` Browser *Browser 
`json:"browser"` Status FlowStatus `json:"status"` + Model string `json:"model"` } type Log struct { diff --git a/backend/graph/schema.graphqls b/backend/graph/schema.graphqls index 3c64b5e..2052711 100644 --- a/backend/graph/schema.graphqls +++ b/backend/graph/schema.graphqls @@ -56,15 +56,17 @@ type Flow { terminal: Terminal! browser: Browser! status: FlowStatus! + model: String! } type Query { + availableModels: [String!]! flows: [Flow!]! flow(id: Uint!): Flow! } type Mutation { - createFlow: Flow! + createFlow(model: String!): Flow! createTask(flowId: Uint!, query: String!): Task! finishFlow(flowId: Uint!): Flow! diff --git a/backend/graph/schema.resolvers.go b/backend/graph/schema.resolvers.go index 8c9dc5c..fd00a91 100644 --- a/backend/graph/schema.resolvers.go +++ b/backend/graph/schema.resolvers.go @@ -12,6 +12,7 @@ import ( "fmt" "log" + "github.com/semanser/ai-coder/config" "github.com/semanser/ai-coder/database" "github.com/semanser/ai-coder/executor" gmodel "github.com/semanser/ai-coder/graph/model" @@ -20,10 +21,15 @@ import ( ) // CreateFlow is the resolver for the createFlow field. 
-func (r *mutationResolver) CreateFlow(ctx context.Context) (*gmodel.Flow, error) { +func (r *mutationResolver) CreateFlow(ctx context.Context, model string) (*gmodel.Flow, error) { + if model == "" { + return nil, fmt.Errorf("model is required") + } + flow, err := r.Db.CreateFlow(ctx, database.CreateFlowParams{ Name: database.StringToNullString("New Task"), Status: database.StringToNullString("in_progress"), + Model: database.StringToNullString(model), }) if err != nil { @@ -36,6 +42,7 @@ func (r *mutationResolver) CreateFlow(ctx context.Context) (*gmodel.Flow, error) ID: uint(flow.ID), Name: flow.Name.String, Status: gmodel.FlowStatus(flow.Status.String), + Model: flow.Model.String, }, nil } @@ -65,10 +72,6 @@ func (r *mutationResolver) CreateTask(ctx context.Context, flowID uint, query st executor.AddCommand(int64(flowID), task) - if err != nil { - return nil, fmt.Errorf("failed to execute command: %w", err) - } - return &gmodel.Task{ ID: uint(task.ID), Message: task.Message.String, @@ -125,6 +128,21 @@ func (r *mutationResolver) Exec(ctx context.Context, containerID string, command return b.String(), nil } +// AvailableModels is the resolver for the availableModels field. +func (r *queryResolver) AvailableModels(ctx context.Context) ([]string, error) { + var availableModels []string + + if config.Config.OpenAIKey != "" && config.Config.OpenAIModel != "" { + availableModels = append(availableModels, config.Config.OpenAIModel) + } + + if config.Config.OllamaModel != "" { + availableModels = append(availableModels, config.Config.OllamaModel) + } + + return availableModels, nil +} + // Flows is the resolver for the flows field. 
func (r *queryResolver) Flows(ctx context.Context) ([]*gmodel.Flow, error) { flows, err := r.Db.ReadAllFlows(ctx) @@ -149,6 +167,7 @@ func (r *queryResolver) Flows(ctx context.Context) ([]*gmodel.Flow, error) { }, Tasks: gTasks, Status: gmodel.FlowStatus(flow.Status.String), + Model: flow.Model.String, }) } @@ -217,6 +236,7 @@ func (r *queryResolver) Flow(ctx context.Context, id uint) (*gmodel.Flow, error) URL: "", ScreenshotURL: "", }, + Model: flow.Model.String, } return gFlow, nil diff --git a/backend/migrations/20240403115154_add_model_to_each_flow.sql b/backend/migrations/20240403115154_add_model_to_each_flow.sql new file mode 100644 index 0000000..043be4e --- /dev/null +++ b/backend/migrations/20240403115154_add_model_to_each_flow.sql @@ -0,0 +1,11 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE flows +ADD COLUMN model TEXT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE flows +DROP COLUMN model; +-- +goose StatementEnd diff --git a/backend/models/flows.sql b/backend/models/flows.sql index 25fe9b8..9cc8b4a 100644 --- a/backend/models/flows.sql +++ b/backend/models/flows.sql @@ -1,9 +1,9 @@ -- name: CreateFlow :one INSERT INTO flows ( - name, status, container_id + name, status, container_id, model ) VALUES ( - ?, ?, ? + ?, ?, ?, ? 
) RETURNING *; diff --git a/frontend/generated/graphql.schema.json b/frontend/generated/graphql.schema.json index d842d34..96e520d 100644 --- a/frontend/generated/graphql.schema.json +++ b/frontend/generated/graphql.schema.json @@ -100,6 +100,22 @@ "isDeprecated": false, "deprecationReason": null }, + { + "name": "model", + "description": null, + "args": [], + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + }, + "isDeprecated": false, + "deprecationReason": null + }, { "name": "name", "description": null, @@ -311,7 +327,24 @@ { "name": "createFlow", "description": null, - "args": [], + "args": [ + { + "name": "model", + "description": null, + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + }, + "defaultValue": null, + "isDeprecated": false, + "deprecationReason": null + } + ], "type": { "kind": "NON_NULL", "name": null, @@ -417,6 +450,30 @@ "name": "Query", "description": null, "fields": [ + { + "name": "availableModels", + "description": null, + "args": [], + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "LIST", + "name": null, + "ofType": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + } + }, + "isDeprecated": false, + "deprecationReason": null + }, { "name": "flow", "description": null, diff --git a/frontend/generated/graphql.ts b/frontend/generated/graphql.ts index a7308ab..696cb07 100644 --- a/frontend/generated/graphql.ts +++ b/frontend/generated/graphql.ts @@ -30,6 +30,7 @@ export type Flow = { __typename?: 'Flow'; browser: Browser; id: Scalars['Uint']['output']; + model: Scalars['String']['output']; name: Scalars['String']['output']; status: FlowStatus; tasks: Array; @@ -62,6 +63,11 @@ export type Mutation_ExecArgs = { }; +export type MutationCreateFlowArgs = { + model: Scalars['String']['input']; +}; + + 
export type MutationCreateTaskArgs = { flowId: Scalars['Uint']['input']; query: Scalars['String']['input']; @@ -74,6 +80,7 @@ export type MutationFinishFlowArgs = { export type Query = { __typename?: 'Query'; + availableModels: Array; flow: Flow; flows: Array; }; @@ -181,6 +188,7 @@ export const FlowFragmentFragmentDoc = gql` id name status + model terminal { containerName connected @@ -209,6 +217,15 @@ export const FlowsDocument = gql` export function useFlowsQuery(options?: Omit, 'query'>) { return Urql.useQuery({ query: FlowsDocument, ...options }); }; +export const AvailableModelsDocument = gql` + query availableModels { + availableModels +} + `; + +export function useAvailableModelsQuery(options?: Omit, 'query'>) { + return Urql.useQuery({ query: AvailableModelsDocument, ...options }); +}; export const FlowDocument = gql` query flow($id: Uint!) { flow(id: $id) { @@ -221,8 +238,8 @@ export function useFlowQuery(options: Omit return Urql.useQuery({ query: FlowDocument, ...options }); }; export const CreateFlowDocument = gql` - mutation createFlow { - createFlow { + mutation createFlow($model: String!) 
{ + createFlow(model: $model) { id name } @@ -311,22 +328,29 @@ export type FlowsQueryVariables = Exact<{ [key: string]: never; }>; export type FlowsQuery = { __typename?: 'Query', flows: Array<{ __typename?: 'Flow', id: any, name: string, status: FlowStatus }> }; +export type AvailableModelsQueryVariables = Exact<{ [key: string]: never; }>; + + +export type AvailableModelsQuery = { __typename?: 'Query', availableModels: Array }; + export type TaskFragmentFragment = { __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }; export type LogFragmentFragment = { __typename?: 'Log', text: string, id: any }; export type BrowserFragmentFragment = { __typename?: 'Browser', url: string, screenshotUrl: string }; -export type FlowFragmentFragment = { __typename?: 'Flow', id: any, name: string, status: FlowStatus, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> }; +export type FlowFragmentFragment = { __typename?: 'Flow', id: any, name: string, status: FlowStatus, model: string, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> }; export type FlowQueryVariables = Exact<{ id: Scalars['Uint']['input']; }>; -export type FlowQuery = { __typename?: 'Query', flow: { __typename?: 'Flow', id: any, name: string, status: FlowStatus, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, 
logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> } }; +export type FlowQuery = { __typename?: 'Query', flow: { __typename?: 'Flow', id: any, name: string, status: FlowStatus, model: string, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> } }; -export type CreateFlowMutationVariables = Exact<{ [key: string]: never; }>; +export type CreateFlowMutationVariables = Exact<{ + model: Scalars['String']['input']; +}>; export type CreateFlowMutation = { __typename?: 'Mutation', createFlow: { __typename?: 'Flow', id: any, name: string } }; @@ -443,6 +467,17 @@ export default { }, "args": [] }, + { + "name": "model", + "type": { + "kind": "NON_NULL", + "ofType": { + "kind": "SCALAR", + "name": "Any" + } + }, + "args": [] + }, { "name": "name", "type": { @@ -573,7 +608,18 @@ export default { "ofType": null } }, - "args": [] + "args": [ + { + "name": "model", + "type": { + "kind": "NON_NULL", + "ofType": { + "kind": "SCALAR", + "name": "Any" + } + } + } + ] }, { "name": "createTask", @@ -638,6 +684,23 @@ export default { "kind": "OBJECT", "name": "Query", "fields": [ + { + "name": "availableModels", + "type": { + "kind": "NON_NULL", + "ofType": { + "kind": "LIST", + "ofType": { + "kind": "NON_NULL", + "ofType": { + "kind": "SCALAR", + "name": "Any" + } + } + } + }, + "args": [] + }, { "name": "flow", "type": { diff --git a/frontend/src/components/Sidebar/NewTask/NewTask.css.ts b/frontend/src/components/Sidebar/NewTask/NewTask.css.ts index 
7513c8c..10d3223 100644 --- a/frontend/src/components/Sidebar/NewTask/NewTask.css.ts +++ b/frontend/src/components/Sidebar/NewTask/NewTask.css.ts @@ -8,7 +8,7 @@ export const wrapperStyles = style([ { display: "block", textDecoration: "none", - background: "none", + background: vars.color.gray3, border: "none", textAlign: "left", color: vars.color.gray12, @@ -20,13 +20,13 @@ export const wrapperStyles = style([ selectors: { "&.active": { color: vars.color.primary9, - backgroundColor: vars.color.gray2, + backgroundColor: vars.color.gray5, }, }, ":hover": { color: vars.color.primary9, - backgroundColor: vars.color.gray2, + backgroundColor: vars.color.gray4, }, }, ]); diff --git a/frontend/src/layouts/AppLayout/AppLayout.graphql b/frontend/src/layouts/AppLayout/AppLayout.graphql index c1b5764..931f260 100644 --- a/frontend/src/layouts/AppLayout/AppLayout.graphql +++ b/frontend/src/layouts/AppLayout/AppLayout.graphql @@ -9,3 +9,7 @@ query flows { ...flowOverviewFragment } } + +query availableModels { + availableModels +} diff --git a/frontend/src/layouts/AppLayout/AppLayout.tsx b/frontend/src/layouts/AppLayout/AppLayout.tsx index 2654dd1..5a2e32e 100644 --- a/frontend/src/layouts/AppLayout/AppLayout.tsx +++ b/frontend/src/layouts/AppLayout/AppLayout.tsx @@ -1,12 +1,17 @@ import { Outlet } from "react-router-dom"; import { Sidebar } from "@/components/Sidebar/Sidebar"; -import { FlowStatus, useFlowsQuery } from "@/generated/graphql"; +import { + FlowStatus, + useAvailableModelsQuery, + useFlowsQuery, +} from "@/generated/graphql"; import { wrapperStyles } from "./AppLayout.css"; export const AppLayout = () => { const [{ data }] = useFlowsQuery(); + const [{ data: availableModelsData }] = useAvailableModelsQuery(); const sidebarItems = data?.flows.map((flow) => ({ diff --git a/frontend/src/pages/ChatPage/ChatPage.graphql b/frontend/src/pages/ChatPage/ChatPage.graphql index f3a6a0e..b774932 100644 --- a/frontend/src/pages/ChatPage/ChatPage.graphql +++ 
b/frontend/src/pages/ChatPage/ChatPage.graphql @@ -22,6 +22,7 @@ fragment flowFragment on Flow { id name status + model terminal { containerName connected @@ -43,8 +44,8 @@ query flow($id: Uint!) { } } -mutation createFlow { - createFlow { +mutation createFlow($model: String!) { + createFlow(model: $model) { id name } diff --git a/frontend/src/pages/ChatPage/ChatPage.tsx b/frontend/src/pages/ChatPage/ChatPage.tsx index 402fb9b..44a603c 100644 --- a/frontend/src/pages/ChatPage/ChatPage.tsx +++ b/frontend/src/pages/ChatPage/ChatPage.tsx @@ -46,6 +46,7 @@ export const ChatPage = () => { true, ); const [activeTab, setActiveTab] = useState("terminal"); + const [selectedModel, setSelectedModel] = useLocalStorage("model", ""); const [{ operation, data }] = useFlowQuery({ pause: isNewFlow, From 4d396d3402e17746bb1f06c3d9805d830866dd12 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Wed, 3 Apr 2024 15:45:54 +0200 Subject: [PATCH 08/14] Add model selector --- backend/database/flows.sql.go | 34 ++- backend/database/models.go | 15 +- backend/executor/queue.go | 4 +- backend/graph/generated.go | 286 +++++++++++++++--- backend/graph/model/models_gen.go | 7 +- backend/graph/schema.graphqls | 11 +- backend/graph/schema.resolvers.go | 42 ++- ...132844_add_model_provider_to_each_flow.sql | 11 + backend/models/flows.sql | 4 +- frontend/generated/graphql.schema.json | 69 ++++- frontend/generated/graphql.ts | 98 ++++-- frontend/package.json | 1 + .../src/components/Dropdown/Dropdown.css.ts | 78 +++++ frontend/src/components/Dropdown/Dropdown.tsx | 28 ++ .../src/components/Messages/Messages.css.ts | 4 + frontend/src/components/Messages/Messages.tsx | 10 +- .../ModelSelector/ModelSelector.css.ts | 28 ++ .../NewTask/ModelSelector/ModelSelector.tsx | 78 +++++ .../components/Sidebar/NewTask/NewTask.css.ts | 13 +- .../components/Sidebar/NewTask/NewTask.tsx | 51 +++- frontend/src/components/Sidebar/Sidebar.tsx | 5 +- .../src/layouts/AppLayout/AppLayout.graphql | 9 +- 
frontend/src/layouts/AppLayout/AppLayout.tsx | 5 +- frontend/src/pages/ChatPage/ChatPage.graphql | 8 +- frontend/src/pages/ChatPage/ChatPage.tsx | 12 +- frontend/yarn.lock | 116 +++++++ 26 files changed, 896 insertions(+), 131 deletions(-) create mode 100644 backend/migrations/20240403132844_add_model_provider_to_each_flow.sql create mode 100644 frontend/src/components/Dropdown/Dropdown.css.ts create mode 100644 frontend/src/components/Dropdown/Dropdown.tsx create mode 100644 frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.css.ts create mode 100644 frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.tsx diff --git a/backend/database/flows.sql.go b/backend/database/flows.sql.go index ae29731..46a6d2e 100644 --- a/backend/database/flows.sql.go +++ b/backend/database/flows.sql.go @@ -12,19 +12,20 @@ import ( const createFlow = `-- name: CreateFlow :one INSERT INTO flows ( - name, status, container_id, model + name, status, container_id, model, model_provider ) VALUES ( - ?, ?, ?, ? + ?, ?, ?, ?, ? 
) -RETURNING id, created_at, updated_at, name, status, container_id, model +RETURNING id, created_at, updated_at, name, status, container_id, model, model_provider ` type CreateFlowParams struct { - Name sql.NullString - Status sql.NullString - ContainerID sql.NullInt64 - Model sql.NullString + Name sql.NullString + Status sql.NullString + ContainerID sql.NullInt64 + Model sql.NullString + ModelProvider sql.NullString } func (q *Queries) CreateFlow(ctx context.Context, arg CreateFlowParams) (Flow, error) { @@ -33,6 +34,7 @@ func (q *Queries) CreateFlow(ctx context.Context, arg CreateFlowParams) (Flow, e arg.Status, arg.ContainerID, arg.Model, + arg.ModelProvider, ) var i Flow err := row.Scan( @@ -43,13 +45,14 @@ func (q *Queries) CreateFlow(ctx context.Context, arg CreateFlowParams) (Flow, e &i.Status, &i.ContainerID, &i.Model, + &i.ModelProvider, ) return i, err } const readAllFlows = `-- name: ReadAllFlows :many SELECT - f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, f.model, + f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, f.model, f.model_provider, c.name AS container_name FROM flows f LEFT JOIN containers c ON f.container_id = c.id @@ -64,6 +67,7 @@ type ReadAllFlowsRow struct { Status sql.NullString ContainerID sql.NullInt64 Model sql.NullString + ModelProvider sql.NullString ContainerName sql.NullString } @@ -84,6 +88,7 @@ func (q *Queries) ReadAllFlows(ctx context.Context) ([]ReadAllFlowsRow, error) { &i.Status, &i.ContainerID, &i.Model, + &i.ModelProvider, &i.ContainerName, ); err != nil { return nil, err @@ -101,7 +106,7 @@ func (q *Queries) ReadAllFlows(ctx context.Context) ([]ReadAllFlowsRow, error) { const readFlow = `-- name: ReadFlow :one SELECT - f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, f.model, + f.id, f.created_at, f.updated_at, f.name, f.status, f.container_id, f.model, f.model_provider, c.name AS container_name, c.image AS container_image, c.status AS container_status, @@ 
-119,6 +124,7 @@ type ReadFlowRow struct { Status sql.NullString ContainerID sql.NullInt64 Model sql.NullString + ModelProvider sql.NullString ContainerName sql.NullString ContainerImage sql.NullString ContainerStatus sql.NullString @@ -136,6 +142,7 @@ func (q *Queries) ReadFlow(ctx context.Context, id int64) (ReadFlowRow, error) { &i.Status, &i.ContainerID, &i.Model, + &i.ModelProvider, &i.ContainerName, &i.ContainerImage, &i.ContainerStatus, @@ -148,7 +155,7 @@ const updateFlowContainer = `-- name: UpdateFlowContainer :one UPDATE flows SET container_id = ? WHERE id = ? -RETURNING id, created_at, updated_at, name, status, container_id, model +RETURNING id, created_at, updated_at, name, status, container_id, model, model_provider ` type UpdateFlowContainerParams struct { @@ -167,6 +174,7 @@ func (q *Queries) UpdateFlowContainer(ctx context.Context, arg UpdateFlowContain &i.Status, &i.ContainerID, &i.Model, + &i.ModelProvider, ) return i, err } @@ -175,7 +183,7 @@ const updateFlowName = `-- name: UpdateFlowName :one UPDATE flows SET name = ? WHERE id = ? -RETURNING id, created_at, updated_at, name, status, container_id, model +RETURNING id, created_at, updated_at, name, status, container_id, model, model_provider ` type UpdateFlowNameParams struct { @@ -194,6 +202,7 @@ func (q *Queries) UpdateFlowName(ctx context.Context, arg UpdateFlowNameParams) &i.Status, &i.ContainerID, &i.Model, + &i.ModelProvider, ) return i, err } @@ -202,7 +211,7 @@ const updateFlowStatus = `-- name: UpdateFlowStatus :one UPDATE flows SET status = ? WHERE id = ? 
-RETURNING id, created_at, updated_at, name, status, container_id, model +RETURNING id, created_at, updated_at, name, status, container_id, model, model_provider ` type UpdateFlowStatusParams struct { @@ -221,6 +230,7 @@ func (q *Queries) UpdateFlowStatus(ctx context.Context, arg UpdateFlowStatusPara &i.Status, &i.ContainerID, &i.Model, + &i.ModelProvider, ) return i, err } diff --git a/backend/database/models.go b/backend/database/models.go index a80c884..0080a51 100644 --- a/backend/database/models.go +++ b/backend/database/models.go @@ -18,13 +18,14 @@ type Container struct { } type Flow struct { - ID int64 - CreatedAt sql.NullTime - UpdatedAt sql.NullTime - Name sql.NullString - Status sql.NullString - ContainerID sql.NullInt64 - Model sql.NullString + ID int64 + CreatedAt sql.NullTime + UpdatedAt sql.NullTime + Name sql.NullString + Status sql.NullString + ContainerID sql.NullInt64 + Model sql.NullString + ModelProvider sql.NullString } type Log struct { diff --git a/backend/executor/queue.go b/backend/executor/queue.go index babf42a..19a3c6d 100644 --- a/backend/executor/queue.go +++ b/backend/executor/queue.go @@ -56,7 +56,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { return } - provider, err := providers.ProviderFactory(providers.ProviderType(flow.Model.String)) + provider, err := providers.ProviderFactory(providers.ProviderType(flow.ModelProvider.String)) if err != nil { log.Printf("failed to get provider: %v", err) @@ -64,7 +64,7 @@ func ProcessQueue(flowId int64, db *database.Queries) { return } - log.Println("Using provider: ", provider.Name()) + log.Printf("Using provider: %s. 
Model: %s\n", provider.Name(), flow.ModelProvider.String) go func() { for { diff --git a/backend/graph/generated.go b/backend/graph/generated.go index cbff079..d9a7c59 100644 --- a/backend/graph/generated.go +++ b/backend/graph/generated.go @@ -70,8 +70,13 @@ type ComplexityRoot struct { Text func(childComplexity int) int } + Model struct { + ID func(childComplexity int) int + Provider func(childComplexity int) int + } + Mutation struct { - CreateFlow func(childComplexity int, model string) int + CreateFlow func(childComplexity int, modelProvider string, modelID string) int CreateTask func(childComplexity int, flowID uint, query string) int Exec func(childComplexity int, containerID string, command string) int FinishFlow func(childComplexity int, flowID uint) int @@ -109,13 +114,13 @@ type ComplexityRoot struct { } type MutationResolver interface { - CreateFlow(ctx context.Context, model string) (*gmodel.Flow, error) + CreateFlow(ctx context.Context, modelProvider string, modelID string) (*gmodel.Flow, error) CreateTask(ctx context.Context, flowID uint, query string) (*gmodel.Task, error) FinishFlow(ctx context.Context, flowID uint) (*gmodel.Flow, error) Exec(ctx context.Context, containerID string, command string) (string, error) } type QueryResolver interface { - AvailableModels(ctx context.Context) ([]string, error) + AvailableModels(ctx context.Context) ([]*gmodel.Model, error) Flows(ctx context.Context) ([]*gmodel.Flow, error) Flow(ctx context.Context, id uint) (*gmodel.Flow, error) } @@ -223,6 +228,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Log.Text(childComplexity), true + case "Model.id": + if e.complexity.Model.ID == nil { + break + } + + return e.complexity.Model.ID(childComplexity), true + + case "Model.provider": + if e.complexity.Model.Provider == nil { + break + } + + return e.complexity.Model.Provider(childComplexity), true + case "Mutation.createFlow": if 
e.complexity.Mutation.CreateFlow == nil { break @@ -233,7 +252,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Mutation.CreateFlow(childComplexity, args["model"].(string)), true + return e.complexity.Mutation.CreateFlow(childComplexity, args["modelProvider"].(string), args["modelId"].(string)), true case "Mutation.createTask": if e.complexity.Mutation.CreateTask == nil { @@ -590,14 +609,23 @@ func (ec *executionContext) field_Mutation_createFlow_args(ctx context.Context, var err error args := map[string]interface{}{} var arg0 string - if tmp, ok := rawArgs["model"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("model")) + if tmp, ok := rawArgs["modelProvider"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("modelProvider")) arg0, err = ec.unmarshalNString2string(ctx, tmp) if err != nil { return nil, err } } - args["model"] = arg0 + args["modelProvider"] = arg0 + var arg1 string + if tmp, ok := rawArgs["modelId"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("modelId")) + arg1, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["modelId"] = arg1 return args, nil } @@ -1176,9 +1204,9 @@ func (ec *executionContext) _Flow_model(ctx context.Context, field graphql.Colle } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*gmodel.Model) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalNModel2ᚖgithubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐModel(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Flow_model(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1188,7 +1216,13 @@ func (ec *executionContext) fieldContext_Flow_model(ctx context.Context, field g IsMethod: false, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) 
(*graphql.FieldContext, error) { - return nil, errors.New("field of type String does not have child fields") + switch field.Name { + case "provider": + return ec.fieldContext_Model_provider(ctx, field) + case "id": + return ec.fieldContext_Model_id(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Model", field.Name) }, } return fc, nil @@ -1282,6 +1316,94 @@ func (ec *executionContext) fieldContext_Log_text(ctx context.Context, field gra return fc, nil } +func (ec *executionContext) _Model_provider(ctx context.Context, field graphql.CollectedField, obj *gmodel.Model) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Model_provider(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Provider, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Model_provider(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Model", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _Model_id(ctx context.Context, field graphql.CollectedField, obj *gmodel.Model) (ret graphql.Marshaler) { + fc, err := 
ec.fieldContext_Model_id(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ID, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Model_id(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Model", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Mutation_createFlow(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Mutation_createFlow(ctx, field) if err != nil { @@ -1296,7 +1418,7 @@ func (ec *executionContext) _Mutation_createFlow(ctx context.Context, field grap }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Mutation().CreateFlow(rctx, fc.Args["model"].(string)) + return ec.resolvers.Mutation().CreateFlow(rctx, fc.Args["modelProvider"].(string), fc.Args["modelId"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -1576,9 +1698,9 @@ func (ec *executionContext) _Query_availableModels(ctx context.Context, field gr } return graphql.Null } - res := 
resTmp.([]string) + res := resTmp.([]*gmodel.Model) fc.Result = res - return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res) + return ec.marshalNModel2ᚕᚖgithubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐModelᚄ(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_Query_availableModels(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -1588,7 +1710,13 @@ func (ec *executionContext) fieldContext_Query_availableModels(ctx context.Conte IsMethod: true, IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - return nil, errors.New("field of type String does not have child fields") + switch field.Name { + case "provider": + return ec.fieldContext_Model_provider(ctx, field) + case "id": + return ec.fieldContext_Model_id(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Model", field.Name) }, } return fc, nil @@ -4632,6 +4760,50 @@ func (ec *executionContext) _Log(ctx context.Context, sel ast.SelectionSet, obj return out } +var modelImplementors = []string{"Model"} + +func (ec *executionContext) _Model(ctx context.Context, sel ast.SelectionSet, obj *gmodel.Model) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, modelImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Model") + case "provider": + out.Values[i] = ec._Model_provider(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "id": + out.Values[i] = ec._Model_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, 
dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var mutationImplementors = []string{"Mutation"} func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler { @@ -5460,6 +5632,60 @@ func (ec *executionContext) marshalNLog2ᚖgithubᚗcomᚋsemanserᚋaiᚑcoder return ec._Log(ctx, sel, v) } +func (ec *executionContext) marshalNModel2ᚕᚖgithubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐModelᚄ(ctx context.Context, sel ast.SelectionSet, v []*gmodel.Model) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNModel2ᚖgithubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐModel(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + +func (ec *executionContext) marshalNModel2ᚖgithubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐModel(ctx context.Context, sel ast.SelectionSet, v *gmodel.Model) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._Model(ctx, sel, v) +} + func (ec *executionContext) unmarshalNString2string(ctx context.Context, v interface{}) (string, error) { res, err := graphql.UnmarshalString(v) return res, graphql.ErrorOnPath(ctx, err) @@ -5475,38 +5701,6 @@ func (ec *executionContext) marshalNString2string(ctx context.Context, 
sel ast.S return res } -func (ec *executionContext) unmarshalNString2ᚕstringᚄ(ctx context.Context, v interface{}) ([]string, error) { - var vSlice []interface{} - if v != nil { - vSlice = graphql.CoerceList(v) - } - var err error - res := make([]string, len(vSlice)) - for i := range vSlice { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) - res[i], err = ec.unmarshalNString2string(ctx, vSlice[i]) - if err != nil { - return nil, err - } - } - return res, nil -} - -func (ec *executionContext) marshalNString2ᚕstringᚄ(ctx context.Context, sel ast.SelectionSet, v []string) graphql.Marshaler { - ret := make(graphql.Array, len(v)) - for i := range v { - ret[i] = ec.marshalNString2string(ctx, sel, v[i]) - } - - for _, e := range ret { - if e == graphql.Null { - return graphql.Null - } - } - - return ret -} - func (ec *executionContext) marshalNTask2githubᚗcomᚋsemanserᚋaiᚑcoderᚋgraphᚋmodelᚐTask(ctx context.Context, sel ast.SelectionSet, v gmodel.Task) graphql.Marshaler { return ec._Task(ctx, sel, &v) } diff --git a/backend/graph/model/models_gen.go b/backend/graph/model/models_gen.go index dd14c90..74e8697 100644 --- a/backend/graph/model/models_gen.go +++ b/backend/graph/model/models_gen.go @@ -21,7 +21,7 @@ type Flow struct { Terminal *Terminal `json:"terminal"` Browser *Browser `json:"browser"` Status FlowStatus `json:"status"` - Model string `json:"model"` + Model *Model `json:"model"` } type Log struct { @@ -29,6 +29,11 @@ type Log struct { Text string `json:"text"` } +type Model struct { + Provider string `json:"provider"` + ID string `json:"id"` +} + type Mutation struct { } diff --git a/backend/graph/schema.graphqls b/backend/graph/schema.graphqls index 2052711..bd0c618 100644 --- a/backend/graph/schema.graphqls +++ b/backend/graph/schema.graphqls @@ -49,6 +49,11 @@ type Browser { screenshotUrl: String! } +type Model { + provider: String! + id: String! +} + type Flow { id: Uint! name: String! @@ -56,17 +61,17 @@ type Flow { terminal: Terminal! 
browser: Browser! status: FlowStatus! - model: String! + model: Model! } type Query { - availableModels: [String!]! + availableModels: [Model!]! flows: [Flow!]! flow(id: Uint!): Flow! } type Mutation { - createFlow(model: String!): Flow! + createFlow(modelProvider: String!, modelId: String!): Flow! createTask(flowId: Uint!, query: String!): Task! finishFlow(flowId: Uint!): Flow! diff --git a/backend/graph/schema.resolvers.go b/backend/graph/schema.resolvers.go index fd00a91..fb1b27e 100644 --- a/backend/graph/schema.resolvers.go +++ b/backend/graph/schema.resolvers.go @@ -21,15 +21,16 @@ import ( ) // CreateFlow is the resolver for the createFlow field. -func (r *mutationResolver) CreateFlow(ctx context.Context, model string) (*gmodel.Flow, error) { - if model == "" { - return nil, fmt.Errorf("modelID is required") +func (r *mutationResolver) CreateFlow(ctx context.Context, modelProvider string, modelID string) (*gmodel.Flow, error) { + if modelID == "" || modelProvider == "" { + return nil, fmt.Errorf("model is required") } flow, err := r.Db.CreateFlow(ctx, database.CreateFlowParams{ - Name: database.StringToNullString("New Task"), - Status: database.StringToNullString("in_progress"), - Model: database.StringToNullString(model), + Name: database.StringToNullString("New Task"), + Status: database.StringToNullString("in_progress"), + Model: database.StringToNullString(modelID), + ModelProvider: database.StringToNullString(modelProvider), }) if err != nil { @@ -42,7 +43,10 @@ func (r *mutationResolver) CreateFlow(ctx context.Context, model string) (*gmode ID: uint(flow.ID), Name: flow.Name.String, Status: gmodel.FlowStatus(flow.Status.String), - Model: flow.Model.String, + Model: &gmodel.Model{ + Provider: flow.ModelProvider.String, + ID: flow.Model.String, + }, }, nil } @@ -129,15 +133,21 @@ func (r *mutationResolver) Exec(ctx context.Context, containerID string, command } // AvailableModels is the resolver for the availableModels field. 
-func (r *queryResolver) AvailableModels(ctx context.Context) ([]string, error) { - var availableModels []string +func (r *queryResolver) AvailableModels(ctx context.Context) ([]*gmodel.Model, error) { + var availableModels []*gmodel.Model if config.Config.OpenAIKey != "" && config.Config.OpenAIModel != "" { - availableModels = append(availableModels, config.Config.OpenAIModel) + availableModels = append(availableModels, &gmodel.Model{ + Provider: "openai", + ID: config.Config.OpenAIModel, + }) } if config.Config.OllamaModel != "" { - availableModels = append(availableModels, config.Config.OllamaModel) + availableModels = append(availableModels, &gmodel.Model{ + Provider: "ollama", + ID: config.Config.OllamaModel, + }) } return availableModels, nil @@ -167,7 +177,10 @@ func (r *queryResolver) Flows(ctx context.Context) ([]*gmodel.Flow, error) { }, Tasks: gTasks, Status: gmodel.FlowStatus(flow.Status.String), - Model: flow.Model.String, + Model: &gmodel.Model{ + Provider: flow.ModelProvider.String, + ID: flow.Model.String, + }, }) } @@ -236,7 +249,10 @@ func (r *queryResolver) Flow(ctx context.Context, id uint) (*gmodel.Flow, error) URL: "", ScreenshotURL: "", }, - Model: flow.Model.String, + Model: &gmodel.Model{ + Provider: flow.ModelProvider.String, + ID: flow.Model.String, + }, } return gFlow, nil diff --git a/backend/migrations/20240403132844_add_model_provider_to_each_flow.sql b/backend/migrations/20240403132844_add_model_provider_to_each_flow.sql new file mode 100644 index 0000000..e1c156e --- /dev/null +++ b/backend/migrations/20240403132844_add_model_provider_to_each_flow.sql @@ -0,0 +1,11 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE flows +ADD COLUMN model_provider TEXT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE flows +DROP COLUMN model_provider; +-- +goose StatementEnd diff --git a/backend/models/flows.sql b/backend/models/flows.sql index 9cc8b4a..f78fe17 100644 --- a/backend/models/flows.sql +++ 
b/backend/models/flows.sql @@ -1,9 +1,9 @@ -- name: CreateFlow :one INSERT INTO flows ( - name, status, container_id, model + name, status, container_id, model, model_provider ) VALUES ( - ?, ?, ?, ? + ?, ?, ?, ?, ? ) RETURNING *; diff --git a/frontend/generated/graphql.schema.json b/frontend/generated/graphql.schema.json index 96e520d..9184e0e 100644 --- a/frontend/generated/graphql.schema.json +++ b/frontend/generated/graphql.schema.json @@ -108,8 +108,8 @@ "kind": "NON_NULL", "name": null, "ofType": { - "kind": "SCALAR", - "name": "String", + "kind": "OBJECT", + "name": "Model", "ofType": null } }, @@ -270,6 +270,49 @@ "enumValues": null, "possibleTypes": null }, + { + "kind": "OBJECT", + "name": "Model", + "description": null, + "fields": [ + { + "name": "id", + "description": null, + "args": [], + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + }, + "isDeprecated": false, + "deprecationReason": null + }, + { + "name": "provider", + "description": null, + "args": [], + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + }, + "isDeprecated": false, + "deprecationReason": null + } + ], + "inputFields": null, + "interfaces": [], + "enumValues": null, + "possibleTypes": null + }, { "kind": "OBJECT", "name": "Mutation", @@ -329,7 +372,23 @@ "description": null, "args": [ { - "name": "model", + "name": "modelId", + "description": null, + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + }, + "defaultValue": null, + "isDeprecated": false, + "deprecationReason": null + }, + { + "name": "modelProvider", "description": null, "type": { "kind": "NON_NULL", @@ -464,8 +523,8 @@ "kind": "NON_NULL", "name": null, "ofType": { - "kind": "SCALAR", - "name": "String", + "kind": "OBJECT", + "name": "Model", "ofType": null } } diff --git 
a/frontend/generated/graphql.ts b/frontend/generated/graphql.ts index 696cb07..0855764 100644 --- a/frontend/generated/graphql.ts +++ b/frontend/generated/graphql.ts @@ -30,7 +30,7 @@ export type Flow = { __typename?: 'Flow'; browser: Browser; id: Scalars['Uint']['output']; - model: Scalars['String']['output']; + model: Model; name: Scalars['String']['output']; status: FlowStatus; tasks: Array; @@ -48,6 +48,12 @@ export type Log = { text: Scalars['String']['output']; }; +export type Model = { + __typename?: 'Model'; + id: Scalars['String']['output']; + provider: Scalars['String']['output']; +}; + export type Mutation = { __typename?: 'Mutation'; _exec: Scalars['String']['output']; @@ -64,7 +70,8 @@ export type Mutation_ExecArgs = { export type MutationCreateFlowArgs = { - model: Scalars['String']['input']; + modelId: Scalars['String']['input']; + modelProvider: Scalars['String']['input']; }; @@ -80,7 +87,7 @@ export type MutationFinishFlowArgs = { export type Query = { __typename?: 'Query'; - availableModels: Array; + availableModels: Array; flow: Flow; flows: Array; }; @@ -160,6 +167,12 @@ export const FlowOverviewFragmentFragmentDoc = gql` status } `; +export const ModelFragmentFragmentDoc = gql` + fragment modelFragment on Model { + id + provider +} + `; export const LogFragmentFragmentDoc = gql` fragment logFragment on Log { text @@ -188,7 +201,9 @@ export const FlowFragmentFragmentDoc = gql` id name status - model + model { + ...modelFragment + } terminal { containerName connected @@ -203,7 +218,8 @@ export const FlowFragmentFragmentDoc = gql` ...taskFragment } } - ${LogFragmentFragmentDoc} + ${ModelFragmentFragmentDoc} +${LogFragmentFragmentDoc} ${BrowserFragmentFragmentDoc} ${TaskFragmentFragmentDoc}`; export const FlowsDocument = gql` @@ -219,9 +235,11 @@ export function useFlowsQuery(options?: Omit, 'query'>) { return Urql.useQuery({ query: AvailableModelsDocument, ...options }); @@ -238,8 +256,8 @@ export function useFlowQuery(options: Omit return 
Urql.useQuery({ query: FlowDocument, ...options }); }; export const CreateFlowDocument = gql` - mutation createFlow($model: String!) { - createFlow(model: $model) { + mutation createFlow($modelProvider: String!, $modelId: String!) { + createFlow(modelProvider: $modelProvider, modelId: $modelId) { id name } @@ -323,6 +341,8 @@ export function useBrowserUpdatedSubscription; @@ -331,7 +351,7 @@ export type FlowsQuery = { __typename?: 'Query', flows: Array<{ __typename?: 'Fl export type AvailableModelsQueryVariables = Exact<{ [key: string]: never; }>; -export type AvailableModelsQuery = { __typename?: 'Query', availableModels: Array }; +export type AvailableModelsQuery = { __typename?: 'Query', availableModels: Array<{ __typename?: 'Model', id: string, provider: string }> }; export type TaskFragmentFragment = { __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }; @@ -339,17 +359,18 @@ export type LogFragmentFragment = { __typename?: 'Log', text: string, id: any }; export type BrowserFragmentFragment = { __typename?: 'Browser', url: string, screenshotUrl: string }; -export type FlowFragmentFragment = { __typename?: 'Flow', id: any, name: string, status: FlowStatus, model: string, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> }; +export type FlowFragmentFragment = { __typename?: 'Flow', id: any, name: string, status: FlowStatus, model: { __typename?: 'Model', id: string, provider: string }, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: 
string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> }; export type FlowQueryVariables = Exact<{ id: Scalars['Uint']['input']; }>; -export type FlowQuery = { __typename?: 'Query', flow: { __typename?: 'Flow', id: any, name: string, status: FlowStatus, model: string, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> } }; +export type FlowQuery = { __typename?: 'Query', flow: { __typename?: 'Flow', id: any, name: string, status: FlowStatus, model: { __typename?: 'Model', id: string, provider: string }, terminal: { __typename?: 'Terminal', containerName: string, connected: boolean, logs: Array<{ __typename?: 'Log', text: string, id: any }> }, browser: { __typename?: 'Browser', url: string, screenshotUrl: string }, tasks: Array<{ __typename?: 'Task', id: any, type: TaskType, message: string, status: TaskStatus, args: any, results: any, createdAt: any }> } }; export type CreateFlowMutationVariables = Exact<{ - model: Scalars['String']['input']; + modelProvider: Scalars['String']['input']; + modelId: Scalars['String']['input']; }>; @@ -472,8 +493,9 @@ export default { "type": { "kind": "NON_NULL", "ofType": { - "kind": "SCALAR", - "name": "Any" + "kind": "OBJECT", + "name": "Model", + "ofType": null } }, "args": [] @@ -562,6 +584,35 @@ export default { ], "interfaces": [] }, + { + "kind": "OBJECT", + "name": "Model", + "fields": [ + { + "name": "id", + "type": { + "kind": "NON_NULL", + "ofType": { + "kind": "SCALAR", + "name": "Any" + } + }, + "args": [] + }, + { + "name": "provider", + "type": { + "kind": "NON_NULL", + "ofType": { + "kind": "SCALAR", + "name": "Any" + } + 
}, + "args": [] + } + ], + "interfaces": [] + }, { "kind": "OBJECT", "name": "Mutation", @@ -610,7 +661,17 @@ export default { }, "args": [ { - "name": "model", + "name": "modelId", + "type": { + "kind": "NON_NULL", + "ofType": { + "kind": "SCALAR", + "name": "Any" + } + } + }, + { + "name": "modelProvider", "type": { "kind": "NON_NULL", "ofType": { @@ -693,8 +754,9 @@ export default { "ofType": { "kind": "NON_NULL", "ofType": { - "kind": "SCALAR", - "name": "Any" + "kind": "OBJECT", + "name": "Model", + "ofType": null } } } diff --git a/frontend/package.json b/frontend/package.json index af9694c..8b1b99f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -16,6 +16,7 @@ }, "dependencies": { "@radix-ui/colors": "^3.0.0", + "@radix-ui/react-dropdown-menu": "^2.0.6", "@radix-ui/react-tabs": "^1.0.4", "@radix-ui/react-tooltip": "^1.0.7", "@uidotdev/usehooks": "^2.4.1", diff --git a/frontend/src/components/Dropdown/Dropdown.css.ts b/frontend/src/components/Dropdown/Dropdown.css.ts new file mode 100644 index 0000000..d148dd2 --- /dev/null +++ b/frontend/src/components/Dropdown/Dropdown.css.ts @@ -0,0 +1,78 @@ +import { globalStyle, style } from "@vanilla-extract/css"; + +import { font } from "@/styles/font.css"; +import { vars } from "@/styles/theme.css"; + +export const triggerStyles = style({ + all: "unset", + borderRadius: 6, + + selectors: { + '&[data-state="open"]': { + backgroundColor: vars.color.gray3, + }, + }, +}); + +export const dropdownMenuContentStyles = style({ + minWidth: 220, + backgroundColor: vars.color.gray3, + border: `1px solid ${vars.color.gray4}`, + borderRadius: 6, + padding: 3, + boxShadow: `0 0 10px 2px #12121187`, +}); + +export const dropdownMenuSubContentStyles = dropdownMenuContentStyles; + +export const dropdownMenuItemStyles = style([ + font.textSmMedium, + { + display: "flex", + borderRadius: 3, + alignItems: "center", + height: 32, + padding: "0 3px", + position: "relative", + paddingLeft: 32, + userSelect: "none", + 
outline: "none", + color: vars.color.gray12, + cursor: "pointer", + + selectors: { + "&[data-highlighted]": { + backgroundColor: vars.color.gray4, + }, + }, + }, +]); + +export const dropdownMenuItemIconStyles = style({ + position: "absolute", + left: 8, + top: 8, + color: vars.color.gray9, + width: 16, + height: 16, +}); + +globalStyle(`${dropdownMenuItemStyles}:hover ${dropdownMenuItemIconStyles}`, { + color: vars.color.primary9, +}); + +export const dropdownMenuSubTriggerStyles = dropdownMenuItemStyles; + +export const dropdownMenuSeparatorStyles = style({ + height: 1, + backgroundColor: vars.color.gray4, + margin: 5, +}); + +export const dropdownMenuRightSlotStyles = style({ + display: "flex", + marginLeft: "auto", + paddingLeft: 20, + top: 4, + color: vars.color.gray9, +}); diff --git a/frontend/src/components/Dropdown/Dropdown.tsx b/frontend/src/components/Dropdown/Dropdown.tsx new file mode 100644 index 0000000..44f9c1c --- /dev/null +++ b/frontend/src/components/Dropdown/Dropdown.tsx @@ -0,0 +1,28 @@ +import * as DropdownMenu from "@radix-ui/react-dropdown-menu"; +import React from "react"; + +import { + dropdownMenuContentStyles, + dropdownMenuItemStyles, + triggerStyles, +} from "./Dropdown.css"; + +type DropdownProps = { + children: React.ReactNode; + content: React.ReactNode; +} & React.ComponentProps; + +export const Dropdown = ({ children, content, ...rest }: DropdownProps) => { + return ( + + + + + {content} + + ); +}; + +export { dropdownMenuItemStyles, dropdownMenuContentStyles }; diff --git a/frontend/src/components/Messages/Messages.css.ts b/frontend/src/components/Messages/Messages.css.ts index f907f2c..5f0090e 100644 --- a/frontend/src/components/Messages/Messages.css.ts +++ b/frontend/src/components/Messages/Messages.css.ts @@ -30,6 +30,10 @@ export const titleStyles = style([ }, ]); +export const modelStyles = style({ + color: vars.color.gray10, +}); + export const newMessageTextarea = style([ font.textSmMedium, { diff --git 
a/frontend/src/components/Messages/Messages.tsx b/frontend/src/components/Messages/Messages.tsx index c35836c..0ea7b69 100644 --- a/frontend/src/components/Messages/Messages.tsx +++ b/frontend/src/components/Messages/Messages.tsx @@ -1,12 +1,13 @@ import { useEffect, useRef } from "react"; -import { FlowStatus, Task } from "@/generated/graphql"; +import { FlowStatus, Model, Task } from "@/generated/graphql"; import { Button } from "../Button/Button"; import { Message } from "./Message/Message"; import { messagesListWrapper, messagesWrapper, + modelStyles, newMessageTextarea, titleStyles, } from "./Messages.css"; @@ -18,6 +19,7 @@ type MessagesProps = { onFlowStop: () => void; flowStatus?: FlowStatus; isNew?: boolean; + model?: Model; }; export const Messages = ({ @@ -27,6 +29,7 @@ export const Messages = ({ onSubmit, isNew, onFlowStop, + model, }: MessagesProps) => { const messages = tasks.map((task) => ({ @@ -90,14 +93,15 @@ export const Messages = ({
{name && (
- {name}{" "} + {name} + {` - ${model?.id}`}{" "} {isFlowFinished ? ( " (Finished)" ) : ( - )} + )}{" "}
)}
diff --git a/frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.css.ts b/frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.css.ts new file mode 100644 index 0000000..4a3c988 --- /dev/null +++ b/frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.css.ts @@ -0,0 +1,28 @@ +import { style } from "@vanilla-extract/css"; + +import { font } from "@/styles/font.css"; +import { vars } from "@/styles/theme.css"; + +export const buttonStyles = style([ + font.textSmRegular, + { + display: "block", + textDecoration: "none", + background: vars.color.gray3, + border: "none", + textAlign: "left", + color: vars.color.gray10, + padding: "9px 16px", + cursor: "pointer", + borderRadius: "0 6px 6px 0", + flex: 1, + width: "100px", + textOverflow: "ellipsis", + overflow: "hidden", + whiteSpace: "nowrap", + + ":hover": { + backgroundColor: vars.color.gray4, + }, + }, +]); diff --git a/frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.tsx b/frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.tsx new file mode 100644 index 0000000..bf53da8 --- /dev/null +++ b/frontend/src/components/Sidebar/NewTask/ModelSelector/ModelSelector.tsx @@ -0,0 +1,78 @@ +import * as DropdownMenu from "@radix-ui/react-dropdown-menu"; +import { useEffect } from "react"; + +import { + Dropdown, + dropdownMenuContentStyles, + dropdownMenuItemStyles, +} from "@/components/Dropdown/Dropdown"; +import { dropdownMenuItemIconStyles } from "@/components/Dropdown/Dropdown.css"; +import { Icon } from "@/components/Icon/Icon"; +import { Model } from "@/generated/graphql"; + +import { buttonStyles } from "./ModelSelector.css"; + +type ModelSelectorProps = { + availableModels: Model[]; + selectedModel?: Model; + activeModel?: Model; + setSelectedModel: (model: Model) => void; +}; + +export const ModelSelector = ({ + availableModels = [], + selectedModel, + activeModel, + setSelectedModel, +}: ModelSelectorProps) => { + // Automatically select 
the first available model + useEffect(() => { + if (!activeModel && availableModels[0]) { + setSelectedModel(availableModels[0]); + } + }, [activeModel, availableModels]); + + const handleValueChange = (value: string) => { + const newModel = availableModels.find((model) => model.id === value); + + if (!newModel) return; + + setSelectedModel(newModel); + }; + + const dropdownContent = ( + + + {availableModels.length > 0 ? ( + availableModels.map((model) => ( + + + + + {model.id} + + )) + ) : ( + + No available models + + )} + + + ); + + return ( + +
{activeModel?.id || "No model"}
+
+ ); +}; diff --git a/frontend/src/components/Sidebar/NewTask/NewTask.css.ts b/frontend/src/components/Sidebar/NewTask/NewTask.css.ts index 10d3223..53cda8a 100644 --- a/frontend/src/components/Sidebar/NewTask/NewTask.css.ts +++ b/frontend/src/components/Sidebar/NewTask/NewTask.css.ts @@ -3,7 +3,14 @@ import { style } from "@vanilla-extract/css"; import { font } from "@/styles/font.css"; import { vars } from "@/styles/theme.css"; -export const wrapperStyles = style([ +export const wrapperStyles = style({ + display: "flex", + alignItems: "center", + marginBottom: 16, + justifyContent: "space-between", +}); + +export const linkWrapperStyles = style([ font.textSmSemibold, { display: "block", @@ -14,8 +21,8 @@ export const wrapperStyles = style([ color: vars.color.gray12, padding: "9px 16px", cursor: "pointer", - marginBottom: 16, - borderRadius: 6, + borderRadius: "6px 0 0 6px", + flex: 1, selectors: { "&.active": { diff --git a/frontend/src/components/Sidebar/NewTask/NewTask.tsx b/frontend/src/components/Sidebar/NewTask/NewTask.tsx index a2c2a5e..4b6855e 100644 --- a/frontend/src/components/Sidebar/NewTask/NewTask.tsx +++ b/frontend/src/components/Sidebar/NewTask/NewTask.tsx @@ -1,11 +1,50 @@ -import { NavLink } from "react-router-dom"; +import { useLocalStorage } from "@uidotdev/usehooks"; +import { useNavigate } from "react-router-dom"; -import { wrapperStyles } from "./NewTask.css"; +import { Tooltip } from "@/components/Tooltip/Tooltip"; +import { Model } from "@/generated/graphql"; + +import { ModelSelector } from "./ModelSelector/ModelSelector"; +import { linkWrapperStyles, wrapperStyles } from "./NewTask.css"; + +type NewTaskProps = { + availableModels: Model[]; +}; + +export const NewTask = ({ availableModels = [] }: NewTaskProps) => { + const navigate = useNavigate(); + const [selectedModel, setSelectedModel] = useLocalStorage( + "model", + ); + const activeModel = availableModels.find( + (model) => model.id == selectedModel?.id, + ); + + const handleNewTask 
= () => { + navigate("/chat/new"); + }; + + const tooltipContent = activeModel + ? "Create a new flow" + : "Please select a model first"; -export const NewTask = () => { return ( - - ✨ New task - +
+ + + + +
); }; diff --git a/frontend/src/components/Sidebar/Sidebar.tsx b/frontend/src/components/Sidebar/Sidebar.tsx index 17a980d..ddfcf5b 100644 --- a/frontend/src/components/Sidebar/Sidebar.tsx +++ b/frontend/src/components/Sidebar/Sidebar.tsx @@ -4,12 +4,13 @@ import { wrapperStyles } from "./Sidebar.css"; type SidebarProps = { items: MenuItemProps[]; + availableModels: string[]; }; -export const Sidebar = ({ items = [] }: SidebarProps) => { +export const Sidebar = ({ items = [], availableModels = [] }: SidebarProps) => { return (
- + {items.map((item) => ( ))} diff --git a/frontend/src/layouts/AppLayout/AppLayout.graphql b/frontend/src/layouts/AppLayout/AppLayout.graphql index 931f260..14a8c02 100644 --- a/frontend/src/layouts/AppLayout/AppLayout.graphql +++ b/frontend/src/layouts/AppLayout/AppLayout.graphql @@ -4,6 +4,11 @@ fragment flowOverviewFragment on Flow { status } +fragment modelFragment on Model { + id + provider +} + query flows { flows { ...flowOverviewFragment @@ -11,5 +16,7 @@ query flows { } query availableModels { - availableModels + availableModels { + ...modelFragment + } } diff --git a/frontend/src/layouts/AppLayout/AppLayout.tsx b/frontend/src/layouts/AppLayout/AppLayout.tsx index 5a2e32e..9822532 100644 --- a/frontend/src/layouts/AppLayout/AppLayout.tsx +++ b/frontend/src/layouts/AppLayout/AppLayout.tsx @@ -22,7 +22,10 @@ export const AppLayout = () => { return (
- +
); diff --git a/frontend/src/pages/ChatPage/ChatPage.graphql b/frontend/src/pages/ChatPage/ChatPage.graphql index b774932..543b0a3 100644 --- a/frontend/src/pages/ChatPage/ChatPage.graphql +++ b/frontend/src/pages/ChatPage/ChatPage.graphql @@ -22,7 +22,9 @@ fragment flowFragment on Flow { id name status - model + model { + ...modelFragment + } terminal { containerName connected @@ -44,8 +46,8 @@ query flow($id: Uint!) { } } -mutation createFlow($model: String!) { - createFlow(model: $model) { +mutation createFlow($modelProvider: String!, $modelId: String!) { + createFlow(modelProvider: $modelProvider, modelId: $modelId) { id name } diff --git a/frontend/src/pages/ChatPage/ChatPage.tsx b/frontend/src/pages/ChatPage/ChatPage.tsx index 44a603c..f8624b2 100644 --- a/frontend/src/pages/ChatPage/ChatPage.tsx +++ b/frontend/src/pages/ChatPage/ChatPage.tsx @@ -17,6 +17,7 @@ import { import { Terminal } from "@/components/Terminal/Terminal"; import { Tooltip } from "@/components/Tooltip/Tooltip"; import { + Model, useBrowserUpdatedSubscription, useCreateFlowMutation, useCreateTaskMutation, @@ -45,8 +46,8 @@ export const ChatPage = () => { "isFollowingTabs", true, ); + const [selectedModel] = useLocalStorage("model"); const [activeTab, setActiveTab] = useState("terminal"); - const [selectedModel, setSelectedModel] = useLocalStorage("model", ""); const [{ operation, data }] = useFlowQuery({ pause: isNewFlow, @@ -61,6 +62,7 @@ export const ChatPage = () => { const status = !isStaleData ? data?.flow.status : undefined; const terminal = !isStaleData ? data?.flow.terminal : undefined; const browser = !isStaleData ? data?.flow.browser : undefined; + const model = !isStaleData ? 
data?.flow.model : undefined; useBrowserUpdatedSubscription( { @@ -97,8 +99,11 @@ export const ChatPage = () => { }); const handleSubmit = async (message: string) => { - if (isNewFlow) { - const result = await createFlowMutation({}); + if (isNewFlow && selectedModel.id) { + const result = await createFlowMutation({ + modelProvider: selectedModel.provider, + modelId: selectedModel.id, + }); const flowId = result?.data?.createFlow.id; if (flowId) { @@ -135,6 +140,7 @@ export const ChatPage = () => { flowStatus={status} isNew={isNewFlow} onFlowStop={handleFlowStop} + model={model} /> diff --git a/frontend/yarn.lock b/frontend/yarn.lock index bf622a5..dd8bcb7 100644 --- a/frontend/yarn.lock +++ b/frontend/yarn.lock @@ -1536,6 +1536,37 @@ "@radix-ui/react-use-callback-ref" "1.0.1" "@radix-ui/react-use-escape-keydown" "1.0.3" +"@radix-ui/react-dropdown-menu@^2.0.6": + version "2.0.6" + resolved "https://registry.yarnpkg.com/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.0.6.tgz#cdf13c956c5e263afe4e5f3587b3071a25755b63" + integrity sha512-i6TuFOoWmLWq+M/eCLGd/bQ2HfAX1RJgvrBQ6AQLmzfvsLdefxbWu8G9zczcPFfcSPehz9GcpF6K9QYreFV8hA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "1.0.1" + "@radix-ui/react-compose-refs" "1.0.1" + "@radix-ui/react-context" "1.0.1" + "@radix-ui/react-id" "1.0.1" + "@radix-ui/react-menu" "2.0.6" + "@radix-ui/react-primitive" "1.0.3" + "@radix-ui/react-use-controllable-state" "1.0.1" + +"@radix-ui/react-focus-guards@1.0.1": + version "1.0.1" + resolved "https://registry.yarnpkg.com/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.1.tgz#1ea7e32092216b946397866199d892f71f7f98ad" + integrity sha512-Rect2dWbQ8waGzhMavsIbmSVCgYxkXLxxR3ZvCX79JOglzdEy4JXMb98lq4hPxUbLr77nP0UOGf4rcMU+s1pUA== + dependencies: + "@babel/runtime" "^7.13.10" + +"@radix-ui/react-focus-scope@1.0.4": + version "1.0.4" + resolved 
"https://registry.yarnpkg.com/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.4.tgz#2ac45fce8c5bb33eb18419cdc1905ef4f1906525" + integrity sha512-sL04Mgvf+FmyvZeYfNu1EPAaaxD+aw7cYeIB9L9Fvq8+urhltTRaEo5ysKOpHuKPclsZcSUMKlN05x4u+CINpA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/react-compose-refs" "1.0.1" + "@radix-ui/react-primitive" "1.0.3" + "@radix-ui/react-use-callback-ref" "1.0.1" + "@radix-ui/react-id@1.0.1": version "1.0.1" resolved "https://registry.yarnpkg.com/@radix-ui/react-id/-/react-id-1.0.1.tgz#73cdc181f650e4df24f0b6a5b7aa426b912c88c0" @@ -1544,6 +1575,31 @@ "@babel/runtime" "^7.13.10" "@radix-ui/react-use-layout-effect" "1.0.1" +"@radix-ui/react-menu@2.0.6": + version "2.0.6" + resolved "https://registry.yarnpkg.com/@radix-ui/react-menu/-/react-menu-2.0.6.tgz#2c9e093c1a5d5daa87304b2a2f884e32288ae79e" + integrity sha512-BVkFLS+bUC8HcImkRKPSiVumA1VPOOEC5WBMiT+QAVsPzW1FJzI9KnqgGxVDPBcql5xXrHkD3JOVoXWEXD8SYA== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "1.0.1" + "@radix-ui/react-collection" "1.0.3" + "@radix-ui/react-compose-refs" "1.0.1" + "@radix-ui/react-context" "1.0.1" + "@radix-ui/react-direction" "1.0.1" + "@radix-ui/react-dismissable-layer" "1.0.5" + "@radix-ui/react-focus-guards" "1.0.1" + "@radix-ui/react-focus-scope" "1.0.4" + "@radix-ui/react-id" "1.0.1" + "@radix-ui/react-popper" "1.1.3" + "@radix-ui/react-portal" "1.0.4" + "@radix-ui/react-presence" "1.0.1" + "@radix-ui/react-primitive" "1.0.3" + "@radix-ui/react-roving-focus" "1.0.4" + "@radix-ui/react-slot" "1.0.2" + "@radix-ui/react-use-callback-ref" "1.0.1" + aria-hidden "^1.1.1" + react-remove-scroll "2.5.5" + "@radix-ui/react-popper@1.1.3": version "1.1.3" resolved "https://registry.yarnpkg.com/@radix-ui/react-popper/-/react-popper-1.1.3.tgz#24c03f527e7ac348fabf18c89795d85d21b00b42" @@ -2198,6 +2254,13 @@ argparse@^2.0.1: resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38" 
integrity sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q== +aria-hidden@^1.1.1: + version "1.2.4" + resolved "https://registry.yarnpkg.com/aria-hidden/-/aria-hidden-1.2.4.tgz#b78e383fdbc04d05762c78b4a25a501e736c4522" + integrity sha512-y+CcFFwelSXpLZk/7fMB2mUbGtX9lKycf1MWJ7CaTIERyitVlyQx6C+sxcROU2BAJ24OiZyK+8wj2i8AlBoS3A== + dependencies: + tslib "^2.0.0" + array-union@^2.1.0: version "2.1.0" resolved "https://registry.yarnpkg.com/array-union/-/array-union-2.1.0.tgz#b798420adbeb1de828d84acd8a2e23d3efe85e8d" @@ -2680,6 +2743,11 @@ detect-indent@^6.0.0: resolved "https://registry.yarnpkg.com/detect-indent/-/detect-indent-6.1.0.tgz#592485ebbbf6b3b1ab2be175c8393d04ca0d57e6" integrity sha512-reYkTUJAZb9gUuZ2RvVCNhVHdg62RHnJ7WJl8ftMi4diZ6NWlciOzQN88pUhSELEwflJht4oQDv0F0BMlwaYtA== +detect-node-es@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/detect-node-es/-/detect-node-es-1.1.0.tgz#163acdf643330caa0b4cd7c21e7ee7755d6fa493" + integrity sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ== + dir-glob@^3.0.1: version "3.0.1" resolved "https://registry.yarnpkg.com/dir-glob/-/dir-glob-3.0.1.tgz#56dbf73d992a4a93ba1584f4534063fd2e41717f" @@ -3076,6 +3144,11 @@ get-intrinsic@^1.1.3, get-intrinsic@^1.2.4: has-symbols "^1.0.3" hasown "^2.0.0" +get-nonce@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/get-nonce/-/get-nonce-1.0.1.tgz#fdf3f0278073820d2ce9426c18f07481b1e0cdf3" + integrity sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q== + glob-parent@^5.1.2: version "5.1.2" resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" @@ -4029,6 +4102,25 @@ react-refresh@^0.14.0: resolved "https://registry.yarnpkg.com/react-refresh/-/react-refresh-0.14.0.tgz#4e02825378a5f227079554d4284889354e5f553e" integrity 
sha512-wViHqhAd8OHeLS/IRMJjTSDHF3U9eWi62F/MledQGPdJGDhodXJ9PBLNGr6WWL7qlH12Mt3TyTpbS+hGXMjCzQ== +react-remove-scroll-bar@^2.3.3: + version "2.3.6" + resolved "https://registry.yarnpkg.com/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.6.tgz#3e585e9d163be84a010180b18721e851ac81a29c" + integrity sha512-DtSYaao4mBmX+HDo5YWYdBWQwYIQQshUV/dVxFxK+KM26Wjwp1gZ6rv6OC3oujI6Bfu6Xyg3TwK533AQutsn/g== + dependencies: + react-style-singleton "^2.2.1" + tslib "^2.0.0" + +react-remove-scroll@2.5.5: + version "2.5.5" + resolved "https://registry.yarnpkg.com/react-remove-scroll/-/react-remove-scroll-2.5.5.tgz#1e31a1260df08887a8a0e46d09271b52b3a37e77" + integrity sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw== + dependencies: + react-remove-scroll-bar "^2.3.3" + react-style-singleton "^2.2.1" + tslib "^2.1.0" + use-callback-ref "^1.3.0" + use-sidecar "^1.1.2" + react-router-dom@^6.22.3: version "6.22.3" resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.22.3.tgz#9781415667fd1361a475146c5826d9f16752a691" @@ -4044,6 +4136,15 @@ react-router@6.22.3: dependencies: "@remix-run/router" "1.15.3" +react-style-singleton@^2.2.1: + version "2.2.1" + resolved "https://registry.yarnpkg.com/react-style-singleton/-/react-style-singleton-2.2.1.tgz#f99e420492b2d8f34d38308ff660b60d0b1205b4" + integrity sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g== + dependencies: + get-nonce "^1.0.0" + invariant "^2.2.4" + tslib "^2.0.0" + react@^18.2.0: version "18.2.0" resolved "https://registry.yarnpkg.com/react/-/react-18.2.0.tgz#555bd98592883255fa00de14f1151a917b5d77d5" @@ -4536,6 +4637,21 @@ urql@^4.0.6: "@urql/core" "^4.2.0" wonka "^6.3.2" +use-callback-ref@^1.3.0: + version "1.3.2" + resolved "https://registry.yarnpkg.com/use-callback-ref/-/use-callback-ref-1.3.2.tgz#6134c7f6ff76e2be0b56c809b17a650c942b1693" + integrity 
sha512-elOQwe6Q8gqZgDA8mrh44qRTQqpIHDcZ3hXTLjBe1i4ph8XpNJnO+aQf3NaG+lriLopI4HMx9VjQLfPQ6vhnoA== + dependencies: + tslib "^2.0.0" + +use-sidecar@^1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/use-sidecar/-/use-sidecar-1.1.2.tgz#2f43126ba2d7d7e117aa5855e5d8f0276dfe73c2" + integrity sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw== + dependencies: + detect-node-es "^1.1.0" + tslib "^2.0.0" + util-deprecate@^1.0.1: version "1.0.2" resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" From 3bfe6556c768ccc6d4ba8491e2be2e78ac4b3ce7 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Thu, 4 Apr 2024 11:09:46 +0200 Subject: [PATCH 09/14] Update .env.example --- backend/.env.example | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/backend/.env.example b/backend/.env.example index 252b6e8..2b9d680 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -1,6 +1,16 @@ -OPEN_AI_KEY= -OPEN_AI_SERVER_URL= +# General DATABASE_URL= DOCKER_HOST= +PORT= + +# OpenAI +OPEN_AI_KEY= +OPEN_AI_SERVER_URL= OPEN_AI_MODEL= -PORT= \ No newline at end of file + +# Ollama +OLLAMA_MODEL= + +# Goose +GOOSE_DRIVER= +GOOSE_DBSTRING= From cdb7d8072259368ab98652b8783b313765912016 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Thu, 4 Apr 2024 11:10:05 +0200 Subject: [PATCH 10/14] Update readme --- README.md | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 1d88883..b4eb7c5 100644 --- a/README.md +++ b/README.md @@ -14,47 +14,35 @@ - 🤳 Self-hosted - 💅 Modern UI -# Usage +# Getting started The simplest way to run Codel is to use a pre-built Docker image. You can find the latest image on the [Github Container Registry](https://github.com/semanser/codel/pkgs/container/codel). 
> [!IMPORTANT] -> Don't forget to set the required environment variables. +> You need to use a corresponding environment variable in order to use any of the supported language models. -```bash -docker run \ - -e OPEN_AI_KEY= \ - -p 3000:8080 \ - -v /var/run/docker.sock:/var/run/docker.sock \ - ghcr.io/semanser/codel:latest -``` -Alternatively, you can create a .env file and run the Docker image with the following command: ```bash docker run \ - --env-file .env \ + -e OPEN_AI_KEY=your_open_ai_key \ # Use any of the supported language models -p 3000:8080 \ -v /var/run/docker.sock:/var/run/docker.sock \ ghcr.io/semanser/codel:latest ``` -Now you can visit [localhost:3000](localhost:3000) in your browser and start using Codel. +Alternatively, you can create a `.env` file and run the Docker image with the `--env-file` flag. More information can be found [here](https://docs.docker.com/reference/cli/docker/container/run/#env) -
- Required environment variables - - - `OPEN_AI_KEY` - OpenAI API key -
+Now you can visit [localhost:3000](http://localhost:3000) in your browser and start using Codel.
- Optional environment variables - - - `OPEN_AI_MODEL` - OpenAI model (default: gpt-4-0125-preview). The list of supported OpenAI models can be found [here](https://pkg.go.dev/github.com/sashabaranov/go-openai#pkg-constants). - - `DATABASE_URL` - PostgreSQL database URL (eg. `postgres://user:password@localhost:5432/database`) - - `DOCKER_HOST` - Docker SDK API (eg. `DOCKER_HOST=unix:///Users//Library/Containers/com.docker.docker/Data/docker.raw.sock`) [more info](https://stackoverflow.com/a/62757128/5922857) - - `PORT` - Port to run the server in the Docker container (default: 8080) + Supported environment variables + * `OPEN_AI_KEY` - OpenAI API key. You can get the key [here](https://platform.openai.com/account/api-keys). + * `OPEN_AI_MODEL` - OpenAI model (default: gpt-4-0125-preview). The list of supported OpenAI models can be found [here](https://pkg.go.dev/github.com/sashabaranov/go-openai#pkg-constants). + * `OPEN_AI_SERVER_URL` - OpenAI server URL (default: https://api.openai.com/v1). Change this URL if you are using an OpenAI compatible server. + * `OLLAMA_MODEL` - locally hosted Ollama model (e.g. `llama2`). The list of supported Ollama models can be found [here](https://ollama.com/models). See backend [.env.example](./backend/.env.example) for more details. +
# Development From 76ef9e30726bb043ff8d073a29da62818732dc45 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Thu, 4 Apr 2024 11:16:26 +0200 Subject: [PATCH 11/14] Fix TS type --- frontend/src/components/Sidebar/Sidebar.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/src/components/Sidebar/Sidebar.tsx b/frontend/src/components/Sidebar/Sidebar.tsx index ddfcf5b..ffec87b 100644 --- a/frontend/src/components/Sidebar/Sidebar.tsx +++ b/frontend/src/components/Sidebar/Sidebar.tsx @@ -1,10 +1,11 @@ +import { Model } from "@/generated/graphql"; import { MenuItem, MenuItemProps } from "./MenuItem/MenuItem"; import { NewTask } from "./NewTask/NewTask"; import { wrapperStyles } from "./Sidebar.css"; type SidebarProps = { items: MenuItemProps[]; - availableModels: string[]; + availableModels: Model[]; }; export const Sidebar = ({ items = [], availableModels = [] }: SidebarProps) => { From 32d4f447a3b43a7445d4170bae376373e88eddaf Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Thu, 4 Apr 2024 11:54:47 +0200 Subject: [PATCH 12/14] Implement host.docker.internal --- backend/providers/ollama.go | 41 ++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/backend/providers/ollama.go b/backend/providers/ollama.go index 8fc5cd2..dfffbab 100644 --- a/backend/providers/ollama.go +++ b/backend/providers/ollama.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log" + "os" "github.com/semanser/ai-coder/assets" "github.com/semanser/ai-coder/config" @@ -26,10 +27,14 @@ func (p OllamaProvider) New() Provider { model := config.Config.OllamaModel baseURL := config.Config.OllamaServerURL + if isRunningInDockerContainer() { + baseURL = fmt.Sprintf("http://host.docker.internal:%s", "11434") + } + client, err := ollama.New( ollama.WithModel(model), - ollama.WithFormat("json"), ollama.WithServerURL(baseURL), + ollama.WithFormat("json"), ) if err != nil { @@ -49,8 +54,16 @@ func (p OllamaProvider) Name() 
ProviderType { } func (p OllamaProvider) Summary(query string, n int) (string, error) { + model := config.Config.OllamaModel + baseURL := config.Config.OllamaServerURL + + if isRunningInDockerContainer() { + baseURL = fmt.Sprintf("http://host.docker.internal:%s", "11434") + } + client, err := ollama.New( - ollama.WithModel(p.model), + ollama.WithModel(model), + ollama.WithServerURL(baseURL), ) if err != nil { @@ -61,8 +74,16 @@ func (p OllamaProvider) Summary(query string, n int) (string, error) { } func (p OllamaProvider) DockerImageName(task string) (string, error) { + model := config.Config.OllamaModel + baseURL := config.Config.OllamaServerURL + + if isRunningInDockerContainer() { + baseURL = fmt.Sprintf("http://host.docker.internal:%s", "11434") + } + client, err := ollama.New( - ollama.WithModel(p.model), + ollama.WithModel(model), + ollama.WithServerURL(baseURL), ) if err != nil { @@ -153,3 +174,17 @@ To use a tool, respond with a JSON object with the following structure: Always use a tool. Always reply with valid JOSN. Always include a message. `, string(bs)) } + +// Source: https://paulbradley.org/indocker/ +func isRunningInDockerContainer() bool { + // docker creates a .dockerenv file at the root + // of the directory tree inside the container. + // if this file exists then the viewer is running + // from inside a container so return true + + if _, err := os.Stat("/.dockerenv"); err == nil { + return true + } + + return false +} From f61b91b33fd7e7770837ef3ceebaff12863f0d6b Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Thu, 4 Apr 2024 12:00:55 +0200 Subject: [PATCH 13/14] Improve docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b4eb7c5..7f1cee1 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ The simplest way to run Codel is to use a pre-built Docker image. 
You can find t ```bash docker run \ - -e OPEN_AI_KEY=your_open_ai_key \ # Use any of the supported language models + -e OPEN_AI_KEY=your_open_ai_key \ # Replace OPEN_AI_KEY with any other supported model -p 3000:8080 \ -v /var/run/docker.sock:/var/run/docker.sock \ ghcr.io/semanser/codel:latest From 63b7817f0d1349e96317f1a55b6bf4e1f55b74f5 Mon Sep 17 00:00:00 2001 From: Andriy Semenets Date: Thu, 4 Apr 2024 12:05:55 +0200 Subject: [PATCH 14/14] Improve docs --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7f1cee1..3aa2ba4 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,12 @@ The simplest way to run Codel is to use a pre-built Docker image. You can find t > [!IMPORTANT] > You need to use a corresponding environment variable in order to use any of the supported language models. - +You can run the Docker image with the following command. Remove or change the environment variables according to your needs. ```bash docker run \ - -e OPEN_AI_KEY=your_open_ai_key \ # Replace OPEN_AI_KEY with any other supported model + -e OPEN_AI_KEY=your_open_ai_key \ + -e OPEN_AI_MODEL=gpt-4-0125-preview \ + -e OLLAMA_MODEL=llama2 \ -p 3000:8080 \ -v /var/run/docker.sock:/var/run/docker.sock \ ghcr.io/semanser/codel:latest