Skip to content

Commit

Permalink
Merge pull request #88 from mkouba/issue-85-encoders
Browse files Browse the repository at this point in the history
core: introduce API to encode return values as response objects
  • Loading branch information
mkouba authored Jan 28, 2025
2 parents 3bf58ff + b41885e commit 0095a17
Show file tree
Hide file tree
Showing 29 changed files with 1,262 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static io.quarkiverse.mcp.server.runtime.FeatureMetadata.Feature.TOOL;
import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT;

import java.lang.annotation.Annotation;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.ArrayList;
Expand Down Expand Up @@ -43,22 +44,31 @@
import io.quarkiverse.mcp.server.TextContent;
import io.quarkiverse.mcp.server.TextResourceContents;
import io.quarkiverse.mcp.server.ToolResponse;
import io.quarkiverse.mcp.server.runtime.EncoderMapper;
import io.quarkiverse.mcp.server.runtime.ExecutionModel;
import io.quarkiverse.mcp.server.runtime.FeatureArgument;
import io.quarkiverse.mcp.server.runtime.FeatureArgument.Provider;
import io.quarkiverse.mcp.server.runtime.FeatureMetadata;
import io.quarkiverse.mcp.server.runtime.FeatureMetadata.Feature;
import io.quarkiverse.mcp.server.runtime.FeatureMethodInfo;
import io.quarkiverse.mcp.server.runtime.JsonTextContentEncoder;
import io.quarkiverse.mcp.server.runtime.JsonTextResourceContentsEncoder;
import io.quarkiverse.mcp.server.runtime.McpMetadata;
import io.quarkiverse.mcp.server.runtime.McpServerRecorder;
import io.quarkiverse.mcp.server.runtime.PromptCompleteManager;
import io.quarkiverse.mcp.server.runtime.PromptEncoderResultMapper;
import io.quarkiverse.mcp.server.runtime.PromptManager;
import io.quarkiverse.mcp.server.runtime.ResourceContentsEncoderResultMapper;
import io.quarkiverse.mcp.server.runtime.ResourceManager;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateCompleteManager;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManager;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManager.VariableMatcher;
import io.quarkiverse.mcp.server.runtime.ResultMappers;
import io.quarkiverse.mcp.server.runtime.ToolEncoderResultMapper;
import io.quarkiverse.mcp.server.runtime.ToolManager;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.AutoAddScopeBuildItem;
import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem;
Expand Down Expand Up @@ -103,7 +113,9 @@ void addBeans(BuildProducer<AdditionalBeanBuildItem> additionalBeans) {
additionalBeans.produce(AdditionalBeanBuildItem.unremovableOf("io.quarkiverse.mcp.server.runtime.ConnectionManager"));
additionalBeans.produce(AdditionalBeanBuildItem.builder().setUnremovable()
.addBeanClasses(PromptManager.class, ToolManager.class, ResourceManager.class, PromptCompleteManager.class,
ResourceTemplateManager.class, ResourceTemplateCompleteManager.class)
ResourceTemplateManager.class, ResourceTemplateCompleteManager.class, JsonTextContentEncoder.class,
JsonTextResourceContentsEncoder.class, ToolEncoderResultMapper.class,
ResourceContentsEncoderResultMapper.class, PromptEncoderResultMapper.class)
.build());
}

