diff --git a/packages/grpc/package.json b/packages/grpc/package.json index 4725d6c7..4d0ad6f6 100644 --- a/packages/grpc/package.json +++ b/packages/grpc/package.json @@ -1,6 +1,6 @@ { "name": "@walmartlabs/cookie-cutter-grpc", - "version": "1.6.0-beta.2", + "version": "1.6.0-beta.3", "license": "Apache-2.0", "main": "dist/index.js", "types": "dist/index.d.ts", diff --git a/packages/grpc/src/__test__/grpc.test.ts b/packages/grpc/src/__test__/grpc.test.ts index d1cd3aff..0dc000dc 100644 --- a/packages/grpc/src/__test__/grpc.test.ts +++ b/packages/grpc/src/__test__/grpc.test.ts @@ -22,11 +22,14 @@ import { GrpcMetadata, grpcSource, IGrpcClientConfiguration, + IGrpcClientOptions, IGrpcConfiguration, + IGrpcServerOptions, IResponseStream, } from ".."; import { sample } from "./Sample"; +const apiKey = "token"; let nextPort = 56011; export interface ISampleService { @@ -78,16 +81,23 @@ export const SampleServiceDefinition = { }, }; -function testApp(handler: any, host?: string): CancelablePromise { +function testApp( + handler: any, + host?: string, + options?: IGrpcServerOptions +): CancelablePromise { return Application.create() .input() .add( - grpcSource({ - port: nextPort, - host, - definitions: [SampleServiceDefinition], - skipNoStreamingValidation: true, - }) + grpcSource( + { + port: nextPort, + host, + definitions: [SampleServiceDefinition], + skipNoStreamingValidation: true, + }, + options + ) ) .done() .dispatch(handler) @@ -96,13 +106,17 @@ function testApp(handler: any, host?: string): CancelablePromise { async function createClient( host?: string, - config?: Partial + config?: Partial, + options?: string | IGrpcClientOptions ): Promise { - const client = grpcClient({ - endpoint: `${host || "localhost"}:${nextPort++}`, - definition: SampleServiceDefinition, - ...config, - }); + const client = grpcClient( + { + endpoint: `${host || "localhost"}:${nextPort++}`, + definition: SampleServiceDefinition, + ...config, + }, + options + ); return client; } @@ -126,6 +140,29 @@ describe("gRPC source", () => { } }); + it("serves requests with api key validation", async () => { + const app = testApp( + { + onNoStreaming: async ( + request: sample.ISampleRequest, + _: IDispatchContext + ): Promise => { + return { name: request.id.toString() }; + }, + }, + undefined, + { apiKey } + ); + try { + const client = await createClient(undefined, undefined, { apiKey }); + const response = await client.NoStreaming({ id: 15 }); + expect(response).toMatchObject({ name: "15" }); + } finally { + app.cancel(); + await app; + } + }); + it("serves response streams", async () => { const app = testApp({ onStreamingOut: async ( @@ -235,6 +272,29 @@ describe("gRPC source", () => { } }); + it("throws error for missing/invalid api key", async () => { + const app = testApp( + { + onNoStreaming: async ( + request: sample.ISampleRequest, + _: IDispatchContext + ): Promise => { + return { name: request.id.toString() }; + }, + }, + undefined, + { apiKey } + ); + try { + const client = await createClient(); + const response = client.NoStreaming({ id: 15 }); + await expect(response).rejects.toThrowError(/Invalid API Key/); + } finally { + app.cancel(); + await app; + } + }); + it("validates that no streaming operations are exposed", () => { const a = () => grpcSource({ diff --git a/packages/grpc/src/index.ts b/packages/grpc/src/index.ts index eccf046e..708a17d8 100644 --- a/packages/grpc/src/index.ts +++ b/packages/grpc/src/index.ts @@ -58,6 +58,10 @@ export interface IGrpcServerConfiguration { readonly skipNoStreamingValidation?: boolean; } +export interface IGrpcServerOptions { + readonly apiKey?: string; +} + export interface IGrpcClientConfiguration { readonly endpoint: string; readonly definition: IGrpcServiceDefinition; @@ -66,6 +70,11 @@ export interface IGrpcClientConfiguration { readonly behavior?: Required; } +export interface IGrpcClientOptions { + readonly certPath?: string; + readonly apiKey?: string; +} + export enum GrpcMetadata { OperationPath = "grpc.OperationPath", ResponseStream = "grpc.ResponseStream", @@ -80,7 +89,8 @@ export interface IResponseStream { } export function grpcSource( - configuration: IGrpcServerConfiguration & IGrpcConfiguration + configuration: IGrpcServerConfiguration & IGrpcConfiguration, + options?: IGrpcServerOptions ): IInputSource & IRequireInitialization { configuration = config.parse( GrpcSourceConfiguration, @@ -90,7 +100,7 @@ export function grpcSource( allocator: Buffer, } ); - return new GrpcInputSource(configuration); + return new GrpcInputSource(configuration, options); } export function grpcMsg(operation: IGrpcServiceMethod, request: any): IMessage { @@ -102,7 +112,7 @@ export function grpcMsg(operation: IGrpcServiceMethod, request: any): IMessage { export function grpcClient( configuration: IGrpcClientConfiguration & IGrpcConfiguration, - certPath?: string + certPathOrOptions?: string | IGrpcClientOptions ): T & IRequireInitialization & IDisposable { configuration = config.parse( GrpcClientConfiguration, @@ -122,5 +132,8 @@ export function grpcClient( }, } ); - return createGrpcClient(configuration, certPath); + if (typeof certPathOrOptions === "string") { + certPathOrOptions = { certPath: certPathOrOptions }; + } + return createGrpcClient(configuration, certPathOrOptions); } diff --git a/packages/grpc/src/internal/GrpcInputSource.ts b/packages/grpc/src/internal/GrpcInputSource.ts index 01a9c69f..e62590b3 100644 --- a/packages/grpc/src/internal/GrpcInputSource.ts +++ b/packages/grpc/src/internal/GrpcInputSource.ts @@ -19,6 +19,7 @@ import { OpenTracingTagKeys, } from "@walmartlabs/cookie-cutter-core"; import { + Metadata, sendUnaryData, Server, ServerCredentials, @@ -38,7 +39,7 @@ import { GrpcResponseStream, GrpcStreamHandler, } from "."; -import { GrpcMetadata, IGrpcConfiguration, IGrpcServerConfiguration } from ".."; +import { GrpcMetadata, IGrpcConfiguration, IGrpcServerConfiguration, IGrpcServerOptions } from ".."; import { GrpcOpenTracingTagKeys } from "./helper"; enum GrpcMetrics { @@ -59,7 +60,10 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization { private tracer: Tracer; private metrics: IMetrics; - constructor(private readonly config: IGrpcServerConfiguration & IGrpcConfiguration) { + constructor( + private readonly config: IGrpcServerConfiguration & IGrpcConfiguration, + private readonly options?: IGrpcServerOptions + ) { if (!config.skipNoStreamingValidation) { for (const def of config.definitions) { for (const key of Object.keys(def)) { @@ -180,7 +184,11 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization { if (value !== undefined) { callback(undefined, value); } else if (error !== undefined) { - callback(this.createError(error), null); + if ((error as ServerErrorResponse).code !== undefined) { + callback(error, null); + } else { + callback(this.createError(error), null); + } } else { callback( this.createError("not implemented", status.UNIMPLEMENTED), @@ -198,7 +206,15 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization { path: method.path, }); }); - + if (this.options?.apiKey) { + if (!this.isApiKeyValid(call.metadata)) { + await msgRef.release( + undefined, + this.createError("Invalid API Key", status.UNAUTHENTICATED) + ); + return; + } + } if (!(await this.queue.enqueue(msgRef))) { await msgRef.release(undefined, new Error("service unavailable")); } @@ -239,4 +255,9 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization { message: error.toString(), }; } + + private isApiKeyValid(meta: Metadata) { + const headerValue = meta.get("authorization"); + return headerValue?.[0]?.toString() === this.options.apiKey; + } } diff --git a/packages/grpc/src/internal/client.ts b/packages/grpc/src/internal/client.ts index 92d922c1..2799d4c1 100644 --- a/packages/grpc/src/internal/client.ts +++ b/packages/grpc/src/internal/client.ts @@ -30,7 +30,7 @@ import { import { FORMAT_HTTP_HEADERS, Span, SpanContext, Tags, Tracer } from "opentracing"; import { performance } from "perf_hooks"; import { createGrpcConfiguration, createServiceDefinition } from "."; -import { IGrpcClientConfiguration, IGrpcConfiguration } from ".."; +import { IGrpcClientConfiguration, IGrpcClientOptions, IGrpcConfiguration } from ".."; enum GrpcMetrics { RequestSent = "cookie_cutter.grpc_client.request_sent", @@ -75,23 +75,27 @@ class ClientBase implements IRequireInitialization, IDisposable { export function createGrpcClient( config: IGrpcClientConfiguration & IGrpcConfiguration, - certPath?: string + options?: IGrpcClientOptions ): T & IDisposable & IRequireInitialization { const serviceDef = createServiceDefinition(config.definition); let client: Client; const ClientType = makeGenericClientConstructor(serviceDef, undefined, undefined); + const certPath = options?.certPath; + const apiKey = options?.apiKey; if (certPath) { const rootCert = readFileSync(certPath); const channelCreds = credentials.createSsl(rootCert); - const metaCallback = (_params: any, callback: (arg0: null, arg1: Metadata) => void) => { - const meta = new Metadata(); - meta.add("custom-auth-header", "token"); - callback(null, meta); - }; - - const callCreds = credentials.createFromMetadataGenerator(metaCallback); - const combCreds = credentials.combineChannelCredentials(channelCreds, callCreds); + let combCreds = channelCreds; + if (apiKey) { + const metaCallback = (_params: any, callback: (arg0: null, arg1: Metadata) => void) => { + const meta = new Metadata(); + meta.add("authorization", apiKey); + callback(null, meta); + }; + const callCreds = credentials.createFromMetadataGenerator(metaCallback); + combCreds = credentials.combineChannelCredentials(channelCreds, callCreds); + } client = new ClientType(config.endpoint, combCreds, createGrpcConfiguration(config)); } else { client = new ClientType( @@ -165,12 +169,16 @@ export function createGrpcClient( const stream = await retrier.retry((bail) => { try { + const meta = createTracingMetadata(wrapper.tracer, span); + if (!certPath && apiKey) { + meta.set("authorization", apiKey); + } return client.makeServerStreamRequest( method.path, method.requestSerialize, method.responseDeserialize, request, - createTracingMetadata(wrapper.tracer, span), + meta, callOptions() ); } catch (e) { @@ -239,12 +247,16 @@ export function createGrpcClient( return await retrier.retry(async (bail) => { try { return await new Promise((resolve, reject) => { + const meta = createTracingMetadata(wrapper.tracer, span); + if (!certPath && apiKey) { + meta.set("authorization", apiKey); + } client.makeUnaryRequest( method.path, method.requestSerialize, method.responseDeserialize, request, - createTracingMetadata(wrapper.tracer, span), + meta, callOptions(), (error, value) => { this.metrics.increment(GrpcMetrics.RequestProcessed, {