Skip to content

Commit 20cda03

Browse files
committed
Add support for QueryRewriter.
See #3830
1 parent ff8754d commit 20cda03

File tree

5 files changed

+100
-19
lines changed

5 files changed

+100
-19
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaCodeBlocks.java

+59-16
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.data.jpa.repository.Modifying;
3636
import org.springframework.data.jpa.repository.NativeQuery;
3737
import org.springframework.data.jpa.repository.QueryHints;
38+
import org.springframework.data.jpa.repository.QueryRewriter;
3839
import org.springframework.data.jpa.repository.query.DeclaredQuery;
3940
import org.springframework.data.jpa.repository.query.JpaQueryMethod;
4041
import org.springframework.data.jpa.repository.query.ParameterBinding;
@@ -85,6 +86,7 @@ static class QueryBlockBuilder {
8586
private @Nullable AotEntityGraph entityGraph;
8687
private @Nullable String sqlResultSetMapping;
8788
private @Nullable Class<?> queryReturnType;
89+
private @Nullable Class<?> queryRewriter = QueryRewriter.IdentityQueryRewriter.class;
8890

8991
private QueryBlockBuilder(AotQueryMethodGenerationContext context, JpaQueryMethod queryMethod) {
9092
this.context = context;
@@ -126,6 +128,11 @@ public QueryBlockBuilder queryReturnType(@Nullable Class<?> queryReturnType) {
126128
return this;
127129
}
128130

131+
public QueryBlockBuilder queryRewriter(@Nullable Class<?> queryRewriter) {
132+
this.queryRewriter = queryRewriter == null ? QueryRewriter.IdentityQueryRewriter.class : queryRewriter;
133+
return this;
134+
}
135+
129136
/**
130137
* Build the query block.
131138
*
@@ -145,12 +152,20 @@ public CodeBlock build() {
145152
CodeBlock.Builder builder = CodeBlock.builder();
146153
builder.add("\n");
147154

148-
String queryStringNameVariableName = null;
155+
String queryStringVariableName = null;
156+
157+
String queryRewriterName = null;
158+
159+
if (queries.result() instanceof StringAotQuery && queryRewriter != QueryRewriter.IdentityQueryRewriter.class) {
160+
161+
queryRewriterName = "queryRewriter";
162+
builder.addStatement("$T $L = new $T()", queryRewriter, queryRewriterName, queryRewriter);
163+
}
149164

150165
if (queries != null && queries.result() instanceof StringAotQuery sq) {
151166

152-
queryStringNameVariableName = "%sString".formatted(queryVariableName);
153-
builder.addStatement("$T $L = $S", String.class, queryStringNameVariableName, sq.getQueryString());
167+
queryStringVariableName = "%sString".formatted(queryVariableName);
168+
builder.add(buildQueryString(sq, queryStringVariableName));
154169
}
155170

156171
String countQueryStringNameVariableName = null;
@@ -159,7 +174,7 @@ public CodeBlock build() {
159174
if (queryMethod.isPageQuery() && queries.count() instanceof StringAotQuery sq) {
160175

161176
countQueryStringNameVariableName = "count%sString".formatted(StringUtils.capitalize(queryVariableName));
162-
builder.addStatement("$T $L = $S", String.class, countQueryStringNameVariableName, sq.getQueryString());
177+
builder.add(buildQueryString(sq, countQueryStringNameVariableName));
163178
}
164179

165180
String sortParameterName = context.getSortParameterName();
@@ -169,14 +184,14 @@ public CodeBlock build() {
169184

170185
if ((StringUtils.hasText(sortParameterName) || StringUtils.hasText(dynamicReturnType))
171186
&& queries.result() instanceof StringAotQuery) {
172-
builder.add(applyRewrite(sortParameterName, dynamicReturnType, queryStringNameVariableName, actualReturnType));
187+
builder.add(applyRewrite(sortParameterName, dynamicReturnType, queryStringVariableName, actualReturnType));
173188
}
174189

175190
if (queries.result().hasExpression() || queries.count().hasExpression()) {
176191
builder.addStatement("class ExpressionMarker{}");
177192
}
178193

179-
builder.add(createQuery(false, queryVariableName, queryStringNameVariableName, queries.result(),
194+
builder.add(createQuery(false, queryVariableName, queryStringVariableName, queryRewriterName, queries.result(),
180195
this.sqlResultSetMapping, this.queryHints, this.entityGraph, this.queryReturnType));
181196

182197
builder.add(applyLimits(queries.result().isExists()));
@@ -187,7 +202,8 @@ public CodeBlock build() {
187202

188203
boolean queryHints = this.queryHints.isPresent() && this.queryHints.getBoolean("forCounting");
189204

190-
builder.add(createQuery(true, countQueryVariableName, countQueryStringNameVariableName, queries.count(), null,
205+
builder.add(createQuery(true, countQueryVariableName, countQueryStringNameVariableName, queryRewriterName,
206+
queries.count(), null,
191207
queryHints ? this.queryHints : MergedAnnotation.missing(), null, Long.class));
192208
builder.addStatement("return ($T) $L.getSingleResult()", Long.class, countQueryVariableName);
193209

@@ -199,6 +215,13 @@ public CodeBlock build() {
199215
return builder.build();
200216
}
201217

218+
private CodeBlock buildQueryString(StringAotQuery sq, String queryStringVariableName) {
219+
220+
CodeBlock.Builder builder = CodeBlock.builder();
221+
builder.addStatement("$T $L = $S", String.class, queryStringVariableName, sq.getQueryString());
222+
return builder.build();
223+
}
224+
202225
private CodeBlock applyRewrite(@Nullable String sort, @Nullable String dynamicReturnType, String queryString,
203226
Class<?> actualReturnType) {
204227

@@ -268,12 +291,14 @@ private CodeBlock applyLimits(boolean exists) {
268291
}
269292

270293
private CodeBlock createQuery(boolean count, String queryVariableName, @Nullable String queryStringNameVariableName,
271-
AotQuery query, @Nullable String sqlResultSetMapping, MergedAnnotation<QueryHints> queryHints,
294+
@Nullable String queryRewriterName, AotQuery query, @Nullable String sqlResultSetMapping,
295+
MergedAnnotation<QueryHints> queryHints,
272296
@Nullable AotEntityGraph entityGraph, @Nullable Class<?> queryReturnType) {
273297

274298
Builder builder = CodeBlock.builder();
275299

276-
builder.add(doCreateQuery(count, queryVariableName, queryStringNameVariableName, query, sqlResultSetMapping,
300+
builder.add(doCreateQuery(count, queryVariableName, queryStringNameVariableName, queryRewriterName, query,
301+
sqlResultSetMapping,
277302
queryReturnType));
278303

279304
if (entityGraph != null) {
@@ -306,18 +331,36 @@ private CodeBlock createQuery(boolean count, String queryVariableName, @Nullable
306331
}
307332

308333
private CodeBlock doCreateQuery(boolean count, String queryVariableName,
309-
@Nullable String queryStringNameVariableName, AotQuery query, @Nullable String sqlResultSetMapping,
334+
@Nullable String queryStringName, @Nullable String queryRewriterName, AotQuery query,
335+
@Nullable String sqlResultSetMapping,
310336
@Nullable Class<?> queryReturnType) {
311337

312338
ReturnedType returnedType = context.getReturnedType();
313339
Builder builder = CodeBlock.builder();
340+
String queryStringNameToUse = queryStringName;
314341

315342
if (query instanceof StringAotQuery sq) {
316343

344+
if (StringUtils.hasText(queryRewriterName)) {
345+
346+
queryStringNameToUse = queryStringName + "Rewritten";
347+
348+
if (StringUtils.hasText(context.getPageableParameterName())) {
349+
builder.addStatement("$T $L = $L.rewrite($L, $L)", String.class, queryStringNameToUse, queryRewriterName,
350+
queryStringName, context.getPageableParameterName());
351+
} else if (StringUtils.hasText(context.getSortParameterName())) {
352+
builder.addStatement("$T $L = $L.rewrite($L, $L)", String.class, queryStringNameToUse, queryRewriterName,
353+
queryStringName, context.getSortParameterName());
354+
} else {
355+
builder.addStatement("$T $L = $L.rewrite($L, $T.unsorted())", String.class, queryStringNameToUse,
356+
queryRewriterName, queryStringName, Sort.class);
357+
}
358+
}
359+
317360
if (StringUtils.hasText(sqlResultSetMapping)) {
318361

319362
builder.addStatement("$T $L = this.$L.createNativeQuery($L, $S)", Query.class, queryVariableName,
320-
context.fieldNameOf(EntityManager.class), queryStringNameVariableName, sqlResultSetMapping);
363+
context.fieldNameOf(EntityManager.class), queryStringNameToUse, sqlResultSetMapping);
321364

322365
return builder.build();
323366
}
@@ -327,10 +370,10 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
327370
if (queryReturnType != null) {
328371

329372
builder.addStatement("$T $L = this.$L.createNativeQuery($L, $T.class)", Query.class, queryVariableName,
330-
context.fieldNameOf(EntityManager.class), queryStringNameVariableName, queryReturnType);
373+
context.fieldNameOf(EntityManager.class), queryStringNameToUse, queryReturnType);
331374
} else {
332375
builder.addStatement("$T $L = this.$L.createNativeQuery($L)", Query.class, queryVariableName,
333-
context.fieldNameOf(EntityManager.class), queryStringNameVariableName);
376+
context.fieldNameOf(EntityManager.class), queryStringNameToUse);
334377
}
335378

336379
return builder.build();
@@ -339,18 +382,18 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
339382
if (sq.hasConstructorExpressionOrDefaultProjection() && !count && returnedType.isProjecting()
340383
&& returnedType.getReturnedType().isInterface()) {
341384
builder.addStatement("$T $L = this.$L.createQuery($L)", Query.class, queryVariableName,
342-
context.fieldNameOf(EntityManager.class), queryStringNameVariableName);
385+
context.fieldNameOf(EntityManager.class), queryStringNameToUse);
343386
} else {
344387

345388
String createQueryMethod = query.isNative() ? "createNativeQuery" : "createQuery";
346389

347390
if (!sq.hasConstructorExpressionOrDefaultProjection() && !count && returnedType.isProjecting()
348391
&& returnedType.getReturnedType().isInterface()) {
349392
builder.addStatement("$T $L = this.$L.$L($L, $T.class)", Query.class, queryVariableName,
350-
context.fieldNameOf(EntityManager.class), createQueryMethod, queryStringNameVariableName, Tuple.class);
393+
context.fieldNameOf(EntityManager.class), createQueryMethod, queryStringNameToUse, Tuple.class);
351394
} else {
352395
builder.addStatement("$T $L = this.$L.$L($L)", Query.class, queryVariableName,
353-
context.fieldNameOf(EntityManager.class), createQueryMethod, queryStringNameVariableName);
396+
context.fieldNameOf(EntityManager.class), createQueryMethod, queryStringNameToUse);
354397
}
355398
}
356399

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaRepositoryContributor.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ protected void customizeClass(RepositoryInformation information, AotRepositoryFr
8181
@Override
8282
protected void customizeConstructor(AotRepositoryConstructorBuilder constructorBuilder) {
8383

84+
// TODO: BeanFactoryQueryRewriterProvider if there is a method using QueryRewriters.
85+
8486
constructorBuilder.addParameter("entityManager", EntityManager.class);
8587
constructorBuilder.addParameter("context", RepositoryFactoryBeanSupport.FragmentCreationContext.class);
8688

@@ -149,7 +151,8 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB
149151

150152
body.add(JpaCodeBlocks.queryBuilder(context, queryMethod).filter(aotQueries)
151153
.queryReturnType(QueriesFactory.getQueryReturnType(aotQueries.result(), returnedType, context))
152-
.nativeQuery(nativeQuery).queryHints(queryHints).entityGraph(aotEntityGraph).build());
154+
.nativeQuery(nativeQuery).queryHints(queryHints).entityGraph(aotEntityGraph)
155+
.queryRewriter(query.isPresent() ? query.getClass("queryRewriter") : null).build());
153156

154157
body.add(
155158
JpaCodeBlocks.executionBuilder(context, queryMethod).modifying(modifying).query(aotQueries.result()).build());

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/aot/JpaRepositoryContributorIntegrationTests.java

+13
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.data.domain.Limit;
3434
import org.springframework.data.domain.Page;
3535
import org.springframework.data.domain.PageRequest;
36+
import org.springframework.data.domain.Pageable;
3637
import org.springframework.data.domain.Slice;
3738
import org.springframework.data.domain.Sort;
3839
import org.springframework.data.jpa.domain.sample.Role;
@@ -624,6 +625,18 @@ void shouldQuerySubtype() {
624625
assertThat(result).isInstanceOf(SpecialUser.class);
625626
}
626627

628+
@Test
629+
void shouldApplyQueryRewriter() {
630+
631+
User result = fragment.findAndApplyQueryRewriter(kylo.getEmailAddress());
632+
633+
assertThat(result).isNotNull();
634+
635+
Page<User> page = fragment.findAndApplyQueryRewriter(kylo.getEmailAddress(), Pageable.unpaged());
636+
637+
assertThat(page).isNotEmpty();
638+
}
639+
627640
void todo() {
628641

629642
// dynamic projections: Not implemented

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/aot/UserRepository.java

+19
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.data.jpa.repository.NativeQuery;
3333
import org.springframework.data.jpa.repository.Query;
3434
import org.springframework.data.jpa.repository.QueryHints;
35+
import org.springframework.data.jpa.repository.QueryRewriter;
3536
import org.springframework.data.repository.CrudRepository;
3637

3738
/**
@@ -229,11 +230,29 @@ interface UserRepository extends CrudRepository<User, Integer> {
229230
@Query("select u from User u where u.emailAddress = ?1 AND TYPE(u) = ?2")
230231
<T extends User> T findByEmailAddress(String emailAddress, Class<T> type);
231232

233+
@Query(value = "select u from PLACEHOLDER u where u.emailAddress = ?1", queryRewriter = MyQueryRewriter.class)
234+
User findAndApplyQueryRewriter(String emailAddress);
235+
236+
@Query(value = "select u from OTHER u where u.emailAddress = ?1", queryRewriter = MyQueryRewriter.class)
237+
Page<User> findAndApplyQueryRewriter(String emailAddress, Pageable pageable);
238+
232239
interface EmailOnly {
233240
String getEmailAddress();
234241
}
235242

236243
record Names(String firstname, String lastname) {
237244
}
238245

246+
static class MyQueryRewriter implements QueryRewriter {
247+
248+
@Override
249+
public String rewrite(String query, Sort sort) {
250+
return query.replaceAll("PLACEHOLDER", "User");
251+
}
252+
253+
@Override
254+
public String rewrite(String query, Pageable pageRequest) {
255+
return query.replaceAll("OTHER", "User");
256+
}
257+
}
239258
}

src/main/antora/modules/ROOT/pages/jpa/aot.adoc

+5-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ This optimization moves query method processing from runtime to build-time, whic
3838
The resulting AOT repository fragment follows the naming scheme of `<Repository FQCN>Impl__Aot` and is placed in the same package as the repository interface.
3939
You can find all queries in their String form for generated repository query methods.
4040

41+
NOTE: Consider AOT repository classes an internal optimization.
42+
Do not use them directly in your code as generation and implementation details may change in future releases.
43+
4144
=== Running with AOT Repositories
4245

4346
AOT is a mandatory step to transform a Spring application to a native executable, so it is automatically enabled when running in this mode.
@@ -79,9 +82,9 @@ Mind that using Value Expressions requires expression parsing and contextual inf
7982
* `CrudRepository` and other base interface methods
8083
* Querydsl and Query by Example methods
8184
* Methods whose implementation would be overly complex
82-
** Methods accepting `ScrollPosition (e.g. `Keyset` pagination)
85+
** Methods accepting `ScrollPosition` (e.g. `Keyset` pagination)
8386
** Stored procedure query methods annotated with `@Procedure`
84-
** For now: Dynamic and interface projections
87+
** Dynamic projections
8588

8689
[[aot.repositories.json]]
8790
== Repository Metadata

0 commit comments

Comments
 (0)