diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt index f1bb8009..259b11f8 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParserOptions.kt @@ -163,7 +163,7 @@ data class SchemaParserOptions internal constructor( GenericWrapper(CompletableFuture::class, 0), GenericWrapper(CompletionStage::class, 0), GenericWrapper(Publisher::class, 0), - GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel, _ -> + GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel -> publish(coroutineContextProvider.provide()) { try { for (item in receiveChannel) { diff --git a/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt b/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt index 4cd44985..555b96c7 100644 --- a/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt +++ b/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt @@ -2,6 +2,7 @@ package graphql.kickstart.tools.resolver import graphql.GraphQLContext import graphql.Scalars +import graphql.kickstart.tools.GraphQLSubscriptionResolver import graphql.kickstart.tools.ResolverInfo import graphql.kickstart.tools.RootResolverInfo import graphql.kickstart.tools.SchemaParserOptions @@ -9,8 +10,10 @@ import graphql.kickstart.tools.util.* import graphql.language.FieldDefinition import graphql.language.TypeName import graphql.schema.DataFetchingEnvironment +import kotlinx.coroutines.channels.ReceiveChannel import org.apache.commons.lang3.ClassUtils import org.apache.commons.lang3.reflect.FieldUtils +import org.reactivestreams.Publisher import org.slf4j.LoggerFactory import java.lang.reflect.AccessibleObject import java.lang.reflect.Method @@ -86,7 +89,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { } private fun findResolverMethod(field: FieldDefinition, search: Search): Method? { - val methods = getAllMethods(search.type) + val methods = getAllMethods(search) val argumentCount = field.inputValueDefinitions.size + if (search.requiredFirstParameterType != null) 1 else 0 val name = field.name @@ -109,10 +112,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { } } - private fun getAllMethods(type: JavaType): List { - val declaredMethods = type.unwrap().declaredNonProxyMethods - val superClassesMethods = ClassUtils.getAllSuperclasses(type.unwrap()).flatMap { it.methods.toList() } - val interfacesMethods = ClassUtils.getAllInterfaces(type.unwrap()).flatMap { it.methods.toList() } + private fun getAllMethods(search: Search): List { + val type = search.type.unwrap() + val declaredMethods = type.declaredNonProxyMethods + val superClassesMethods = ClassUtils.getAllSuperclasses(type).flatMap { it.methods.toList() } + val interfacesMethods = ClassUtils.getAllInterfaces(type).flatMap { it.methods.toList() } return (declaredMethods + superClassesMethods + interfacesMethods) .asSequence() @@ -121,9 +125,26 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { // discard any methods that are coming off the root of the class hierarchy // to avoid issues with duplicate method declarations .filter { it.declaringClass != Object::class.java } + // subscription resolvers must return a publisher + .filter { search.source !is GraphQLSubscriptionResolver || resolverMethodReturnsPublisher(it) } .toList() } + private fun resolverMethodReturnsPublisher(method: Method) = + method.returnType.isAssignableFrom(Publisher::class.java) || receiveChannelToPublisherWrapper(method) + + private fun receiveChannelToPublisherWrapper(method: Method) = + method.returnType.isAssignableFrom(ReceiveChannel::class.java) + && options.genericWrappers.any { wrapper -> + val isReceiveChannelWrapper = wrapper.type == method.returnType + val hasPublisherTransformer = wrapper + .transformer.javaClass + .declaredMethods + .filter { it.name == "invoke" } + .any { it.returnType.isAssignableFrom(Publisher::class.java) } + isReceiveChannelWrapper && hasPublisherTransformer + } + private fun isBoolean(type: GraphQLLangType) = type.unwrap().let { it is TypeName && it.name == Scalars.GraphQLBoolean.name } private fun verifyMethodArguments(method: Method, requiredCount: Int, search: Search): Boolean { @@ -166,14 +187,18 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { private fun getMissingFieldMessage(field: FieldDefinition, searches: List, scannedProperties: Boolean): String { val signatures = mutableListOf("") val isBoolean = isBoolean(field.type) + var isSubscription = false searches.forEach { search -> signatures.addAll(getMissingMethodSignatures(field, search, isBoolean, scannedProperties)) + isSubscription = isSubscription || search.source is GraphQLSubscriptionResolver } val sourceName = if (field.sourceLocation != null && field.sourceLocation.sourceName != null) field.sourceLocation.sourceName else "" val sourceLocation = if (field.sourceLocation != null) "$sourceName:${field.sourceLocation.line}" else "" - return "No method${if (scannedProperties) " or field" else ""} found as defined in schema $sourceLocation with any of the following signatures (with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}" + return "No method${if (scannedProperties) " or field" else ""} found as defined in schema $sourceLocation with any of the following signatures " + + "(with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}" + + if (isSubscription) "\n\nNote that a Subscription data fetcher must return a Publisher of events" else "" } private fun getMissingMethodSignatures(field: FieldDefinition, search: Search, isBoolean: Boolean, scannedProperties: Boolean): List { diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt index 50700fb7..87fb9daf 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt @@ -662,4 +662,44 @@ class SchemaParserTest { } } } + + @Test + fun `parser should verify subscription resolver return type`() { + val error = assertThrows(FieldResolverError::class.java) { + SchemaParser.newParser() + .schemaString( + """ + type Subscription { + onItemCreated: Int! + } + + type Query { + test: String + } + """ + ) + .resolvers( + Subscription(), + object : GraphQLQueryResolver { fun test() = "test" } + ) + .build() + .makeExecutableSchema() + } + + val expected = """ + No method or field found as defined in schema :3 with any of the following signatures (with or without one of [interface graphql.schema.DataFetchingEnvironment, class graphql.GraphQLContext] as the last argument), in priority order: + + graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated() + graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.getOnItemCreated() + graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated + + Note that a Subscription data fetcher must return a Publisher of events + """.trimIndent() + + assertEquals(error.message, expected) + } + + class Subscription : GraphQLSubscriptionResolver { + fun onItemCreated(env: DataFetchingEnvironment) = env.hashCode() + } }