Skip to content

feat(vertexai): Add presencePenalty and frequencyPenalty to GenerationConfig #17062

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,9 @@ final class GenerationConfig {
this.topP,
this.topK,
this.responseMimeType,
this.responseSchema});
this.responseSchema,
this.presencePenalty,
this.frequencyPenalty});

/// Number of generated responses to return.
///
Expand Down Expand Up @@ -598,6 +600,12 @@ final class GenerationConfig {
/// Note: The default value varies by model.
final int? topK;

/// Controls the likelihood of repeating the same words or phrases already generated in the text.
final double? presencePenalty;

/// Controls the likelihood of repeating words, with the penalty increasing for each repetition.
final double? frequencyPenalty;

/// Output response mimetype of the generated candidate text.
///
/// Supported mimetype:
Expand Down Expand Up @@ -627,6 +635,10 @@ final class GenerationConfig {
'responseMimeType': responseMimeType,
if (responseSchema case final responseSchema?)
'responseSchema': responseSchema,
if (presencePenalty case final presencePenalty?)
'presencePenalty': presencePenalty,
if (frequencyPenalty case final frequencyPenalty?)
'frequencyPenalty': frequencyPenalty,
};
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import 'package:firebase_core/firebase_core.dart';
import 'package:firebase_vertexai/firebase_vertexai.dart';
import 'package:flutter_test/flutter_test.dart';

import 'mock.dart';
import 'utils/stub_client.dart';

void main() {
setupFirebaseVertexAIMocks();
// ignore: unused_local_variable
late FirebaseApp app;
setUpAll(() async {
// Initialize Firebase
app = await Firebase.initializeApp();
});
group('GenerationConfig Tests', () {
const defaultModelName = 'some-model';

(ClientController, GenerativeModel) createModel({
String modelName = defaultModelName,
}) {
final client = ClientController();
final model = createModelWithClient(
app: app,
model: modelName,
client: client.client,
location: 'us-central1');
return (client, model);
}
test('toJson method', (){
final generationConfig = GenerationConfig(candidateCount: 2, stopSequences: ['a'], maxOutputTokens: 10,
topK: 1, topP: 0.5, frequencyPenalty: 0.5, presencePenalty: 0.5);
final result = {
if (candidateCount case final candidateCount?)
'candidateCount': candidateCount,
if (stopSequences case final stopSequences?)
'stopSequences': stopSequences,
if (maxOutputTokens case final maxOutputTokens?)
'maxOutputTokens': maxOutputTokens,
if (temperature case final temperature?) 'temperature': temperature,
if (topP case final topP?) 'topP': topP,
if (topK case final topK?) 'topK': topK,
'presencePenalty': presencePenalty,
'frequencyPenalty': frequencyPenalty,
};
expect(generationConfig.toJson(), result);
});
test('Test to check if presencePenalty and frequencyPenalty is include in generateContent', () async {
final (client, model) = createModel();
final response = await client.checkRequest(
() => model.generateContent([Content.text('Some prompt')],
generationConfig: GenerationConfig(presencePenalty: 0.7, frequencyPenalty: 0.3)),
response: arbitraryGenerateContentResponse,
verifyRequest: (_,request){
expect(request['generationConfig'], {
'candidateCount': 1,
'stopSequences': ['a'],
'maxOutputTokens': 10,
'temperature': 0.5,
'topP': 0.5,
'topK': 1,
'presencePenalty' : 0.7,
'frequencyPenalty' : 0.3,
});
},
);

});
});
}
Loading