22
22
23
23
import org .bson .Document ;
24
24
import org .jspecify .annotations .NullUnmarked ;
25
-
26
25
import org .springframework .core .ResolvableType ;
27
26
import org .springframework .core .annotation .MergedAnnotation ;
28
27
import org .springframework .data .domain .SliceImpl ;
29
28
import org .springframework .data .domain .Sort .Order ;
30
29
import org .springframework .data .mongodb .core .MongoOperations ;
31
30
import org .springframework .data .mongodb .core .aggregation .Aggregation ;
31
+ import org .springframework .data .mongodb .core .aggregation .AggregationOperation ;
32
32
import org .springframework .data .mongodb .core .aggregation .AggregationOptions ;
33
33
import org .springframework .data .mongodb .core .aggregation .AggregationPipeline ;
34
34
import org .springframework .data .mongodb .core .aggregation .AggregationResults ;
@@ -80,12 +80,7 @@ CodeBlock build() {
80
80
81
81
builder .add ("\n " );
82
82
83
- Class <?> outputType = queryMethod .getReturnedObjectType ();
84
- if (MongoSimpleTypes .HOLDER .isSimpleType (outputType )) {
85
- outputType = Document .class ;
86
- } else if (ClassUtils .isAssignable (AggregationResults .class , outputType )) {
87
- outputType = queryMethod .getReturnType ().getComponentType ().getType ();
88
- }
83
+ Class <?> outputType = getOutputType (queryMethod );
89
84
90
85
if (ReflectionUtils .isVoid (queryMethod .getReturnedObjectType ())) {
91
86
builder .addStatement ("$L.aggregate($L, $T.class)" , mongoOpsRef , aggregationVariableName , outputType );
@@ -146,7 +141,6 @@ CodeBlock build() {
146
141
builder .addStatement ("return $L.aggregateStream($L, $T.class)" , mongoOpsRef , aggregationVariableName ,
147
142
outputType );
148
143
} else {
149
-
150
144
builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
151
145
aggregationVariableName , outputType );
152
146
}
@@ -155,6 +149,17 @@ CodeBlock build() {
155
149
156
150
return builder .build ();
157
151
}
152
+
153
+ }
154
+
155
+ private static Class <?> getOutputType (MongoQueryMethod queryMethod ) {
156
+ Class <?> outputType = queryMethod .getReturnedObjectType ();
157
+ if (MongoSimpleTypes .HOLDER .isSimpleType (outputType )) {
158
+ outputType = Document .class ;
159
+ } else if (ClassUtils .isAssignable (AggregationResults .class , outputType ) && queryMethod .getReturnType ().getComponentType () != null ) {
160
+ outputType = queryMethod .getReturnType ().getComponentType ().getType ();
161
+ }
162
+ return outputType ;
158
163
}
159
164
160
165
@ NullUnmarked
@@ -173,13 +178,7 @@ static class AggregationCodeBlockBuilder {
173
178
174
179
this .context = context ;
175
180
this .queryMethod = queryMethod ;
176
- String parameterNames = StringUtils .collectionToDelimitedString (context .getAllParameterNames (), ", " );
177
-
178
- if (StringUtils .hasText (parameterNames )) {
179
- this .parameterNames = ", " + parameterNames ;
180
- } else {
181
- this .parameterNames = "" ;
182
- }
181
+ this .parameterNames = StringUtils .collectionToDelimitedString (context .getAllParameterNames (), ", " );
183
182
}
184
183
185
184
AggregationCodeBlockBuilder stages (AggregationInteraction aggregation ) {
@@ -231,7 +230,8 @@ private CodeBlock pipeline(String pipelineVariableName) {
231
230
builder .add (aggregationStages (context .localVariable ("stages" ), source .stages ()));
232
231
233
232
if (StringUtils .hasText (sortParameter )) {
234
- builder .add (sortingStage (sortParameter ));
233
+ Class <?> outputType = getOutputType (queryMethod );
234
+ builder .add (sortingStage (sortParameter , outputType ));
235
235
}
236
236
237
237
if (StringUtils .hasText (limitParameter )) {
@@ -244,6 +244,7 @@ private CodeBlock pipeline(String pipelineVariableName) {
244
244
245
245
builder .addStatement ("$T $L = createPipeline($L)" , AggregationPipeline .class , pipelineVariableName ,
246
246
context .localVariable ("stages" ));
247
+
247
248
return builder .build ();
248
249
}
249
250
@@ -312,7 +313,7 @@ private CodeBlock aggregationStages(String stageListVariableName, Collection<Str
312
313
return builder .build ();
313
314
}
314
315
315
- private CodeBlock sortingStage (String sortProvider ) {
316
+ private CodeBlock sortingStage (String sortProvider , Class <?> outputType ) {
316
317
317
318
Builder builder = CodeBlock .builder ();
318
319
@@ -322,8 +323,17 @@ private CodeBlock sortingStage(String sortProvider) {
322
323
builder .addStatement ("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);" ,
323
324
context .localVariable ("sortDocument" ), context .localVariable ("order" ));
324
325
builder .endControlFlow ();
325
- builder .addStatement ("stages.add(new $T($S, $L))" , Document .class , "$sort" ,
326
- context .localVariable ("sortDocument" ));
326
+
327
+ if (outputType == Document .class || MongoSimpleTypes .HOLDER .isSimpleType (outputType )
328
+ || ClassUtils .isAssignable (context .getRepositoryInformation ().getDomainType (), outputType )) {
329
+ builder .addStatement ("$L.add(new $T($S, $L))" , context .localVariable ("stages" ), Document .class , "$sort" ,
330
+ context .localVariable ("sortDocument" ));
331
+ } else {
332
+ builder .addStatement ("$L.add(($T) _ctx -> new $T($S, _ctx.getMappedObject($L, $T.class)))" ,
333
+ context .localVariable ("stages" ), AggregationOperation .class , Document .class , "$sort" ,
334
+ context .localVariable ("sortDocument" ), outputType );
335
+ }
336
+
327
337
builder .endControlFlow ();
328
338
329
339
return builder .build ();
@@ -333,7 +343,7 @@ private CodeBlock pagingStage(String pageableProvider, boolean slice) {
333
343
334
344
Builder builder = CodeBlock .builder ();
335
345
336
- builder .add (sortingStage (pageableProvider + ".getSort()" ));
346
+ builder .add (sortingStage (pageableProvider + ".getSort()" , getOutputType ( queryMethod ) ));
337
347
338
348
builder .beginControlFlow ("if ($L.isPaged())" , pageableProvider );
339
349
builder .beginControlFlow ("if ($L.getOffset() > 0)" , pageableProvider );
0 commit comments