Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

airframe-grpc: Support gRPC backend #1192

Merged
merged 16 commits into from
Jul 28, 2020
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.grpc

/**
*/
object Grpc {
def server: GrpcServerConfig = GrpcServerConfig()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.grpc
import io.grpc.{Server, ServerBuilder}
import wvlet.airframe.{Design, Session}
import wvlet.airframe.http.Router
import wvlet.log.LogSupport
import wvlet.log.io.IOUtil
import scala.language.existentials

/**
*/
case class GrpcServerConfig(
name: String = "default",
private val serverPort: Option[Int] = None,
router: Router = Router.empty
) extends LogSupport {
lazy val port = serverPort.getOrElse(IOUtil.unusedPort)

def withName(name: String): GrpcServerConfig = this.copy(name = name)
def withPort(port: Int): GrpcServerConfig = this.copy(serverPort = Some(port))
def withRouter(router: Router): GrpcServerConfig = this.copy(router = router)

def newServer(session: Session): GrpcServer = {
val services = GrpcServiceBuilder.buildService(router, session)
debug(s"service:\n${services.map(_.getServiceDescriptor).mkString("\n")}")
val serverBuilder = ServerBuilder.forPort(port)
for (service <- services) {
serverBuilder.addService(service)
}
new GrpcServer(this, serverBuilder.build())
}

def design: Design = {
Design.newDesign
.bind[GrpcServerConfig].toInstance(this)
.bind[GrpcServer].toProvider { (config: GrpcServerConfig, session: Session) => config.newServer(session) }
.onStart { _.start }
}
}

class GrpcServer(grpcServerConfig: GrpcServerConfig, server: Server) extends AutoCloseable with LogSupport {
def port: Int = grpcServerConfig.port
def localAddress: String = s"localhost:${grpcServerConfig.port}"

def start: Unit = {
info(s"Starting gRPC server ${grpcServerConfig.name} at ${localAddress}")
server.start()
}

def awaitTermination: Unit = {
server.awaitTermination()
}

override def close(): Unit = {
info(s"Closing gRPC server ${grpcServerConfig.name} at ${localAddress}")
server.shutdownNow()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.grpc
import java.io.{ByteArrayInputStream, InputStream}

import io.grpc.MethodDescriptor.Marshaller
import io.grpc.stub.ServerCalls
import io.grpc.{MethodDescriptor, ServerServiceDefinition}
import wvlet.airframe.Session
import wvlet.airframe.codec.{MessageCodec, MessageCodecFactory}
import wvlet.airframe.control.IO
import wvlet.airframe.http.Router
import wvlet.airframe.http.router.Route
import wvlet.airframe.msgpack.spi.MsgPack

/**
*/
object GrpcServiceBuilder {

def buildMethodDescriptor(r: Route, codecFactory: MessageCodecFactory): MethodDescriptor[MsgPack, Any] = {
val b = MethodDescriptor.newBuilder[MsgPack, Any]()
// TODO setIdempotent, setSafe, sampling, etc.
b.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName(s"${r.serviceName}/${r.methodSurface.name}")
.setRequestMarshaller(RPCRequestMarshaller)
.setResponseMarshaller(
new RPCResponseMarshaller[Any](
codecFactory.of(r.returnTypeSurface).asInstanceOf[MessageCodec[Any]]
)
)
.build()
}

def buildService(
router: Router,
session: Session,
codecFactory: MessageCodecFactory = MessageCodecFactory.defaultFactoryForJSON
): Seq[ServerServiceDefinition] = {
val services = for ((serviceName, routes) <- router.routes.groupBy(_.serviceName)) yield {
val routeAndMethods = for (route <- routes) yield {
(route, buildMethodDescriptor(route, codecFactory))
}

val serviceBuilder = ServerServiceDefinition.builder(serviceName)

for ((r, m) <- routeAndMethods) {
// TODO Support Client/Server Streams
val controller = session.getInstanceOf(r.controllerSurface)
serviceBuilder.addMethod(
m,
ServerCalls.asyncUnaryCall(new RPCRequestHandler[Any](controller, r.methodSurface, codecFactory))
)
}
val serviceDef = serviceBuilder.build()
serviceDef
}

services.toSeq
}

object RPCRequestMarshaller extends Marshaller[MsgPack] {
override def stream(value: MsgPack): InputStream = {
new ByteArrayInputStream(value)
}
override def parse(stream: InputStream): MsgPack = {
IO.readFully(stream)
}
}

class RPCResponseMarshaller[A](codec: MessageCodec[A]) extends Marshaller[A] {
override def stream(value: A): InputStream = {
new ByteArrayInputStream(codec.toMsgPack(value))
}
override def parse(stream: InputStream): A = {
codec.fromMsgPack(stream.readAllBytes())
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.grpc
import io.grpc.stub.ServerCalls.UnaryMethod
import io.grpc.stub.StreamObserver
import wvlet.airframe.codec.MessageCodecFactory
import wvlet.airframe.codec.PrimitiveCodec.ValueCodec
import wvlet.airframe.http.router.HttpRequestMapper
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.airframe.msgpack.spi.Value.MapValue
import wvlet.airframe.surface.{CName, MethodSurface}
import wvlet.log.LogSupport

import scala.util.{Failure, Success, Try}

/**
* Receives MessagePack Map value for the RPC request, and call the controller method
*/
class RPCRequestHandler[A](controller: Any, methodSurface: MethodSurface, codecFactory: MessageCodecFactory)
extends UnaryMethod[MsgPack, A]
with LogSupport {

private val argCodecs = methodSurface.args.map(a => codecFactory.of(a.surface))

override def invoke(request: MsgPack, responseObserver: StreamObserver[A]): Unit = {
// Build method arguments from MsgPack
val requestValue = ValueCodec.unpack(request)
trace(requestValue)

val result = Try {
requestValue match {
case m: MapValue =>
val mapValue = HttpRequestMapper.toCanonicalKeyNameMap(m)
val args = for ((arg, i) <- methodSurface.args.zipWithIndex) yield {
val argOpt = mapValue.get(CName.toCanonicalName(arg.name)) match {
case Some(paramValue) =>
Option(argCodecs(i).fromMsgPack(paramValue.toMsgpack)).orElse {
throw new IllegalArgumentException(s"Failed to parse ${paramValue} for ${arg}")
}
case None =>
// If no value is found, use the method parameter's default argument
arg.getMethodArgDefaultValue(controller)
}
argOpt.getOrElse {
throw new IllegalArgumentException(s"No key for ${arg.name} is found in ${m}")
}
}

trace(s"RPC call ${methodSurface.name}(${args.mkString(", ")})")
methodSurface.call(controller, args: _*)
case _ =>
throw new IllegalArgumentException(s"Invalid argument: ${requestValue}")
}
}
result match {
case Success(v) =>
responseObserver.onNext(v.asInstanceOf[A])
case Failure(e) =>
responseObserver.onError(e)
}
responseObserver.onCompleted()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.grpc
import java.io.{ByteArrayInputStream, InputStream}

import io.grpc.MethodDescriptor.Marshaller
import wvlet.airframe.codec.{INVALID_DATA, MessageCodecException, MessageContext}
import wvlet.airframe.codec.PrimitiveCodec.StringCodec
import wvlet.airframe.msgpack.spi.MessagePack
import wvlet.log.LogSupport

/**
* Marshalling String as MessagePack
*/
private[grpc] object StringMarshaller extends Marshaller[String] with LogSupport {
override def stream(value: String): InputStream = {
new ByteArrayInputStream(StringCodec.toMsgPack(value))
}
override def parse(stream: InputStream): String = {
val unpacker = MessagePack.newUnpacker(stream)
val v = MessageContext()

StringCodec.unpack(unpacker, v)
if (!v.isNull) {
val s = v.getString
s
} else {
v.getError match {
case Some(e) => throw new RuntimeException(e)
case None => throw new MessageCodecException(INVALID_DATA, StringCodec, "invalid input")
}
}
}
}
Loading