-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathchat.go
126 lines (107 loc) · 2.79 KB
/
chat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package gollama
import (
"fmt"
"strings"
)
// ChatOption is an optional argument to (*Gollama).Chat. It is an empty
// interface, so values of any type compile as options; Chat acts only on
// PromptImage, []PromptImage, Tool, []Tool and StructuredFormat values and
// silently ignores everything else.
type ChatOption interface{}
// Chat sends a single-turn, non-streaming request to the Ollama /api/chat
// endpoint and returns the model's reply.
//
// prompt becomes the "user" message. If c.SystemPrompt is non-empty it is sent
// first as a "system" message.
//
// options may be given in any order; values of unrecognised types are ignored.
// Recognised option types are:
//   - PromptImage: an image file attached to the user message as vision input
//     (the file is read and base64-encoded). May be repeated.
//   - []PromptImage: replaces any images collected so far.
//   - Tool: a tool made available to the model. May be repeated.
//   - []Tool: replaces any tools collected so far.
//   - StructuredFormat: a response format; it is sent only when it has at least
//     one property.
//
// When c.SeedOrNegative is negative, c.TemperatureIfNegativeSeed is sent as the
// sampling temperature; otherwise the temperature is sent as 0 alongside the
// seed.
//
// On success Chat returns a *ChatOuput holding the reply's role, content, tool
// calls and prompt/response token counts; the content is whitespace-trimmed
// when c.TrimSpace is set. It returns nil and an error if an image cannot be
// encoded, if the API request fails, or if the response reports a model other
// than c.ModelName.
func (c *Gollama) Chat(prompt string, options ...ChatOption) (*ChatOuput, error) {
	var (
		temperature   float64
		seed          = c.SeedOrNegative
		contextLength = c.ContextLength
		promptImages  []PromptImage
		tools         []Tool
		format        StructuredFormat
	)

	for _, option := range options {
		switch opt := option.(type) {
		case PromptImage:
			promptImages = append(promptImages, opt)
		case []PromptImage:
			// Copy instead of aliasing the caller's slice: a later PromptImage
			// option would otherwise append into the caller's backing array.
			promptImages = append([]PromptImage(nil), opt...)
		case Tool:
			tools = append(tools, opt)
		case []Tool:
			// Same aliasing concern as []PromptImage above.
			tools = append([]Tool(nil), opt...)
		case StructuredFormat:
			format = opt
		}
		// Options of any other type are deliberately ignored.
	}

	// A non-negative seed is sent with temperature 0; only the "no seed" case
	// uses the configured temperature.
	if seed < 0 {
		temperature = c.TemperatureIfNegativeSeed
	}

	messages := make([]chatMessage, 0, 2)
	if c.SystemPrompt != "" {
		messages = append(messages, chatMessage{
			Role:    "system",
			Content: c.SystemPrompt,
		})
	}

	userMessage := chatMessage{
		Role:    "user",
		Content: prompt,
	}
	if len(promptImages) > 0 {
		encoded := make([]string, 0, len(promptImages))
		for _, image := range promptImages {
			b64, err := base64EncodeFile(image.Filename)
			if err != nil {
				return nil, fmt.Errorf("encoding image %q: %w", image.Filename, err)
			}
			encoded = append(encoded, b64)
		}
		userMessage.Images = encoded
	}
	messages = append(messages, userMessage)

	req := chatRequest{
		Stream:   false,
		Model:    c.ModelName,
		Messages: messages,
		Options: chatOptionsRequest{
			Seed:          seed,
			Temperature:   temperature,
			ContextLength: contextLength,
		},
	}
	// Optional fields are attached only when set.
	if len(tools) > 0 {
		req.Tools = &tools
	}
	if len(format.Properties) > 0 {
		req.Format = &format
	}

	var resp chatResponse
	if err := c.apiPost("/api/chat", &resp, req); err != nil {
		return nil, err
	}
	if resp.Model != c.ModelName {
		return nil, fmt.Errorf("model not found: requested %q, response reported %q", c.ModelName, resp.Model)
	}

	out := &ChatOuput{
		Role:           resp.Message.Role,
		Content:        resp.Message.Content,
		ToolCalls:      resp.Message.ToolCalls,
		PromptTokens:   resp.PromptEvalCount,
		ResponseTokens: resp.EvalCount,
	}
	if c.TrimSpace {
		out.Content = strings.TrimSpace(out.Content)
	}
	return out, nil
}