Expand Down Expand Up @@ -385,10 +397,13 @@ void registerForReflection(List<FeatureMethodBuildItem> featureMethods,

private void validateFeatureMethod(MethodInfo method, Feature feature, AnnotationInstance featureAnnotation) {
if (Modifier.isStatic(method.flags())) {
throw new IllegalStateException("MCP feature method must not be static: " + method);
throw new IllegalStateException(feature + " method must not be static: " + method);
}
if (Modifier.isPrivate(method.flags())) {
throw new IllegalStateException("MCP feature method must not be private: " + method);
throw new IllegalStateException(feature + " method must not be private: " + method);
}
if (method.returnType().kind() == Kind.VOID) {
throw new IllegalStateException(feature + " method may not return void: " + method);
}
switch (feature) {
case PROMPT -> validatePromptMethod(method);
Expand All @@ -405,16 +420,7 @@ private void validateFeatureMethod(MethodInfo method, Feature feature, Annotatio
ClassType.create(DotNames.PROMPT_MESSAGE));

private void validatePromptMethod(MethodInfo method) {
org.jboss.jandex.Type type = method.returnType();
if (DotNames.UNI.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (DotNames.LIST.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (!PROMPT_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Prompt method return type: " + method);
}
// No need to validate return type

List<MethodParameterInfo> parameters = parameters(method);
for (MethodParameterInfo param : parameters) {
Expand Down Expand Up @@ -469,33 +475,25 @@ private void validateResourceTemplateCompleteMethod(MethodInfo method) {
ClassType.create(DotNames.STRING));

private void validateToolMethod(MethodInfo method) {
org.jboss.jandex.Type type = method.returnType();
// No need to validate return type
}

private boolean useEncoder(org.jboss.jandex.Type type, Set<org.jboss.jandex.Type> types) {
if (DotNames.UNI.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (DotNames.LIST.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (!TOOL_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Tool method return type: " + method);
}
return !types.contains(type);
}

static final Set<org.jboss.jandex.Type> RESOURCE_TYPES = Set.of(ClassType.create(DotNames.RESOURCE_RESPONSE),
ClassType.create(DotNames.RESOURCE_CONTENTS), ClassType.create(DotNames.TEXT_RESOURCE_CONTENTS),
ClassType.create(DotNames.BLOB_RESOURCE_CONTENTS));

private void validateResourceMethod(MethodInfo method) {
org.jboss.jandex.Type type = method.returnType();
if (DotNames.UNI.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (DotNames.LIST.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (!RESOURCE_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Resource method return type: " + method);
}
// No need to validate return type

List<MethodParameterInfo> parameters = parameters(method);
if (!parameters.isEmpty()) {
Expand All @@ -505,16 +503,7 @@ private void validateResourceMethod(MethodInfo method) {
}

private void validateResourceTemplateMethod(MethodInfo method, AnnotationInstance featureAnnotation) {
org.jboss.jandex.Type type = method.returnType();
if (DotNames.UNI.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (DotNames.LIST.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
}
if (!RESOURCE_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Resource template method return type: " + method);
}
// No need to validate return type

AnnotationValue uriTemplateValue = featureAnnotation.value("uriTemplate");
if (uriTemplateValue == null) {
Expand Down Expand Up @@ -623,46 +612,99 @@ private FeatureArgument.Provider providerFrom(org.jboss.jandex.Type type) {

private ResultHandle getMapper(BytecodeCreator bytecode, org.jboss.jandex.Type returnType,
Feature feature) {
// Returns a function that converts the returned object to Uni<RESPONSE>
// where the RESPONSE is one of ToolResponse, PromptResponse, ResourceResponse, CompleteResponse
// IMPL NOTE: at this point the method return type is already validated
return switch (feature) {
case PROMPT -> readResultMapper(bytecode,
createMapperField(PROMPT, returnType, DotNames.PROMPT_RESPONSE, c -> "MESSAGE"));
case PROMPT -> propmpResultMapper(bytecode, returnType);
case PROMPT_COMPLETE -> readResultMapper(bytecode,
createMapperField(PROMPT_COMPLETE, returnType, DotNames.COMPLETE_RESPONSE, c -> "STRING"));
case TOOL -> readResultMapper(bytecode, createMapperField(TOOL, returnType, DotNames.TOOL_RESPONSE, c -> {
return isContent(c) ? "CONTENT" : "STRING";
}));
case RESOURCE, RESOURCE_TEMPLATE -> readResultMapper(bytecode,
createMapperField(RESOURCE, returnType, DotNames.RESOURCE_RESPONSE, c -> "CONTENT"));
createMapperClassSimpleName(PROMPT_COMPLETE, returnType, DotNames.COMPLETE_RESPONSE, c -> "String"));
case TOOL -> toolResultMapper(bytecode, returnType);
case RESOURCE, RESOURCE_TEMPLATE -> resourceResultMapper(bytecode, returnType);
case RESOURCE_TEMPLATE_COMPLETE -> readResultMapper(bytecode,
createMapperField(RESOURCE_TEMPLATE_COMPLETE, returnType, DotNames.COMPLETE_RESPONSE, c -> "STRING"));
createMapperClassSimpleName(RESOURCE_TEMPLATE_COMPLETE, returnType, DotNames.COMPLETE_RESPONSE,
c -> "String"));
default -> throw new IllegalArgumentException("Unsupported feature: " + feature);
};
}

static String createMapperField(FeatureMetadata.Feature feature, org.jboss.jandex.Type returnType,
ResultHandle resourceResultMapper(BytecodeCreator bytecode, org.jboss.jandex.Type returnType) {
if (useEncoder(returnType, RESOURCE_TYPES)) {
return encoderResultMapper(bytecode, returnType, ResourceContentsEncoderResultMapper.class);
} else {
return readResultMapper(bytecode,
createMapperClassSimpleName(RESOURCE, returnType, DotNames.RESOURCE_RESPONSE, c -> "Content"));
}
}

ResultHandle encoderResultMapper(BytecodeCreator bytecode, org.jboss.jandex.Type returnType, Class<?> mapperClazz) {
// Arc.container().instance(mapperClazz).get();
ResultHandle container = bytecode
.invokeStaticMethod(MethodDescriptor.ofMethod(Arc.class, "container", ArcContainer.class));
ResultHandle instance = bytecode.invokeInterfaceMethod(MethodDescriptor.ofMethod(ArcContainer.class, "instance",
InstanceHandle.class, Class.class, Annotation[].class), container,
bytecode.loadClass(mapperClazz), bytecode.newArray(Annotation.class, 0));
ResultHandle mapper = bytecode.invokeInterfaceMethod(
MethodDescriptor.ofMethod(InstanceHandle.class, "get", Object.class),
instance);
if (DotNames.UNI.equals(returnType.name())) {
if (DotNames.LIST.equals(returnType.asParameterizedType().arguments().get(0).name())) {
mapper = bytecode.invokeVirtualMethod(
MethodDescriptor.ofMethod(mapperClazz, "uniList", EncoderMapper.class), mapper);
} else {
mapper = bytecode.invokeVirtualMethod(
MethodDescriptor.ofMethod(mapperClazz, "uni", EncoderMapper.class), mapper);
}
} else if (DotNames.LIST.equals(returnType.name())) {
mapper = bytecode.invokeVirtualMethod(
MethodDescriptor.ofMethod(mapperClazz, "list", EncoderMapper.class), mapper);
}
return mapper;
}

ResultHandle toolResultMapper(BytecodeCreator bytecode, org.jboss.jandex.Type returnType) {
if (useEncoder(returnType, TOOL_TYPES)) {
return encoderResultMapper(bytecode, returnType, ToolEncoderResultMapper.class);
} else {
return readResultMapper(bytecode, createMapperClassSimpleName(TOOL, returnType, DotNames.TOOL_RESPONSE, c -> {
return isContent(c) ? "Content" : "String";
}));
}
}

ResultHandle propmpResultMapper(BytecodeCreator bytecode, org.jboss.jandex.Type returnType) {
if (useEncoder(returnType, PROMPT_TYPES)) {
return encoderResultMapper(bytecode, returnType, PromptEncoderResultMapper.class);
} else {
return readResultMapper(bytecode,
createMapperClassSimpleName(PROMPT, returnType, DotNames.PROMPT_RESPONSE, c -> "OfMessage"));
}
}

static String createMapperClassSimpleName(FeatureMetadata.Feature feature, org.jboss.jandex.Type returnType,
DotName baseType, Function<DotName, String> componentMapper) {
if (returnType.name().equals(baseType)) {
return "TO_UNI";
return "ToUni";
}
org.jboss.jandex.Type type = returnType;
StringBuilder ret;
if (feature == PROMPT_COMPLETE || feature == RESOURCE_TEMPLATE_COMPLETE) {
ret = new StringBuilder("COMPLETE_");
ret = new StringBuilder("Complete");
} else {
ret = new StringBuilder(feature.toString())
.append("_");
String f = feature.toString();
// TOOL -> Tool
ret = new StringBuilder().append(f.charAt(0)).append(f.substring(1).toLowerCase());
}
if (DotNames.UNI.equals(type.name())) {
type = type.asParameterizedType().arguments().get(0);
if (type.name().equals(baseType)) {
return "IDENTITY";
return "Identity";
}
ret.append("UNI_");
ret.append("Uni");
}
if (DotNames.LIST.equals(type.name())) {
type = type.asParameterizedType().arguments().get(0);
ret.append("LIST_");
ret.append("List");
}
ret.append(componentMapper.apply(type.name()));

Expand All @@ -674,8 +716,9 @@ private boolean isContent(DotName typeName) {
|| DotNames.IMAGE_CONTENT.equals(typeName) || DotNames.RESOURCE_CONTENT.equals(typeName);
}

private ResultHandle readResultMapper(BytecodeCreator bytecode, String contantName) {
return bytecode.readStaticField(FieldDescriptor.of(ResultMappers.class, contantName, Function.class));
private ResultHandle readResultMapper(BytecodeCreator bytecode, String mapperClassSimpleName) {
String mapperClassName = ResultMappers.class.getName() + "$" + mapperClassSimpleName;
return bytecode.readStaticField(FieldDescriptor.of(mapperClassName, "INSTANCE", mapperClassName));
}

private static ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem transformedAnnotations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ public class McpServerProcessorTest {

@Test
public void testCreateMapperField() {
assertEquals("RESOURCE_CONTENT",
McpServerProcessor.createMapperField(Feature.RESOURCE, ClassType.create(DotNames.TEXT_RESOURCE_CONTENTS),
assertEquals("ResourceContent",
McpServerProcessor.createMapperClassSimpleName(Feature.RESOURCE,
ClassType.create(DotNames.TEXT_RESOURCE_CONTENTS),
DotNames.RESOURCE_RESPONSE,
c -> "CONTENT"));
assertEquals("IDENTITY",
McpServerProcessor.createMapperField(Feature.RESOURCE,
c -> "Content"));
assertEquals("Identity",
McpServerProcessor.createMapperClassSimpleName(Feature.RESOURCE,
ParameterizedType.create(DotNames.UNI, ClassType.create(DotNames.RESOURCE_RESPONSE)),
DotNames.RESOURCE_RESPONSE,
c -> "CONTENT"));
c -> "Content"));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.quarkiverse.mcp.server;

import jakarta.annotation.Priority;

/**
* Encodes an object as {@link Content}.
* <p>
* Implementation classes must be CDI beans. Qualifiers are ignored. {@link jakarta.enterprise.context.Dependent} beans are
* reused during encoding.
* <p>
* Encoders may define the priority with {@link Priority}. An encoder with higher priority takes precedence.
*
* @param <TYPE>
* @see Content
* @see Tool
* @see Prompt
*/
public interface ContentEncoder<TYPE> extends Encoder<TYPE, Content> {

}
24 changes: 24 additions & 0 deletions core/runtime/src/main/java/io/quarkiverse/mcp/server/Encoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkiverse.mcp.server;

/**
*
* @param <TYPE> The type to be encoded
* @param <ENCODED> The resulting type of encoding
*/
public interface Encoder<TYPE, ENCODED> {

/**
*
* @param runtimeType The runtime class of an object that should be encoded, must not be {@code null}
* @return {@code true} if this encoder can encode the provided type, {@code false} otherwise
*/
boolean supports(Class<?> runtimeType);

/**
*
* @param value
* @return the encoded value
*/
ENCODED encode(TYPE value);

}
18 changes: 6 additions & 12 deletions core/runtime/src/main/java/io/quarkiverse/mcp/server/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,19 @@
* <p>
* The result of a "prompt get" operation is always represented as a {@link PromptResponse}. However, the annotated method can
* also return other types that are converted according to the following rules.
* <p>
* <ul>
* <li>If the method returns a {@link PromptMessage} then the reponse has no description and contains the single
* <li>If it returns a {@link PromptMessage} then the reponse has no description and contains the single
* message object.</li>
* <li>If the method returns a {@link List} of {@link PromptMessage}s then the reponse has no description and contains the
* <li>If it returns a {@link List} of {@link PromptMessage}s then the reponse has no description and contains the
* list of messages.</li>
* <li>The method may return a {@link Uni} that wraps any of the type mentioned above.</li>
* </ul>
* In other words, the return type must be one of the following list:
* <ul>
* <li>{@code PromptResponse}</li>
* <li>{@code PromptMessage}</li>
* <li>{@code List<PromptMessage>}</li>
* <li>{@code Uni<PromptResponse>}</li>
* <li>{@code Uni<PromptMessage>}</li>
* <li>{@code Uni<List<PromptMessage>>}</li>
* <li>If it returns any other type {@code X} then {@code X} is encoded using the {@link PromptResponseEncoder} API.</li>
* <li>It may also return a {@link Uni} that wraps any of the type mentioned above.</li>
* </ul>
*
* @see PromptResponse
* @see PromptArg
* @see PromptResponseEncoder
*/
@Retention(RUNTIME)
@Target(METHOD)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.quarkiverse.mcp.server;

import jakarta.annotation.Priority;

/**
* Encodes an object as {@link PromptResponse}.
* <p>
* If a propmpt response encoder exists and matches a specific return type then it always takes precedence over matching
* {@link ContentEncoder}.
* <p>
* Implementation classes must be CDI beans. Qualifiers are ignored. {@link jakarta.enterprise.context.Dependent} beans are
* reused during encoding.
* <p>
* Encoders may define the priority with {@link Priority}. An encoder with higher priority takes precedence.
*
* @param <TYPE>
* @see PromptResponse
* @see Prompt
*/
public interface PromptResponseEncoder<TYPE> extends Encoder<TYPE, PromptResponse> {

}
Loading

0 comments on commit 0095a17

Please # to comment.