From 08caf19cfbdbc0babb21b8d56cdec373719abc9d Mon Sep 17 00:00:00 2001 From: Nikolay Matrosov Date: Mon, 30 Oct 2023 17:45:19 +0100 Subject: [PATCH] feat: make possible to provide additional auth headers needed for AI services --- src/session.ts | 28 ++++++++++++++++++++++------ src/types.ts | 1 + 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/session.ts b/src/session.ts index 4dd758e4..c88cf599 100644 --- a/src/session.ts +++ b/src/session.ts @@ -4,16 +4,18 @@ import { import { createChannel } from 'nice-grpc'; import { Required } from 'utility-types'; import { + ChannelSslOptions, GeneratedServiceClientCtor, IamTokenCredentialsConfig, OAuthCredentialsConfig, - ServiceAccountCredentialsConfig, WrappedServiceClientType, - SessionConfig, ChannelSslOptions, + ServiceAccountCredentialsConfig, + SessionConfig, + WrappedServiceClientType, } from './types'; import { IamTokenService } from './token-service/iam-token-service'; import { MetadataTokenService } from './token-service/metadata-token-service'; import { clientFactory } from './utils/client-factory'; -import { serviceClients, cloudApi } from '.'; +import { cloudApi, serviceClients } from '.'; import { getServiceClientEndpoint } from './service-endpoints'; const isOAuth = (config: SessionConfig): config is OAuthCredentialsConfig => 'oauthToken' in config; @@ -39,7 +41,8 @@ const newTokenCreator = (config: SessionConfig): () => Promise => { yandexPassportOauthToken: config.oauthToken, }); }; - } if (isIamToken(config)) { + } + if (isIamToken(config)) { const { iamToken } = config; return async () => iamToken; @@ -50,7 +53,11 @@ const newTokenCreator = (config: SessionConfig): () => Promise => { return async () => tokenService.getToken(); }; -const newChannelCredentials = (tokenCreator: TokenCreator, sslOptions?: ChannelSslOptions) => credentials.combineChannelCredentials( +const newChannelCredentials = ( + tokenCreator: TokenCreator, + sslOptions?: ChannelSslOptions, + headers?: Record, +) => credentials.combineChannelCredentials( credentials.createSsl(sslOptions?.rootCerts, sslOptions?.privateKey, sslOptions?.certChain), credentials.createFromMetadataGenerator( ( @@ -62,6 +69,15 @@ const newChannelCredentials = (tokenCreator: TokenCreator, sslOptions?: ChannelS const md = new Metadata(); md.set('authorization', `Bearer ${token}`); + if (headers) { + for (const [key, value] of Object.entries(headers)) { + const lowerCaseKey = key.toLowerCase(); + + if (lowerCaseKey !== 'authorization') { + md.set(lowerCaseKey, value); + } + } + } return callback(null, md); }) @@ -87,7 +103,7 @@ export class Session { ...config, }; this.tokenCreator = newTokenCreator(this.config); - this.channelCredentials = newChannelCredentials(this.tokenCreator, this.config.ssl); + this.channelCredentials = newChannelCredentials(this.tokenCreator, this.config.ssl, this.config.headers); } get pollInterval(): number { diff --git a/src/types.ts b/src/types.ts index 18bda45f..2a902367 100644 --- a/src/types.ts +++ b/src/types.ts @@ -44,6 +44,7 @@ export interface ChannelSslOptions { export interface GenericCredentialsConfig { pollInterval?: number; ssl?: ChannelSslOptions + headers?: Record; } export interface OAuthCredentialsConfig extends GenericCredentialsConfig {