Skip to content

Commit

Permalink
Support json schema (#196)
Browse files Browse the repository at this point in the history
* Support json schema

* Update ai_test.go

* Update gemini.go
  • Loading branch information
sunshineplan authored Feb 10, 2025
1 parent c13eed9 commit 93554cb
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 31 deletions.
23 changes: 15 additions & 8 deletions ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,24 @@ type AI interface {
Close() error
}

type Function struct {
Name string
Description string
Parameters Schema
}

type Schema struct {
Type string `json:"type"`
Properties map[string]any `json:"properties"`
Enum []string `json:"enum,omitempty"`
Required []string `json:"required"`
Items *Schema `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
}

// JSONSchema is a named JSON schema used to constrain a model's JSON
// output; it is passed to Model.SetJSONResponse.
type JSONSchema struct {
	Name        string // schema name reported to the provider
	Description string // optional human-readable description of the schema
	Schema      Schema // root schema the response must conform to
}

// Function declares a callable tool that a model may invoke via
// function calling (see SetFunctionCall implementations).
type Function struct {
	Name        string // tool name the model refers to
	Description string // what the tool does, shown to the model
	Parameters  Schema // schema of the tool's arguments
}

var _ encoding.TextUnmarshaler = new(FunctionCallingMode)
Expand Down Expand Up @@ -67,7 +74,7 @@ type Model interface {
SetMaxTokens(x int64)
SetTemperature(x float64)
SetTopP(x float64)
SetJSONResponse(b bool)
SetJSONResponse(set bool, schema *JSONSchema)
}

type Chatbot interface {
Expand Down
64 changes: 64 additions & 0 deletions ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,68 @@ func testImage(t *testing.T, model string, c ai.AI) {
}
}

// testJSON exercises a provider's JSON-response support in two steps:
// plain JSON mode (the reply must parse as JSON) and schema-constrained
// mode (the reply must decode into the declared array of color objects).
// It is a no-op when model is empty, and always restores the default
// text response mode before returning.
func testJSON(t *testing.T, model string, c ai.AI) {
	if model == "" {
		return // no model configured for this provider
	}
	c.SetModel(model)
	c.SetTemperature(0) // deterministic output keeps failures reproducible
	// Reset JSON mode even if the test fails: t.Fatal runs deferred calls
	// (runtime.Goexit), and later tests reuse the same client. The original
	// only reset on the success path, leaking JSON mode into other tests.
	defer c.SetJSONResponse(false, nil)

	// Step 1: JSON mode without a schema — the reply must be valid JSON.
	c.SetJSONResponse(true, nil)
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	resp, err := c.Chat(ctx, ai.Text("List the primary colors."))
	if err != nil {
		t.Fatal(err)
	}
	res := resp.Results()
	if len(res) == 0 {
		t.Fatal("no result")
	}
	t.Log(res[0])
	var a any
	if err := json.Unmarshal([]byte(res[0]), &a); err != nil {
		t.Fatal(err)
	}

	// Step 2: JSON mode constrained by a schema — the reply must decode
	// into the declared array of {name, RGB} objects.
	c.SetJSONResponse(true, &ai.JSONSchema{
		Name: "color list",
		Schema: ai.Schema{
			Type: "array",
			Items: &ai.Schema{
				Type: "object",
				Properties: map[string]any{
					"name": map[string]string{
						"type":        "string",
						"description": "The name of the color",
					},
					"RGB": map[string]string{
						"type":        "string",
						"description": "The RGB value of the color, in hex",
					},
				},
				Required: []string{"name", "RGB"},
			},
		},
	})
	// Fresh timeout so the second request gets a full minute rather than
	// whatever the first request left of the shared context.
	ctx2, cancel2 := context.WithTimeout(context.Background(), time.Minute)
	defer cancel2()
	resp, err = c.Chat(ctx2, ai.Text("List the primary colors."))
	if err != nil {
		t.Fatal(err)
	}
	res = resp.Results()
	if len(res) == 0 {
		t.Fatal("no result")
	}
	t.Log(res[0])
	type color struct {
		Name, RGB string
	}
	var v []color
	if err := json.Unmarshal([]byte(res[0]), &v); err != nil {
		t.Fatal(err)
	}
}

func testFunctionCall(t *testing.T, model string, c ai.AI) {
if model == "" {
return
Expand Down Expand Up @@ -292,6 +354,7 @@ func TestGemini(t *testing.T) {
if err := testChatSession(model, gemini); err != nil {
t.Error(err)
}
testJSON(t, model, gemini)
testImage(t, os.Getenv("GEMINI_MODEL_FOR_IMAGE"), gemini)
testFunctionCall(t, os.Getenv("GEMINI_MODEL_FOR_TOOLS"), gemini)
}
Expand Down Expand Up @@ -320,6 +383,7 @@ func TestChatGPT(t *testing.T) {
if err := testChatSession(model, chatgpt); err != nil {
t.Error(err)
}
testJSON(t, model, chatgpt)
testImage(t, os.Getenv("CHATGPT_MODEL_FOR_IMAGE"), chatgpt)
testFunctionCall(t, os.Getenv("CHATGPT_MODEL_FOR_TOOLS"), chatgpt)
}
Expand Down
2 changes: 1 addition & 1 deletion anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (ai *Anthropic) SetTopP(f float64) { ai.topP = &f }
func (ai *Anthropic) SetCount(i int64) {
fmt.Println("Anthropic doesn't support SetCount")
}
func (ai *Anthropic) SetJSONResponse(b bool) {
// SetJSONResponse implements the ai.Model interface but is a no-op:
// JSON-mode responses are not supported for Anthropic here, so both
// arguments are ignored and a notice is printed instead.
func (ai *Anthropic) SetJSONResponse(_ bool, _ *ai.JSONSchema) {
	fmt.Println("Anthropic currently doesn't support SetJSONResponse")
}

Expand Down
34 changes: 26 additions & 8 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type ChatGPT struct {
temperature *float64
topP *float64
count *int64
json *bool
json openai.ChatCompletionNewParamsResponseFormatUnion

limiter *rate.Limiter
}
Expand Down Expand Up @@ -137,7 +137,29 @@ func (ai *ChatGPT) SetCount(i int64) { ai.count = &i }
func (ai *ChatGPT) SetMaxTokens(i int64) { ai.maxTokens = &i }
func (ai *ChatGPT) SetTemperature(f float64) { ai.temperature = &f }
func (ai *ChatGPT) SetTopP(f float64) { ai.topP = &f }
func (ai *ChatGPT) SetJSONResponse(b bool) { ai.json = &b }
// SetJSONResponse configures the response format used for subsequent
// requests. With set=false the stored format is the union's zero value,
// which createRequest treats as "no format set". With set=true and a nil
// schema, plain JSON-object mode is requested; with a non-nil schema, the
// provider is asked for output conforming to that named JSON schema.
func (ai *ChatGPT) SetJSONResponse(set bool, schema *ai.JSONSchema) {
	var responseFormat openai.ChatCompletionNewParamsResponseFormatUnion
	if set {
		if schema != nil {
			// Round-trip through encoding/json to turn the typed ai.Schema
			// into the loosely-typed value the OpenAI SDK expects.
			// NOTE(review): both errors are ignored — presumably ai.Schema
			// always marshals/unmarshals cleanly; confirm.
			var format any
			b, _ := json.Marshal(schema.Schema)
			_ = json.Unmarshal(b, &format)
			responseFormat = shared.ResponseFormatJSONSchemaParam{
				Type: openai.F(shared.ResponseFormatJSONSchemaTypeJSONSchema),
				JSONSchema: openai.F(shared.ResponseFormatJSONSchemaJSONSchemaParam{
					Name:        openai.String(schema.Name),
					Description: openai.String(schema.Description),
					Schema:      openai.F(format),
				}),
			}
		} else {
			responseFormat = shared.ResponseFormatJSONObjectParam{
				Type: openai.F(shared.ResponseFormatJSONObjectTypeJSONObject),
			}
		}
	}
	ai.json = responseFormat
}

var _ ai.ChatResponse = new(ChatResponse[*openai.ChatCompletion])

Expand Down Expand Up @@ -245,12 +267,8 @@ func (c *ChatGPT) createRequest(
if c.topP != nil {
req.TopP = openai.Float(*c.topP)
}
if c.json != nil && *c.json {
req.ResponseFormat = openai.F[openai.ChatCompletionNewParamsResponseFormatUnion](
openai.ChatCompletionNewParamsResponseFormat{
Type: openai.F(openai.ChatCompletionNewParamsResponseFormatTypeJSONObject),
},
)
if c.json != nil {
req.ResponseFormat = openai.F(c.json)
}
var msgs []openai.ChatCompletionMessageParamUnion
for _, i := range history {
Expand Down
3 changes: 2 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type ModelConfig struct {
Temperature *float64
TopP *float64
JSONResponse *bool
JSONSchema *JSONSchema
Tools []Function
ToolConfig FunctionCallingMode
}
Expand All @@ -37,7 +38,7 @@ func ApplyModelConfig(ai AI, cfg ModelConfig) {
ai.SetTopP(*cfg.TopP)
}
if cfg.JSONResponse != nil {
ai.SetJSONResponse(*cfg.JSONResponse)
ai.SetJSONResponse(*cfg.JSONResponse, cfg.JSONSchema)
}
ai.SetFunctionCall(cfg.Tools, cfg.ToolConfig)
}
Expand Down
70 changes: 57 additions & 13 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,36 @@ func (ai *Gemini) SetModel(model string) {
ai.model.GenerationConfig = ai.config
}

// genaiSchema converts an ai.Schema into the equivalent *genai.Schema.
// A nil input yields (nil, nil). It recurses into Items, so arbitrarily
// nested array schemas are converted correctly — the previous version
// converted only one level of Items inline and silently dropped
// Items.Items and anything deeper.
func genaiSchema(schema *ai.Schema) (*genai.Schema, error) {
	if schema == nil {
		return nil, nil
	}
	properties, err := genaiProperties(schema.Properties)
	if err != nil {
		return nil, err
	}
	// Recurse for element schemas; nil Items simply stays nil.
	items, err := genaiSchema(schema.Items)
	if err != nil {
		return nil, err
	}
	return &genai.Schema{
		Type:       genaiType(schema.Type),
		Properties: properties,
		Enum:       schema.Enum,
		Items:      items,
		Required:   schema.Required,
	}, nil
}

func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMode) {
if len(f) == 0 {
gemini.model.Tools = nil
Expand All @@ -115,19 +145,14 @@ func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMo
}
var declarations []*genai.FunctionDeclaration
for _, i := range f {
p, err := genaiProperties(i.Parameters.Properties)
schema, err := genaiSchema(&i.Parameters)
if err != nil {
continue
}
declarations = append(declarations, &genai.FunctionDeclaration{
Name: i.Name,
Description: i.Description,
Parameters: &genai.Schema{
Type: genaiType(i.Parameters.Type),
Properties: p,
Enum: i.Parameters.Enum,
Required: i.Parameters.Required,
},
Parameters: schema,
})
}
gemini.model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}}
Expand All @@ -148,16 +173,35 @@ func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMo
gemini.model.ToolConfig = nil
}
}
func (ai *Gemini) SetCount(i int64) { ai.config.SetCandidateCount(int32(i)) }
func (ai *Gemini) SetMaxTokens(i int64) { ai.config.SetMaxOutputTokens(int32(i)) }
func (ai *Gemini) SetTemperature(f float64) { ai.config.SetTemperature(float32(f)) }
func (ai *Gemini) SetTopP(f float64) { ai.config.SetTopP(float32(f)) }
func (ai *Gemini) SetJSONResponse(json bool) {
if json {
// SetCount sets the number of response candidates to generate and
// pushes the updated config onto the underlying model.
// The receiver is named gemini (matching SetFunctionCall) instead of
// ai, which shadowed the imported ai package.
func (gemini *Gemini) SetCount(i int64) {
	gemini.config.SetCandidateCount(int32(i))
	gemini.model.GenerationConfig = gemini.config
}
// SetMaxTokens caps the number of output tokens and pushes the updated
// config onto the underlying model.
// The receiver is named gemini (matching SetFunctionCall) instead of
// ai, which shadowed the imported ai package.
func (gemini *Gemini) SetMaxTokens(i int64) {
	gemini.config.SetMaxOutputTokens(int32(i))
	gemini.model.GenerationConfig = gemini.config
}
// SetTemperature sets the sampling temperature and pushes the updated
// config onto the underlying model.
// The receiver is named gemini (matching SetFunctionCall) instead of
// ai, which shadowed the imported ai package.
func (gemini *Gemini) SetTemperature(f float64) {
	gemini.config.SetTemperature(float32(f))
	gemini.model.GenerationConfig = gemini.config
}
// SetTopP sets the nucleus-sampling probability mass and pushes the
// updated config onto the underlying model.
// The receiver is named gemini (matching SetFunctionCall) instead of
// ai, which shadowed the imported ai package.
func (gemini *Gemini) SetTopP(f float64) {
	gemini.config.SetTopP(float32(f))
	gemini.model.GenerationConfig = gemini.config
}
// SetJSONResponse toggles JSON-mode output. When set is true the model
// is asked for an "application/json" response, optionally constrained
// by schema; when false it reverts to plain text and clears any schema.
// The receiver is named gemini (matching SetFunctionCall) instead of
// ai, which shadowed the imported ai package — note the parameter type
// *ai.JSONSchema below refers to that package.
func (gemini *Gemini) SetJSONResponse(set bool, schema *ai.JSONSchema) {
	if set {
		gemini.config.ResponseMIMEType = "application/json"
		if schema != nil {
			// Conversion errors are deliberately ignored: on failure
			// genaiSchema returns nil, leaving the response unconstrained
			// rather than failing the configuration call.
			gemini.config.ResponseSchema, _ = genaiSchema(&schema.Schema)
		} else {
			gemini.config.ResponseSchema = nil
		}
	} else {
		gemini.config.ResponseMIMEType = "text/plain"
		gemini.config.ResponseSchema = nil
	}
	gemini.model.GenerationConfig = gemini.config
}

func toParts(src []ai.Part) (dst []genai.Part) {
Expand Down

0 comments on commit 93554cb

Please sign in to comment.