Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Implement server ping utility #6

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkiverse.mcp.server.test.MyPrompts.Options;
import io.quarkiverse.mcp.server.test.prompts.MyPrompts.Options;

@ApplicationScoped
public class FooService {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,35 @@
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import org.awaitility.Awaitility;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;

public class McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class, FooService.class, MyPrompts.class));
public abstract class McpServerTest {

@TestHTTPResource
URI testUri;

@Test
public void testInit() throws URISyntaxException {
List<SseEvent<String>> sseMessages;

AtomicInteger idGenerator = new AtomicInteger();

protected URI initClient() throws URISyntaxException {
return initClient(null);
}

protected URI initClient(Consumer<JsonObject> initResultAssert) throws URISyntaxException {
McpClient mcpClient = QuarkusRestClientBuilder.newBuilder()
.baseUri(testUri)
.build(McpClient.class);

List<SseEvent<String>> sseMessages = new CopyOnWriteArrayList<>();
sseMessages = new CopyOnWriteArrayList<>();
mcpClient.init().subscribe().with(s -> sseMessages.add(s), e -> {
});
Awaitility.await().until(() -> !sseMessages.isEmpty());
Expand All @@ -47,7 +47,7 @@ public void testInit() throws URISyntaxException {
.put("params",
new JsonObject()
.put("clientInfo", new JsonObject()
.put("name", "FooClient")
.put("name", "test-client")
.put("version", "1.0"))
.put("protocolVersion", "2024-11-05"));

Expand All @@ -60,12 +60,14 @@ public void testInit() throws URISyntaxException {
.statusCode(200)
.extract().body().asString());

assertEquals(initMessage.getInteger("id"), initResponse.getInteger("id"));
assertEquals("2.0", initResponse.getString("jsonrpc"));
JsonObject initResult = initResponse.getJsonObject("result");
JsonObject initResult = assertResponseMessage(initMessage, initResponse);
assertNotNull(initResult);
assertEquals("2024-11-05", initResult.getString("protocolVersion"));

if (initResultAssert != null) {
initResultAssert.accept(initResult);
}

// Send "notifications/initialized"
given()
.contentType(ContentType.JSON)
Expand All @@ -77,69 +79,20 @@ public void testInit() throws URISyntaxException {
.then()
.statusCode(200);

JsonObject promptListMessage = newMessage("prompts/list");

JsonObject promptListResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptListMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

assertEquals(promptListMessage.getInteger("id"), promptListResponse.getInteger("id"));
assertEquals("2.0", promptListResponse.getString("jsonrpc"));
JsonObject promptListResult = promptListResponse.getJsonObject("result");
assertNotNull(promptListResult);
JsonArray prompts = promptListResult.getJsonArray("prompts");
assertEquals(4, prompts.size());

assertPromptMessage("Hello Lu!", endpoint, "foo", new JsonObject()
.put("name", "Lu")
.put("repeat", 1)
.put("options", new JsonObject()
.put("enabled", true)));
assertPromptMessage("LU", endpoint, "BAR", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_bar", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_list_bar", new JsonObject()
.put("val", "Lu"));
return endpoint;
}

private void assertPromptMessage(String expectedText, URI endpoint, String name, JsonObject arguments) {
JsonObject promptGetMessage = newMessage("prompts/get")
.put("params", new JsonObject()
.put("name", name)
.put("arguments", arguments));

JsonObject promptGetResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptGetMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

assertEquals(promptGetMessage.getInteger("id"), promptGetResponse.getInteger("id"));
assertEquals("2.0", promptGetResponse.getString("jsonrpc"));
JsonObject promptGetResult = promptGetResponse.getJsonObject("result");
assertNotNull(promptGetResult);
JsonArray messages = promptGetResult.getJsonArray("messages");
assertEquals(1, messages.size());
JsonObject message = messages.getJsonObject(0);
assertEquals("user", message.getString("role"));
JsonObject content = message.getJsonObject("content");
assertEquals("text", content.getString("type"));
assertEquals(expectedText, content.getString("text"));
protected JsonObject assertResponseMessage(JsonObject message, JsonObject response) {
assertEquals(message.getInteger("id"), response.getInteger("id"));
assertEquals("2.0", response.getString("jsonrpc"));
return response.getJsonObject("result");
}

private static AtomicInteger ID = new AtomicInteger();

private JsonObject newMessage(String method) {
return new JsonObject().put("jsonrpc", "2.0").put("method", method).put("id", ID.incrementAndGet());
protected JsonObject newMessage(String method) {
return new JsonObject()
.put("jsonrpc", "2.0")
.put("method", method)
.put("id", idGenerator.incrementAndGet());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.quarkiverse.mcp.server.test.ping;

import static io.restassured.RestAssured.given;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.net.URISyntaxException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonObject;

public class PingTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class));

@Test
public void testPing() throws URISyntaxException {
URI endpoint = initClient();

JsonObject pingMessage = newMessage("ping");

JsonObject pingResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(pingMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

JsonObject pingResult = assertResponseMessage(pingMessage, pingResponse);
assertNotNull(pingResult);
assertTrue(pingResult.isEmpty());
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkiverse.mcp.server.test;
package io.quarkiverse.mcp.server.test.prompts;

import java.util.List;

Expand All @@ -8,6 +8,7 @@
import io.quarkiverse.mcp.server.PromptArg;
import io.quarkiverse.mcp.server.PromptMessage;
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.test.FooService;
import io.quarkus.arc.Arc;
import io.quarkus.runtime.BlockingOperationControl;
import io.smallrye.common.vertx.VertxContext;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package io.quarkiverse.mcp.server.test.prompts;

import static io.restassured.RestAssured.given;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.net.URI;
import java.net.URISyntaxException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.test.FooService;
import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;

public class PromptsTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class, FooService.class, MyPrompts.class));

@Test
public void testPrompts() throws URISyntaxException {
URI endpoint = initClient();

JsonObject promptListMessage = newMessage("prompts/list");

JsonObject promptListResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptListMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

JsonObject promptListResult = assertResponseMessage(promptListMessage, promptListResponse);
assertNotNull(promptListResult);
JsonArray prompts = promptListResult.getJsonArray("prompts");
assertEquals(4, prompts.size());

assertPromptMessage("Hello Lu!", endpoint, "foo", new JsonObject()
.put("name", "Lu")
.put("repeat", 1)
.put("options", new JsonObject()
.put("enabled", true)));
assertPromptMessage("LU", endpoint, "BAR", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_bar", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_list_bar", new JsonObject()
.put("val", "Lu"));
}

private void assertPromptMessage(String expectedText, URI endpoint, String name, JsonObject arguments) {
JsonObject promptGetMessage = newMessage("prompts/get")
.put("params", new JsonObject()
.put("name", name)
.put("arguments", arguments));

JsonObject promptGetResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptGetMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

JsonObject promptGetResult = assertResponseMessage(promptGetMessage, promptGetResponse);
assertNotNull(promptGetResult);
JsonArray messages = promptGetResult.getJsonArray("messages");
assertEquals(1, messages.size());
JsonObject message = messages.getJsonObject(0);
assertEquals("user", message.getString("role"));
JsonObject content = message.getJsonObject("content");
assertEquals("text", content.getString("type"));
assertEquals(expectedText, content.getString("text"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void handle(RoutingContext ctx) {
}
String id = ctx.request().getParam("id");
if (id == null) {
LOG.error("Connection id is required");
LOG.error("Connection id is missing");
ctx.fail(400);
return;
}
Expand All @@ -56,7 +56,7 @@ public void handle(RoutingContext ctx) {
JsonObject message = ctx.body().asJsonObject();
String jsonrpc = message.getString("jsonrpc");
if (!"2.0".equals(jsonrpc)) {
LOG.errorf("Invalid jsonrpc [%s]", message);
LOG.errorf("Invalid jsonrpc version [%s]", message);
ctx.fail(400);
return;
}
Expand Down Expand Up @@ -99,11 +99,11 @@ private void initializing(JsonObject message, RoutingContext ctx, McpConnectionI
String method = message.getString("method");
if ("notifications/initialized".equals(method)) {
if (connection.initialized()) {
LOG.infof("Client initialized [id: %s]", connection.id());
LOG.infof("Client successfully initialized [%s]", connection.id());
ctx.end();
}
} else {
LOG.infof("Client not initialized yet [id: %s]", connection.id());
LOG.infof("Client not initialized yet [%s]", connection.id());
ctx.fail(400);
}
// TODO ping
Expand All @@ -114,20 +114,31 @@ private void operation(JsonObject message, RoutingContext ctx, McpConnectionImpl
switch (method) {
case "prompts/list" -> promptsList(message, ctx);
case "prompts/get" -> promptsGet(message, ctx);
case "ping" -> ping(message, ctx);
default -> throw new IllegalArgumentException("Unsupported method: " + method);
}
}

private void ping(JsonObject message, RoutingContext ctx) {
// https://spec.modelcontextprotocol.io/specification/basic/utilities/ping/
Object id = message.getValue("id");
LOG.infof("Ping [id: %s]", id);
ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
ctx.end(result(id, new JsonObject()).encode());
}

private void promptsList(JsonObject message, RoutingContext ctx) {
LOG.infof("List prompts");
Object id = message.getValue("id");
LOG.infof("List prompts [id: %s]", id);
PromptManager promptManager = Arc.container().instance(PromptManager.class).get();
ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
ctx.end(result(message.getValue("id"), new JsonObject()
ctx.end(result(id, new JsonObject()
.put("prompts", new JsonArray(promptManager.list()))).encode());
}

private void promptsGet(JsonObject message, RoutingContext ctx) {
LOG.infof("Get prompts");
Object id = message.getValue("id");
LOG.infof("Get prompts [id: %s]", id);
PromptManager promptManager = Arc.container().instance(PromptManager.class).get();
ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
JsonObject params = message.getJsonObject("params");
Expand All @@ -137,7 +148,7 @@ private void promptsGet(JsonObject message, RoutingContext ctx) {
@Override
public void handle(AsyncResult<List<PromptMessage>> ar) {
if (ar.succeeded()) {
ctx.end(result(message.getValue("id"), new JsonObject()
ctx.end(result(id, new JsonObject()
.put("messages", new JsonArray(ar.result())))
.encode());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void handle(RoutingContext ctx) {

String id = Base64.getUrlEncoder().encodeToString(UUID.randomUUID().toString().getBytes());

LOG.infof("Connection initialized: %s]", id);
LOG.infof("Client connection initialized [%s]", id);

McpConnectionImpl connection = new McpConnectionImpl(id, response);
connectionManager.add(connection);
Expand Down
Loading