diff --git a/annotation-processor/build.gradle b/annotation-processor/build.gradle index 9189d5e..a6dfda8 100644 --- a/annotation-processor/build.gradle +++ b/annotation-processor/build.gradle @@ -11,8 +11,8 @@ dependencies { api project(":runtime") - testCompile'org.junit.jupiter:junit-jupiter:5.5.2' - testCompile 'org.assertj:assertj-core:3.14.0' + testCompile 'org.junit.jupiter:junit-jupiter:5.8.2' + testCompile 'org.assertj:assertj-core:3.22.0' testRuntime 'ch.qos.logback:logback-classic:1.2.10' testCompile("io.github.jbock-java:compile-testing:0.19.11") } diff --git a/annotation-processor/src/main/kotlin/pl/touk/krush/model/EntityDefinition.kt b/annotation-processor/src/main/kotlin/pl/touk/krush/model/EntityDefinition.kt index db73771..383d171 100644 --- a/annotation-processor/src/main/kotlin/pl/touk/krush/model/EntityDefinition.kt +++ b/annotation-processor/src/main/kotlin/pl/touk/krush/model/EntityDefinition.kt @@ -160,6 +160,20 @@ fun EntityGraph.traverse(function: (TypeElement, EntityDefinition) -> Unit) { this.entries.forEach { (key, value) -> function.invoke(key, value) } } +class DFS(val graphs: EntityGraphs) { + private val result = mutableSetOf() + private val visited = mutableSetOf() + + fun visit(elem: TypeElement): List { + val current = graphs.entity(elem.packageName, elem) ?: throw EntityNotMappedException(elem) + result.add(current) + visited.add(elem) + val remaining = current.associations.map { it.target }.filterNot { visited.contains(it) } + remaining.forEach { visit(it) } + return result.toList() + } +} + fun EntityGraph.allAssociations() = this.values.flatMap { entityDef -> entityDef.associations.map { it.target } }.toSet() diff --git a/annotation-processor/src/main/kotlin/pl/touk/krush/source/MappingsGenerator.kt b/annotation-processor/src/main/kotlin/pl/touk/krush/source/MappingsGenerator.kt index b376b14..f2cae8f 100644 --- a/annotation-processor/src/main/kotlin/pl/touk/krush/source/MappingsGenerator.kt +++ b/annotation-processor/src/main/kotlin/pl/touk/krush/source/MappingsGenerator.kt @@ -55,7 +55,7 @@ class MappingsGenerator : SourceGenerator { fileSpec.addFunction(buildSelfReferencesToEntityListFunc(entityType, entity)) } fileSpec.addFunction(buildAddSubEntitiesToEntityFunc(entityClass, entity)) - fileSpec.addFunction(buildToEntityMapFunc(hasSelfRef, entityClass, entity, graph)) + fileSpec.addFunction(buildToEntityMapFunc(hasSelfRef, entityClass, entity, graphs)) // Functions for inserting objects into the DB buildFromEntityFunc(entityType, entity)?.let { funSpec -> @@ -122,7 +122,11 @@ class MappingsGenerator : SourceGenerator { } else if (!assoc.nullable) { "\t${assoc.name} = this.to${assoc.target.simpleName}()" } else { - "\t${assoc.name} = row[${entity.tableName}.${assoc.defaultIdPropName()}]?.let { this.to${assoc.target.simpleName}() }" + if (assoc.isSelfReferential) { + "\t${assoc.name} = row[${entity.tableName}.${assoc.defaultIdPropName()}]?.let { nextAlias?.let { this.to${assoc.target.simpleName}(nextAlias) } }" + } else { + "\t${assoc.name} = row[${entity.tableName}.${assoc.defaultIdPropName()}]?.let { this.to${assoc.target.simpleName}() }" + } } } @@ -274,7 +278,6 @@ class MappingsGenerator : SourceGenerator { // Allowing a null id here allows users to not include a join with the other table if they don't // need the relation-lists to be populated addStatement("val $assocVarId = ${idReadingBlock(setAssoc.targetId, setAssoc.targetTable, nullable = true, rowReference = "row")}") -// beginControlFlow("if ($assocVarId != null && !containsEntity(%T::class, $assocVarId)) {", targetClass) beginControlFlow("if ($assocVarId != null) {") addStatement("val $attrValName = $entityParamName.$assocVar as MutableList<$targetTypeName>") @@ -350,7 +353,9 @@ class MappingsGenerator : SourceGenerator { return func.build() } - private fun buildToEntityMapFunc(hasSelfRef: Boolean, entityClass: ClassName, entity: EntityDefinition, graph: EntityGraph): FunSpec { + private fun buildToEntityMapFunc( + hasSelfRef: Boolean, entityClass: ClassName, entity: EntityDefinition, graphs: EntityGraphs + ): FunSpec { val rootKey = entity.id?.asUnderlyingTypeName() ?: throw MissingIdException(entity) val func = if (hasSelfRef) { @@ -382,17 +387,20 @@ class MappingsGenerator : SourceGenerator { addStatement("}") - val selfRefAssociations = graph.values - .flatMap { entityDef -> - entityDef.associations.filter { it.isSelfReferential } - } + val selfRefAssociations = DFS(graphs).visit(entity.type) + .flatMap { it.associations } + .filter { it.isSelfReferential } + + val selfRefAssociationsFiltered = selfRefAssociations + // filter out bidirectional associations processed twice + .filterNot { selfRefAssoc -> selfRefAssociations.any { it != selfRefAssoc && it.target == selfRefAssoc.source && it.isBidirectional } } - if(selfRefAssociations.isNotEmpty()) { + if (selfRefAssociationsFiltered.isNotEmpty()) { // Go through all self references requested and add them to the respective list. addStatement("selfReferenceRequests.forEach { (clazz, unsatisfiedMap) -> ") addStatement("\twhen(clazz) {") - selfRefAssociations + selfRefAssociationsFiltered .forEach { selfRefAssoc -> val entityName = selfRefAssoc.source.simpleName val subjectIdName = "subject${entityName}Id" diff --git a/annotation-processor/src/test/kotlin/pl/touk/krush/model/DFSTest.kt b/annotation-processor/src/test/kotlin/pl/touk/krush/model/DFSTest.kt new file mode 100644 index 0000000..f6466b3 --- /dev/null +++ b/annotation-processor/src/test/kotlin/pl/touk/krush/model/DFSTest.kt @@ -0,0 +1,31 @@ +package pl.touk.krush.model + +import com.squareup.kotlinpoet.metadata.KotlinPoetMetadataPreview +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pl.touk.krush.AnnotationProcessorTest +import javax.lang.model.util.Elements +import javax.lang.model.util.Types + +@KotlinPoetMetadataPreview +class DFSTest(types: Types, elements: Elements) : AnnotationProcessorTest(types, elements), EntityGraphSampleData { + + @Test + fun shouldVisitAllNodesUsingDFS() { + //given + val graphBuilder = oneToOneGraphBuilder(getTypeEnv()) + + //when + val graphs = graphBuilder.build() + val typeElement = oneToOneSourceEntity(getTypeEnv()) + + val elements = DFS(graphs).visit(typeElement) + + //then + assertThat(elements) + .hasSize(2) + .extracting("type") + .containsOnly(typeElement, oneToOneTargetEntity(getTypeEnv())) + } + +} diff --git a/build.gradle b/build.gradle index 9b0992c..960bb88 100644 --- a/build.gradle +++ b/build.gradle @@ -9,7 +9,7 @@ buildscript { } plugins { - id 'org.jetbrains.kotlin.jvm' version '1.4.31' + id 'org.jetbrains.kotlin.jvm' version '1.6.10' id 'pl.allegro.tech.build.axion-release' version '1.13.6' id 'maven-publish' } diff --git a/example/src/main/kotlin/pl/touk/krush/realreferences/Category.kt b/example/src/main/kotlin/pl/touk/krush/realreferences/Category.kt index c376046..6bd2be8 100644 --- a/example/src/main/kotlin/pl/touk/krush/realreferences/Category.kt +++ b/example/src/main/kotlin/pl/touk/krush/realreferences/Category.kt @@ -12,7 +12,7 @@ data class Category( @Column val name: String, - @OneToOne + @ManyToOne @JoinColumn(name = "parent_id") val parent: Category?,