diff --git a/main.go b/main.go index 463e848..8856179 100644 --- a/main.go +++ b/main.go @@ -15,15 +15,12 @@ import ( openai "github.com/spideyz0r/openai-go" ) -const ( - model = "gpt-3.5-turbo" -) - func main() { help := getopt.BoolLong("help", 'h', "display this help") apiKey := getopt.StringLong("api", 'a', "", "API key (default: OPENAI_API_KEY environment variable)") temperature := getopt.StringLong("temperature", 't', "0.8", "temperature (default: 0.8)") - system_role := getopt.StringLong("system-role", 'r', "You're an expert in everything.", "system role (default: You're an expert in everything. You like speaking.)") + model := getopt.StringLong("model", 'm', "gpt-4", "gpt chat model (default: gpt-4)") + system_role := getopt.StringLong("system-role", 'S', "You're an expert in everything.", "system role (default: You're an expert in everything. You like speaking.)") output_width := getopt.StringLong("output-width", 'w', "80", "output width (default: 80)") delim := getopt.StringLong("delimiter", 'd', "\n", "set the delimiter for the user input (default: new line)") stdin_input := getopt.BoolLong("stdin", 's', "read the message from stdin and exit (default: false)") @@ -65,7 +62,7 @@ func main() { if err != nil { log.Fatal(err) } - output, err := sendMessage(client, messages, float32(t), uint(w)) + output, err := sendMessage(client, messages, float32(t), uint(w), *model) if err != nil { log.Fatal(err) } @@ -87,7 +84,7 @@ func main() { stop := make(chan bool) go spinner(10*time.Millisecond, stop) - output, err := sendMessage(client, messages, float32(t), uint(w)) + output, err := sendMessage(client, messages, float32(t), uint(w), *model) if err != nil { log.Fatal(err) } @@ -96,7 +93,7 @@ func main() { } } -func sendMessage(client *openai.OpenAIClient, messages []openai.Message, t float32, w uint) (string, error) { +func sendMessage(client *openai.OpenAIClient, messages []openai.Message, t float32, w uint, model string) (string, error) { completion, err := client.GetCompletion(model, messages, t) if err != nil { return "", err