35
35
import org .springframework .data .jpa .repository .Modifying ;
36
36
import org .springframework .data .jpa .repository .NativeQuery ;
37
37
import org .springframework .data .jpa .repository .QueryHints ;
38
+ import org .springframework .data .jpa .repository .QueryRewriter ;
38
39
import org .springframework .data .jpa .repository .query .DeclaredQuery ;
39
40
import org .springframework .data .jpa .repository .query .JpaQueryMethod ;
40
41
import org .springframework .data .jpa .repository .query .ParameterBinding ;
@@ -85,6 +86,7 @@ static class QueryBlockBuilder {
85
86
private @ Nullable AotEntityGraph entityGraph ;
86
87
private @ Nullable String sqlResultSetMapping ;
87
88
private @ Nullable Class <?> queryReturnType ;
89
+ private @ Nullable Class <?> queryRewriter = QueryRewriter .IdentityQueryRewriter .class ;
88
90
89
91
private QueryBlockBuilder (AotQueryMethodGenerationContext context , JpaQueryMethod queryMethod ) {
90
92
this .context = context ;
@@ -126,6 +128,11 @@ public QueryBlockBuilder queryReturnType(@Nullable Class<?> queryReturnType) {
126
128
return this ;
127
129
}
128
130
131
+ public QueryBlockBuilder queryRewriter (@ Nullable Class <?> queryRewriter ) {
132
+ this .queryRewriter = queryRewriter == null ? QueryRewriter .IdentityQueryRewriter .class : queryRewriter ;
133
+ return this ;
134
+ }
135
+
129
136
/**
130
137
* Build the query block.
131
138
*
@@ -145,12 +152,20 @@ public CodeBlock build() {
145
152
CodeBlock .Builder builder = CodeBlock .builder ();
146
153
builder .add ("\n " );
147
154
148
- String queryStringNameVariableName = null ;
155
+ String queryStringVariableName = null ;
156
+
157
+ String queryRewriterName = null ;
158
+
159
+ if (queries .result () instanceof StringAotQuery && queryRewriter != QueryRewriter .IdentityQueryRewriter .class ) {
160
+
161
+ queryRewriterName = "queryRewriter" ;
162
+ builder .addStatement ("$T $L = new $T()" , queryRewriter , queryRewriterName , queryRewriter );
163
+ }
149
164
150
165
if (queries != null && queries .result () instanceof StringAotQuery sq ) {
151
166
152
- queryStringNameVariableName = "%sString" .formatted (queryVariableName );
153
- builder .addStatement ( "$T $L = $S" , String . class , queryStringNameVariableName , sq . getQueryString ( ));
167
+ queryStringVariableName = "%sString" .formatted (queryVariableName );
168
+ builder .add ( buildQueryString ( sq , queryStringVariableName ));
154
169
}
155
170
156
171
String countQueryStringNameVariableName = null ;
@@ -159,7 +174,7 @@ public CodeBlock build() {
159
174
if (queryMethod .isPageQuery () && queries .count () instanceof StringAotQuery sq ) {
160
175
161
176
countQueryStringNameVariableName = "count%sString" .formatted (StringUtils .capitalize (queryVariableName ));
162
- builder .addStatement ( "$T $L = $S" , String . class , countQueryStringNameVariableName , sq . getQueryString ( ));
177
+ builder .add ( buildQueryString ( sq , countQueryStringNameVariableName ));
163
178
}
164
179
165
180
String sortParameterName = context .getSortParameterName ();
@@ -169,14 +184,14 @@ public CodeBlock build() {
169
184
170
185
if ((StringUtils .hasText (sortParameterName ) || StringUtils .hasText (dynamicReturnType ))
171
186
&& queries .result () instanceof StringAotQuery ) {
172
- builder .add (applyRewrite (sortParameterName , dynamicReturnType , queryStringNameVariableName , actualReturnType ));
187
+ builder .add (applyRewrite (sortParameterName , dynamicReturnType , queryStringVariableName , actualReturnType ));
173
188
}
174
189
175
190
if (queries .result ().hasExpression () || queries .count ().hasExpression ()) {
176
191
builder .addStatement ("class ExpressionMarker{}" );
177
192
}
178
193
179
- builder .add (createQuery (false , queryVariableName , queryStringNameVariableName , queries .result (),
194
+ builder .add (createQuery (false , queryVariableName , queryStringVariableName , queryRewriterName , queries .result (),
180
195
this .sqlResultSetMapping , this .queryHints , this .entityGraph , this .queryReturnType ));
181
196
182
197
builder .add (applyLimits (queries .result ().isExists ()));
@@ -187,7 +202,8 @@ public CodeBlock build() {
187
202
188
203
boolean queryHints = this .queryHints .isPresent () && this .queryHints .getBoolean ("forCounting" );
189
204
190
- builder .add (createQuery (true , countQueryVariableName , countQueryStringNameVariableName , queries .count (), null ,
205
+ builder .add (createQuery (true , countQueryVariableName , countQueryStringNameVariableName , queryRewriterName ,
206
+ queries .count (), null ,
191
207
queryHints ? this .queryHints : MergedAnnotation .missing (), null , Long .class ));
192
208
builder .addStatement ("return ($T) $L.getSingleResult()" , Long .class , countQueryVariableName );
193
209
@@ -199,6 +215,13 @@ public CodeBlock build() {
199
215
return builder .build ();
200
216
}
201
217
218
+ private CodeBlock buildQueryString (StringAotQuery sq , String queryStringVariableName ) {
219
+
220
+ CodeBlock .Builder builder = CodeBlock .builder ();
221
+ builder .addStatement ("$T $L = $S" , String .class , queryStringVariableName , sq .getQueryString ());
222
+ return builder .build ();
223
+ }
224
+
202
225
private CodeBlock applyRewrite (@ Nullable String sort , @ Nullable String dynamicReturnType , String queryString ,
203
226
Class <?> actualReturnType ) {
204
227
@@ -268,12 +291,14 @@ private CodeBlock applyLimits(boolean exists) {
268
291
}
269
292
270
293
private CodeBlock createQuery (boolean count , String queryVariableName , @ Nullable String queryStringNameVariableName ,
271
- AotQuery query , @ Nullable String sqlResultSetMapping , MergedAnnotation <QueryHints > queryHints ,
294
+ @ Nullable String queryRewriterName , AotQuery query , @ Nullable String sqlResultSetMapping ,
295
+ MergedAnnotation <QueryHints > queryHints ,
272
296
@ Nullable AotEntityGraph entityGraph , @ Nullable Class <?> queryReturnType ) {
273
297
274
298
Builder builder = CodeBlock .builder ();
275
299
276
- builder .add (doCreateQuery (count , queryVariableName , queryStringNameVariableName , query , sqlResultSetMapping ,
300
+ builder .add (doCreateQuery (count , queryVariableName , queryStringNameVariableName , queryRewriterName , query ,
301
+ sqlResultSetMapping ,
277
302
queryReturnType ));
278
303
279
304
if (entityGraph != null ) {
@@ -306,18 +331,36 @@ private CodeBlock createQuery(boolean count, String queryVariableName, @Nullable
306
331
}
307
332
308
333
private CodeBlock doCreateQuery (boolean count , String queryVariableName ,
309
- @ Nullable String queryStringNameVariableName , AotQuery query , @ Nullable String sqlResultSetMapping ,
334
+ @ Nullable String queryStringName , @ Nullable String queryRewriterName , AotQuery query ,
335
+ @ Nullable String sqlResultSetMapping ,
310
336
@ Nullable Class <?> queryReturnType ) {
311
337
312
338
ReturnedType returnedType = context .getReturnedType ();
313
339
Builder builder = CodeBlock .builder ();
340
+ String queryStringNameToUse = queryStringName ;
314
341
315
342
if (query instanceof StringAotQuery sq ) {
316
343
344
+ if (StringUtils .hasText (queryRewriterName )) {
345
+
346
+ queryStringNameToUse = queryStringName + "Rewritten" ;
347
+
348
+ if (StringUtils .hasText (context .getPageableParameterName ())) {
349
+ builder .addStatement ("$T $L = $L.rewrite($L, $L)" , String .class , queryStringNameToUse , queryRewriterName ,
350
+ queryStringName , context .getPageableParameterName ());
351
+ } else if (StringUtils .hasText (context .getSortParameterName ())) {
352
+ builder .addStatement ("$T $L = $L.rewrite($L, $L)" , String .class , queryStringNameToUse , queryRewriterName ,
353
+ queryStringName , context .getSortParameterName ());
354
+ } else {
355
+ builder .addStatement ("$T $L = $L.rewrite($L, $T.unsorted())" , String .class , queryStringNameToUse ,
356
+ queryRewriterName , queryStringName , Sort .class );
357
+ }
358
+ }
359
+
317
360
if (StringUtils .hasText (sqlResultSetMapping )) {
318
361
319
362
builder .addStatement ("$T $L = this.$L.createNativeQuery($L, $S)" , Query .class , queryVariableName ,
320
- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName , sqlResultSetMapping );
363
+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse , sqlResultSetMapping );
321
364
322
365
return builder .build ();
323
366
}
@@ -327,10 +370,10 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
327
370
if (queryReturnType != null ) {
328
371
329
372
builder .addStatement ("$T $L = this.$L.createNativeQuery($L, $T.class)" , Query .class , queryVariableName ,
330
- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName , queryReturnType );
373
+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse , queryReturnType );
331
374
} else {
332
375
builder .addStatement ("$T $L = this.$L.createNativeQuery($L)" , Query .class , queryVariableName ,
333
- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName );
376
+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse );
334
377
}
335
378
336
379
return builder .build ();
@@ -339,18 +382,18 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
339
382
if (sq .hasConstructorExpressionOrDefaultProjection () && !count && returnedType .isProjecting ()
340
383
&& returnedType .getReturnedType ().isInterface ()) {
341
384
builder .addStatement ("$T $L = this.$L.createQuery($L)" , Query .class , queryVariableName ,
342
- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName );
385
+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse );
343
386
} else {
344
387
345
388
String createQueryMethod = query .isNative () ? "createNativeQuery" : "createQuery" ;
346
389
347
390
if (!sq .hasConstructorExpressionOrDefaultProjection () && !count && returnedType .isProjecting ()
348
391
&& returnedType .getReturnedType ().isInterface ()) {
349
392
builder .addStatement ("$T $L = this.$L.$L($L, $T.class)" , Query .class , queryVariableName ,
350
- context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameVariableName , Tuple .class );
393
+ context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameToUse , Tuple .class );
351
394
} else {
352
395
builder .addStatement ("$T $L = this.$L.$L($L)" , Query .class , queryVariableName ,
353
- context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameVariableName );
396
+ context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameToUse );
354
397
}
355
398
}
356
399
0 commit comments