
Add AI integration tests #7038


Merged · 3 commits · Jun 18, 2025
5 changes: 4 additions & 1 deletion firebase-ai/firebase-ai.gradle.kts
@@ -67,7 +67,10 @@ android {
targetSdk = targetSdkVersion
baseline = file("lint-baseline.xml")
}
sourceSets { getByName("test").java.srcDirs("src/testUtil") }
sourceSets {
// getByName("test").java.srcDirs("src/testUtil")
getByName("androidTest") { kotlin.srcDirs("src/testUtil") }
}
}

// Enable Kotlin "Explicit API Mode". This causes the Kotlin compiler to fail if any
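The `sourceSets` change above compiles the shared helpers under `src/testUtil` into the instrumentation (`androidTest`) variant instead of the JVM unit-test variant. As a rough sketch (an assumption for illustration, not part of this change), the same directory could also be exposed to JVM unit tests if that were ever needed:

```kotlin
// Hypothetical variant of the block above; the PR itself only wires androidTest.
sourceSets {
  getByName("androidTest") { kotlin.srcDirs("src/testUtil") } // instrumentation tests (this PR)
  // getByName("test") { kotlin.srcDirs("src/testUtil") }     // optional: JVM unit tests too
}
```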
@@ -0,0 +1,72 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.firebase.ai

import androidx.test.platform.app.InstrumentationRegistry
import com.google.firebase.FirebaseApp
import com.google.firebase.FirebaseOptions
import com.google.firebase.ai.type.GenerativeBackend

class AIModels {

companion object {
private val API_KEY: String = ""
private val APP_ID: String = ""
private val PROJECT_ID: String = "fireescape-integ-tests"
// General purpose models
var app: FirebaseApp? = null
var flash2Model: GenerativeModel? = null
var flash2LiteModel: GenerativeModel? = null

/** Returns a list of general purpose models to test */
fun getModels(): List<GenerativeModel> {
if (flash2Model == null) {
setup()
}
return listOf(flash2Model!!, flash2LiteModel!!)
}

fun app(): FirebaseApp {
if (app == null) {
setup()
}
return app!!
}

fun setup() {
val context = InstrumentationRegistry.getInstrumentation().context
app =
FirebaseApp.initializeApp(
context,
FirebaseOptions.Builder()
.setApiKey(API_KEY)
.setApplicationId(APP_ID)
.setProjectId(PROJECT_ID)
.build()
)
flash2Model =
FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI())
.generativeModel(
modelName = "gemini-2.0-flash",
)
flash2LiteModel =
FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI())
.generativeModel(
modelName = "gemini-2.0-flash-lite",
)
}
}
}
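The helper above lazily initializes a dedicated `FirebaseApp` and two Vertex AI backed models the first time a test asks for them. A minimal sketch of the same idea using Kotlin's `lazy` delegate, as a hypothetical drop-in for the same companion object (not part of this change), which is thread-safe by default and avoids the nullable fields and `!!` unwrapping:

```kotlin
// Hypothetical alternative to the companion object above, built only from the
// calls the PR already uses; not part of this change.
private val lazyApp: FirebaseApp by lazy {
  val context = InstrumentationRegistry.getInstrumentation().context
  FirebaseApp.initializeApp(
    context,
    FirebaseOptions.Builder()
      .setApiKey(API_KEY)
      .setApplicationId(APP_ID)
      .setProjectId(PROJECT_ID)
      .build()
  )
}

private val lazyModels: List<GenerativeModel> by lazy {
  listOf("gemini-2.0-flash", "gemini-2.0-flash-lite").map { name ->
    FirebaseAI.getInstance(lazyApp, GenerativeBackend.vertexAI())
      .generativeModel(modelName = name)
  }
}
```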
@@ -0,0 +1,204 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.firebase.ai

import android.graphics.Bitmap
import com.google.firebase.ai.AIModels.Companion.getModels
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.ContentModality
import com.google.firebase.ai.type.CountTokensResponse
import java.io.ByteArrayOutputStream
import kotlinx.coroutines.runBlocking
import org.junit.Test

class CountTokensTests {

/** Ensures that the token count matches the expected value for a simple five-word prompt. */
@Test
fun testCountTokensAmount() {
for (model in getModels()) {
runBlocking {
val response = model.countTokens("this is five different words")
assert(response.totalTokens == 5)
assert(response.promptTokensDetails.size == 1)
assert(response.promptTokensDetails[0].modality == ContentModality.TEXT)
assert(response.promptTokensDetails[0].tokenCount == 5)
}
}
}

/** Ensures that the model returns token counts in the correct modality for text. */
@Test
fun testCountTokensTextModality() {
for (model in getModels()) {
runBlocking {
val response = model.countTokens("this is a text prompt")
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.TEXT))
}
}
}

/** Ensures that the model returns token counts in the correct modality for bitmap images. */
@Test
fun testCountTokensImageModality() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val response = model.countTokens(bitmap)
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.IMAGE))
}
}
}

/**
* Ensures the model can count tokens for multiple modalities at once and return the
* corresponding per-modality token details correctly.
*/
@Test
fun testCountTokensTextAndImageModality() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val response =
model.countTokens(
Content.Builder().text("this is text").build(),
Content.Builder().image(bitmap).build()
)
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 2)
assert(containsModality(response, ContentModality.TEXT))
assert(containsModality(response, ContentModality.IMAGE))
}
}
}

/**
* Ensures the model can count the tokens of a sent inline file and attributes them to the
* modality implied by the mime type; in this case, a plaintext file has its tokens counted as
* `ContentModality.TEXT`.
*/
@Test
fun testCountTokensTextFileModality() {
for (model in getModels()) {
runBlocking {
val response =
model.countTokens(
Content.Builder().inlineData("this is text".toByteArray(), "text/plain").build()
)
checkTokenCountsMatch(response)
assert(response.totalTokens == 3)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.TEXT))
}
}
}

/**
* Ensures the model can count the tokens of a sent inline file and attributes them to the
* modality implied by the mime type; in this case, a PNG-encoded bitmap has its tokens counted
* as `ContentModality.IMAGE`.
*/
@Test
fun testCountTokensImageFileModality() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val stream = ByteArrayOutputStream()
bitmap.compress(Bitmap.CompressFormat.PNG, 1, stream)
val array = stream.toByteArray()
val response = model.countTokens(Content.Builder().inlineData(array, "image/png").build())
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.IMAGE))
}
}
}

/**
* Ensures that nothing is free, that is, empty content contains no tokens. For some reason, this
* is treated as `ContentModality.TEXT`.
*/
@Test
fun testCountTokensNothingIsFree() {
for (model in getModels()) {
runBlocking {
val response = model.countTokens(Content.Builder().build())
checkTokenCountsMatch(response)
assert(response.totalTokens == 0)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.TEXT))
}
}
}

/**
* Checks whether the model can count the tokens of a sent inline file and attributes them to
* the modality implied by the mime type. In this case, a JSON file is not recognized, so no
* tokens are counted. If the model ever starts handling JSON, this test will make us aware of
* the change.
*/
@Test
fun testCountTokensJsonFileModality() {
for (model in getModels()) {
runBlocking {
val json =
"""
{
"foo": "bar",
"baz": 3,
"qux": [
{
"quux": [
1,
2
]
}
]
}
"""
.trimIndent()
val response =
model.countTokens(
Content.Builder().inlineData(json.toByteArray(), "application/json").build()
)
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.isEmpty())
assert(response.totalTokens == 0)
}
}
}

fun checkTokenCountsMatch(response: CountTokensResponse) {
assert(sumTokenCount(response) == response.totalTokens)
}

fun sumTokenCount(response: CountTokensResponse): Int {
return response.promptTokensDetails.sumOf { it.tokenCount }
}

fun containsModality(response: CountTokensResponse, modality: ContentModality): Boolean {
for (token in response.promptTokensDetails) {
if (token.modality == modality) {
return true
}
}
return false
}
}
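The tests above run every assertion against each model in a plain loop, so a failure does not immediately say which model produced it. Below is a sketch of the first assertion rewritten as a JUnit 4 parameterized test (hypothetical, not part of this change; assumes JUnit 4.12+ so single parameters need no array wrapping):

```kotlin
package com.google.firebase.ai

// Hypothetical sketch, not part of this change: one test instance per model,
// so a failure is reported against the model that produced it.
import com.google.firebase.ai.AIModels.Companion.getModels
import com.google.firebase.ai.type.ContentModality
import kotlinx.coroutines.runBlocking
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

@RunWith(Parameterized::class)
class CountTokensPerModelTests(private val model: GenerativeModel) {

  companion object {
    @JvmStatic @Parameterized.Parameters fun models(): List<GenerativeModel> = getModels()
  }

  /** Same assertion as testCountTokensAmount, run once per model. */
  @Test
  fun countsFiveTokensForFiveWords() = runBlocking {
    val response = model.countTokens("this is five different words")
    assert(response.totalTokens == 5)
    assert(response.promptTokensDetails.single().modality == ContentModality.TEXT)
  }
}
```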
@@ -0,0 +1,83 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.firebase.ai

import android.graphics.Bitmap
import com.google.firebase.ai.AIModels.Companion.getModels
import com.google.firebase.ai.type.Content
import kotlinx.coroutines.runBlocking
import org.junit.Test

class GenerateContentTests {
private val validator = TypesValidator()

/**
* Ensures the model can respond to prompts and that the structure of the response is as
* expected.
*/
@Test
fun testGenerateContent_BasicRequest() {
for (model in getModels()) {
runBlocking {
val response = model.generateContent("pick a random color")
validator.validateResponse(response)
}
}
}

/**
* Ensures that the model can answer very simple questions. Testing the model's "logic" or the
* content of its responses any further is prone to flaking, and even this test carries some of
* that risk. This is probably as far as we can consistently test for a reasonable response,
* short of sending the request and response back to the model and asking whether it meets our
* expectations.
*/
@Test
fun testGenerateContent_ColorMixing() {
for (model in getModels()) {
runBlocking {
val response = model.generateContent("what color is created when red and yellow are mixed?")
validator.validateResponse(response)
assert(response.text!!.contains("orange", true))
}
}
}

/**
* Ensures that the model accepts image content sent alongside text. The bitmap is a tiny
* smiley drawn pixel by pixel; only the response structure is validated, since asserting on
* how the model interprets such a small image would be prone to flaking.
*/
@Test
fun testGenerateContent_CanSendImage() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val yellow = Integer.parseUnsignedInt("FFFFFF00", 16)
bitmap.setPixel(3, 3, yellow)
bitmap.setPixel(6, 3, yellow)
bitmap.setPixel(3, 6, yellow)
bitmap.setPixel(4, 7, yellow)
bitmap.setPixel(5, 7, yellow)
bitmap.setPixel(6, 6, yellow)
val response =
model.generateContent(
Content.Builder().text("here is a tiny smile").image(bitmap).build()
)
validator.validateResponse(response)
}
}
}
}
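A natural follow-up (hypothetical, not part of this change) would be a streaming variant of the basic request test. The sketch below assumes the SDK's `generateContentStream` API, which returns a `Flow` of partial responses:

```kotlin
package com.google.firebase.ai

// Hypothetical sketch, not part of this change; assumes generateContentStream
// returns a Flow of partial GenerateContentResponse chunks.
import com.google.firebase.ai.AIModels.Companion.getModels
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.junit.Test

class GenerateContentStreamTests {

  /** Ensures each model can stream a response and that at least one chunk carries text. */
  @Test
  fun testGenerateContentStream_BasicRequest() {
    for (model in getModels()) {
      runBlocking {
        val chunks = model.generateContentStream("tell me a short story").toList()
        assert(chunks.isNotEmpty())
        assert(chunks.any { !it.text.isNullOrEmpty() })
      }
    }
  }
}
```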