From 535fb9099d1679fc45f461e855449f83dd85ab68 Mon Sep 17 00:00:00 2001 From: mustafajadid Date: Mon, 16 Dec 2024 13:57:49 -0800 Subject: [PATCH] Extend Firebase SDK with new APIs to consume streaming callable function response. - Handling the server-sent event (SSE) parsing internally - Providing proper error handling and connection management - Maintaining memory efficiency for long-running streams --- .../google/firebase/functions/StramTests.kt | 128 ++++++++++ .../firebase/functions/FirebaseFunctions.kt | 226 ++++++++++++++++++ .../functions/HttpsCallableReference.kt | 83 +++++++ .../firebase/functions/SSETaskListener.kt | 14 ++ 4 files changed, 451 insertions(+) create mode 100644 firebase-functions/src/androidTest/java/com/google/firebase/functions/StramTests.kt create mode 100644 firebase-functions/src/main/java/com/google/firebase/functions/SSETaskListener.kt diff --git a/firebase-functions/src/androidTest/java/com/google/firebase/functions/StramTests.kt b/firebase-functions/src/androidTest/java/com/google/firebase/functions/StramTests.kt new file mode 100644 index 00000000000..6f2a2693ec2 --- /dev/null +++ b/firebase-functions/src/androidTest/java/com/google/firebase/functions/StramTests.kt @@ -0,0 +1,128 @@ +package com.google.firebase.functions.ktx + +import androidx.test.InstrumentationRegistry +import androidx.test.runner.AndroidJUnit4 +import com.google.android.gms.tasks.Tasks +import com.google.common.truth.Truth.assertThat +import com.google.firebase.FirebaseApp +import com.google.firebase.functions.FirebaseFunctions +import com.google.firebase.functions.FirebaseFunctionsException +import com.google.firebase.functions.SSETaskListener +import com.google.firebase.ktx.Firebase +import com.google.firebase.ktx.initialize +import java.util.concurrent.ExecutionException +import java.util.concurrent.TimeUnit +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(AndroidJUnit4::class) +class StreamTests { + + private lateinit var app: FirebaseApp + private lateinit var listener: SSETaskListener + + private lateinit var functions: FirebaseFunctions + var onNext = mutableListOf() + var onError: Any? = null + var onComplete: Any? = null + + @Before + fun setup() { + app = Firebase.initialize(InstrumentationRegistry.getContext())!! + functions = FirebaseFunctions.getInstance() + functions.useEmulator("10.0.2.2", 5001) + listener = + object : SSETaskListener { + override fun onNext(event: Any) { + onNext.add(event) + } + + override fun onError(event: Any) { + onError = event + } + + override fun onComplete(event: Any) { + onComplete = event + } + } + } + + @After + fun clear() { + onNext.clear() + onError = null + onComplete = null + } + + @Test + fun testGenStream() { + val input = hashMapOf("data" to "Why is the sky blue") + + val function = functions.getHttpsCallable("genStream") + val httpsCallableResult = Tasks.await(function.stream(input, listener)) + + val onNextStringList = onNext.map { it.toString() } + assertThat(onNextStringList) + .containsExactly( + "{chunk=hello}", + "{chunk=world}", + "{chunk=this}", + "{chunk=is}", + "{chunk=cool}" + ) + assertThat(onError).isNull() + assertThat(onComplete).isEqualTo("hello world this is cool") + assertThat(httpsCallableResult.data).isEqualTo("hello world this is cool") + } + + @Test + fun testGenStreamError() { + val input = hashMapOf("data" to "Why is the sky blue") + val function = functions.getHttpsCallable("genStreamError").withTimeout(7, TimeUnit.SECONDS) + + try { + Tasks.await(function.stream(input, listener)) + } catch (exception: Exception) { + onError = exception + } + + val onNextStringList = onNext.map { it.toString() } + assertThat(onNextStringList) + .containsExactly( + "{chunk=hello}", + "{chunk=world}", + "{chunk=this}", + "{chunk=is}", + "{chunk=cool}" + ) + assertThat(onError).isInstanceOf(ExecutionException::class.java) + val cause = (onError as ExecutionException).cause + assertThat(cause).isInstanceOf(FirebaseFunctionsException::class.java) + assertThat((cause as FirebaseFunctionsException).message).contains("Socket closed") + assertThat(onComplete).isNull() + } + + @Test + fun testGenStreamNoReturn() { + val input = hashMapOf("data" to "Why is the sky blue") + + val function = functions.getHttpsCallable("genStreamNoReturn") + try { + Tasks.await(function.stream(input, listener), 7, TimeUnit.SECONDS) + } catch (_: Exception) {} + + val onNextStringList = onNext.map { it.toString() } + assertThat(onNextStringList) + .containsExactly( + "{chunk=hello}", + "{chunk=world}", + "{chunk=this}", + "{chunk=is}", + "{chunk=cool}" + ) + assertThat(onError).isNull() + assertThat(onComplete).isNull() + } +} diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt b/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt index 3c0e7d6553e..2858c009ce5 100644 --- a/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt +++ b/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt @@ -30,7 +30,10 @@ import com.google.firebase.functions.FirebaseFunctionsException.Code.Companion.f import com.google.firebase.functions.FirebaseFunctionsException.Companion.fromResponse import dagger.assisted.Assisted import dagger.assisted.AssistedInject +import java.io.BufferedReader import java.io.IOException +import java.io.InputStream +import java.io.InputStreamReader import java.io.InterruptedIOException import java.net.MalformedURLException import java.net.URL @@ -311,6 +314,229 @@ internal constructor( return tcs.task } + internal fun stream( + name: String, + data: Any?, + options: HttpsCallOptions, + listener: SSETaskListener + ): Task { + return providerInstalled.task + .continueWithTask(executor) { contextProvider.getContext(options.limitedUseAppCheckTokens) } + .continueWithTask(executor) { task: Task -> + if (!task.isSuccessful) { + return@continueWithTask Tasks.forException(task.exception!!) + } + val context = task.result + val url = getURL(name) + stream(url, data, options, context, listener) + } + } + + internal fun stream( + url: URL, + data: Any?, + options: HttpsCallOptions, + listener: SSETaskListener + ): Task { + return providerInstalled.task + .continueWithTask(executor) { contextProvider.getContext(options.limitedUseAppCheckTokens) } + .continueWithTask(executor) { task: Task -> + if (!task.isSuccessful) { + return@continueWithTask Tasks.forException(task.exception!!) + } + val context = task.result + stream(url, data, options, context, listener) + } + } + + private fun stream( + url: URL, + data: Any?, + options: HttpsCallOptions, + context: HttpsCallableContext?, + listener: SSETaskListener + ): Task { + Preconditions.checkNotNull(url, "url cannot be null") + val tcs = TaskCompletionSource() + val callClient = options.apply(client) + callClient.postStream(url, tcs, listener) { applyCommonConfiguration(data, context) } + + return tcs.task + } + + private inline fun OkHttpClient.postStream( + url: URL, + tcs: TaskCompletionSource, + listener: SSETaskListener, + crossinline config: Request.Builder.() -> Unit = {} + ) { + val requestBuilder = Request.Builder().url(url) + requestBuilder.config() + val request = requestBuilder.build() + + val call = newCall(request) + call.enqueue( + object : Callback { + override fun onFailure(ignored: Call, e: IOException) { + val exception: Exception = + if (e is InterruptedIOException) { + FirebaseFunctionsException( + FirebaseFunctionsException.Code.DEADLINE_EXCEEDED.name, + FirebaseFunctionsException.Code.DEADLINE_EXCEEDED, + null, + e + ) + } else { + FirebaseFunctionsException( + FirebaseFunctionsException.Code.INTERNAL.name, + FirebaseFunctionsException.Code.INTERNAL, + null, + e + ) + } + listener.onError(exception) + tcs.setException(exception) + } + + @Throws(IOException::class) + override fun onResponse(ignored: Call, response: Response) { + try { + validateResponse(response) + val bodyStream = response.body()?.byteStream() + if (bodyStream != null) { + processSSEStream(bodyStream, serializer, listener, tcs) + } else { + val error = + FirebaseFunctionsException( + "Response body is null", + FirebaseFunctionsException.Code.INTERNAL, + null + ) + listener.onError(error) + tcs.setException(error) + } + } catch (e: FirebaseFunctionsException) { + listener.onError(e) + tcs.setException(e) + } + } + } + ) + } + + private fun validateResponse(response: Response) { + if (response.isSuccessful) return + + val htmlContentType = "text/html; charset=utf-8" + val trimMargin: String + if (response.code() == 404 && response.header("Content-Type") == htmlContentType) { + trimMargin = """URL not found. Raw response: ${response.body()?.string()}""".trimMargin() + throw FirebaseFunctionsException( + trimMargin, + FirebaseFunctionsException.Code.fromHttpStatus(response.code()), + null + ) + } + + val text = response.body()?.string() ?: "" + val error: Any? + try { + val json = JSONObject(text) + error = serializer.decode(json.opt("error")) + } catch (e: Throwable) { + throw FirebaseFunctionsException( + "${e.message} Unexpected Response:\n$text ", + FirebaseFunctionsException.Code.INTERNAL, + e + ) + } + throw FirebaseFunctionsException( + error.toString(), + FirebaseFunctionsException.Code.INTERNAL, + error + ) + } + + private fun Request.Builder.applyCommonConfiguration(data: Any?, context: HttpsCallableContext?) { + val body: MutableMap = HashMap() + val encoded = serializer.encode(data) + body["data"] = encoded + if (context!!.authToken != null) { + header("Authorization", "Bearer " + context.authToken) + } + if (context.instanceIdToken != null) { + header("Firebase-Instance-ID-Token", context.instanceIdToken) + } + if (context.appCheckToken != null) { + header("X-Firebase-AppCheck", context.appCheckToken) + } + header("Accept", "text/event-stream") + val bodyJSON = JSONObject(body) + val contentType = MediaType.parse("application/json") + val requestBody = RequestBody.create(contentType, bodyJSON.toString()) + post(requestBody) + } + + private fun processSSEStream( + inputStream: InputStream, + serializer: Serializer, + listener: SSETaskListener, + tcs: TaskCompletionSource + ) { + BufferedReader(InputStreamReader(inputStream)).use { reader -> + try { + reader.lineSequence().forEach { line -> + val dataChunk = + when { + line.startsWith("data:") -> line.removePrefix("data:") + line.startsWith("result:") -> line.removePrefix("result:") + else -> return@forEach + } + try { + val json = JSONObject(dataChunk) + when { + json.has("message") -> + serializer.decode(json.opt("message"))?.let { listener.onNext(it) } + json.has("error") -> { + serializer.decode(json.opt("error"))?.let { + throw FirebaseFunctionsException( + it.toString(), + FirebaseFunctionsException.Code.INTERNAL, + it + ) + } + } + json.has("result") -> { + serializer.decode(json.opt("result"))?.let { + listener.onComplete(it) + tcs.setResult(HttpsCallableResult(it)) + } + return + } + } + } catch (e: Throwable) { + throw FirebaseFunctionsException( + "${e.message} Invalid JSON: $dataChunk", + FirebaseFunctionsException.Code.INTERNAL, + e + ) + } + } + throw FirebaseFunctionsException( + "Stream ended unexpectedly without completion.", + FirebaseFunctionsException.Code.INTERNAL, + null + ) + } catch (e: Exception) { + throw FirebaseFunctionsException( + e.message ?: "Error reading stream", + FirebaseFunctionsException.Code.INTERNAL, + e + ) + } + } + } + public companion object { /** A task that will be resolved once ProviderInstaller has installed what it needs to. */ private val providerInstalled = TaskCompletionSource() diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt b/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt index 90bdb63221b..da8734757d5 100644 --- a/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt +++ b/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt @@ -125,6 +125,89 @@ public class HttpsCallableReference { } } + /** + * Streams data to the specified HTTPS endpoint asynchronously. + * + * The data passed into the endpoint can be any of the following types: + * + * * Any primitive type, including `null`, `int`, `long`, `float`, and `boolean`. + * * [String] + * * [List<?>][java.util.List], where the contained objects are also one of these types. + * * [Map<String, ?>][java.util.Map], where the values are also one of these types. + * * [org.json.JSONArray] + * * [org.json.JSONObject] + * * [org.json.JSONObject.NULL] + * + * If the returned task fails, the exception will be one of the following types: + * + * * [java.io.IOException] + * - if the HTTPS request failed to connect. + * * [FirebaseFunctionsException] + * - if the request connected, but the function returned an error. + * + * The request to the Cloud Functions backend made by this method automatically includes a + * Firebase Instance ID token to identify the app instance. If a user is logged in with Firebase + * Auth, an auth token for the user will also be automatically included. + * + * Firebase Instance ID sends data to the Firebase backend periodically to collect information + * regarding the app instance. To stop this, see + * [com.google.firebase.iid.FirebaseInstanceId.deleteInstanceId]. It will resume with a new + * Instance ID the next time you call this method. + * + * Streaming events are handled by the provided [SSETaskListener], which will receive events and + * handle errors and completion notifications. + * + * @param data Parameters to pass to the endpoint. + * @param listener A listener to handle streaming events, errors, and completion notifications. + * @return A Task that will be completed when the streaming operation has finished. + * @see org.json.JSONArray + * @see org.json.JSONObject + * @see java.io.IOException + * @see FirebaseFunctionsException + */ + public fun stream(data: Any?, listener: SSETaskListener): Task { + return if (name != null) { + functionsClient.stream(name, data, options, listener) + } else { + functionsClient.stream(url!!, data, options, listener) + } + } + + /** + * Streams data to the specified HTTPS endpoint asynchronously without arguments. + * + * The request to the Cloud Functions backend made by this method automatically includes a + * Firebase Instance ID token to identify the app instance. If a user is logged in with Firebase + * Auth, an auth token for the user will also be automatically included. + * + * Firebase Instance ID sends data to the Firebase backend periodically to collect information + * regarding the app instance. To stop this, see + * [com.google.firebase.iid.FirebaseInstanceId.deleteInstanceId]. It will resume with a new + * Instance ID the next time you call this method. + * + * Streaming events are handled by the provided [SSETaskListener], which will receive events and + * handle errors and completion notifications. + * + * If the returned task fails, the exception will be one of the following types: + * + * * [java.io.IOException] + * - if the HTTPS request failed to connect. + * * [FirebaseFunctionsException] + * - if the request connected, but the function returned an error. + * + * @param listener A listener to handle streaming events, errors, and completion notifications. + * @return A Task that will be completed when the streaming operation has finished. + * @see java.io.IOException + * @see FirebaseFunctionsException + */ + public fun stream(listener: SSETaskListener): Task { + return if (name != null) { + functionsClient.stream(name, null, options, listener) + } else { + functionsClient.stream(url!!, null, options, listener) + } + } + /** * Changes the timeout for calls from this instance of Functions. The default is 60 seconds. * diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/SSETaskListener.kt b/firebase-functions/src/main/java/com/google/firebase/functions/SSETaskListener.kt new file mode 100644 index 00000000000..dffeddfeec2 --- /dev/null +++ b/firebase-functions/src/main/java/com/google/firebase/functions/SSETaskListener.kt @@ -0,0 +1,14 @@ +package com.google.firebase.functions + +/** Listener for events from a Server-Sent Events stream. */ +public interface SSETaskListener { + + /** Called when a new event is received. */ + public fun onNext(event: Any) + + /** Called when an error occurs. */ + public fun onError(event: Any) + + /** Called when the stream is closed. */ + public fun onComplete(event: Any) +}