diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart index 4101364057a9..12a119f92864 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart @@ -551,7 +551,9 @@ final class GenerationConfig { this.topP, this.topK, this.responseMimeType, - this.responseSchema}); + this.responseSchema, + this.presencePenalty, + this.frequencyPenalty}); /// Number of generated responses to return. /// @@ -598,6 +600,12 @@ final class GenerationConfig { /// Note: The default value varies by model. final int? topK; + /// Controls the likelihood of repeating the same words or phrases already generated in the text. + final double? presencePenalty; + + /// Controls the likelihood of repeating words, with the penalty increasing for each repetition. + final double? frequencyPenalty; + /// Output response mimetype of the generated candidate text. /// /// Supported mimetype: @@ -627,6 +635,10 @@ final class GenerationConfig { 'responseMimeType': responseMimeType, if (responseSchema case final responseSchema?) 'responseSchema': responseSchema, + if (presencePenalty case final presencePenalty?) + 'presencePenalty': presencePenalty, + if (frequencyPenalty case final frequencyPenalty?) + 'frequencyPenalty': frequencyPenalty, }; } diff --git a/packages/firebase_vertexai/firebase_vertexai/test/generation_config_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/generation_config_test.dart new file mode 100644 index 000000000000..d056634c1d58 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/test/generation_config_test.dart @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'package:firebase_core/firebase_core.dart'; +import 'package:firebase_vertexai/firebase_vertexai.dart'; +import 'package:flutter_test/flutter_test.dart'; + +import 'mock.dart'; +import 'utils/stub_client.dart'; + +void main() { + setupFirebaseVertexAIMocks(); + // ignore: unused_local_variable + late FirebaseApp app; + setUpAll(() async { + // Initialize Firebase + app = await Firebase.initializeApp(); + }); + group('GenerationConfig Tests', () { + const defaultModelName = 'some-model'; + + (ClientController, GenerativeModel) createModel({ + String modelName = defaultModelName, + }) { + final client = ClientController(); + final model = createModelWithClient( + app: app, + model: modelName, + client: client.client, + location: 'us-central1'); + return (client, model); + } + test('toJson method', (){ + final generationConfig = GenerationConfig(candidateCount: 2, stopSequences: ['a'], maxOutputTokens: 10, + topK: 1, topP: 0.5, frequencyPenalty: 0.5, presencePenalty: 0.5); + final result = { + if (candidateCount case final candidateCount?) + 'candidateCount': candidateCount, + if (stopSequences case final stopSequences?) + 'stopSequences': stopSequences, + if (maxOutputTokens case final maxOutputTokens?) + 'maxOutputTokens': maxOutputTokens, + if (temperature case final temperature?) 'temperature': temperature, + if (topP case final topP?) 'topP': topP, + if (topK case final topK?) 'topK': topK, + 'presencePenalty': presencePenalty, + 'frequencyPenalty': frequencyPenalty, + }; + expect(generationConfig.toJson(), result); + }); + test('Test to check if presencePenalty and frequencyPenalty is include in generateContent', () async { + final (client, model) = createModel(); + final response = await client.checkRequest( + () => model.generateContent([Content.text('Some prompt')], + generationConfig: GenerationConfig(presencePenalty: 0.7, frequencyPenalty: 0.3)), + response: arbitraryGenerateContentResponse, + verifyRequest: (_,request){ + expect(request['generationConfig'], { + 'candidateCount': 1, + 'stopSequences': ['a'], + 'maxOutputTokens': 10, + 'temperature': 0.5, + 'topP': 0.5, + 'topK': 1, + 'presencePenalty' : 0.7, + 'frequencyPenalty' : 0.3, + }); + }, + ); + + }); + }); +}