Skip to content

Commit 40b7637

Browse files
authored
Refactor live bidi (#6870)
Per [b/410070347](https://b.corp.google.com/issues/410070347), this refactors our bidi model to be more thread safe and to take better advantage of immutability and Kotlin coroutines. This solves various edge-case issues and makes the model easier to maintain. This PR also fixes the following: - [b/410063693](https://b.corp.google.com/issues/410063693) -> Add serialization for bidi responses - [b/410064609](https://b.corp.google.com/issues/410064609) -> Retain audio data if present when turn complete or interrupted - [b/410069806](https://b.corp.google.com/issues/410069806) -> Use blocking instead of background dispatcher for bidi - [b/410841715](https://b.corp.google.com/issues/410841715) -> Catch Bidi exceptions - [b/410067576](https://b.corp.google.com/issues/410067576) -> Send zeroed out audio data to server when not speaking
1 parent 47e37b5 commit 40b7637

17 files changed

+806
-388
lines changed

firebase-vertexai/CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
* [feature] Added support for `HarmBlockThreshold.OFF`. See the
1010
[model documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters#how_to_configure_content_filters){: .external}
1111
for more information.
12+
* [fixed] Improved thread usage when using a `LiveGenerativeModel`. (#6870)
13+
* [fixed] Fixed an issue with `LiveContentResponse` audio data not being present when the model was
14+
interrupted or the turn completed. (#6870)
15+
* [fixed] Fixed an issue with `LiveSession` not converting exceptions to `FirebaseVertexAIException`. (#6870)
16+
1217

1318
# 16.3.0
1419
* [feature] Emits a warning when attempting to use an incompatible model with

firebase-vertexai/api.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ package com.google.firebase.vertexai.type {
629629
method public suspend Object? send(String text, kotlin.coroutines.Continuation<? super kotlin.Unit>);
630630
method public suspend Object? sendFunctionResponse(java.util.List<com.google.firebase.vertexai.type.FunctionResponsePart> functionList, kotlin.coroutines.Continuation<? super kotlin.Unit>);
631631
method public suspend Object? sendMediaStream(java.util.List<com.google.firebase.vertexai.type.MediaData> mediaChunks, kotlin.coroutines.Continuation<? super kotlin.Unit>);
632-
method public suspend Object? startAudioConversation(kotlin.jvm.functions.Function1<? super com.google.firebase.vertexai.type.FunctionCallPart,com.google.firebase.vertexai.type.FunctionResponsePart>? functionCallHandler = null, kotlin.coroutines.Continuation<? super kotlin.Unit>);
632+
method @RequiresPermission(android.Manifest.permission.RECORD_AUDIO) public suspend Object? startAudioConversation(kotlin.jvm.functions.Function1<? super com.google.firebase.vertexai.type.FunctionCallPart,com.google.firebase.vertexai.type.FunctionResponsePart>? functionCallHandler = null, kotlin.coroutines.Continuation<? super kotlin.Unit>);
633633
method public void stopAudioConversation();
634634
method public void stopReceiving();
635635
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package com.google.firebase.vertexai
1919
import android.util.Log
2020
import com.google.firebase.Firebase
2121
import com.google.firebase.FirebaseApp
22-
import com.google.firebase.annotations.concurrent.Background
22+
import com.google.firebase.annotations.concurrent.Blocking
2323
import com.google.firebase.app
2424
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2525
import com.google.firebase.auth.internal.InternalAuthProvider
@@ -41,7 +41,7 @@ import kotlin.coroutines.CoroutineContext
4141
public class FirebaseVertexAI
4242
internal constructor(
4343
private val firebaseApp: FirebaseApp,
44-
@Background private val backgroundDispatcher: CoroutineContext,
44+
@Blocking private val blockingDispatcher: CoroutineContext,
4545
private val location: String,
4646
private val appCheckProvider: Provider<InteropAppCheckTokenProvider>,
4747
private val internalAuthProvider: Provider<InternalAuthProvider>,
@@ -133,7 +133,7 @@ internal constructor(
133133
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
134134
firebaseApp.options.apiKey,
135135
firebaseApp,
136-
backgroundDispatcher,
136+
blockingDispatcher,
137137
generationConfig,
138138
tools,
139139
systemInstruction,

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIMultiResourceComponent.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package com.google.firebase.vertexai
1818

1919
import androidx.annotation.GuardedBy
2020
import com.google.firebase.FirebaseApp
21-
import com.google.firebase.annotations.concurrent.Background
21+
import com.google.firebase.annotations.concurrent.Blocking
2222
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2323
import com.google.firebase.auth.internal.InternalAuthProvider
2424
import com.google.firebase.inject.Provider
@@ -31,7 +31,7 @@ import kotlin.coroutines.CoroutineContext
3131
*/
3232
internal class FirebaseVertexAIMultiResourceComponent(
3333
private val app: FirebaseApp,
34-
@Background val backgroundDispatcher: CoroutineContext,
34+
@Blocking val blockingDispatcher: CoroutineContext,
3535
private val appCheckProvider: Provider<InteropAppCheckTokenProvider>,
3636
private val internalAuthProvider: Provider<InternalAuthProvider>,
3737
) {
@@ -43,7 +43,7 @@ internal class FirebaseVertexAIMultiResourceComponent(
4343
instances[location]
4444
?: FirebaseVertexAI(
4545
app,
46-
backgroundDispatcher,
46+
blockingDispatcher,
4747
location,
4848
appCheckProvider,
4949
internalAuthProvider

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAIRegistrar.kt

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package com.google.firebase.vertexai
1818

1919
import androidx.annotation.Keep
2020
import com.google.firebase.FirebaseApp
21-
import com.google.firebase.annotations.concurrent.Background
21+
import com.google.firebase.annotations.concurrent.Blocking
2222
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2323
import com.google.firebase.auth.internal.InternalAuthProvider
2424
import com.google.firebase.components.Component
@@ -41,13 +41,13 @@ internal class FirebaseVertexAIRegistrar : ComponentRegistrar {
4141
Component.builder(FirebaseVertexAIMultiResourceComponent::class.java)
4242
.name(LIBRARY_NAME)
4343
.add(Dependency.required(firebaseApp))
44-
.add(Dependency.required(backgroundDispatcher))
44+
.add(Dependency.required(blockingDispatcher))
4545
.add(Dependency.optionalProvider(appCheckInterop))
4646
.add(Dependency.optionalProvider(internalAuthProvider))
4747
.factory { container ->
4848
FirebaseVertexAIMultiResourceComponent(
4949
container[firebaseApp],
50-
container.get(backgroundDispatcher),
50+
container.get(blockingDispatcher),
5151
container.getProvider(appCheckInterop),
5252
container.getProvider(internalAuthProvider)
5353
)
@@ -62,7 +62,7 @@ internal class FirebaseVertexAIRegistrar : ComponentRegistrar {
6262
private val firebaseApp = unqualified(FirebaseApp::class.java)
6363
private val appCheckInterop = unqualified(InteropAppCheckTokenProvider::class.java)
6464
private val internalAuthProvider = unqualified(InternalAuthProvider::class.java)
65-
private val backgroundDispatcher =
66-
Qualified.qualified(Background::class.java, CoroutineDispatcher::class.java)
65+
private val blockingDispatcher =
66+
Qualified.qualified(Blocking::class.java, CoroutineDispatcher::class.java)
6767
}
6868
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/LiveGenerativeModel.kt

+13-10
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
package com.google.firebase.vertexai
1818

1919
import com.google.firebase.FirebaseApp
20-
import com.google.firebase.annotations.concurrent.Background
20+
import com.google.firebase.annotations.concurrent.Blocking
2121
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
2222
import com.google.firebase.auth.internal.InternalAuthProvider
2323
import com.google.firebase.vertexai.common.APIController
2424
import com.google.firebase.vertexai.common.AppCheckHeaderProvider
25-
import com.google.firebase.vertexai.type.BidiGenerateContentClientMessage
25+
import com.google.firebase.vertexai.common.JSON
2626
import com.google.firebase.vertexai.type.Content
27+
import com.google.firebase.vertexai.type.LiveClientSetupMessage
2728
import com.google.firebase.vertexai.type.LiveGenerationConfig
2829
import com.google.firebase.vertexai.type.LiveSession
2930
import com.google.firebase.vertexai.type.PublicPreviewAPI
@@ -38,6 +39,7 @@ import kotlinx.coroutines.channels.ClosedReceiveChannelException
3839
import kotlinx.serialization.ExperimentalSerializationApi
3940
import kotlinx.serialization.encodeToString
4041
import kotlinx.serialization.json.Json
42+
import kotlinx.serialization.json.JsonObject
4143

4244
/**
4345
* Represents a multimodal model (like Gemini) capable of real-time content generation based on
@@ -47,7 +49,7 @@ import kotlinx.serialization.json.Json
4749
public class LiveGenerativeModel
4850
internal constructor(
4951
private val modelName: String,
50-
@Background private val backgroundDispatcher: CoroutineContext,
52+
@Blocking private val blockingDispatcher: CoroutineContext,
5153
private val config: LiveGenerationConfig? = null,
5254
private val tools: List<Tool>? = null,
5355
private val systemInstruction: Content? = null,
@@ -58,7 +60,7 @@ internal constructor(
5860
modelName: String,
5961
apiKey: String,
6062
firebaseApp: FirebaseApp,
61-
backgroundDispatcher: CoroutineContext,
63+
blockingDispatcher: CoroutineContext,
6264
config: LiveGenerationConfig? = null,
6365
tools: List<Tool>? = null,
6466
systemInstruction: Content? = null,
@@ -68,7 +70,7 @@ internal constructor(
6870
internalAuthProvider: InternalAuthProvider? = null,
6971
) : this(
7072
modelName,
71-
backgroundDispatcher,
73+
blockingDispatcher,
7274
config,
7375
tools,
7476
systemInstruction,
@@ -93,7 +95,7 @@ internal constructor(
9395
@OptIn(ExperimentalSerializationApi::class)
9496
public suspend fun connect(): LiveSession {
9597
val clientMessage =
96-
BidiGenerateContentClientMessage(
98+
LiveClientSetupMessage(
9799
modelName,
98100
config?.toInternal(),
99101
tools?.map { it.toInternal() },
@@ -104,10 +106,11 @@ internal constructor(
104106
try {
105107
val webSession = controller.getWebSocketSession(location)
106108
webSession.send(Frame.Text(data))
107-
val receivedJson = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8)
108-
// TODO: Try to decode the json instead of string matching.
109-
return if (receivedJson.contains("setupComplete")) {
110-
LiveSession(session = webSession, backgroundDispatcher = backgroundDispatcher)
109+
val receivedJsonStr = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8)
110+
val receivedJson = JSON.parseToJsonElement(receivedJsonStr)
111+
112+
return if (receivedJson is JsonObject && "setupComplete" in receivedJson) {
113+
LiveSession(session = webSession, blockingDispatcher = blockingDispatcher)
111114
} else {
112115
webSession.close()
113116
throw ServiceConnectionHandshakeFailedException("Unable to connect to the server")

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ internal constructor(
165165

166166
suspend fun getWebSocketSession(location: String): ClientWebSocketSession =
167167
client.webSocketSession(getBidiEndpoint(location))
168+
168169
fun generateContentStream(
169170
request: GenerateContentRequest
170171
): Flow<GenerateContentResponse.Internal> =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai.common.util
18+
19+
import android.media.AudioRecord
20+
import kotlinx.coroutines.flow.flow
21+
import kotlinx.coroutines.yield
22+
23+
/**
 * The smallest buffer size that can be used with an [AudioRecord] configured like this one.
 *
 * Shorthand for [AudioRecord.getMinBufferSize] that forwards this instance's own sample rate,
 * channel configuration, and audio format, so callers don't have to repeat them.
 */
internal val AudioRecord.minBufferSize: Int
  get() {
    return AudioRecord.getMinBufferSize(sampleRate, channelConfiguration, audioFormat)
  }
30+
31+
/**
 * Reads from this [AudioRecord] and returns the data in a flow.
 *
 * Every emitted array is a fresh copy that holds only the bytes returned by a single
 * [AudioRecord.read] call, so the internal read buffer can be reused safely between emissions.
 *
 * The flow never completes on its own: it keeps looping until its collector is cancelled. It
 * yields whenever this instance is not recording, and also whenever a read produces no data, so
 * the loop always passes through a suspension point and cancellation can be observed.
 */
internal fun AudioRecord.readAsFlow() = flow {
  val buffer = ByteArray(minBufferSize)

  while (true) {
    if (recordingState != AudioRecord.RECORDSTATE_RECORDING) {
      yield()
      continue
    }

    val bytesRead = read(buffer, 0, buffer.size)
    if (bytesRead > 0) {
      // Copy out only the bytes that were actually read; `buffer` is overwritten next iteration.
      emit(buffer.copyOf(bytesRead))
    } else {
      // Nothing was read (zero, or a negative value). Without a suspension point on this path the
      // loop would spin without ever giving the collector a chance to cancel.
      yield()
    }
  }
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/kotlin.kt

+62
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616

1717
package com.google.firebase.vertexai.common.util
1818

19+
import java.io.ByteArrayOutputStream
1920
import java.lang.reflect.Field
21+
import kotlin.coroutines.EmptyCoroutineContext
22+
import kotlinx.coroutines.CoroutineScope
23+
import kotlinx.coroutines.Job
24+
import kotlinx.coroutines.cancel
25+
import kotlinx.coroutines.currentCoroutineContext
26+
import kotlinx.coroutines.flow.Flow
27+
import kotlinx.coroutines.flow.flow
28+
import kotlinx.coroutines.flow.fold
2029

2130
/**
2231
* Removes the last character from the [StringBuilder].
@@ -39,3 +48,56 @@ internal fun StringBuilder.removeLast(): StringBuilder =
3948
* ```
4049
*/
4150
internal inline fun <reified T : Annotation> Field.getAnnotation() = getAnnotation(T::class.java)
51+
52+
/**
 * Buffers the byte arrays of this flow and only re-emits them once at least [minSize] bytes have
 * been gathered.
 *
 * When the threshold is reached, everything buffered so far is emitted as a single array (which
 * may therefore be larger than [minSize]) and the buffer is cleared.
 *
 * For example:
 * ```
 * val byteArr = flowOf(byteArrayOf(1), byteArrayOf(2, 3, 4), byteArrayOf(5, 6, 7, 8))
 * val expectedResult = listOf(byteArrayOf(1, 2, 3, 4), byteArrayOf(5, 6, 7, 8))
 *
 * byteArr.accumulateUntil(4).toList() shouldContainExactly expectedResult
 * ```
 *
 * @param minSize The minimum amount of bytes the array should have before being sent down-stream.
 * @param emitLeftOvers If the flow completes while there are still buffered bytes that don't meet
 * the [minSize], emit them anyway instead of dropping them.
 */
internal fun Flow<ByteArray>.accumulateUntil(
  minSize: Int,
  emitLeftOvers: Boolean = false
): Flow<ByteArray> = flow {
  val leftovers =
    fold(ByteArrayOutputStream()) { pending, chunk ->
      pending.write(chunk, 0, chunk.size)
      if (pending.size() >= minSize) {
        emit(pending.toByteArray())
        pending.reset()
      }
      pending
    }

  val hasLeftovers = leftovers.size() > 0
  if (emitLeftOvers && hasLeftovers) {
    emit(leftovers.toByteArray())
  }
}
86+
87+
/**
 * Creates a new [Job] whose parent is the job of the [currentCoroutineContext], if it has one.
 *
 * Intended for cases where a coroutine scope should be cancelled together with a parent scope
 * that you don't fully control, while cancelling the child on its own should leave the parent
 * untouched.
 *
 * If the current coroutine context carries no job, a freshly created job is used as the parent.
 */
internal suspend inline fun childJob() = Job(parent = currentCoroutineContext()[Job] ?: Job())
97+
98+
/**
 * A shared [CoroutineScope] that has already been cancelled.
 *
 * Handy as the initial value of a mutable [CoroutineScope] property that should start out in a
 * cancelled state until a real scope is assigned.
 */
internal val CancelledCoroutineScope = CoroutineScope(EmptyCoroutineContext).also { it.cancel() }

0 commit comments

Comments
 (0)