From 93554cbf276a07cde850b49a0fd2d5cb639e579c Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 10 Feb 2025 16:52:58 +0800 Subject: [PATCH] Support json schema (#196) * Support json schema * Update ai_test.go * Update gemini.go --- ai.go | 23 +++++++++----- ai_test.go | 64 ++++++++++++++++++++++++++++++++++++++ anthropic/anthropic.go | 2 +- chatgpt/chatgpt.go | 34 +++++++++++++++----- config.go | 3 +- gemini/gemini.go | 70 ++++++++++++++++++++++++++++++++++-------- 6 files changed, 165 insertions(+), 31 deletions(-) diff --git a/ai.go b/ai.go index 335ddd9..a2e5e3a 100644 --- a/ai.go +++ b/ai.go @@ -24,17 +24,24 @@ type AI interface { Close() error } -type Function struct { - Name string - Description string - Parameters Schema -} - type Schema struct { Type string `json:"type"` Properties map[string]any `json:"properties"` Enum []string `json:"enum,omitempty"` - Required []string `json:"required"` + Items *Schema `json:"items,omitempty"` + Required []string `json:"required,omitempty"` +} + +type JSONSchema struct { + Name string + Description string + Schema Schema +} + +type Function struct { + Name string + Description string + Parameters Schema } var _ encoding.TextUnmarshaler = new(FunctionCallingMode) @@ -67,7 +74,7 @@ type Model interface { SetMaxTokens(x int64) SetTemperature(x float64) SetTopP(x float64) - SetJSONResponse(b bool) + SetJSONResponse(set bool, schema *JSONSchema) } type Chatbot interface { diff --git a/ai_test.go b/ai_test.go index 6cf5092..2edfda3 100644 --- a/ai_test.go +++ b/ai_test.go @@ -132,6 +132,68 @@ func testImage(t *testing.T, model string, c ai.AI) { } } +func testJSON(t *testing.T, model string, c ai.AI) { + if model == "" { + return + } else { + c.SetModel(model) + } + c.SetTemperature(0) + c.SetJSONResponse(true, nil) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + resp, err := c.Chat(ctx, ai.Text("List the primary colors.")) + if err != nil { + t.Fatal(err) + } + if res := resp.Results(); len(res) == 0 { + t.Fatal("no result") + } else { + t.Log(res[0]) + var a any + if err := json.Unmarshal([]byte(res[0]), &a); err != nil { + t.Fatal(err) + } + } + c.SetJSONResponse(true, &ai.JSONSchema{ + Name: "color list", + Schema: ai.Schema{ + Type: "array", + Items: &ai.Schema{ + Type: "object", + Properties: map[string]any{ + "name": map[string]string{ + "type": "string", + "description": "The name of the color", + }, + "RGB": map[string]string{ + "type": "string", + "description": "The RGB value of the color, in hex", + }, + }, + Required: []string{"name", "RGB"}, + }, + }, + }) + resp, err = c.Chat(ctx, ai.Text("List the primary colors.")) + if err != nil { + t.Fatal(err) + } + if res := resp.Results(); len(res) == 0 { + t.Fatal("no result") + } else { + t.Log(res[0]) + type color struct { + Name, RGB string + } + var v []color + if err := json.Unmarshal([]byte(res[0]), &v); err != nil { + t.Fatal(err) + } + } + c.SetJSONResponse(false, nil) +} + func testFunctionCall(t *testing.T, model string, c ai.AI) { if model == "" { return @@ -292,6 +354,7 @@ func TestGemini(t *testing.T) { if err := testChatSession(model, gemini); err != nil { t.Error(err) } + testJSON(t, model, gemini) testImage(t, os.Getenv("GEMINI_MODEL_FOR_IMAGE"), gemini) testFunctionCall(t, os.Getenv("GEMINI_MODEL_FOR_TOOLS"), gemini) } @@ -320,6 +383,7 @@ func TestChatGPT(t *testing.T) { if err := testChatSession(model, chatgpt); err != nil { t.Error(err) } + testJSON(t, model, chatgpt) testImage(t, os.Getenv("CHATGPT_MODEL_FOR_IMAGE"), chatgpt) testFunctionCall(t, os.Getenv("CHATGPT_MODEL_FOR_TOOLS"), chatgpt) } diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index 8680df6..ebd54a0 100644 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -139,7 +139,7 @@ func (ai *Anthropic) SetTopP(f float64) { ai.topP = &f } func (ai *Anthropic) SetCount(i int64) { fmt.Println("Anthropic doesn't support SetCount") } -func (ai *Anthropic) SetJSONResponse(b bool) { +func (ai *Anthropic) SetJSONResponse(_ bool, _ *ai.JSONSchema) { fmt.Println("Anthropic currently doesn't support SetJSONResponse") } diff --git a/chatgpt/chatgpt.go b/chatgpt/chatgpt.go index 2d52d8d..2dba94f 100644 --- a/chatgpt/chatgpt.go +++ b/chatgpt/chatgpt.go @@ -32,7 +32,7 @@ type ChatGPT struct { temperature *float64 topP *float64 count *int64 - json *bool + json openai.ChatCompletionNewParamsResponseFormatUnion limiter *rate.Limiter } @@ -137,7 +137,29 @@ func (ai *ChatGPT) SetCount(i int64) { ai.count = &i } func (ai *ChatGPT) SetMaxTokens(i int64) { ai.maxTokens = &i } func (ai *ChatGPT) SetTemperature(f float64) { ai.temperature = &f } func (ai *ChatGPT) SetTopP(f float64) { ai.topP = &f } -func (ai *ChatGPT) SetJSONResponse(b bool) { ai.json = &b } +func (ai *ChatGPT) SetJSONResponse(set bool, schema *ai.JSONSchema) { + var responseFormat openai.ChatCompletionNewParamsResponseFormatUnion + if set { + if schema != nil { + var format any + b, _ := json.Marshal(schema.Schema) + _ = json.Unmarshal(b, &format) + responseFormat = shared.ResponseFormatJSONSchemaParam{ + Type: openai.F(shared.ResponseFormatJSONSchemaTypeJSONSchema), + JSONSchema: openai.F(shared.ResponseFormatJSONSchemaJSONSchemaParam{ + Name: openai.String(schema.Name), + Description: openai.String(schema.Description), + Schema: openai.F(format), + }), + } + } else { + responseFormat = shared.ResponseFormatJSONObjectParam{ + Type: openai.F(shared.ResponseFormatJSONObjectTypeJSONObject), + } + } + } + ai.json = responseFormat +} var _ ai.ChatResponse = new(ChatResponse[*openai.ChatCompletion]) @@ -245,12 +267,8 @@ func (c *ChatGPT) createRequest( if c.topP != nil { req.TopP = openai.Float(*c.topP) } - if c.json != nil && *c.json { - req.ResponseFormat = openai.F[openai.ChatCompletionNewParamsResponseFormatUnion]( - openai.ChatCompletionNewParamsResponseFormat{ - Type: openai.F(openai.ChatCompletionNewParamsResponseFormatTypeJSONObject), - }, - ) + if c.json != nil { + req.ResponseFormat = openai.F(c.json) } var msgs []openai.ChatCompletionMessageParamUnion for _, i := range history { diff --git a/config.go b/config.go index aa2e3a3..f100051 100644 --- a/config.go +++ b/config.go @@ -19,6 +19,7 @@ type ModelConfig struct { Temperature *float64 TopP *float64 JSONResponse *bool + JSONSchema *JSONSchema Tools []Function ToolConfig FunctionCallingMode } @@ -37,7 +38,7 @@ func ApplyModelConfig(ai AI, cfg ModelConfig) { ai.SetTopP(*cfg.TopP) } if cfg.JSONResponse != nil { - ai.SetJSONResponse(*cfg.JSONResponse) + ai.SetJSONResponse(*cfg.JSONResponse, cfg.JSONSchema) } ai.SetFunctionCall(cfg.Tools, cfg.ToolConfig) } diff --git a/gemini/gemini.go b/gemini/gemini.go index 66686c6..7e4e04d 100644 --- a/gemini/gemini.go +++ b/gemini/gemini.go @@ -107,6 +107,36 @@ func (ai *Gemini) SetModel(model string) { ai.model.GenerationConfig = ai.config } +func genaiSchema(schema *ai.Schema) (*genai.Schema, error) { + if schema == nil { + return nil, nil + } + p, err := genaiProperties(schema.Properties) + if err != nil { + return nil, err + } + var items *genai.Schema + if schema.Items != nil { + p, err := genaiProperties(schema.Items.Properties) + if err != nil { + return nil, err + } + items = &genai.Schema{ + Type: genaiType(schema.Items.Type), + Properties: p, + Enum: schema.Items.Enum, + Required: schema.Items.Required, + } + } + return &genai.Schema{ + Type: genaiType(schema.Type), + Properties: p, + Enum: schema.Enum, + Items: items, + Required: schema.Required, + }, nil +} + func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMode) { if len(f) == 0 { gemini.model.Tools = nil @@ -115,19 +145,14 @@ func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMo } var declarations []*genai.FunctionDeclaration for _, i := range f { - p, err := genaiProperties(i.Parameters.Properties) + schema, err := genaiSchema(&i.Parameters) if err != nil { continue } declarations = append(declarations, &genai.FunctionDeclaration{ Name: i.Name, Description: i.Description, - Parameters: &genai.Schema{ - Type: genaiType(i.Parameters.Type), - Properties: p, - Enum: i.Parameters.Enum, - Required: i.Parameters.Required, - }, + Parameters: schema, }) } gemini.model.Tools = []*genai.Tool{{FunctionDeclarations: declarations}} @@ -148,16 +173,35 @@ func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMo gemini.model.ToolConfig = nil } } -func (ai *Gemini) SetCount(i int64) { ai.config.SetCandidateCount(int32(i)) } -func (ai *Gemini) SetMaxTokens(i int64) { ai.config.SetMaxOutputTokens(int32(i)) } -func (ai *Gemini) SetTemperature(f float64) { ai.config.SetTemperature(float32(f)) } -func (ai *Gemini) SetTopP(f float64) { ai.config.SetTopP(float32(f)) } -func (ai *Gemini) SetJSONResponse(json bool) { - if json { +func (ai *Gemini) SetCount(i int64) { + ai.config.SetCandidateCount(int32(i)) + ai.model.GenerationConfig = ai.config +} +func (ai *Gemini) SetMaxTokens(i int64) { + ai.config.SetMaxOutputTokens(int32(i)) + ai.model.GenerationConfig = ai.config +} +func (ai *Gemini) SetTemperature(f float64) { + ai.config.SetTemperature(float32(f)) + ai.model.GenerationConfig = ai.config +} +func (ai *Gemini) SetTopP(f float64) { + ai.config.SetTopP(float32(f)) + ai.model.GenerationConfig = ai.config +} +func (ai *Gemini) SetJSONResponse(set bool, schema *ai.JSONSchema) { + if set { ai.config.ResponseMIMEType = "application/json" + if schema != nil { + ai.config.ResponseSchema, _ = genaiSchema(&schema.Schema) + } else { + ai.config.ResponseSchema = nil + } } else { ai.config.ResponseMIMEType = "text/plain" + ai.config.ResponseSchema = nil } + ai.model.GenerationConfig = ai.config } func toParts(src []ai.Part) (dst []genai.Part) {