Skip to content

Validate subscription data resolver during schema parsing #756

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ data class SchemaParserOptions internal constructor(
GenericWrapper(CompletableFuture::class, 0),
GenericWrapper(CompletionStage::class, 0),
GenericWrapper(Publisher::class, 0),
GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel, _ ->
GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel ->
publish(coroutineContextProvider.provide()) {
try {
for (item in receiveChannel) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package graphql.kickstart.tools.resolver

import graphql.GraphQLContext
import graphql.Scalars
import graphql.kickstart.tools.GraphQLSubscriptionResolver
import graphql.kickstart.tools.ResolverInfo
import graphql.kickstart.tools.RootResolverInfo
import graphql.kickstart.tools.SchemaParserOptions
import graphql.kickstart.tools.util.*
import graphql.language.FieldDefinition
import graphql.language.TypeName
import graphql.schema.DataFetchingEnvironment
import kotlinx.coroutines.channels.ReceiveChannel
import org.apache.commons.lang3.ClassUtils
import org.apache.commons.lang3.reflect.FieldUtils
import org.reactivestreams.Publisher
import org.slf4j.LoggerFactory
import java.lang.reflect.AccessibleObject
import java.lang.reflect.Method
Expand Down Expand Up @@ -86,7 +89,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
}

private fun findResolverMethod(field: FieldDefinition, search: Search): Method? {
val methods = getAllMethods(search.type)
val methods = getAllMethods(search)
val argumentCount = field.inputValueDefinitions.size + if (search.requiredFirstParameterType != null) 1 else 0
val name = field.name

Expand All @@ -109,10 +112,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
}
}

private fun getAllMethods(type: JavaType): List<Method> {
val declaredMethods = type.unwrap().declaredNonProxyMethods
val superClassesMethods = ClassUtils.getAllSuperclasses(type.unwrap()).flatMap { it.methods.toList() }
val interfacesMethods = ClassUtils.getAllInterfaces(type.unwrap()).flatMap { it.methods.toList() }
private fun getAllMethods(search: Search): List<Method> {
val type = search.type.unwrap()
val declaredMethods = type.declaredNonProxyMethods
val superClassesMethods = ClassUtils.getAllSuperclasses(type).flatMap { it.methods.toList() }
val interfacesMethods = ClassUtils.getAllInterfaces(type).flatMap { it.methods.toList() }

return (declaredMethods + superClassesMethods + interfacesMethods)
.asSequence()
Expand All @@ -121,9 +125,26 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
// discard any methods that are coming off the root of the class hierarchy
// to avoid issues with duplicate method declarations
.filter { it.declaringClass != Object::class.java }
// subscription resolvers must return a publisher
.filter { search.source !is GraphQLSubscriptionResolver || resolverMethodReturnsPublisher(it) }
.toList()
}

private fun resolverMethodReturnsPublisher(method: Method) =
method.returnType.isAssignableFrom(Publisher::class.java) || receiveChannelToPublisherWrapper(method)

private fun receiveChannelToPublisherWrapper(method: Method) =
method.returnType.isAssignableFrom(ReceiveChannel::class.java)
&& options.genericWrappers.any { wrapper ->
val isReceiveChannelWrapper = wrapper.type == method.returnType
val hasPublisherTransformer = wrapper
.transformer.javaClass
.declaredMethods
.filter { it.name == "invoke" }
.any { it.returnType.isAssignableFrom(Publisher::class.java) }
isReceiveChannelWrapper && hasPublisherTransformer
}

private fun isBoolean(type: GraphQLLangType) = type.unwrap().let { it is TypeName && it.name == Scalars.GraphQLBoolean.name }

private fun verifyMethodArguments(method: Method, requiredCount: Int, search: Search): Boolean {
Expand Down Expand Up @@ -166,14 +187,18 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
private fun getMissingFieldMessage(field: FieldDefinition, searches: List<Search>, scannedProperties: Boolean): String {
val signatures = mutableListOf("")
val isBoolean = isBoolean(field.type)
var isSubscription = false

searches.forEach { search ->
signatures.addAll(getMissingMethodSignatures(field, search, isBoolean, scannedProperties))
isSubscription = isSubscription || search.source is GraphQLSubscriptionResolver
}

val sourceName = if (field.sourceLocation != null && field.sourceLocation.sourceName != null) field.sourceLocation.sourceName else "<unknown>"
val sourceLocation = if (field.sourceLocation != null) "$sourceName:${field.sourceLocation.line}" else "<unknown>"
return "No method${if (scannedProperties) " or field" else ""} found as defined in schema $sourceLocation with any of the following signatures (with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}"
return "No method${if (scannedProperties) " or field" else ""} found as defined in schema $sourceLocation with any of the following signatures " +
"(with or without one of $allowedLastArgumentTypes as the last argument), in priority order:\n${signatures.joinToString("\n ")}" +
if (isSubscription) "\n\nNote that a Subscription data fetcher must return a Publisher of events" else ""
}

private fun getMissingMethodSignatures(field: FieldDefinition, search: Search, isBoolean: Boolean, scannedProperties: Boolean): List<String> {
Expand Down
40 changes: 40 additions & 0 deletions src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -662,4 +662,44 @@ class SchemaParserTest {
}
}
}

@Test
fun `parser should verify subscription resolver return type`() {
val error = assertThrows(FieldResolverError::class.java) {
SchemaParser.newParser()
.schemaString(
"""
type Subscription {
onItemCreated: Int!
}

type Query {
test: String
}
"""
)
.resolvers(
Subscription(),
object : GraphQLQueryResolver { fun test() = "test" }
)
.build()
.makeExecutableSchema()
}

val expected = """
No method or field found as defined in schema <unknown>:3 with any of the following signatures (with or without one of [interface graphql.schema.DataFetchingEnvironment, class graphql.GraphQLContext] as the last argument), in priority order:

graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated()
graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.getOnItemCreated()
graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated

Note that a Subscription data fetcher must return a Publisher of events
""".trimIndent()

assertEquals(error.message, expected)
}

class Subscription : GraphQLSubscriptionResolver {
fun onItemCreated(env: DataFetchingEnvironment) = env.hashCode()
}
}