Skip to content

Commit

Permalink
Merge pull request #6 from mkouba/server-ping
Browse files Browse the repository at this point in the history
Implement server ping utility
  • Loading branch information
mkouba authored Dec 12, 2024
2 parents 0532bc3 + defd705 commit a2fa9ca
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkiverse.mcp.server.test.MyPrompts.Options;
import io.quarkiverse.mcp.server.test.prompts.MyPrompts.Options;

@ApplicationScoped
public class FooService {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,35 @@
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import org.awaitility.Awaitility;
import org.jboss.resteasy.reactive.client.SseEvent;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;

public class McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class, FooService.class, MyPrompts.class));
public abstract class McpServerTest {

@TestHTTPResource
URI testUri;

@Test
public void testInit() throws URISyntaxException {
List<SseEvent<String>> sseMessages;

AtomicInteger idGenerator = new AtomicInteger();

protected URI initClient() throws URISyntaxException {
return initClient(null);
}

protected URI initClient(Consumer<JsonObject> initResultAssert) throws URISyntaxException {
McpClient mcpClient = QuarkusRestClientBuilder.newBuilder()
.baseUri(testUri)
.build(McpClient.class);

List<SseEvent<String>> sseMessages = new CopyOnWriteArrayList<>();
sseMessages = new CopyOnWriteArrayList<>();
mcpClient.init().subscribe().with(s -> sseMessages.add(s), e -> {
});
Awaitility.await().until(() -> !sseMessages.isEmpty());
Expand All @@ -47,7 +47,7 @@ public void testInit() throws URISyntaxException {
.put("params",
new JsonObject()
.put("clientInfo", new JsonObject()
.put("name", "FooClient")
.put("name", "test-client")
.put("version", "1.0"))
.put("protocolVersion", "2024-11-05"));

Expand All @@ -60,12 +60,14 @@ public void testInit() throws URISyntaxException {
.statusCode(200)
.extract().body().asString());

assertEquals(initMessage.getInteger("id"), initResponse.getInteger("id"));
assertEquals("2.0", initResponse.getString("jsonrpc"));
JsonObject initResult = initResponse.getJsonObject("result");
JsonObject initResult = assertResponseMessage(initMessage, initResponse);
assertNotNull(initResult);
assertEquals("2024-11-05", initResult.getString("protocolVersion"));

if (initResultAssert != null) {
initResultAssert.accept(initResult);
}

// Send "notifications/initialized"
given()
.contentType(ContentType.JSON)
Expand All @@ -77,69 +79,20 @@ public void testInit() throws URISyntaxException {
.then()
.statusCode(200);

JsonObject promptListMessage = newMessage("prompts/list");

JsonObject promptListResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptListMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

assertEquals(promptListMessage.getInteger("id"), promptListResponse.getInteger("id"));
assertEquals("2.0", promptListResponse.getString("jsonrpc"));
JsonObject promptListResult = promptListResponse.getJsonObject("result");
assertNotNull(promptListResult);
JsonArray prompts = promptListResult.getJsonArray("prompts");
assertEquals(4, prompts.size());

assertPromptMessage("Hello Lu!", endpoint, "foo", new JsonObject()
.put("name", "Lu")
.put("repeat", 1)
.put("options", new JsonObject()
.put("enabled", true)));
assertPromptMessage("LU", endpoint, "BAR", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_bar", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_list_bar", new JsonObject()
.put("val", "Lu"));
return endpoint;
}

private void assertPromptMessage(String expectedText, URI endpoint, String name, JsonObject arguments) {
JsonObject promptGetMessage = newMessage("prompts/get")
.put("params", new JsonObject()
.put("name", name)
.put("arguments", arguments));

JsonObject promptGetResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptGetMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

assertEquals(promptGetMessage.getInteger("id"), promptGetResponse.getInteger("id"));
assertEquals("2.0", promptGetResponse.getString("jsonrpc"));
JsonObject promptGetResult = promptGetResponse.getJsonObject("result");
assertNotNull(promptGetResult);
JsonArray messages = promptGetResult.getJsonArray("messages");
assertEquals(1, messages.size());
JsonObject message = messages.getJsonObject(0);
assertEquals("user", message.getString("role"));
JsonObject content = message.getJsonObject("content");
assertEquals("text", content.getString("type"));
assertEquals(expectedText, content.getString("text"));
protected JsonObject assertResponseMessage(JsonObject message, JsonObject response) {
assertEquals(message.getInteger("id"), response.getInteger("id"));
assertEquals("2.0", response.getString("jsonrpc"));
return response.getJsonObject("result");
}

private static AtomicInteger ID = new AtomicInteger();

private JsonObject newMessage(String method) {
return new JsonObject().put("jsonrpc", "2.0").put("method", method).put("id", ID.incrementAndGet());
protected JsonObject newMessage(String method) {
return new JsonObject()
.put("jsonrpc", "2.0")
.put("method", method)
.put("id", idGenerator.incrementAndGet());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.quarkiverse.mcp.server.test.ping;

import static io.restassured.RestAssured.given;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.net.URISyntaxException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonObject;

public class PingTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class));

@Test
public void testPing() throws URISyntaxException {
URI endpoint = initClient();

JsonObject pingMessage = newMessage("ping");

JsonObject pingResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(pingMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

JsonObject pingResult = assertResponseMessage(pingMessage, pingResponse);
assertNotNull(pingResult);
assertTrue(pingResult.isEmpty());
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package io.quarkiverse.mcp.server.test;
package io.quarkiverse.mcp.server.test.prompts;

import java.util.List;

Expand All @@ -8,6 +8,7 @@
import io.quarkiverse.mcp.server.PromptArg;
import io.quarkiverse.mcp.server.PromptMessage;
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.test.FooService;
import io.quarkus.arc.Arc;
import io.quarkus.runtime.BlockingOperationControl;
import io.smallrye.common.vertx.VertxContext;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package io.quarkiverse.mcp.server.test.prompts;

import static io.restassured.RestAssured.given;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.net.URI;
import java.net.URISyntaxException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.test.FooService;
import io.quarkiverse.mcp.server.test.McpClient;
import io.quarkiverse.mcp.server.test.McpServerTest;
import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.ContentType;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;

public class PromptsTest extends McpServerTest {

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(McpClient.class, FooService.class, MyPrompts.class));

@Test
public void testPrompts() throws URISyntaxException {
URI endpoint = initClient();

JsonObject promptListMessage = newMessage("prompts/list");

JsonObject promptListResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptListMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

JsonObject promptListResult = assertResponseMessage(promptListMessage, promptListResponse);
assertNotNull(promptListResult);
JsonArray prompts = promptListResult.getJsonArray("prompts");
assertEquals(4, prompts.size());

assertPromptMessage("Hello Lu!", endpoint, "foo", new JsonObject()
.put("name", "Lu")
.put("repeat", 1)
.put("options", new JsonObject()
.put("enabled", true)));
assertPromptMessage("LU", endpoint, "BAR", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_bar", new JsonObject()
.put("val", "Lu"));
assertPromptMessage("LU", endpoint, "uni_list_bar", new JsonObject()
.put("val", "Lu"));
}

private void assertPromptMessage(String expectedText, URI endpoint, String name, JsonObject arguments) {
JsonObject promptGetMessage = newMessage("prompts/get")
.put("params", new JsonObject()
.put("name", name)
.put("arguments", arguments));

JsonObject promptGetResponse = new JsonObject(given()
.contentType(ContentType.JSON)
.when()
.body(promptGetMessage.encode())
.post(endpoint)
.then()
.statusCode(200)
.extract().body().asString());

JsonObject promptGetResult = assertResponseMessage(promptGetMessage, promptGetResponse);
assertNotNull(promptGetResult);
JsonArray messages = promptGetResult.getJsonArray("messages");
assertEquals(1, messages.size());
JsonObject message = messages.getJsonObject(0);
assertEquals("user", message.getString("role"));
JsonObject content = message.getJsonObject("content");
assertEquals("text", content.getString("type"));
assertEquals(expectedText, content.getString("text"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void handle(RoutingContext ctx) {
}
String id = ctx.request().getParam("id");
if (id == null) {
LOG.error("Connection id is required");
LOG.error("Connection id is missing");
ctx.fail(400);
return;
}
Expand All @@ -56,7 +56,7 @@ public void handle(RoutingContext ctx) {
JsonObject message = ctx.body().asJsonObject();
String jsonrpc = message.getString("jsonrpc");
if (!"2.0".equals(jsonrpc)) {
LOG.errorf("Invalid jsonrpc [%s]", message);
LOG.errorf("Invalid jsonrpc version [%s]", message);
ctx.fail(400);
return;
}
Expand Down Expand Up @@ -99,11 +99,11 @@ private void initializing(JsonObject message, RoutingContext ctx, McpConnectionI
String method = message.getString("method");
if ("notifications/initialized".equals(method)) {
if (connection.initialized()) {
LOG.infof("Client initialized [id: %s]", connection.id());
LOG.infof("Client successfully initialized [%s]", connection.id());
ctx.end();
}
} else {
LOG.infof("Client not initialized yet [id: %s]", connection.id());
LOG.infof("Client not initialized yet [%s]", connection.id());
ctx.fail(400);
}
// TODO ping
Expand All @@ -114,20 +114,31 @@ private void operation(JsonObject message, RoutingContext ctx, McpConnectionImpl
switch (method) {
case "prompts/list" -> promptsList(message, ctx);
case "prompts/get" -> promptsGet(message, ctx);
case "ping" -> ping(message, ctx);
default -> throw new IllegalArgumentException("Unsupported method: " + method);
}
}

private void ping(JsonObject message, RoutingContext ctx) {
// https://spec.modelcontextprotocol.io/specification/basic/utilities/ping/
Object id = message.getValue("id");
LOG.infof("Ping [id: %s]", id);
ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
ctx.end(result(id, new JsonObject()).encode());
}

private void promptsList(JsonObject message, RoutingContext ctx) {
LOG.infof("List prompts");
Object id = message.getValue("id");
LOG.infof("List prompts [id: %s]", id);
PromptManager promptManager = Arc.container().instance(PromptManager.class).get();
ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
ctx.end(result(message.getValue("id"), new JsonObject()
ctx.end(result(id, new JsonObject()
.put("prompts", new JsonArray(promptManager.list()))).encode());
}

private void promptsGet(JsonObject message, RoutingContext ctx) {
LOG.infof("Get prompts");
Object id = message.getValue("id");
LOG.infof("Get prompts [id: %s]", id);
PromptManager promptManager = Arc.container().instance(PromptManager.class).get();
ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
JsonObject params = message.getJsonObject("params");
Expand All @@ -137,7 +148,7 @@ private void promptsGet(JsonObject message, RoutingContext ctx) {
@Override
public void handle(AsyncResult<List<PromptMessage>> ar) {
if (ar.succeeded()) {
ctx.end(result(message.getValue("id"), new JsonObject()
ctx.end(result(id, new JsonObject()
.put("messages", new JsonArray(ar.result())))
.encode());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void handle(RoutingContext ctx) {

String id = Base64.getUrlEncoder().encodeToString(UUID.randomUUID().toString().getBytes());

LOG.infof("Connection initialized: %s]", id);
LOG.infof("Client connection initialized [%s]", id);

McpConnectionImpl connection = new McpConnectionImpl(id, response);
connectionManager.add(connection);
Expand Down

0 comments on commit a2fa9ca

Please # to comment.