From bb84fbd185448bdee5e848e761f094b91365e4c2 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 29 Apr 2024 19:25:26 -0400 Subject: [PATCH] feat(vertexai/genai): constrained decoding (#9731) Expose FunctionCallingMode, which lets the user configure how the model issues function calls. By setting the new field `Model.ToolConfig`, the user can select whether the model can call any provided function, only those in a provided list, or no functions. Regenerating also brought along a few other new types and values, some of which required configuration. --- vertexai/genai/aiplatformpb_veneer.gen.go | 282 ++++++++++++++++++---- vertexai/genai/client.go | 2 + vertexai/genai/client_test.go | 70 ++++-- vertexai/genai/config.yaml | 27 ++- vertexai/genai/example_test.go | 137 +++++++++++ 5 files changed, 448 insertions(+), 70 deletions(-) diff --git a/vertexai/genai/aiplatformpb_veneer.gen.go b/vertexai/genai/aiplatformpb_veneer.gen.go index a8879083f955..63a1e8a5b2e2 100644 --- a/vertexai/genai/aiplatformpb_veneer.gen.go +++ b/vertexai/genai/aiplatformpb_veneer.gen.go @@ -24,13 +24,11 @@ import ( "cloud.google.com/go/vertexai/internal/support" ) -// Blob contains raw media bytes. -// -// Text should not be sent as raw bytes, use the 'text' field. +// Blob contains binary data like images. Use [Text] for text. type Blob struct { // Required. The IANA standard MIME type of the source data. MIMEType string - // Required. Raw bytes for media formats. + // Required. Raw bytes. Data []byte } @@ -64,12 +62,19 @@ const ( BlockedReasonSafety BlockedReason = 1 // BlockedReasonOther means candidates blocked due to other reason. BlockedReasonOther BlockedReason = 2 + // BlockedReasonBlocklist means candidates blocked due to the terms which are included from the + // terminology blocklist. + BlockedReasonBlocklist BlockedReason = 3 + // BlockedReasonProhibitedContent means candidates blocked due to prohibited content. + BlockedReasonProhibitedContent BlockedReason = 4 ) var namesForBlockedReason = map[BlockedReason]string{ - BlockedReasonUnspecified: "BlockedReasonUnspecified", - BlockedReasonSafety: "BlockedReasonSafety", - BlockedReasonOther: "BlockedReasonOther", + BlockedReasonUnspecified: "BlockedReasonUnspecified", + BlockedReasonSafety: "BlockedReasonSafety", + BlockedReasonOther: "BlockedReasonOther", + BlockedReasonBlocklist: "BlockedReasonBlocklist", + BlockedReasonProhibitedContent: "BlockedReasonProhibitedContent", } func (v BlockedReason) String() string { @@ -371,6 +376,69 @@ func (FunctionCall) fromProto(p *pb.FunctionCall) *FunctionCall { } } +// FunctionCallingConfig holds configuration for function calling. +type FunctionCallingConfig struct { + // Optional. Function calling mode. + Mode FunctionCallingMode + // Optional. Function names to call. Only set when the Mode is ANY. Function + // names should match [FunctionDeclaration.name]. With mode set to ANY, model + // will predict a function call from the set of function names provided. + AllowedFunctionNames []string +} + +func (v *FunctionCallingConfig) toProto() *pb.FunctionCallingConfig { + if v == nil { + return nil + } + return &pb.FunctionCallingConfig{ + Mode: pb.FunctionCallingConfig_Mode(v.Mode), + AllowedFunctionNames: v.AllowedFunctionNames, + } +} + +func (FunctionCallingConfig) fromProto(p *pb.FunctionCallingConfig) *FunctionCallingConfig { + if p == nil { + return nil + } + return &FunctionCallingConfig{ + Mode: FunctionCallingMode(p.Mode), + AllowedFunctionNames: p.AllowedFunctionNames, + } +} + +// FunctionCallingMode is function calling mode. +type FunctionCallingMode int32 + +const ( + // FunctionCallingUnspecified means unspecified function calling mode. This value should not be used. + FunctionCallingUnspecified FunctionCallingMode = 0 + // FunctionCallingAuto means default model behavior, model decides to predict either a function call + // or a natural language repspose. + FunctionCallingAuto FunctionCallingMode = 1 + // FunctionCallingAny means model is constrained to always predicting a function call only. + // If "allowed_function_names" are set, the predicted function call will be + // limited to any one of "allowed_function_names", else the predicted + // function call will be any one of the provided "function_declarations". + FunctionCallingAny FunctionCallingMode = 2 + // FunctionCallingNone means model will not predict any function call. Model behavior is same as when + // not passing any function declarations. + FunctionCallingNone FunctionCallingMode = 3 +) + +var namesForFunctionCallingMode = map[FunctionCallingMode]string{ + FunctionCallingUnspecified: "FunctionCallingUnspecified", + FunctionCallingAuto: "FunctionCallingAuto", + FunctionCallingAny: "FunctionCallingAny", + FunctionCallingNone: "FunctionCallingNone", +} + +func (v FunctionCallingMode) String() string { + if n, ok := namesForFunctionCallingMode[v]; ok { + return n + } + return fmt.Sprintf("FunctionCallingMode(%d)", v) +} + // FunctionDeclaration is structured representation of a function declaration as defined by the // [OpenAPI 3.0 specification](https://spec.openapis.org/oas/v3.0.3). Included // in this declaration are the function name and parameters. This @@ -379,8 +447,8 @@ func (FunctionCall) fromProto(p *pb.FunctionCall) *FunctionCall { type FunctionDeclaration struct { // Required. The name of the function to call. // Must start with a letter or an underscore. - // Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum - // length of 64. + // Must be a-z, A-Z, 0-9, or contain underscores, dots and dashes, with a + // maximum length of 64. Name string // Optional. Description and purpose of the function. // Model uses it to decide how and whether to call the function. @@ -389,8 +457,10 @@ type FunctionDeclaration struct { // format. Reflects the Open API 3.03 Parameter Object. string Key: the name // of the parameter. Parameter names are case sensitive. Schema Value: the // Schema defining the type used for the parameter. For function with no - // parameters, this can be left unset. Example with 1 required and 1 optional - // parameter: type: OBJECT properties: + // parameters, this can be left unset. Parameter names must start with a + // letter or an underscore and must only contain chars a-z, A-Z, 0-9, or + // underscores with a maximum length of 64. Example with 1 required and 1 + // optional parameter: type: OBJECT properties: // // param1: // type: STRING @@ -400,6 +470,10 @@ type FunctionDeclaration struct { // required: // - param1 Parameters *Schema + // Optional. Describes the output from this function in JSON Schema format. + // Reflects the Open API 3.03 Response Object. The Schema defines the type + // used for the response value of the function. + Response *Schema } func (v *FunctionDeclaration) toProto() *pb.FunctionDeclaration { @@ -410,6 +484,7 @@ func (v *FunctionDeclaration) toProto() *pb.FunctionDeclaration { Name: v.Name, Description: v.Description, Parameters: v.Parameters.toProto(), + Response: v.Response.toProto(), } } @@ -421,6 +496,7 @@ func (FunctionDeclaration) fromProto(p *pb.FunctionDeclaration) *FunctionDeclara Name: p.Name, Description: p.Description, Parameters: (Schema{}).fromProto(p.Parameters), + Response: (Schema{}).fromProto(p.Response), } } @@ -504,6 +580,18 @@ type GenerationConfig struct { MaxOutputTokens *int32 // Optional. Stop sequences. StopSequences []string + // Optional. Positive penalties. + PresencePenalty *float32 + // Optional. Frequency penalties. + FrequencyPenalty *float32 + // Optional. Output response mimetype of the generated candidate text. + // Supported mimetype: + // - `text/plain`: (default) Text output. + // - `application/json`: JSON response in the candidates. + // The model needs to be prompted to output the appropriate response type, + // otherwise the behavior is undefined. + // This is a preview feature. + ResponseMIMEType string } func (v *GenerationConfig) toProto() *pb.GenerationConfig { @@ -511,12 +599,15 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig { return nil } return &pb.GenerationConfig{ - Temperature: v.Temperature, - TopP: v.TopP, - TopK: int32pToFloat32p(v.TopK), - CandidateCount: v.CandidateCount, - MaxOutputTokens: v.MaxOutputTokens, - StopSequences: v.StopSequences, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: int32pToFloat32p(v.TopK), + CandidateCount: v.CandidateCount, + MaxOutputTokens: v.MaxOutputTokens, + StopSequences: v.StopSequences, + PresencePenalty: v.PresencePenalty, + FrequencyPenalty: v.FrequencyPenalty, + ResponseMimeType: v.ResponseMIMEType, } } @@ -525,13 +616,41 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig { return nil } return &GenerationConfig{ - Temperature: p.Temperature, - TopP: p.TopP, - TopK: float32pToInt32p(p.TopK), - CandidateCount: p.CandidateCount, - MaxOutputTokens: p.MaxOutputTokens, - StopSequences: p.StopSequences, + Temperature: p.Temperature, + TopP: p.TopP, + TopK: float32pToInt32p(p.TopK), + CandidateCount: p.CandidateCount, + MaxOutputTokens: p.MaxOutputTokens, + StopSequences: p.StopSequences, + PresencePenalty: p.PresencePenalty, + FrequencyPenalty: p.FrequencyPenalty, + ResponseMIMEType: p.ResponseMimeType, + } +} + +// HarmBlockMethod determines how harm blocking is done. +type HarmBlockMethod int32 + +const ( + // HarmBlockMethodUnspecified means the harm block method is unspecified. + HarmBlockMethodUnspecified HarmBlockMethod = 0 + // HarmBlockMethodSeverity means the harm block method uses both probability and severity scores. + HarmBlockMethodSeverity HarmBlockMethod = 1 + // HarmBlockMethodProbability means the harm block method uses the probability score. + HarmBlockMethodProbability HarmBlockMethod = 2 +) + +var namesForHarmBlockMethod = map[HarmBlockMethod]string{ + HarmBlockMethodUnspecified: "HarmBlockMethodUnspecified", + HarmBlockMethodSeverity: "HarmBlockMethodSeverity", + HarmBlockMethodProbability: "HarmBlockMethodProbability", +} + +func (v HarmBlockMethod) String() string { + if n, ok := namesForHarmBlockMethod[v]; ok { + return n } + return fmt.Sprintf("HarmBlockMethod(%d)", v) } // HarmBlockThreshold specifies probability based thresholds levels for blocking. @@ -627,7 +746,7 @@ func (v HarmProbability) String() string { return fmt.Sprintf("HarmProbability(%d)", v) } -// HarmSeverity is harm severity levels. +// HarmSeverity specifies harm severity levels. type HarmSeverity int32 const ( @@ -741,6 +860,9 @@ type SafetySetting struct { Category HarmCategory // Required. The harm block threshold. Threshold HarmBlockThreshold + // Optional. Specify if the threshold is used for probability or severity + // score. If not specified, the threshold is used for probability score. + Method HarmBlockMethod } func (v *SafetySetting) toProto() *pb.SafetySetting { @@ -750,6 +872,7 @@ func (v *SafetySetting) toProto() *pb.SafetySetting { return &pb.SafetySetting{ Category: pb.HarmCategory(v.Category), Threshold: pb.SafetySetting_HarmBlockThreshold(v.Threshold), + Method: pb.SafetySetting_HarmBlockMethod(v.Method), } } @@ -760,6 +883,7 @@ func (SafetySetting) fromProto(p *pb.SafetySetting) *SafetySetting { return &SafetySetting{ Category: HarmCategory(p.Category), Threshold: HarmBlockThreshold(p.Threshold), + Method: HarmBlockMethod(p.Method), } } @@ -773,23 +897,49 @@ type Schema struct { // Optional. The format of the data. // Supported formats: // - // for NUMBER type: float, double - // for INTEGER type: int32, int64 + // for NUMBER type: "float", "double" + // for INTEGER type: "int32", "int64" + // for STRING type: "email", "byte", etc Format string + // Optional. The title of the Schema. + Title string // Optional. The description of the data. Description string // Optional. Indicates if the value may be null. Nullable bool - // Optional. Schema of the elements of Type.ARRAY. + // Optional. SCHEMA FIELDS FOR TYPE ARRAY + // Schema of the elements of Type.ARRAY. Items *Schema + // Optional. Minimum number of the elements for Type.ARRAY. + MinItems int64 + // Optional. Maximum number of the elements for Type.ARRAY. + MaxItems int64 // Optional. Possible values of the element of Type.STRING with enum format. // For example we can define an Enum Direction as : // {type:STRING, format:enum, enum:["EAST", NORTH", "SOUTH", "WEST"]} Enum []string - // Optional. Properties of Type.OBJECT. + // Optional. SCHEMA FIELDS FOR TYPE OBJECT + // Properties of Type.OBJECT. Properties map[string]*Schema // Optional. Required properties of Type.OBJECT. Required []string + // Optional. Minimum number of the properties for Type.OBJECT. + MinProperties int64 + // Optional. Maximum number of the properties for Type.OBJECT. + MaxProperties int64 + // Optional. SCHEMA FIELDS FOR TYPE INTEGER and NUMBER + // Minimum value of the Type.INTEGER and Type.NUMBER + Minimum float64 + // Optional. Maximum value of the Type.INTEGER and Type.NUMBER + Maximum float64 + // Optional. SCHEMA FIELDS FOR TYPE STRING + // Minimum length of the Type.STRING + MinLength int64 + // Optional. Maximum length of the Type.STRING + MaxLength int64 + // Optional. Pattern of the Type.STRING to restrict a string to a regular + // expression. + Pattern string } func (v *Schema) toProto() *pb.Schema { @@ -797,14 +947,24 @@ func (v *Schema) toProto() *pb.Schema { return nil } return &pb.Schema{ - Type: pb.Type(v.Type), - Format: v.Format, - Description: v.Description, - Nullable: v.Nullable, - Items: v.Items.toProto(), - Enum: v.Enum, - Properties: support.TransformMapValues(v.Properties, (*Schema).toProto), - Required: v.Required, + Type: pb.Type(v.Type), + Format: v.Format, + Title: v.Title, + Description: v.Description, + Nullable: v.Nullable, + Items: v.Items.toProto(), + MinItems: v.MinItems, + MaxItems: v.MaxItems, + Enum: v.Enum, + Properties: support.TransformMapValues(v.Properties, (*Schema).toProto), + Required: v.Required, + MinProperties: v.MinProperties, + MaxProperties: v.MaxProperties, + Minimum: v.Minimum, + Maximum: v.Maximum, + MinLength: v.MinLength, + MaxLength: v.MaxLength, + Pattern: v.Pattern, } } @@ -813,14 +973,24 @@ func (Schema) fromProto(p *pb.Schema) *Schema { return nil } return &Schema{ - Type: Type(p.Type), - Format: p.Format, - Description: p.Description, - Nullable: p.Nullable, - Items: (Schema{}).fromProto(p.Items), - Enum: p.Enum, - Properties: support.TransformMapValues(p.Properties, (Schema{}).fromProto), - Required: p.Required, + Type: Type(p.Type), + Format: p.Format, + Title: p.Title, + Description: p.Description, + Nullable: p.Nullable, + Items: (Schema{}).fromProto(p.Items), + MinItems: p.MinItems, + MaxItems: p.MaxItems, + Enum: p.Enum, + Properties: support.TransformMapValues(p.Properties, (Schema{}).fromProto), + Required: p.Required, + MinProperties: p.MinProperties, + MaxProperties: p.MaxProperties, + Minimum: p.Minimum, + Maximum: p.Maximum, + MinLength: p.MinLength, + MaxLength: p.MaxLength, + Pattern: p.Pattern, } } @@ -861,6 +1031,30 @@ func (Tool) fromProto(p *pb.Tool) *Tool { } } +// ToolConfig configures tools. +type ToolConfig struct { + // Optional. Function calling config. + FunctionCallingConfig *FunctionCallingConfig +} + +func (v *ToolConfig) toProto() *pb.ToolConfig { + if v == nil { + return nil + } + return &pb.ToolConfig{ + FunctionCallingConfig: v.FunctionCallingConfig.toProto(), + } +} + +func (ToolConfig) fromProto(p *pb.ToolConfig) *ToolConfig { + if p == nil { + return nil + } + return &ToolConfig{ + FunctionCallingConfig: (FunctionCallingConfig{}).fromProto(p.FunctionCallingConfig), + } +} + // Type contains the list of OpenAPI data types as defined by // https://swagger.io/docs/specification/data-models/data-types/ type Type int32 diff --git a/vertexai/genai/client.go b/vertexai/genai/client.go index a0f387272a53..de186900abc8 100644 --- a/vertexai/genai/client.go +++ b/vertexai/genai/client.go @@ -95,6 +95,7 @@ type GenerativeModel struct { GenerationConfig SafetySettings []*SafetySetting Tools []*Tool + ToolConfig *ToolConfig // configuration for tools } const defaultMaxOutputTokens = 2048 @@ -145,6 +146,7 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) *pb.Ge Contents: support.TransformSlice(contents, (*Content).toProto), SafetySettings: support.TransformSlice(m.SafetySettings, (*SafetySetting).toProto), Tools: support.TransformSlice(m.Tools, (*Tool).toProto), + ToolConfig: m.ToolConfig.toProto(), GenerationConfig: m.GenerationConfig.toProto(), } } diff --git a/vertexai/genai/client_test.go b/vertexai/genai/client_test.go index 17b06d5d5443..d0e5f7dbb2c6 100644 --- a/vertexai/genai/client_test.go +++ b/vertexai/genai/client_test.go @@ -234,32 +234,52 @@ func TestLive(t *testing.T) { model := client.GenerativeModel(*modelName) model.SetTemperature(0) model.Tools = []*Tool{weatherTool} - session := model.StartChat() - res, err := session.SendMessage(ctx, Text("What is the weather like in New York?")) - if err != nil { - t.Fatal(err) - } - part := res.Candidates[0].Content.Parts[0] - funcall, ok := part.(FunctionCall) - if !ok { - t.Fatalf("want FunctionCall, got %T", part) - } - if g, w := funcall.Name, weatherTool.FunctionDeclarations[0].Name; g != w { - t.Errorf("FunctionCall.Name: got %q, want %q", g, w) - } - if g, c := funcall.Args["location"], "New York"; !strings.Contains(g.(string), c) { - t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, g, c) - } - res, err = session.SendMessage(ctx, FunctionResponse{ - Name: weatherTool.FunctionDeclarations[0].Name, - Response: map[string]any{ - "weather_there": "cold", - }, + t.Run("funcall", func(t *testing.T) { + session := model.StartChat() + res, err := session.SendMessage(ctx, Text("What is the weather like in New York?")) + if err != nil { + t.Fatal(err) + } + part := res.Candidates[0].Content.Parts[0] + funcall, ok := part.(FunctionCall) + if !ok { + t.Fatalf("want FunctionCall, got %T", part) + } + if g, w := funcall.Name, weatherTool.FunctionDeclarations[0].Name; g != w { + t.Errorf("FunctionCall.Name: got %q, want %q", g, w) + } + if g, c := funcall.Args["location"], "New York"; !strings.Contains(g.(string), c) { + t.Errorf(`FunctionCall.Args["location"]: got %q, want string containing %q`, g, c) + } + res, err = session.SendMessage(ctx, FunctionResponse{ + Name: weatherTool.FunctionDeclarations[0].Name, + Response: map[string]any{ + "weather_there": "cold", + }, + }) + if err != nil { + t.Fatal(err) + } + checkMatch(t, responseString(res), "(it's|it is|weather) .*cold") + }) + t.Run("funcall-none", func(t *testing.T) { + model.ToolConfig = &ToolConfig{ + FunctionCallingConfig: &FunctionCallingConfig{ + Mode: FunctionCallingNone, // never return a FunctionCall part + }, + } + session := model.StartChat() + res, err := session.SendMessage(ctx, Text("What is the weather like in New York?")) + if err != nil { + t.Fatal(err) + } + // We should not find a FunctionCall part. + for _, p := range res.Candidates[0].Content.Parts { + if _, ok := p.(FunctionCall); ok { + t.Fatal("saw FunctionCall") + } + } }) - if err != nil { - t.Fatal(err) - } - checkMatch(t, responseString(res), "(it's|it is|weather) .*cold") }) } diff --git a/vertexai/genai/config.yaml b/vertexai/genai/config.yaml index f223835508a2..dedbe681b0ca 100644 --- a/vertexai/genai/config.yaml +++ b/vertexai/genai/config.yaml @@ -18,6 +18,12 @@ types: valueNames: SafetySetting_HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockUnspecified + SafetySetting_HarmBlockMethod: + name: HarmBlockMethod + protoPrefix: SafetySetting_ + veneerPrefix: HarmBlockMethod + doc: 'determines how harm blocking is done.' + SafetyRating_HarmProbability: name: HarmProbability protoPrefix: SafetyRating_ @@ -50,7 +56,8 @@ types: fields: MimeType: name: MIMEType - docVerb: contains + doc: 'contains binary data like images. Use [Text] for text.' + removeOtherDoc: true FileData: fields: @@ -63,11 +70,23 @@ types: FunctionResponse: + FunctionCallingConfig: + doc: 'holds configuration for function calling.' + + FunctionCallingConfig_Mode: + name: FunctionCallingMode + protoPrefix: FunctionCallingConfig + veneerPrefix: FunctionCalling + valueNames: + FunctionCallingConfig_MODE_UNSPECIFIED: FunctionCallingUnspecified + GenerationConfig: fields: TopK: type: '*int32' convertToFrom: int32pToFloat32p, float32pToInt32p + ResponseMimeType: + name: ResponseMIMEType SafetyRating: docVerb: 'is the' @@ -105,10 +124,16 @@ types: GoogleSearchRetrieval: omit: true + ToolConfig: + doc: 'configures tools.' + Schema: fields: Example: omit: true + Default: + # TODO(jba): protoveneer should treat a *structpb.Value as an any + omit: true CitationMetadata: FunctionDeclaration: diff --git a/vertexai/genai/example_test.go b/vertexai/genai/example_test.go index c549f3b22d8d..229c7e71cf5c 100644 --- a/vertexai/genai/example_test.go +++ b/vertexai/genai/example_test.go @@ -135,6 +135,143 @@ func ExampleChatSession() { printResponse(res) } +func ExampleTool() { + ctx := context.Background() + client, err := genai.NewClient(ctx, projectID, location) + if err != nil { + log.Fatal(err) + } + defer client.Close() + + currentWeather := func(city string) string { + switch city { + case "New York, NY": + return "cold" + case "Miami, FL": + return "warm" + default: + return "unknown" + } + } + + // To use functions / tools, we have to first define a schema that describes + // the function to the model. The schema is similar to OpenAPI 3.0. + // + // In this example, we create a single function that provides the model with + // a weather forecast in a given location. + schema := &genai.Schema{ + Type: genai.TypeObject, + Properties: map[string]*genai.Schema{ + "location": { + Type: genai.TypeString, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: genai.TypeString, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + } + + weatherTool := &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{{ + Name: "CurrentWeather", + Description: "Get the current weather in a given location", + Parameters: schema, + }}, + } + + model := client.GenerativeModel("gemini-1.0-pro") + + // Before initiating a conversation, we tell the model which tools it has + // at its disposal. + model.Tools = []*genai.Tool{weatherTool} + + // For using tools, the chat mode is useful because it provides the required + // chat context. A model needs to have tools supplied to it in the chat + // history so it can use them in subsequent conversations. + // + // The flow of message expected here is: + // + // 1. We send a question to the model + // 2. The model recognizes that it needs to use a tool to answer the question, + // an returns a FunctionCall response asking to use the CurrentWeather + // tool. + // 3. We send a FunctionResponse message, simulating the return value of + // CurrentWeather for the model's query. + // 4. The model provides its text answer in response to this message. + session := model.StartChat() + + res, err := session.SendMessage(ctx, genai.Text("What is the weather like in New York?")) + if err != nil { + log.Fatal(err) + } + + part := res.Candidates[0].Content.Parts[0] + funcall, ok := part.(genai.FunctionCall) + if !ok { + log.Fatalf("expected FunctionCall: %v", part) + } + + if funcall.Name != "CurrentWeather" { + log.Fatalf("expected CurrentWeather: %v", funcall.Name) + } + + // Expect the model to pass a proper string "location" argument to the tool. + locArg, ok := funcall.Args["location"].(string) + if !ok { + log.Fatalf("expected string: %v", funcall.Args["location"]) + } + + weatherData := currentWeather(locArg) + res, err = session.SendMessage(ctx, genai.FunctionResponse{ + Name: weatherTool.FunctionDeclarations[0].Name, + Response: map[string]any{ + "weather": weatherData, + }, + }) + if err != nil { + log.Fatal(err) + } + + printResponse(res) +} + +func ExampleToolConifg() { + // This example shows how to affect how the model uses the tools provided to it. + // By setting the ToolConfig, you can disable function calling. + + // Assume we have created a Model and have set its Tools field with some functions. + // See the Example for Tool for details. + var model *genai.GenerativeModel + + // By default, the model will use the functions in its responses if it thinks they are + // relevant, by returning FunctionCall parts. + // Here we set the model's ToolConfig to disable function calling completely. + model.ToolConfig = &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: genai.FunctionCallingNone, + }, + } + + // Subsequent calls to ChatSession.SendMessage will not result in FunctionCall responses. + session := model.StartChat() + res, err := session.SendMessage(context.Background(), genai.Text("What is the weather like in New York?")) + if err != nil { + log.Fatal(err) + } + for _, part := range res.Candidates[0].Content.Parts { + if _, ok := part.(genai.FunctionCall); ok { + log.Fatal("did not expect FunctionCall") + } + } + + // It is also possible to force a function call by using FunctionCallingAny + // instead of FunctionCallingNone. See the documentation for FunctionCallingMode + // for details. +} + func printResponse(resp *genai.GenerateContentResponse) { for _, cand := range resp.Candidates { for _, part := range cand.Content.Parts {