Skip to content

Commit

Permalink
Filter selfReferenceRequests to minimal needed
Browse files Browse the repository at this point in the history
  • Loading branch information
pjagielski committed Jan 15, 2022
1 parent 1776b59 commit 770dcb0
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 14 deletions.
4 changes: 2 additions & 2 deletions annotation-processor/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ dependencies {

api project(":runtime")

testCompile'org.junit.jupiter:junit-jupiter:5.5.2'
testCompile 'org.assertj:assertj-core:3.14.0'
testCompile 'org.junit.jupiter:junit-jupiter:5.8.2'
testCompile 'org.assertj:assertj-core:3.22.0'
testRuntime 'ch.qos.logback:logback-classic:1.2.10'
testCompile("io.github.jbock-java:compile-testing:0.19.11")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,20 @@ fun EntityGraph.traverse(function: (TypeElement, EntityDefinition) -> Unit) {
this.entries.forEach { (key, value) -> function.invoke(key, value) }
}

class DFS(val graphs: EntityGraphs) {
private val result = mutableSetOf<EntityDefinition>()
private val visited = mutableSetOf<TypeElement>()

fun visit(elem: TypeElement): List<EntityDefinition> {
val current = graphs.entity(elem.packageName, elem) ?: throw EntityNotMappedException(elem)
result.add(current)
visited.add(elem)
val remaining = current.associations.map { it.target }.filterNot { visited.contains(it) }
remaining.forEach { visit(it) }
return result.toList()
}
}

fun EntityGraph.allAssociations() =
this.values.flatMap { entityDef -> entityDef.associations.map { it.target } }.toSet()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MappingsGenerator : SourceGenerator {
fileSpec.addFunction(buildSelfReferencesToEntityListFunc(entityType, entity))
}
fileSpec.addFunction(buildAddSubEntitiesToEntityFunc(entityClass, entity))
fileSpec.addFunction(buildToEntityMapFunc(hasSelfRef, entityClass, entity, graph))
fileSpec.addFunction(buildToEntityMapFunc(hasSelfRef, entityClass, entity, graphs))

// Functions for inserting objects into the DB
buildFromEntityFunc(entityType, entity)?.let { funSpec ->
Expand Down Expand Up @@ -122,7 +122,11 @@ class MappingsGenerator : SourceGenerator {
} else if (!assoc.nullable) {
"\t${assoc.name} = this.to${assoc.target.simpleName}()"
} else {
"\t${assoc.name} = row[${entity.tableName}.${assoc.defaultIdPropName()}]?.let { this.to${assoc.target.simpleName}() }"
if (assoc.isSelfReferential) {
"\t${assoc.name} = row[${entity.tableName}.${assoc.defaultIdPropName()}]?.let { nextAlias?.let { this.to${assoc.target.simpleName}(nextAlias) } }"
} else {
"\t${assoc.name} = row[${entity.tableName}.${assoc.defaultIdPropName()}]?.let { this.to${assoc.target.simpleName}() }"
}
}
}

Expand Down Expand Up @@ -274,7 +278,6 @@ class MappingsGenerator : SourceGenerator {
// Allowing a null id here allows users to not include a join with the other table if they don't
// need the relation-lists to be populated
addStatement("val $assocVarId = ${idReadingBlock(setAssoc.targetId, setAssoc.targetTable, nullable = true, rowReference = "row")}")
// beginControlFlow("if ($assocVarId != null && !containsEntity(%T::class, $assocVarId)) {", targetClass)
beginControlFlow("if ($assocVarId != null) {")

addStatement("val $attrValName = $entityParamName.$assocVar as MutableList<$targetTypeName>")
Expand Down Expand Up @@ -350,7 +353,9 @@ class MappingsGenerator : SourceGenerator {
return func.build()
}

private fun buildToEntityMapFunc(hasSelfRef: Boolean, entityClass: ClassName, entity: EntityDefinition, graph: EntityGraph): FunSpec {
private fun buildToEntityMapFunc(
hasSelfRef: Boolean, entityClass: ClassName, entity: EntityDefinition, graphs: EntityGraphs
): FunSpec {
val rootKey = entity.id?.asUnderlyingTypeName() ?: throw MissingIdException(entity)

val func = if (hasSelfRef) {
Expand Down Expand Up @@ -382,17 +387,20 @@ class MappingsGenerator : SourceGenerator {

addStatement("}")

val selfRefAssociations = graph.values
.flatMap { entityDef ->
entityDef.associations.filter { it.isSelfReferential }
}
val selfRefAssociations = DFS(graphs).visit(entity.type)
.flatMap { it.associations }
.filter { it.isSelfReferential }

val selfRefAssociationsFiltered = selfRefAssociations
// filter out bidirectional associations processed twice
.filterNot { selfRefAssoc -> selfRefAssociations.any { it != selfRefAssoc && it.target == selfRefAssoc.source && it.isBidirectional } }

if(selfRefAssociations.isNotEmpty()) {
if (selfRefAssociationsFiltered.isNotEmpty()) {
// Go through all self references requested and add them to the respective list.
addStatement("selfReferenceRequests.forEach { (clazz, unsatisfiedMap) -> ")
addStatement("\twhen(clazz) {")

selfRefAssociations
selfRefAssociationsFiltered
.forEach { selfRefAssoc ->
val entityName = selfRefAssoc.source.simpleName
val subjectIdName = "subject${entityName}Id"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package pl.touk.krush.model

import com.squareup.kotlinpoet.metadata.KotlinPoetMetadataPreview
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import pl.touk.krush.AnnotationProcessorTest
import javax.lang.model.util.Elements
import javax.lang.model.util.Types

@KotlinPoetMetadataPreview
class DFSTest(types: Types, elements: Elements) : AnnotationProcessorTest(types, elements), EntityGraphSampleData {

@Test
fun shouldVisitAllNodesUsingDFS() {
//given
val graphBuilder = oneToOneGraphBuilder(getTypeEnv())

//when
val graphs = graphBuilder.build()
val typeElement = oneToOneSourceEntity(getTypeEnv())

val elements = DFS(graphs).visit(typeElement)

//then
assertThat(elements)
.hasSize(2)
.extracting("type")
.containsOnly(typeElement, oneToOneTargetEntity(getTypeEnv()))
}

}
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ buildscript {
}

plugins {
id 'org.jetbrains.kotlin.jvm' version '1.4.31'
id 'org.jetbrains.kotlin.jvm' version '1.6.10'
id 'pl.allegro.tech.build.axion-release' version '1.13.6'
id 'maven-publish'
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ data class Category(
@Column
val name: String,

@OneToOne
@ManyToOne
@JoinColumn(name = "parent_id")
val parent: Category?,

Expand Down

0 comments on commit 770dcb0

Please # to comment.