Skip to content

Add Roots support to the MCP client and server #2895

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions js/plugins/mcp/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
"exports": {
".": {
"require": "./lib/index.js",
"import": "./lib/index.mjs",
"types": "./lib/index.d.ts",
"default": "./lib/index.js"
"import": "./lib/index.mjs"
}
},
"type": "commonjs",
Expand All @@ -28,7 +26,7 @@
"build:clean": "rimraf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch",
"test": "tsx --test ./tests/*_test.ts"
"test": "jest"
},
"repository": {
"type": "git",
Expand All @@ -43,16 +41,23 @@
},
"devDependencies": {
"@jest/globals": "^29.7.0",
"@types/jest": "^29.5.12",
"@types/node": "^20.11.16",
"jest": "^29.7.0",
"npm-run-all": "^4.1.5",
"rimraf": "^6.0.1",
"ts-jest": "^29.1.2",
"ts-jest": "^29.2.5",
"tsup": "^8.3.5",
"tsx": "^4.19.2",
"typescript": "^5.3.0"
"typescript": "^5.6.3"
},
"dependencies": {
"@modelcontextprotocol/sdk": "^1.8.0"
"@modelcontextprotocol/sdk": "^1.11.0"
},
"jest": {
"preset": "ts-jest",
"testMatch": [
"<rootDir>/tests/**/*_test.ts"
]
}
}
136 changes: 136 additions & 0 deletions js/plugins/mcp/src/client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import type { Client } from '@modelcontextprotocol/sdk/client/index.js' with { 'resolution-mode': 'import' };
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' with { 'resolution-mode': 'import' };
import type {
ListRootsRequest,
ListRootsResult,
Root,
ServerCapabilities,
} from '@modelcontextprotocol/sdk/types.js' with { 'resolution-mode': 'import' };
import { ListRootsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import { Genkit, GenkitError } from 'genkit';
import { registerAllPrompts } from './client/prompts';
import { registerResourceTools } from './client/resources';
import { registerAllTools } from './client/tools';
import type { McpClientOptions } from './index.js';

async function transportFrom(options: McpClientOptions): Promise<Transport> {
if (options.transport) return options.transport;
if (options.serverUrl) {
const { SSEClientTransport } = await import(
'@modelcontextprotocol/sdk/client/sse.js'
);
return new SSEClientTransport(new URL(options.serverUrl));
}
if (options.serverProcess) {
const { StdioClientTransport } = await import(
'@modelcontextprotocol/sdk/client/stdio.js'
);
return new StdioClientTransport(options.serverProcess);
}
if (options.serverWebsocketUrl) {
const { WebSocketClientTransport } = await import(
'@modelcontextprotocol/sdk/client/websocket.js'
);
let url = options.serverWebsocketUrl;
if (typeof url === 'string') url = new URL(url);
return new WebSocketClientTransport(url);
}

throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl:\n${JSON.stringify(
options,
null,
2
)}`,
});
}

export class GenkitMcpClient {
ai: Genkit;
options: McpClientOptions;
client?: Client;
serverCapabilities?: ServerCapabilities | undefined = {};
_isSetup: boolean = false;

constructor(ai: Genkit, options: McpClientOptions) {
this.ai = ai;
this.options = options;
}

async setup(): Promise<void> {
if (this._isSetup) return;
const { Client } = await import(
'@modelcontextprotocol/sdk/client/index.js'
);

const transport = await transportFrom(this.options);
this.client = new Client(
{
name: this.options.name,
version: this.options.version || '1.0.0',
roots: this.options.roots,
},
{
capabilities: {
// TODO: Allow actually changing the roots dynamically. This requires
// manipulating which tools, resources, etc. are registered, since
// they can change based on the roots.
roots: { listChanged: false },
},
}
);

this.client.setRequestHandler(
ListRootsRequestSchema,
this.listRoots.bind(this)
);

await this.client.connect(transport);
this.serverCapabilities = this.client.getServerCapabilities();

await this.registerCapabilities();
this._isSetup = true;
}

async registerCapabilities(): Promise<void> {
if (!this.client || !this.serverCapabilities) {
return;
}
const promises: Promise<any>[] = [];
if (this.serverCapabilities?.tools) {
promises.push(registerAllTools(this.ai, this.client, this.options));
}
if (this.serverCapabilities?.prompts) {
promises.push(registerAllPrompts(this.ai, this.client, this.options));
}
if (this.serverCapabilities?.resources) {
promises.push(registerResourceTools(this.ai, this.client, this.options));
}
await Promise.all(promises);
}

async listRoots(req: ListRootsRequest): Promise<ListRootsResult> {
if (!this.options.roots) {
return { roots: [] };
}
const mcpRoots: Root[] = this.options.roots.map<Root>((root) => root);
return { roots: mcpRoots };
}
}
1 change: 1 addition & 0 deletions js/plugins/mcp/src/client/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export function fromMcpPart(part: PromptMessage['content']): Part {
url: `data:${part.mimeType};base64,${part.data}`,
},
};
case 'audio':
case 'resource':
return {};
}
Expand Down
8 changes: 6 additions & 2 deletions js/plugins/mcp/src/client/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ function registerPrompt(
}

/**
* Lookup all tools available in the server and register each as a Genkit tool.
* Lookup all prompts available in the server and register each as a Genkit
* prompt.
*/
export async function registerAllPrompts(
ai: Genkit,
Expand All @@ -72,7 +73,10 @@ export async function registerAllPrompts(
): Promise<void> {
let cursor: string | undefined;
while (true) {
const { nextCursor, prompts } = await client.listPrompts({ cursor });
const { nextCursor, prompts } = await client.listPrompts({
cursor,
roots: params.roots,
});
prompts.forEach((p) => registerPrompt(ai, client, p, params));
cursor = nextCursor;
if (!cursor) break;
Expand Down
25 changes: 22 additions & 3 deletions js/plugins/mcp/src/client/resources.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ export async function registerResourceTools(
client: Client,
params: McpClientOptions
) {
const rootsList = params.roots
? params.roots
.map<string>((root) => `${root.name} (${root.uri})`)
.join(', ')
: '';

ai.defineTool(
{
name: `${params.name}/list_resources`,
Expand All @@ -36,18 +42,31 @@ export async function registerResourceTools(
cursor: z.string().optional(),
/** When specified, automatically paginate and fetch all resources. */
all: z.boolean().optional(),
roots: z
.array(z.object({ name: z.string().optional(), uri: z.string() }))
.optional()
.describe(
`The list of roots to limit the results to. Available roots: ${rootsList}`
),
}),
},
async ({ cursor, all }) => {
async ({
cursor,
all,
roots,
}): Promise<{ nextCursor?: string | undefined; resources: Resource[] }> => {
if (!all) {
return client.listResources();
return client.listResources({ roots: roots || params.roots });
}

let currentCursor: string | undefined = cursor;
const resources: Resource[] = [];
while (true) {
const { nextCursor, resources: newResources } =
await client.listResources({ cursor: currentCursor });
await client.listResources({
cursor: currentCursor,
roots: roots || params.roots,
});
resources.push(...newResources);
currentCursor = nextCursor;
if (!currentCursor) break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,22 @@
* limitations under the License.
*/

// no tests... :(
import { Genkit, z } from 'genkit';
import type { McpClientOptions } from '../index.js';

export async function registerClientRoots(
ai: Genkit,
params: McpClientOptions
) {
ai.defineTool(
{
name: `${params.name}/list_roots`,
description: `Lists the available roots for this MCP client`,
inputSchema: z.object({}),
outputSchema: z.array(z.object({ name: z.string(), uri: z.string() })),
},
async () => {
return params.roots ?? [];
}
);
}
5 changes: 4 additions & 1 deletion js/plugins/mcp/src/client/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ export async function registerAllTools(
): Promise<void> {
let cursor: string | undefined;
while (true) {
const { nextCursor, tools } = await client.listTools({ cursor });
const { nextCursor, tools } = await client.listTools({
cursor,
roots: params.roots,
});
tools.forEach((t) => registerTool(ai, client, t, params));
cursor = nextCursor;
if (!cursor) break;
Expand Down
67 changes: 12 additions & 55 deletions js/plugins/mcp/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@

import type { StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js' with { 'resolution-mode': 'import' };
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' with { 'resolution-mode': 'import' };
import { Genkit, GenkitError } from 'genkit';
import { Genkit } from 'genkit';
import { genkitPlugin } from 'genkit/plugin';
import { registerAllPrompts } from './client/prompts.js';
import { registerResourceTools } from './client/resources.js';
import { registerAllTools } from './client/tools.js';
import { GenkitMcpServer } from './server.js';
import { GenkitMcpClient } from './client';
import { GenkitMcpServer } from './server';

export interface McpClientOptions {
/** Provide a name for this client which will be its namespace for all tools and prompts. */
Expand All @@ -38,59 +36,16 @@ export interface McpClientOptions {
serverWebsocketUrl?: string | URL;
/** Return tool responses in raw MCP form instead of processing them for Genkit compatibility. */
rawToolResponses?: boolean;
/** Specify the MCP roots this client would like the server to work within. */
roots?: { name: string; uri: string }[];
}

async function transportFrom(params: McpClientOptions): Promise<Transport> {
if (params.transport) return params.transport;
if (params.serverUrl) {
const { SSEClientTransport } = await import(
'@modelcontextprotocol/sdk/client/sse.js'
);
return new SSEClientTransport(new URL(params.serverUrl));
}
if (params.serverProcess) {
const { StdioClientTransport } = await import(
'@modelcontextprotocol/sdk/client/stdio.js'
);
return new StdioClientTransport(params.serverProcess);
}
if (params.serverWebsocketUrl) {
const { WebSocketClientTransport } = await import(
'@modelcontextprotocol/sdk/client/websocket.js'
);
let url = params.serverWebsocketUrl;
if (typeof url === 'string') url = new URL(url);
return new WebSocketClientTransport(url);
}
const mcpClients: Record<string, GenkitMcpClient> = {};

throw new GenkitError({
status: 'INVALID_ARGUMENT',
message:
'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl.',
});
}

export function mcpClient(params: McpClientOptions) {
return genkitPlugin(params.name, async (ai: Genkit) => {
const { Client } = await import(
'@modelcontextprotocol/sdk/client/index.js'
);

const transport = await transportFrom(params);
ai.options.model;
const client = new Client(
{ name: params.name, version: params.version || '1.0.0' },
{ capabilities: {} }
);
await client.connect(transport);
const capabilties = await client.getServerCapabilities();
const promises: Promise<any>[] = [];
if (capabilties?.tools) promises.push(registerAllTools(ai, client, params));
if (capabilties?.prompts)
promises.push(registerAllPrompts(ai, client, params));
if (capabilties?.resources)
promises.push(registerResourceTools(ai, client, params));
await Promise.all(promises);
export function mcpClient(options: McpClientOptions) {
return genkitPlugin(options.name, async (ai: Genkit) => {
mcpClients[options.name] = new GenkitMcpClient(ai, options);
await mcpClients[options.name].setup();
});
}

Expand All @@ -99,6 +54,8 @@ export interface McpServerOptions {
name: string;
/** The version you want the server to advertise to clients. Defaults to 1.0.0. */
version?: string;
/** The MCP roots this server is associated with or serves. */
roots?: { name: string; uri: string }[];
}

export function mcpServer(ai: Genkit, options: McpServerOptions) {
Expand Down
Loading