From 438576cc9ebc33da21801ba36df02f664eb57cc6 Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Wed, 7 May 2025 12:35:17 -0700 Subject: [PATCH 1/3] Add Roots support to the MCP client and server --- js/plugins/mcp/package.json | 19 +- js/plugins/mcp/src/client/message.ts | 1 + js/plugins/mcp/src/client/prompts.ts | 5 +- js/plugins/mcp/src/client/resources.ts | 37 +- js/plugins/mcp/src/client/tools.ts | 5 +- js/plugins/mcp/src/index.ts | 43 +- js/plugins/mcp/src/server.ts | 29 +- js/plugins/mcp/tests/client_test.ts | 464 +++++++++++++++++ js/plugins/mcp/tests/mcp_test.ts | 17 - js/plugins/mcp/tests/server_test.ts | 673 +++++++++++++++++++++++++ js/pnpm-lock.yaml | 67 +-- 11 files changed, 1286 insertions(+), 74 deletions(-) create mode 100644 js/plugins/mcp/tests/client_test.ts delete mode 100644 js/plugins/mcp/tests/mcp_test.ts create mode 100644 js/plugins/mcp/tests/server_test.ts diff --git a/js/plugins/mcp/package.json b/js/plugins/mcp/package.json index 0ec2b1e11d..f143b888b1 100644 --- a/js/plugins/mcp/package.json +++ b/js/plugins/mcp/package.json @@ -16,9 +16,7 @@ "exports": { ".": { "require": "./lib/index.js", - "import": "./lib/index.mjs", - "types": "./lib/index.d.ts", - "default": "./lib/index.js" + "import": "./lib/index.mjs" } }, "type": "commonjs", @@ -28,7 +26,7 @@ "build:clean": "rimraf ./lib", "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", - "test": "tsx --test ./tests/*_test.ts" + "test": "jest" }, "repository": { "type": "git", @@ -43,16 +41,23 @@ }, "devDependencies": { "@jest/globals": "^29.7.0", + "@types/jest": "^29.5.12", "@types/node": "^20.11.16", "jest": "^29.7.0", "npm-run-all": "^4.1.5", "rimraf": "^6.0.1", - "ts-jest": "^29.1.2", + "ts-jest": "^29.2.5", "tsup": "^8.3.5", "tsx": "^4.19.2", - "typescript": "^5.3.0" + "typescript": "^5.6.3" }, "dependencies": { - "@modelcontextprotocol/sdk": "^1.8.0" + "@modelcontextprotocol/sdk": "^1.11.0" + }, + "jest": { + "preset": "ts-jest", + "testMatch": [ + "/tests/**/*_test.ts" + ] } } diff --git a/js/plugins/mcp/src/client/message.ts b/js/plugins/mcp/src/client/message.ts index ab5f306267..9d4acc4ca0 100644 --- a/js/plugins/mcp/src/client/message.ts +++ b/js/plugins/mcp/src/client/message.ts @@ -40,6 +40,7 @@ export function fromMcpPart(part: PromptMessage['content']): Part { url: `data:${part.mimeType};base64,${part.data}`, }, }; + case 'audio': case 'resource': return {}; } diff --git a/js/plugins/mcp/src/client/prompts.ts b/js/plugins/mcp/src/client/prompts.ts index fb7ef601c6..ee4b7c26bd 100644 --- a/js/plugins/mcp/src/client/prompts.ts +++ b/js/plugins/mcp/src/client/prompts.ts @@ -72,7 +72,10 @@ export async function registerAllPrompts( ): Promise { let cursor: string | undefined; while (true) { - const { nextCursor, prompts } = await client.listPrompts({ cursor }); + const { nextCursor, prompts } = await client.listPrompts({ + cursor, + roots: params.roots, + }); prompts.forEach((p) => registerPrompt(ai, client, p, params)); cursor = nextCursor; if (!cursor) break; diff --git a/js/plugins/mcp/src/client/resources.ts b/js/plugins/mcp/src/client/resources.ts index 00c45c381d..69a0625f11 100644 --- a/js/plugins/mcp/src/client/resources.ts +++ b/js/plugins/mcp/src/client/resources.ts @@ -27,27 +27,56 @@ export async function registerResourceTools( client: Client, params: McpClientOptions ) { + const rootsList = params.roots + ? params.roots + .map((root) => `${root.name} (${root.uri})`) + .join(', ') + : ''; + ai.defineTool( { name: `${params.name}/list_resources`, - description: `list all available resources for '${params.name}'`, + description: `list all available resources for '${params.name}'${params.roots ? `, within roots ${rootsList}` : ''}`, inputSchema: z.object({ /** Provide a cursor for accessing additional paginated results. */ cursor: z.string().optional(), /** When specified, automatically paginate and fetch all resources. */ all: z.boolean().optional(), + /** The list of roots to limit the results to. Must be a subset of params.roots. */ + roots: z + .array(z.object({ name: z.string().optional(), uri: z.string() })) + .optional() + .describe( + `The list of roots to limit the results to. Must be a subset of ${rootsList}` + ), }), }, - async ({ cursor, all }) => { + async ({ + cursor, + all, + roots, + }): Promise<{ nextCursor?: string | undefined; resources: Resource[] }> => { + // Filter the roots so that they only contain roots in the params.roots list. + if (roots) { + roots = roots.filter((root) => + params.roots?.some( + (pRoot) => pRoot.name === root.name && pRoot.uri === root.uri + ) + ); + } + if (!all) { - return client.listResources(); + return client.listResources({ roots: roots || params.roots }); } let currentCursor: string | undefined = cursor; const resources: Resource[] = []; while (true) { const { nextCursor, resources: newResources } = - await client.listResources({ cursor: currentCursor }); + await client.listResources({ + cursor: currentCursor, + roots: roots || params.roots, + }); resources.push(...newResources); currentCursor = nextCursor; if (!currentCursor) break; diff --git a/js/plugins/mcp/src/client/tools.ts b/js/plugins/mcp/src/client/tools.ts index 04778bbdc8..49e164842e 100644 --- a/js/plugins/mcp/src/client/tools.ts +++ b/js/plugins/mcp/src/client/tools.ts @@ -88,7 +88,10 @@ export async function registerAllTools( ): Promise { let cursor: string | undefined; while (true) { - const { nextCursor, tools } = await client.listTools({ cursor }); + const { nextCursor, tools } = await client.listTools({ + cursor, + roots: params.roots, + }); tools.forEach((t) => registerTool(ai, client, t, params)); cursor = nextCursor; if (!cursor) break; diff --git a/js/plugins/mcp/src/index.ts b/js/plugins/mcp/src/index.ts index c950cf037d..5516927784 100644 --- a/js/plugins/mcp/src/index.ts +++ b/js/plugins/mcp/src/index.ts @@ -18,10 +18,10 @@ import type { StdioServerParameters } from '@modelcontextprotocol/sdk/client/std import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' with { 'resolution-mode': 'import' }; import { Genkit, GenkitError } from 'genkit'; import { genkitPlugin } from 'genkit/plugin'; -import { registerAllPrompts } from './client/prompts.js'; -import { registerResourceTools } from './client/resources.js'; -import { registerAllTools } from './client/tools.js'; -import { GenkitMcpServer } from './server.js'; +import { registerAllPrompts } from './client/prompts'; +import { registerResourceTools } from './client/resources'; +import { registerAllTools } from './client/tools'; +import { GenkitMcpServer } from './server'; export interface McpClientOptions { /** Provide a name for this client which will be its namespace for all tools and prompts. */ @@ -38,6 +38,8 @@ export interface McpClientOptions { serverWebsocketUrl?: string | URL; /** Return tool responses in raw MCP form instead of processing them for Genkit compatibility. */ rawToolResponses?: boolean; + /** Specify the MCP roots this client would like the server to work within. */ + roots?: { name: string; uri: string }[]; } async function transportFrom(params: McpClientOptions): Promise { @@ -77,19 +79,36 @@ export function mcpClient(params: McpClientOptions) { ); const transport = await transportFrom(params); - ai.options.model; const client = new Client( - { name: params.name, version: params.version || '1.0.0' }, - { capabilities: {} } + { + name: params.name, + version: params.version || '1.0.0', + roots: params.roots, + }, + { + capabilities: { + // TODO: Support sending root list change notifications to the server. + // Would require some way to update the list of roots outside of + // client params. Also requires that our registrations of tools and + // resources are able to handle updates to the list of roots this + // client defines, since tool and resource lists are affected by the + // list of roots. + roots: { listChanged: false }, + }, + } ); await client.connect(transport); - const capabilties = await client.getServerCapabilities(); + const capabilities = client.getServerCapabilities(); const promises: Promise[] = []; - if (capabilties?.tools) promises.push(registerAllTools(ai, client, params)); - if (capabilties?.prompts) + if (capabilities?.tools) { + promises.push(registerAllTools(ai, client, params)); + } + if (capabilities?.prompts) { promises.push(registerAllPrompts(ai, client, params)); - if (capabilties?.resources) + } + if (capabilities?.resources) { promises.push(registerResourceTools(ai, client, params)); + } await Promise.all(promises); }); } @@ -99,6 +118,8 @@ export interface McpServerOptions { name: string; /** The version you want the server to advertise to clients. Defaults to 1.0.0. */ version?: string; + /** The MCP roots this server is associated with or serves. */ + roots?: { name: string; uri: string }[]; } export function mcpServer(ai: Genkit, options: McpServerOptions) { diff --git a/js/plugins/mcp/src/server.ts b/js/plugins/mcp/src/server.ts index 55599f6170..ee8305ea07 100644 --- a/js/plugins/mcp/src/server.ts +++ b/js/plugins/mcp/src/server.ts @@ -33,10 +33,13 @@ import type { GetPromptResult, ListPromptsRequest, ListPromptsResult, + ListRootsRequest, + ListRootsResult, ListToolsRequest, ListToolsResult, Prompt, PromptMessage, + Root, Tool, } from '@modelcontextprotocol/sdk/types.js' with { 'resolution-mode': 'import' }; import { logger } from 'genkit/logging'; @@ -52,7 +55,6 @@ export class GenkitMcpServer { constructor(ai: Genkit, options: McpServerOptions) { this.ai = ai; this.options = options; - this.setup(); } async setup(): Promise { @@ -62,11 +64,19 @@ export class GenkitMcpServer { ); this.server = new Server( - { name: this.options.name, version: this.options.version || '1.0.0' }, + { + name: this.options.name, + version: this.options.version || '1.0.0', + roots: this.options.roots, + }, { capabilities: { prompts: {}, tools: {}, + // TODO: Support sending root list change notifications to the client. + // Would require some way to update the list of roots outside of the + // server params. + roots: { listChanged: false }, }, } ); @@ -76,6 +86,7 @@ export class GenkitMcpServer { GetPromptRequestSchema, ListPromptsRequestSchema, ListToolsRequestSchema, + ListRootsRequestSchema, } = await import('@modelcontextprotocol/sdk/types.js'); this.server.setRequestHandler( @@ -94,6 +105,10 @@ export class GenkitMcpServer { GetPromptRequestSchema, this.getPrompt.bind(this) ); + this.server.setRequestHandler( + ListRootsRequestSchema, + this.listRoots.bind(this) + ); const allActions = await this.ai.registry.listActions(); const toolList: ToolAction[] = []; @@ -124,6 +139,16 @@ export class GenkitMcpServer { }; } + async listRoots(req: ListRootsRequest): Promise { + await this.setup(); + if (!this.options.roots) { + return { roots: [] }; + } + + const mcpRoots: Root[] = this.options.roots.map((root) => root); + return { roots: mcpRoots }; + } + async callTool(req: CallToolRequest): Promise { await this.setup(); const tool = this.toolActions.find( diff --git a/js/plugins/mcp/tests/client_test.ts b/js/plugins/mcp/tests/client_test.ts new file mode 100644 index 0000000000..d94b32696a --- /dev/null +++ b/js/plugins/mcp/tests/client_test.ts @@ -0,0 +1,464 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { beforeEach, describe, expect, it, jest } from '@jest/globals'; +import type { StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; +import type { Genkit } from 'genkit'; +import { McpClientOptions, mcpClient } from '../src/index'; + +type MockedFunction any> = jest.MockedFunction; +type MockedGenkitPluginSignature = ( + name: string, + onInit: (ai: Genkit) => Promise +) => { name: string }; + +const mockSSEClientTransport = jest.fn(); +const mockStdioClientTransport = jest.fn(); +const mockWebSocketClientTransport = jest.fn(); + +type McpServerCapabilities = { + tools?: boolean; + prompts?: boolean; + resources?: boolean; + roots?: { listChanged: boolean }; +}; + +const mockMcpClientInstance = { + connect: jest + .fn<(transport: Transport) => Promise>() + .mockResolvedValue(undefined), + getServerCapabilities: jest + .fn<() => McpServerCapabilities | undefined>() + .mockReturnValue({ + tools: true, + prompts: true, + resources: true, + }), +}; +const MockMcpClient = jest.fn(() => mockMcpClientInstance); + +type RegistrationFunction = ( + ai: Genkit, + client: typeof mockMcpClientInstance, + params: McpClientOptions +) => Promise; + +const mockGenkitInstance = { + options: {}, +} as Genkit; + +jest.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({ + SSEClientTransport: mockSSEClientTransport, +})); +jest.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({ + StdioClientTransport: mockStdioClientTransport, +})); +jest.mock('@modelcontextprotocol/sdk/client/websocket.js', () => ({ + WebSocketClientTransport: mockWebSocketClientTransport, +})); +jest.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ + Client: MockMcpClient, +})); + +jest.mock('genkit/plugin', () => { + return { + genkitPlugin: jest.fn(), + }; +}); + +jest.mock('genkit', () => { + class MockGenkitError extends Error { + constructor(public details: { status?: string; message: string }) { + super(details.message); + this.name = 'GenkitError'; + if (details.status) { + (this as any).status = details.status; + } + } + } + + const originalGenkitModule = + jest.requireActual('genkit'); + return { + __esModule: true, + ...originalGenkitModule, + GenkitError: MockGenkitError, + }; +}); + +jest.mock('../src/client/tools.ts', () => ({ + registerAllTools: jest.fn(), +})); +jest.mock('../src/client/prompts.ts', () => ({ + registerAllPrompts: jest.fn(), +})); +jest.mock('../src/client/resources.ts', () => ({ + registerResourceTools: jest.fn(), +})); + +const getMockedGenkitErrorConstructor = () => { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const genkitModule = require('genkit'); + return genkitModule.GenkitError as unknown as new (details: { + status?: string; + message: string; + }) => Error & { details: { status?: string; message: string } }; +}; + +describe('mcpClient', () => { + let pluginSetupFunction: (ai: Genkit) => Promise; + + let localMockRegisterAllTools: MockedFunction; + let localMockRegisterAllPrompts: MockedFunction; + let localMockRegisterResourceTools: MockedFunction; + + let localMockGenkitPlugin: MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + + const mockedPluginModule = jest.requireMock('genkit/plugin') as { + genkitPlugin: MockedFunction; + }; + localMockGenkitPlugin = mockedPluginModule.genkitPlugin; + + const mockedToolsModule = jest.requireMock('../src/client/tools.ts') as { + registerAllTools: MockedFunction; + }; + localMockRegisterAllTools = mockedToolsModule.registerAllTools; + + const mockedPromptsModule = jest.requireMock( + '../src/client/prompts.ts' + ) as { + registerAllPrompts: MockedFunction; + }; + localMockRegisterAllPrompts = mockedPromptsModule.registerAllPrompts; + + const mockedResourcesModule = jest.requireMock( + '../src/client/resources.ts' + ) as { + registerResourceTools: MockedFunction; + }; + localMockRegisterResourceTools = + mockedResourcesModule.registerResourceTools; + + localMockGenkitPlugin.mockImplementation((name, func) => { + pluginSetupFunction = func; + return { name: `plugin-${name}` }; + }); + + mockMcpClientInstance.connect.mockResolvedValue(undefined); + MockMcpClient.mockClear().mockReturnValue(mockMcpClientInstance); + mockSSEClientTransport.mockClear(); + mockStdioClientTransport.mockClear(); + mockWebSocketClientTransport.mockClear(); + + mockMcpClientInstance.getServerCapabilities.mockReturnValue({ + tools: true, + prompts: true, + resources: true, + roots: { listChanged: false }, + }); + + localMockRegisterAllTools.mockClear().mockResolvedValue(undefined); + localMockRegisterAllPrompts.mockClear().mockResolvedValue(undefined); + localMockRegisterResourceTools.mockClear().mockResolvedValue(undefined); + }); + + it('should initialize with SSE transport and register all capabilities by default', async () => { + const options: McpClientOptions = { + name: 'test-sse-client', + serverUrl: 'http://localhost:1234/sse', + }; + mcpClient(options); + const capabilitiesToReturn = { + tools: true, + prompts: true, + resources: true, + roots: { listChanged: false }, + }; + mockMcpClientInstance.getServerCapabilities.mockReturnValue( + capabilitiesToReturn + ); + + await pluginSetupFunction(mockGenkitInstance); + + expect(MockMcpClient).toHaveBeenCalledTimes(1); + expect(MockMcpClient).toHaveBeenCalledWith(expect.any(Object), { + capabilities: { + roots: { listChanged: false }, + }, + }); + expect(mockMcpClientInstance.connect).toHaveBeenCalled(); + expect(mockMcpClientInstance.getServerCapabilities).toHaveBeenCalled(); + expect(localMockRegisterAllTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + options + ); + expect(localMockRegisterAllPrompts).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + options + ); + expect(localMockRegisterResourceTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + options + ); + }); + + it('should initialize with Stdio transport', async () => { + const serverProcessParams: StdioServerParameters = { + command: 'my-server-cmd', + }; + const options: McpClientOptions = { + name: 'test-stdio-client', + serverProcess: serverProcessParams, + }; + + mcpClient(options); + await pluginSetupFunction(mockGenkitInstance); + + expect(mockStdioClientTransport).toHaveBeenCalledWith(serverProcessParams); + expect(mockMcpClientInstance.connect).toHaveBeenCalledWith( + expect.any(mockStdioClientTransport) + ); + }); + + it('should initialize with WebSocket transport (string URL)', async () => { + const options: McpClientOptions = { + name: 'test-ws-client-string', + serverWebsocketUrl: 'ws://localhost:5678', + }; + + mcpClient(options); + await pluginSetupFunction(mockGenkitInstance); + + expect(mockWebSocketClientTransport).toHaveBeenCalledWith( + new URL(options.serverWebsocketUrl as string) + ); + expect(mockMcpClientInstance.connect).toHaveBeenCalledWith( + expect.any(mockWebSocketClientTransport) + ); + }); + + it('should initialize with WebSocket transport (URL object)', async () => { + const url = new URL('wss://secure.example.com:8080'); + const options: McpClientOptions = { + name: 'test-ws-client-url', + serverWebsocketUrl: url, + }; + + mcpClient(options); + await pluginSetupFunction(mockGenkitInstance); + + expect(mockWebSocketClientTransport).toHaveBeenCalledWith(url); + expect(mockMcpClientInstance.connect).toHaveBeenCalledWith( + expect.any(mockWebSocketClientTransport) + ); + }); + + it('should initialize with an existing transport', async () => { + const mockTransportInstance = {} as Transport; + const options: McpClientOptions = { + name: 'test-existing-transport-client', + transport: mockTransportInstance, + }; + + mcpClient(options); + await pluginSetupFunction(mockGenkitInstance); + + expect(mockMcpClientInstance.connect).toHaveBeenCalledWith( + mockTransportInstance + ); + expect(mockSSEClientTransport).not.toHaveBeenCalled(); + expect(mockStdioClientTransport).not.toHaveBeenCalled(); + expect(mockWebSocketClientTransport).not.toHaveBeenCalled(); + }); + + it('should throw GenkitError if no valid transport options are provided', async () => { + const options: McpClientOptions = { + name: 'test-error-client', + }; + + mcpClient(options); + const MockedError = getMockedGenkitErrorConstructor(); + + await expect(pluginSetupFunction(mockGenkitInstance)).rejects.toThrow( + MockedError + ); + try { + await pluginSetupFunction(mockGenkitInstance); + } catch (e: any) { + expect(e.details.status).toBe('INVALID_ARGUMENT'); + expect(e.message).toBe( + 'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl.' + ); + } + + expect(mockSSEClientTransport).not.toHaveBeenCalled(); + expect(mockStdioClientTransport).not.toHaveBeenCalled(); + expect(mockWebSocketClientTransport).not.toHaveBeenCalled(); + }); + + it('should use provided version, roots, and pass options to registration functions', async () => { + const options: McpClientOptions = { + name: 'test-options-client', + serverUrl: 'http://localhost:1234/sse', + version: '2.0.0-beta', + roots: [{ name: 'root1', uri: 'file:///project1' }], + rawToolResponses: true, + }; + + mcpClient(options); + await pluginSetupFunction(mockGenkitInstance); + + expect(MockMcpClient).toHaveBeenCalledWith( + { + name: options.name, + version: options.version, + roots: options.roots, + }, + { capabilities: { roots: { listChanged: false } } } + ); + expect(localMockRegisterAllTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + options + ); + expect(localMockRegisterAllPrompts).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + options + ); + expect(localMockRegisterResourceTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + options + ); + }); + + describe('capability-based registration', () => { + const baseOptions: McpClientOptions = { + name: 'test-caps-client', + serverUrl: 'http://localhost:9000/sse', + }; + + it('should only register tools if only tools capability is present', async () => { + mockMcpClientInstance.getServerCapabilities.mockReturnValue({ + tools: true, + prompts: false, + resources: false, + }); + mcpClient(baseOptions); + await pluginSetupFunction(mockGenkitInstance); + + expect(localMockRegisterAllTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + baseOptions + ); + expect(localMockRegisterAllPrompts).not.toHaveBeenCalled(); + expect(localMockRegisterResourceTools).not.toHaveBeenCalled(); + }); + + it('should only register prompts if only prompts capability is present', async () => { + mockMcpClientInstance.getServerCapabilities.mockReturnValue({ + tools: false, + prompts: true, + resources: false, + }); + mcpClient(baseOptions); + await pluginSetupFunction(mockGenkitInstance); + + expect(localMockRegisterAllTools).not.toHaveBeenCalled(); + expect(localMockRegisterAllPrompts).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + baseOptions + ); + expect(localMockRegisterResourceTools).not.toHaveBeenCalled(); + }); + + it('should only register resource tools if only resources capability is present', async () => { + mockMcpClientInstance.getServerCapabilities.mockReturnValue({ + tools: false, + prompts: false, + resources: true, + }); + mcpClient(baseOptions); + await pluginSetupFunction(mockGenkitInstance); + + expect(localMockRegisterAllTools).not.toHaveBeenCalled(); + expect(localMockRegisterAllPrompts).not.toHaveBeenCalled(); + expect(localMockRegisterResourceTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + baseOptions + ); + }); + + it('should register nothing if no capabilities are present (all false)', async () => { + mockMcpClientInstance.getServerCapabilities.mockReturnValue({ + tools: false, + prompts: false, + resources: false, + }); + mcpClient(baseOptions); + await pluginSetupFunction(mockGenkitInstance); + + expect(localMockRegisterAllTools).not.toHaveBeenCalled(); + expect(localMockRegisterAllPrompts).not.toHaveBeenCalled(); + expect(localMockRegisterResourceTools).not.toHaveBeenCalled(); + }); + + it('should register nothing if capabilities object is undefined', async () => { + ( + mockMcpClientInstance.getServerCapabilities as MockedFunction + ).mockResolvedValue(undefined); + mcpClient(baseOptions); + await pluginSetupFunction(mockGenkitInstance); + + expect(localMockRegisterAllTools).not.toHaveBeenCalled(); + expect(localMockRegisterAllPrompts).not.toHaveBeenCalled(); + expect(localMockRegisterResourceTools).not.toHaveBeenCalled(); + }); + + it('should register tools and prompts if those capabilities are present', async () => { + mockMcpClientInstance.getServerCapabilities.mockReturnValue({ + tools: true, + prompts: true, + resources: false, + }); + mcpClient(baseOptions); + await pluginSetupFunction(mockGenkitInstance); + + expect(localMockRegisterAllTools).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + baseOptions + ); + expect(localMockRegisterAllPrompts).toHaveBeenCalledWith( + mockGenkitInstance, + mockMcpClientInstance, + baseOptions + ); + expect(localMockRegisterResourceTools).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/js/plugins/mcp/tests/mcp_test.ts b/js/plugins/mcp/tests/mcp_test.ts deleted file mode 100644 index 0849c3e82b..0000000000 --- a/js/plugins/mcp/tests/mcp_test.ts +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// no tests... :( diff --git a/js/plugins/mcp/tests/server_test.ts b/js/plugins/mcp/tests/server_test.ts new file mode 100644 index 0000000000..bc409b6472 --- /dev/null +++ b/js/plugins/mcp/tests/server_test.ts @@ -0,0 +1,673 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { beforeEach, describe, expect, it, jest } from '@jest/globals'; +import { + Genkit, + GenkitError, + MessageData, + PromptAction, + ToolAction, +} from 'genkit'; +import type { McpServerOptions } from '../src/index'; +import { GenkitMcpServer } from '../src/server'; + +const mockToJsonSchema = jest.fn(); +const mockToToolDefinition = jest.fn(); +const mockLoggerInfo = jest.fn(); +const mockListActions = + jest.fn<() => Promise>>(); + +const mockMcpServerInstance = { + setRequestHandler: jest.fn(), + connect: jest.fn<() => Promise>().mockResolvedValue(undefined), +}; +const MockMcpServer = jest.fn(() => mockMcpServerInstance); +const mockStdioServerTransportInstance = {}; +const MockStdioServerTransport = jest.fn( + () => mockStdioServerTransportInstance +); + +jest.mock('@modelcontextprotocol/sdk/server/index.js', () => ({ + Server: MockMcpServer, +})); + +jest.mock('@modelcontextprotocol/sdk/types.js', () => ({ + CallToolRequestSchema: { _id: 'CallToolRequestSchema' }, + GetPromptRequestSchema: { _id: 'GetPromptRequestSchema' }, + ListPromptsRequestSchema: { _id: 'ListPromptsRequestSchema' }, + ListToolsRequestSchema: { _id: 'ListToolsRequestSchema' }, + ListRootsRequestSchema: { _id: 'ListRootsRequestSchema' }, +})); + +jest.mock('@modelcontextprotocol/sdk/server/stdio.js', () => ({ + StdioServerTransport: MockStdioServerTransport, +})); + +jest.mock('@genkit-ai/core/schema', () => ({ + get toJsonSchema() { + return mockToJsonSchema; + }, +})); + +jest.mock('genkit/tool', () => ({ + get toToolDefinition() { + return mockToToolDefinition; + }, +})); + +jest.mock('genkit/logging', () => ({ + logger: { + get info() { + return mockLoggerInfo; + }, + }, +})); + +interface MockGenkitInstance { + registry: { + listActions: typeof mockListActions; + }; +} + +const createMockGenkit = (): MockGenkitInstance => ({ + registry: { listActions: mockListActions }, +}); + +const createMockGenkitAction = ( + type: 'tool' | 'prompt', + name: string, + inputSchema?: any, + description?: string, + metadata?: Record +): { [key: string]: ToolAction | PromptAction } => { + const actionFunction = jest.fn() as unknown as ToolAction | PromptAction; + (actionFunction as any).__action = { + name, + inputSchema, + + inputJsonSchema: inputSchema + ? { type: 'object', properties: {} } + : undefined, + description, + ...metadata, + }; + return { [`/${type}/${name}`]: actionFunction }; +}; + +const createDefaultOptions = ( + overrides: Partial = {} +): McpServerOptions => ({ + name: 'test-mcp-server', + version: '1.0.0-test', + ...overrides, +}); + +describe('GenkitMcpServer', () => { + let mockGenkit: MockGenkitInstance; + + beforeEach(() => { + jest.clearAllMocks(); + mockGenkit = createMockGenkit(); + mockToToolDefinition.mockReturnValue({ + name: 'defaultMockedTool', + inputSchema: { type: 'object' }, + description: 'Default mock', + }); + }); + + describe('constructor and setup', () => { + it('should initialize MCP Server with options and capabilities', async () => { + const options = createDefaultOptions({ + name: 'my-server', + version: '2.0', + roots: [{ name: 'root1', uri: 'genkit:/tool/t1' }], + }); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + options + ); + await server.setup(); + expect(MockMcpServer).toHaveBeenCalledWith( + { + name: 'my-server', + version: '2.0', + roots: [{ name: 'root1', uri: 'genkit:/tool/t1' }], + }, + { + capabilities: { + prompts: {}, + tools: {}, + roots: { listChanged: false }, + }, + } + ); + }); + + it('should register request handlers for MCP SDK schemas', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + await server.setup(); + + expect(mockMcpServerInstance.setRequestHandler).toHaveBeenCalledTimes(5); + const { + ListToolsRequestSchema, + CallToolRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema, + ListRootsRequestSchema, + } = await import('@modelcontextprotocol/sdk/types.js'); + + expect(mockMcpServerInstance.setRequestHandler).toHaveBeenCalledWith( + ListToolsRequestSchema, + expect.any(Function) + ); + expect(mockMcpServerInstance.setRequestHandler).toHaveBeenCalledWith( + CallToolRequestSchema, + expect.any(Function) + ); + expect(mockMcpServerInstance.setRequestHandler).toHaveBeenCalledWith( + ListPromptsRequestSchema, + expect.any(Function) + ); + expect(mockMcpServerInstance.setRequestHandler).toHaveBeenCalledWith( + GetPromptRequestSchema, + expect.any(Function) + ); + expect(mockMcpServerInstance.setRequestHandler).toHaveBeenCalledWith( + ListRootsRequestSchema, + expect.any(Function) + ); + }); + + it('should fetch and categorize actions from Genkit registry', async () => { + const toolAction1 = createMockGenkitAction('tool', 'tool1'); + const promptAction1 = createMockGenkitAction('prompt', 'prompt1'); + const otherAction = { + '/other/action': jest.fn() as unknown as ToolAction, + }; + mockListActions.mockResolvedValue({ + ...toolAction1, + ...promptAction1, + ...otherAction, + }); + + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + await server.setup(); + + expect(mockListActions).toHaveBeenCalled(); + expect(server.toolActions).toEqual( + expect.arrayContaining([Object.values(toolAction1)[0]]) + ); + expect(server.promptActions).toEqual( + expect.arrayContaining([Object.values(promptAction1)[0]]) + ); + }); + + it('should call setup only once even if methods are called multiple times', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + await server.listTools({} as any); + await server.listPrompts({} as any); + expect(mockListActions).toHaveBeenCalledTimes(1); + expect(MockMcpServer).toHaveBeenCalledTimes(1); + }); + }); + + describe('listTools', () => { + it('should return tools based on registered tool actions', async () => { + const toolAction = createMockGenkitAction( + 'tool', + 'toolA', + { type: 'string' }, + 'Tool A desc' + ); + mockListActions.mockResolvedValue(toolAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + const mockToolFn = Object.values(toolAction)[0]; + + mockToToolDefinition.mockReturnValue({ + name: 'toolA', + inputSchema: { type: 'string' }, + description: 'Tool A desc', + }); + + const result = await server.listTools({} as any); + expect(mockToToolDefinition).toHaveBeenCalledWith(mockToolFn); + expect(result.tools).toEqual([ + { + name: 'toolA', + inputSchema: { type: 'string' }, + description: 'Tool A desc', + }, + ]); + }); + + it('should use default inputSchema if not defined', async () => { + const toolAction = createMockGenkitAction( + 'tool', + 'toolB', + undefined, // inputSchema + 'Tool B desc' + ); + mockListActions.mockResolvedValue(toolAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + mockToToolDefinition.mockReturnValue({ + name: 'toolB', + // inputSchema is undefined here + description: 'Tool B desc', + }); + + const result = await server.listTools({} as any); + expect(result.tools[0].inputSchema).toEqual({ type: 'object' }); + }); + }); + + describe('listRoots', () => { + it('should return roots from options', async () => { + const roots = [{ name: 'root1', uri: 'genkit:/tool/t1' }]; + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions({ roots } as McpServerOptions) + ); + await server.setup(); + const result = await server.listRoots({} as any); + expect(result.roots).toEqual(roots); + }); + + it('should return empty array if no roots in options', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions({ roots: undefined } as McpServerOptions) + ); + await server.setup(); + + const result = await server.listRoots({} as any); + expect(result.roots).toEqual([]); + }); + }); + + describe('callTool', () => { + it('should find and call the specified tool action', async () => { + const toolAction = createMockGenkitAction('tool', 'myTool'); + const toolFn = Object.values(toolAction)[0] as jest.MockedFunction; + toolFn.mockResolvedValue('tool_result_data' as any); + mockListActions.mockResolvedValue(toolAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + const result = await server.callTool({ + params: { name: 'myTool', arguments: { arg1: 'val1' } }, + } as any); + + expect(toolFn).toHaveBeenCalledWith({ arg1: 'val1' }); + expect(result.content).toEqual([ + { type: 'text', text: JSON.stringify('tool_result_data') }, + ]); + }); + + it('should throw GenkitError if tool is not found', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + mockListActions.mockResolvedValue({}); + await expect( + server.callTool({ params: { name: 'nonExistentTool' } } as any) + ).rejects.toThrow( + new GenkitError({ + status: 'NOT_FOUND', + message: + "Tried to call tool 'nonExistentTool' but it could not be found.", + }) + ); + }); + }); + + describe('listPrompts', () => { + it('should return prompts with arguments derived from inputSchema', async () => { + const promptInputSchema = { + properties: { query: { type: 'string', description: 'User query' } }, + required: ['query'], + }; + const promptAction = createMockGenkitAction( + 'prompt', + 'myPrompt', + promptInputSchema, + 'A prompt' + ); + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + const promptFn = Object.values(promptAction)[0]; + mockToJsonSchema.mockReturnValue({ + properties: { query: { type: 'string', description: 'User query' } }, + required: ['query'], + }); + + const result = await server.listPrompts({} as any); + expect(mockToJsonSchema).toHaveBeenCalledWith({ + schema: (promptFn as any).__action.inputSchema, + jsonSchema: (promptFn as any).__action.inputJsonSchema, + }); + expect(result.prompts).toEqual([ + { + name: 'myPrompt', + description: 'A prompt', + arguments: [ + { name: 'query', description: 'User query', required: true }, + ], + }, + ]); + }); + + it('should throw if prompt inputSchema is not an object (no properties)', async () => { + const promptAction = createMockGenkitAction('prompt', 'badPrompt', { + type: 'string', + }); + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + mockToJsonSchema.mockReturnValue({ type: 'string' }); + await expect(server.listPrompts({} as any)).rejects.toThrow( + new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + '[@genkit-ai/mcp] MCP prompts must take objects as input schema.', + }) + ); + }); + + it('should throw if prompt argument is not a string', async () => { + const promptAction = createMockGenkitAction( + 'prompt', + 'promptWithNonString', + { properties: { numArg: { type: 'number' } } } + ); + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + mockToJsonSchema.mockReturnValue({ + properties: { numArg: { type: 'number', description: 'A number' } }, + }); + + await expect(server.listPrompts({} as any)).rejects.toThrow( + new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `[@genkit-ai/mcp] MCP prompts may only take string arguments, but promptWithNonString has property 'numArg' of type 'number'.`, + }) + ); + }); + + it('should handle undefined inputSchema for a prompt gracefully', async () => { + const promptAction = createMockGenkitAction( + 'prompt', + 'noSchemaPrompt', + undefined, + 'Prompt without schema' + ); + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + mockToJsonSchema.mockReturnValue(undefined); // Simulate no schema or unparsable + + const result = await server.listPrompts({} as any); + expect(result.prompts[0].arguments).toBeUndefined(); + }); + }); + + describe('getPrompt', () => { + const createMessageData = ( + role: 'user' | 'model' | 'system' | 'tool', + text?: string, + mediaUrl?: string, + mediaContentType?: string + ): MessageData => ({ + role, + content: [ + ...(text ? [{ text, custom: {} }] : []), + ...(mediaUrl + ? [ + { + media: { url: mediaUrl, contentType: mediaContentType }, + custom: {}, + }, + ] + : []), + ].filter(Boolean) as MessageData['content'], // Filter out undefined if neither text nor media + }); + + it('should find, call prompt, and format messages (text only)', async () => { + const promptAction = createMockGenkitAction( + 'prompt', + 'chatPrompt', + undefined, + 'Chatty prompt' + ); + const promptFn = Object.values( + promptAction + )[0] as jest.MockedFunction; + promptFn.mockResolvedValue( + Promise.resolve({ + messages: [ + createMessageData('user', 'Hello'), + createMessageData('model', 'Hi there'), + ], + } as any) + ); + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + const result = await server.getPrompt({ + params: { name: 'chatPrompt', arguments: { query: 'Hi' } }, + } as any); + + expect(promptFn).toHaveBeenCalledWith({ query: 'Hi' }); + expect(result.description).toBe('Chatty prompt'); + expect(result.messages).toEqual([ + { role: 'user', content: { type: 'text', text: 'Hello' } }, + { role: 'assistant', content: { type: 'text', text: 'Hi there' } }, + ]); + }); + + it('should format media messages (data URL)', async () => { + const promptAction = createMockGenkitAction('prompt', 'imagePrompt'); + const promptFn = Object.values( + promptAction + )[0] as jest.MockedFunction; + promptFn.mockResolvedValue( + Promise.resolve({ + messages: [ + createMessageData( + 'user', + undefined, + '', + 'image/png' + ), + ], + } as any) + ); + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + const result = await server.getPrompt({ + params: { name: 'imagePrompt' }, + } as any); + expect(result.messages).toEqual([ + { + role: 'user', + content: { type: 'image', mimeType: 'image/png', data: 'abc' }, + }, + ]); + }); + + it('should throw GenkitError if prompt is not found', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + mockListActions.mockResolvedValue({}); + await expect( + server.getPrompt({ params: { name: 'ghostPrompt' } } as any) + ).rejects.toThrow( + new GenkitError({ + status: 'NOT_FOUND', + message: + "[@genkit-ai/mcp] Tried to call prompt 'ghostPrompt' but it could not be found.", + }) + ); + }); + + it('should throw if message role is unsupported', async () => { + const promptAction = createMockGenkitAction('prompt', 'systemRolePrompt'); + const promptFn = Object.values( + promptAction + )[0] as jest.MockedFunction; + promptFn.mockResolvedValue( + Promise.resolve({ + messages: [createMessageData('system', 'System message')], + } as any) + ); + + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + + await expect( + server.getPrompt({ params: { name: 'systemRolePrompt' } } as any) + ).rejects.toThrow( + new GenkitError({ + status: 'UNIMPLEMENTED', + message: + "[@genkit-ai/mcp] MCP prompt messages do not support role 'system'. Only 'user' and 'model' messages are supported.", + }) + ); + }); + + it('should throw if media URL is not a data URL', async () => { + const promptAction = createMockGenkitAction('prompt', 'httpImagePrompt'); + const promptFn = Object.values( + promptAction + )[0] as jest.MockedFunction; + promptFn.mockResolvedValue( + Promise.resolve({ + messages: [ + createMessageData( + 'user', + undefined, + 'http://example.com/image.png' + ), + ], + } as any) + ); + + mockListActions.mockResolvedValue(promptAction); + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + await expect( + server.getPrompt({ params: { name: 'httpImagePrompt' } } as any) + ).rejects.toThrow( + new GenkitError({ + status: 'UNIMPLEMENTED', + message: + '[@genkit-ai/mcp] MCP prompt messages only support base64 data images.', + }) + ); + }); + }); + + describe('start', () => { + it('should use StdioServerTransport by default if no transport is provided', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + await server.start(); + + expect(MockStdioServerTransport).toHaveBeenCalled(); + expect(mockMcpServerInstance.connect).toHaveBeenCalledWith( + mockStdioServerTransportInstance + ); + expect(mockLoggerInfo).toHaveBeenCalledWith( + "[@genkit-ai/mcp] MCP server 'test-mcp-server' started successfully." + ); + }); + + it('should use provided transport', async () => { + const customTransport = { custom: 'transport' }; + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + await server.start(customTransport as any); + + expect(MockStdioServerTransport).not.toHaveBeenCalled(); + expect(mockMcpServerInstance.connect).toHaveBeenCalledWith( + customTransport + ); + expect(mockLoggerInfo).toHaveBeenCalledWith( + "[@genkit-ai/mcp] MCP server 'test-mcp-server' started successfully." + ); + }); + + it('should call setup before connecting', async () => { + const server = new GenkitMcpServer( + mockGenkit as unknown as Genkit, + createDefaultOptions() + ); + // Spy on setup to ensure it's called by start + const setupSpy = jest.spyOn(server, 'setup'); + await server.start(); + expect(setupSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index a19d3e0108..3b13486538 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -634,8 +634,8 @@ importers: specifier: workspace:^ version: link:../../core '@modelcontextprotocol/sdk': - specifier: ^1.8.0 - version: 1.8.0 + specifier: ^1.11.0 + version: 1.11.0 genkit: specifier: workspace:^ version: link:../../genkit @@ -643,6 +643,9 @@ importers: '@jest/globals': specifier: ^29.7.0 version: 29.7.0 + '@types/jest': + specifier: ^29.5.12 + version: 29.5.13 '@types/node': specifier: ^20.11.16 version: 20.16.9 @@ -656,8 +659,8 @@ importers: specifier: ^6.0.1 version: 6.0.1 ts-jest: - specifier: ^29.1.2 - version: 29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@5.6.3)))(typescript@5.6.3) + specifier: ^29.2.5 + version: 29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(esbuild@0.24.0)(jest@29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@5.6.3)))(typescript@5.6.3) tsup: specifier: ^8.3.5 version: 8.3.5(postcss@8.4.47)(tsx@4.19.2)(typescript@5.6.3)(yaml@2.7.0) @@ -665,7 +668,7 @@ importers: specifier: ^4.19.2 version: 4.19.2 typescript: - specifier: ^5.3.0 + specifier: ^5.6.3 version: 5.6.3 plugins/next: @@ -1388,7 +1391,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 0.10.1(@genkit-ai/ai@1.4.0)(@genkit-ai/core@1.4.0) + version: 0.10.1(@genkit-ai/ai@1.8.0)(@genkit-ai/core@1.8.0) devDependencies: rimraf: specifier: ^6.0.1 @@ -2520,11 +2523,11 @@ packages: '@firebase/webchannel-wrapper@1.0.3': resolution: {integrity: sha512-2xCRM9q9FlzGZCdgDMJwc0gyUkWFtkosy7Xxr6sFgQwn+wMNIWd7xIvYNauU1r64B5L5rsGKy/n9TKJ0aAFeqQ==} - '@genkit-ai/ai@1.4.0': - resolution: {integrity: sha512-s0YZ7quoYF4LYFFVnJz/3GvBmXPl8Ty9a5ZMOCB8k0xmAopiFwKEpaCMFbpIyF04EmB2U8x5/k3bjliD32eZXQ==} + '@genkit-ai/ai@1.8.0': + resolution: {integrity: sha512-TIhFgQCThdVOyrk6qiVF8dPfz4XmL3RxE9OCidhqcpGrGE5YeRvle+nWIbkNIojdLURQf3/dBxNdaqZZ7B3msQ==} - '@genkit-ai/core@1.4.0': - resolution: {integrity: sha512-Y85RsvXfejH7vQOH/O8/GgaKqDeqiDDMnWNKa2Cy2ugwlsy1P5jSHkQ5wUPgCCTSwQG4eOfdmmwGpFVvNi0QXw==} + '@genkit-ai/core@1.8.0': + resolution: {integrity: sha512-XvK/Gq7fi8pFCJftzby/6EWoVBj5EUSai/9/104Y699wbub3qreoIGH2DN/CKOUzt0gD2i7fHeYlyTAnJ+TPbw==} '@gerrit0/mini-shiki@1.24.4': resolution: {integrity: sha512-YEHW1QeAg6UmxEmswiQbOVEg1CW22b1XUD/lNTliOsu0LD0wqoyleFMnmbTp697QE0pcadQiR5cVtbbAPncvpw==} @@ -3156,8 +3159,8 @@ packages: react-dom: ^18 || ^19 zod: '>= 3' - '@modelcontextprotocol/sdk@1.8.0': - resolution: {integrity: sha512-e06W7SwrontJDHwCawNO5SGxG+nU9AAx+jpHHZqGl/WrDBdWOpvirC+s58VpJTB5QemI4jTRcjWT4Pt3Q1NPQQ==} + '@modelcontextprotocol/sdk@1.11.0': + resolution: {integrity: sha512-k/1pb70eD638anoi0e8wUGAlbMJXyvdV4p62Ko+EZ7eBe1xMx8Uhak1R5DgfoofsK5IBBnRwsYGTaLZl+6/+RQ==} engines: {node: '>=18'} '@next/env@15.2.4': @@ -6111,6 +6114,7 @@ packages: node-domexception@1.0.0: resolution: {integrity: sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==} engines: {node: '>=10.5.0'} + deprecated: Use your platform's native DOMException instead node-ensure@0.0.0: resolution: {integrity: sha512-DRI60hzo2oKN1ma0ckc6nQWlHU69RH6xN0sjQTjMpChPfTYvKZdcQFfdYK2RWbJcKyUizSIy/l8OTGxMAM1QDw==} @@ -6362,8 +6366,8 @@ packages: resolution: {integrity: sha512-saLsH7WeYYPiD25LDuLRRY/i+6HaPYr6G1OUlN39otzkSTxKnubR9RTxS3/Kk50s1g2JTgFwWQDQyplC5/SHZg==} engines: {node: '>= 6'} - pkce-challenge@4.1.0: - resolution: {integrity: sha512-ZBmhE1C9LcPoH9XZSdwiPtbPHZROwAnMy+kIFQVrnMCxY4Cudlz3gBOpzilgc0jOgRaiT3sIWfpMomW2ar2orQ==} + pkce-challenge@5.0.0: + resolution: {integrity: sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==} engines: {node: '>=16.20.0'} pkg-dir@4.2.0: @@ -8088,9 +8092,9 @@ snapshots: '@firebase/webchannel-wrapper@1.0.3': {} - '@genkit-ai/ai@1.4.0': + '@genkit-ai/ai@1.8.0': dependencies: - '@genkit-ai/core': 1.4.0 + '@genkit-ai/core': 1.8.0 '@opentelemetry/api': 1.9.0 '@types/node': 20.17.17 colorette: 2.0.20 @@ -8102,7 +8106,7 @@ snapshots: transitivePeerDependencies: - supports-color - '@genkit-ai/core@1.4.0': + '@genkit-ai/core@1.8.0': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/context-async-hooks': 1.30.1(@opentelemetry/api@1.9.0) @@ -8761,7 +8765,7 @@ snapshots: - encoding - supports-color - '@modelcontextprotocol/sdk@1.8.0': + '@modelcontextprotocol/sdk@1.11.0': dependencies: content-type: 1.0.5 cors: 2.8.5 @@ -8769,7 +8773,7 @@ snapshots: eventsource: 3.0.5 express: 5.0.1 express-rate-limit: 7.5.0(express@5.0.1) - pkce-challenge: 4.1.0 + pkce-challenge: 5.0.0 raw-body: 3.0.0 zod: 3.24.1 zod-to-json-schema: 3.24.1(zod@3.24.1) @@ -11096,10 +11100,10 @@ snapshots: - encoding - supports-color - genkitx-openai@0.10.1(@genkit-ai/ai@1.4.0)(@genkit-ai/core@1.4.0): + genkitx-openai@0.10.1(@genkit-ai/ai@1.8.0)(@genkit-ai/core@1.8.0): dependencies: - '@genkit-ai/ai': 1.4.0 - '@genkit-ai/core': 1.4.0 + '@genkit-ai/ai': 1.8.0 + '@genkit-ai/core': 1.8.0 openai: 4.53.0(encoding@0.1.13) zod: 3.24.1 transitivePeerDependencies: @@ -12819,7 +12823,7 @@ snapshots: pirates@4.0.6: {} - pkce-challenge@4.1.0: {} + pkce-challenge@5.0.0: {} pkg-dir@4.2.0: dependencies: @@ -13521,31 +13525,32 @@ snapshots: ts-interface-checker@0.1.13: {} - ts-jest@29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.11.30)(ts-node@10.9.2(@types/node@20.11.30)(typescript@4.9.5)))(typescript@4.9.5): + ts-jest@29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(esbuild@0.24.0)(jest@29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@5.6.3)))(typescript@5.6.3): dependencies: bs-logger: 0.2.6 ejs: 3.1.10 fast-json-stable-stringify: 2.1.0 - jest: 29.7.0(@types/node@20.11.30)(ts-node@10.9.2(@types/node@20.11.30)(typescript@4.9.5)) + jest: 29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@5.6.3)) jest-util: 29.7.0 json5: 2.2.3 lodash.memoize: 4.1.2 make-error: 1.3.6 semver: 7.6.3 - typescript: 4.9.5 + typescript: 5.6.3 yargs-parser: 21.1.1 optionalDependencies: '@babel/core': 7.25.7 '@jest/transform': 29.7.0 '@jest/types': 29.6.3 babel-jest: 29.7.0(@babel/core@7.25.7) + esbuild: 0.24.0 - ts-jest@29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@4.9.5)))(typescript@4.9.5): + ts-jest@29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.11.30)(ts-node@10.9.2(@types/node@20.11.30)(typescript@4.9.5)))(typescript@4.9.5): dependencies: bs-logger: 0.2.6 ejs: 3.1.10 fast-json-stable-stringify: 2.1.0 - jest: 29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@4.9.5)) + jest: 29.7.0(@types/node@20.11.30)(ts-node@10.9.2(@types/node@20.11.30)(typescript@4.9.5)) jest-util: 29.7.0 json5: 2.2.3 lodash.memoize: 4.1.2 @@ -13559,18 +13564,18 @@ snapshots: '@jest/types': 29.6.3 babel-jest: 29.7.0(@babel/core@7.25.7) - ts-jest@29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@5.6.3)))(typescript@5.6.3): + ts-jest@29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@4.9.5)))(typescript@4.9.5): dependencies: bs-logger: 0.2.6 ejs: 3.1.10 fast-json-stable-stringify: 2.1.0 - jest: 29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@5.6.3)) + jest: 29.7.0(@types/node@20.16.9)(ts-node@10.9.2(@types/node@20.16.9)(typescript@4.9.5)) jest-util: 29.7.0 json5: 2.2.3 lodash.memoize: 4.1.2 make-error: 1.3.6 semver: 7.6.3 - typescript: 5.6.3 + typescript: 4.9.5 yargs-parser: 21.1.1 optionalDependencies: '@babel/core': 7.25.7 From 3a4d4cbf8df1caa8236547bb1f9893315f1d8b1d Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Thu, 8 May 2025 18:02:51 -0700 Subject: [PATCH 2/3] Allow updating roots --- js/plugins/mcp/src/client.ts | 138 +++++++++++++++++++++++++++ js/plugins/mcp/src/client/prompts.ts | 3 +- js/plugins/mcp/src/client/roots.ts | 37 +++++++ js/plugins/mcp/src/index.ts | 87 ++++------------- 4 files changed, 195 insertions(+), 70 deletions(-) create mode 100644 js/plugins/mcp/src/client.ts create mode 100644 js/plugins/mcp/src/client/roots.ts diff --git a/js/plugins/mcp/src/client.ts b/js/plugins/mcp/src/client.ts new file mode 100644 index 0000000000..f4aee2252f --- /dev/null +++ b/js/plugins/mcp/src/client.ts @@ -0,0 +1,138 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import type { Client } from '@modelcontextprotocol/sdk/client/index.js' with { 'resolution-mode': 'import' }; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' with { 'resolution-mode': 'import' }; +import type { + ListRootsRequest, + ListRootsResult, + Root, + ServerCapabilities, +} from '@modelcontextprotocol/sdk/types.js' with { 'resolution-mode': 'import' }; +import { ListRootsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; +import { Genkit, GenkitError } from 'genkit'; +import { registerAllPrompts } from './client/prompts'; +import { registerResourceTools } from './client/resources'; +import { registerAllTools } from './client/tools'; +import type { McpClientOptions } from './index.js'; + +async function transportFrom(params: McpClientOptions): Promise { + if (params.transport) return params.transport; + if (params.serverUrl) { + const { SSEClientTransport } = await import( + '@modelcontextprotocol/sdk/client/sse.js' + ); + return new SSEClientTransport(new URL(params.serverUrl)); + } + if (params.serverProcess) { + const { StdioClientTransport } = await import( + '@modelcontextprotocol/sdk/client/stdio.js' + ); + return new StdioClientTransport(params.serverProcess); + } + if (params.serverWebsocketUrl) { + const { WebSocketClientTransport } = await import( + '@modelcontextprotocol/sdk/client/websocket.js' + ); + let url = params.serverWebsocketUrl; + if (typeof url === 'string') url = new URL(url); + return new WebSocketClientTransport(url); + } + + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + 'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl.', + }); +} + +export class GenkitMcpClient { + ai: Genkit; + options: McpClientOptions; + client?: Client; + serverCapabilities?: ServerCapabilities | undefined = {}; + + constructor(ai: Genkit, options: McpClientOptions) { + this.ai = ai; + this.options = options; + } + + async setup(): Promise { + const { Client } = await import( + '@modelcontextprotocol/sdk/client/index.js' + ); + + const transport = await transportFrom(this.options); + this.client = new Client( + { + name: this.options.name, + version: this.options.version || '1.0.0', + roots: this.options.roots, + }, + { + capabilities: { + roots: { listChanged: true }, + }, + } + ); + await this.client.connect(transport); + this.serverCapabilities = this.client.getServerCapabilities(); + this.client.setRequestHandler( + ListRootsRequestSchema, + this.listRoots.bind(this) + ); + + await this.registerCapabilites(); + } + + async registerCapabilites(): Promise { + if (!this.client || !this.serverCapabilities) { + return; + } + const promises: Promise[] = []; + if (this.serverCapabilities?.tools) { + promises.push(registerAllTools(this.ai, this.client, this.options)); + } + if (this.serverCapabilities?.prompts) { + promises.push(registerAllPrompts(this.ai, this.client, this.options)); + } + if (this.serverCapabilities?.resources) { + promises.push(registerResourceTools(this.ai, this.client, this.options)); + } + await Promise.all(promises); + } + + async notifyRootsChanged() { + this.client?.sendRootsListChanged(); + } + + set roots(roots: { name: string; uri: string }[]) { + this.options.roots = roots; + // Have to re-register the tools, resources, etc., since they can change + // based on the roots. + this.registerCapabilites(); + this.notifyRootsChanged(); + } + + async listRoots(req: ListRootsRequest): Promise { + await this.setup(); + if (!this.options.roots) { + return { roots: [] }; + } + const mcpRoots: Root[] = this.options.roots.map((root) => root); + return { roots: mcpRoots }; + } +} diff --git a/js/plugins/mcp/src/client/prompts.ts b/js/plugins/mcp/src/client/prompts.ts index ee4b7c26bd..3985709caf 100644 --- a/js/plugins/mcp/src/client/prompts.ts +++ b/js/plugins/mcp/src/client/prompts.ts @@ -63,7 +63,8 @@ function registerPrompt( } /** - * Lookup all tools available in the server and register each as a Genkit tool. + * Lookup all tools available in the server and register each as a Genkit + * prompt. */ export async function registerAllPrompts( ai: Genkit, diff --git a/js/plugins/mcp/src/client/roots.ts b/js/plugins/mcp/src/client/roots.ts new file mode 100644 index 0000000000..bf6ab61070 --- /dev/null +++ b/js/plugins/mcp/src/client/roots.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import type { Client } from '@modelcontextprotocol/sdk/client/index.js' with { 'resolution-mode': 'import' }; +import { Genkit, z } from 'genkit'; +import type { McpClientOptions } from '../index.js'; + +export async function registerClientRoots( + ai: Genkit, + client: Client, + params: McpClientOptions +) { + ai.defineTool( + { + name: `${params.name}/list_roots`, + description: `Lists the available roots for this MCP client`, + inputSchema: z.object({}), + outputSchema: z.array(z.object({ name: z.string(), uri: z.string() })), + }, + async () => { + return params.roots ?? []; + } + ); +} diff --git a/js/plugins/mcp/src/index.ts b/js/plugins/mcp/src/index.ts index 5516927784..4c6cf8d0ad 100644 --- a/js/plugins/mcp/src/index.ts +++ b/js/plugins/mcp/src/index.ts @@ -16,11 +16,9 @@ import type { StdioServerParameters } from '@modelcontextprotocol/sdk/client/stdio.js' with { 'resolution-mode': 'import' }; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' with { 'resolution-mode': 'import' }; -import { Genkit, GenkitError } from 'genkit'; +import { Genkit } from 'genkit'; import { genkitPlugin } from 'genkit/plugin'; -import { registerAllPrompts } from './client/prompts'; -import { registerResourceTools } from './client/resources'; -import { registerAllTools } from './client/tools'; +import { GenkitMcpClient } from './client'; import { GenkitMcpServer } from './server'; export interface McpClientOptions { @@ -42,77 +40,28 @@ export interface McpClientOptions { roots?: { name: string; uri: string }[]; } -async function transportFrom(params: McpClientOptions): Promise { - if (params.transport) return params.transport; - if (params.serverUrl) { - const { SSEClientTransport } = await import( - '@modelcontextprotocol/sdk/client/sse.js' - ); - return new SSEClientTransport(new URL(params.serverUrl)); - } - if (params.serverProcess) { - const { StdioClientTransport } = await import( - '@modelcontextprotocol/sdk/client/stdio.js' - ); - return new StdioClientTransport(params.serverProcess); - } - if (params.serverWebsocketUrl) { - const { WebSocketClientTransport } = await import( - '@modelcontextprotocol/sdk/client/websocket.js' - ); - let url = params.serverWebsocketUrl; - if (typeof url === 'string') url = new URL(url); - return new WebSocketClientTransport(url); - } - - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: - 'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl.', - }); -} +const mcpClients: Record = {}; export function mcpClient(params: McpClientOptions) { return genkitPlugin(params.name, async (ai: Genkit) => { - const { Client } = await import( - '@modelcontextprotocol/sdk/client/index.js' - ); - - const transport = await transportFrom(params); - const client = new Client( - { - name: params.name, - version: params.version || '1.0.0', - roots: params.roots, - }, - { - capabilities: { - // TODO: Support sending root list change notifications to the server. - // Would require some way to update the list of roots outside of - // client params. Also requires that our registrations of tools and - // resources are able to handle updates to the list of roots this - // client defines, since tool and resource lists are affected by the - // list of roots. - roots: { listChanged: false }, - }, - } - ); - await client.connect(transport); - const capabilities = client.getServerCapabilities(); - const promises: Promise[] = []; - if (capabilities?.tools) { - promises.push(registerAllTools(ai, client, params)); - } - if (capabilities?.prompts) { - promises.push(registerAllPrompts(ai, client, params)); - } - if (capabilities?.resources) { - promises.push(registerResourceTools(ai, client, params)); - } - await Promise.all(promises); + mcpClients[params.name] = new GenkitMcpClient(ai, { + name: params.name, + version: params.version || '1.0.0', + roots: params.roots, + }); }); } +export function setMcpClientRoots( + name: string, + roots: { name: string; uri: string }[] +) { + if (!mcpClients[name]) { + throw new Error(`MCP client plugin ${name} doesn't exist.`); + } + mcpClients[name].roots = roots; +} + export interface McpServerOptions { /** The name you want to give your server for MCP inspection. */ name: string; From d6cdd1e7d9e46a0ba26fd471bca244184eaeee8f Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Fri, 9 May 2025 11:00:09 -0700 Subject: [PATCH 3/3] Remove abililty to change roots --- js/plugins/mcp/src/client.ts | 54 +++++++++++++------------- js/plugins/mcp/src/client/prompts.ts | 2 +- js/plugins/mcp/src/client/resources.ts | 14 +------ js/plugins/mcp/src/client/roots.ts | 2 - js/plugins/mcp/src/index.ts | 21 ++-------- js/plugins/mcp/tests/client_test.ts | 4 +- 6 files changed, 36 insertions(+), 61 deletions(-) diff --git a/js/plugins/mcp/src/client.ts b/js/plugins/mcp/src/client.ts index f4aee2252f..8123aef9c0 100644 --- a/js/plugins/mcp/src/client.ts +++ b/js/plugins/mcp/src/client.ts @@ -29,33 +29,36 @@ import { registerResourceTools } from './client/resources'; import { registerAllTools } from './client/tools'; import type { McpClientOptions } from './index.js'; -async function transportFrom(params: McpClientOptions): Promise { - if (params.transport) return params.transport; - if (params.serverUrl) { +async function transportFrom(options: McpClientOptions): Promise { + if (options.transport) return options.transport; + if (options.serverUrl) { const { SSEClientTransport } = await import( '@modelcontextprotocol/sdk/client/sse.js' ); - return new SSEClientTransport(new URL(params.serverUrl)); + return new SSEClientTransport(new URL(options.serverUrl)); } - if (params.serverProcess) { + if (options.serverProcess) { const { StdioClientTransport } = await import( '@modelcontextprotocol/sdk/client/stdio.js' ); - return new StdioClientTransport(params.serverProcess); + return new StdioClientTransport(options.serverProcess); } - if (params.serverWebsocketUrl) { + if (options.serverWebsocketUrl) { const { WebSocketClientTransport } = await import( '@modelcontextprotocol/sdk/client/websocket.js' ); - let url = params.serverWebsocketUrl; + let url = options.serverWebsocketUrl; if (typeof url === 'string') url = new URL(url); return new WebSocketClientTransport(url); } throw new GenkitError({ status: 'INVALID_ARGUMENT', - message: - 'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl.', + message: `Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl:\n${JSON.stringify( + options, + null, + 2 + )}`, }); } @@ -64,6 +67,7 @@ export class GenkitMcpClient { options: McpClientOptions; client?: Client; serverCapabilities?: ServerCapabilities | undefined = {}; + _isSetup: boolean = false; constructor(ai: Genkit, options: McpClientOptions) { this.ai = ai; @@ -71,6 +75,7 @@ export class GenkitMcpClient { } async setup(): Promise { + if (this._isSetup) return; const { Client } = await import( '@modelcontextprotocol/sdk/client/index.js' ); @@ -84,21 +89,27 @@ export class GenkitMcpClient { }, { capabilities: { - roots: { listChanged: true }, + // TODO: Allow actually changing the roots dynamically. This requires + // manipulating which tools, resources, etc. are registered, since + // they can change based on the roots. + roots: { listChanged: false }, }, } ); - await this.client.connect(transport); - this.serverCapabilities = this.client.getServerCapabilities(); + this.client.setRequestHandler( ListRootsRequestSchema, this.listRoots.bind(this) ); - await this.registerCapabilites(); + await this.client.connect(transport); + this.serverCapabilities = this.client.getServerCapabilities(); + + await this.registerCapabilities(); + this._isSetup = true; } - async registerCapabilites(): Promise { + async registerCapabilities(): Promise { if (!this.client || !this.serverCapabilities) { return; } @@ -115,20 +126,7 @@ export class GenkitMcpClient { await Promise.all(promises); } - async notifyRootsChanged() { - this.client?.sendRootsListChanged(); - } - - set roots(roots: { name: string; uri: string }[]) { - this.options.roots = roots; - // Have to re-register the tools, resources, etc., since they can change - // based on the roots. - this.registerCapabilites(); - this.notifyRootsChanged(); - } - async listRoots(req: ListRootsRequest): Promise { - await this.setup(); if (!this.options.roots) { return { roots: [] }; } diff --git a/js/plugins/mcp/src/client/prompts.ts b/js/plugins/mcp/src/client/prompts.ts index 3985709caf..8e60081ed0 100644 --- a/js/plugins/mcp/src/client/prompts.ts +++ b/js/plugins/mcp/src/client/prompts.ts @@ -63,7 +63,7 @@ function registerPrompt( } /** - * Lookup all tools available in the server and register each as a Genkit + * Lookup all prompts available in the server and register each as a Genkit * prompt. */ export async function registerAllPrompts( diff --git a/js/plugins/mcp/src/client/resources.ts b/js/plugins/mcp/src/client/resources.ts index 69a0625f11..84885b40a1 100644 --- a/js/plugins/mcp/src/client/resources.ts +++ b/js/plugins/mcp/src/client/resources.ts @@ -36,18 +36,17 @@ export async function registerResourceTools( ai.defineTool( { name: `${params.name}/list_resources`, - description: `list all available resources for '${params.name}'${params.roots ? `, within roots ${rootsList}` : ''}`, + description: `list all available resources for '${params.name}'`, inputSchema: z.object({ /** Provide a cursor for accessing additional paginated results. */ cursor: z.string().optional(), /** When specified, automatically paginate and fetch all resources. */ all: z.boolean().optional(), - /** The list of roots to limit the results to. Must be a subset of params.roots. */ roots: z .array(z.object({ name: z.string().optional(), uri: z.string() })) .optional() .describe( - `The list of roots to limit the results to. Must be a subset of ${rootsList}` + `The list of roots to limit the results to. Available roots: ${rootsList}` ), }), }, @@ -56,15 +55,6 @@ export async function registerResourceTools( all, roots, }): Promise<{ nextCursor?: string | undefined; resources: Resource[] }> => { - // Filter the roots so that they only contain roots in the params.roots list. - if (roots) { - roots = roots.filter((root) => - params.roots?.some( - (pRoot) => pRoot.name === root.name && pRoot.uri === root.uri - ) - ); - } - if (!all) { return client.listResources({ roots: roots || params.roots }); } diff --git a/js/plugins/mcp/src/client/roots.ts b/js/plugins/mcp/src/client/roots.ts index bf6ab61070..9ece627749 100644 --- a/js/plugins/mcp/src/client/roots.ts +++ b/js/plugins/mcp/src/client/roots.ts @@ -14,13 +14,11 @@ * limitations under the License. */ -import type { Client } from '@modelcontextprotocol/sdk/client/index.js' with { 'resolution-mode': 'import' }; import { Genkit, z } from 'genkit'; import type { McpClientOptions } from '../index.js'; export async function registerClientRoots( ai: Genkit, - client: Client, params: McpClientOptions ) { ai.defineTool( diff --git a/js/plugins/mcp/src/index.ts b/js/plugins/mcp/src/index.ts index 4c6cf8d0ad..f8126f1651 100644 --- a/js/plugins/mcp/src/index.ts +++ b/js/plugins/mcp/src/index.ts @@ -42,26 +42,13 @@ export interface McpClientOptions { const mcpClients: Record = {}; -export function mcpClient(params: McpClientOptions) { - return genkitPlugin(params.name, async (ai: Genkit) => { - mcpClients[params.name] = new GenkitMcpClient(ai, { - name: params.name, - version: params.version || '1.0.0', - roots: params.roots, - }); +export function mcpClient(options: McpClientOptions) { + return genkitPlugin(options.name, async (ai: Genkit) => { + mcpClients[options.name] = new GenkitMcpClient(ai, options); + await mcpClients[options.name].setup(); }); } -export function setMcpClientRoots( - name: string, - roots: { name: string; uri: string }[] -) { - if (!mcpClients[name]) { - throw new Error(`MCP client plugin ${name} doesn't exist.`); - } - mcpClients[name].roots = roots; -} - export interface McpServerOptions { /** The name you want to give your server for MCP inspection. */ name: string; diff --git a/js/plugins/mcp/tests/client_test.ts b/js/plugins/mcp/tests/client_test.ts index d94b32696a..85062f3d7f 100644 --- a/js/plugins/mcp/tests/client_test.ts +++ b/js/plugins/mcp/tests/client_test.ts @@ -47,7 +47,9 @@ const mockMcpClientInstance = { tools: true, prompts: true, resources: true, + roots: { listChanged: false }, }), + setRequestHandler: jest.fn(), }; const MockMcpClient = jest.fn(() => mockMcpClientInstance); @@ -307,7 +309,7 @@ describe('mcpClient', () => { } catch (e: any) { expect(e.details.status).toBe('INVALID_ARGUMENT'); expect(e.message).toBe( - 'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl.' + 'Unable to create a server connection with supplied options. Must provide transport, stdio, or sseUrl:\n{\n "name": "test-error-client"\n}' ); }