From 0ee8b7a8e279e232d9969e60d9d10989f753a96e Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Wed, 12 Feb 2025 09:54:17 -0800 Subject: [PATCH] rpc (feature): Support RPC method returning Rx[A] (#3829) --- .../http/codegen/client/RPCClientGenerator.scala | 15 ++++++++++++++- .../test/scala/example/rpc/RPCTestService.scala | 3 +++ .../wvlet/airframe/test/api/HelloRPC.scala | 2 ++ .../wvlet/airframe/test/api/HelloRPCImpl.scala | 4 ++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/airframe-http-codegen/src/main/scala/wvlet/airframe/http/codegen/client/RPCClientGenerator.scala b/airframe-http-codegen/src/main/scala/wvlet/airframe/http/codegen/client/RPCClientGenerator.scala index 0671aa0335..9a43eb8a4c 100644 --- a/airframe-http-codegen/src/main/scala/wvlet/airframe/http/codegen/client/RPCClientGenerator.scala +++ b/airframe-http-codegen/src/main/scala/wvlet/airframe/http/codegen/client/RPCClientGenerator.scala @@ -16,6 +16,7 @@ import wvlet.airframe.http.{Http, HttpMethod} import wvlet.airframe.http.codegen.HttpClientIR import wvlet.airframe.http.codegen.HttpClientIR.{ClientMethodDef, ClientServiceDef} import wvlet.airframe.http.codegen.client.HttpClientGenerator.RichSurface +import wvlet.airframe.rx.Rx /** * The default RPC client generator using Http.client.Sync/AsyncClient @@ -141,7 +142,19 @@ object RPCClientGenerator extends HttpClientGenerator { m.inputParameters .map(x => s"${x.name}: ${x.surface.fullTypeName}") - val returnType = if (isAsync) s"Rx[${m.returnType.fullTypeName}]" else m.returnType.fullTypeName + val isRxResponse = m.returnType.rawType.isAssignableFrom(classOf[Rx[_]]) && m.returnType.typeArgs.size == 1 + val returnElementType = if (isRxResponse) { + // for methods returning Rx[A], extract A + m.returnType.typeArgs(0).fullTypeName + } else { + m.returnType.fullTypeName + } + val returnType = + if (isAsync) + s"Rx[${returnElementType}]" + else + returnElementType + if (m.isRPC) { s"""def ${m.name}(${inputArgs.mkString(", ")}): ${returnType} = { | client.rpc[${m.typeArgString}](${sendRequestArgs(m)}) diff --git a/airframe-http-codegen/src/test/scala/example/rpc/RPCTestService.scala b/airframe-http-codegen/src/test/scala/example/rpc/RPCTestService.scala index 759ce8f2ce..065a7c27f0 100644 --- a/airframe-http-codegen/src/test/scala/example/rpc/RPCTestService.scala +++ b/airframe-http-codegen/src/test/scala/example/rpc/RPCTestService.scala @@ -13,6 +13,7 @@ */ package example.rpc import wvlet.airframe.http.{RPC, RxRouter, RxRouterProvider} +import wvlet.airframe.rx.Rx /** */ @@ -41,6 +42,8 @@ trait RPCExample { def rpcWithOption(p1: Option[String]): Unit def rpcWithPrimitiveAndOption(p1: String, p2: Option[String]): Unit def rpcWithOptionOfComplexType(p1: Option[RPCRequest]): Unit + + def rpcWithRxResponse(p1: Int): Rx[RPCResponse] } object RPCExample extends RxRouterProvider { diff --git a/airframe-integration-test-api/src/main/scala-3/wvlet/airframe/test/api/HelloRPC.scala b/airframe-integration-test-api/src/main/scala-3/wvlet/airframe/test/api/HelloRPC.scala index 9445571e67..9b80b16bf7 100644 --- a/airframe-integration-test-api/src/main/scala-3/wvlet/airframe/test/api/HelloRPC.scala +++ b/airframe-integration-test-api/src/main/scala-3/wvlet/airframe/test/api/HelloRPC.scala @@ -14,6 +14,7 @@ package wvlet.airframe.test.api import wvlet.airframe.http.* +import wvlet.airframe.rx.Rx @RPC trait HelloRPC: @@ -23,6 +24,7 @@ trait HelloRPC: def serverStatus: Status def ackStatus(status: Status): Status def variousParams(params: VariousParams): VariousParams + def ackStatusAsync(name: String): Rx[Status] object HelloRPC extends RxRouterProvider: override def router: RxRouter = RxRouter.of[HelloRPC] diff --git a/airframe-integration-test/src/main/scala-3/wvlet/airframe/test/api/HelloRPCImpl.scala b/airframe-integration-test/src/main/scala-3/wvlet/airframe/test/api/HelloRPCImpl.scala index a208c03644..b255efd5a3 100644 --- a/airframe-integration-test/src/main/scala-3/wvlet/airframe/test/api/HelloRPCImpl.scala +++ b/airframe-integration-test/src/main/scala-3/wvlet/airframe/test/api/HelloRPCImpl.scala @@ -13,6 +13,7 @@ */ package wvlet.airframe.test.api +import wvlet.airframe.rx.Rx import wvlet.airframe.test.api.HelloRPC.VariousParams import wvlet.log.LogSupport @@ -28,3 +29,6 @@ class HelloRPCImpl extends HelloRPC with LogSupport: override def variousParams(params: VariousParams): VariousParams = info(s"received: ${params}") params + + override def ackStatusAsync(name: String): Rx[Status] = + Rx.const(Status.OK)