Skip to content

Commit 05c6ba9

Browse files
committed
Polishing.
Refactor evaluation handling, align with JPA using MongoParameters and the originating Method. Add support for reactive GeoNear count.
1 parent 9b4d2c9 commit 05c6ba9

29 files changed

+481
-449
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveFindOperation.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ interface TerminatingFindNear<T> {
194194
* @return never {@literal null}.
195195
*/
196196
Flux<GeoResult<T>> all();
197+
198+
/**
199+
* Count matching elements.
200+
*
201+
* @return number of elements matching the query.
202+
* @since 5.0
203+
*/
204+
Mono<Long> count();
197205
}
198206

199207
/**

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveFindOperationSupport.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ public <R> TerminatingFindNear<R> map(QueryResultConverter<? super G, ? extends
239239
public Flux<GeoResult<G>> all() {
240240
return template.doGeoNear(nearQuery, domainType, getCollectionName(), returnType, resultConverter);
241241
}
242+
243+
@Override
244+
public Mono<Long> count() {
245+
return template.doGeoNearCount(nearQuery, domainType, getCollectionName());
246+
}
242247
}
243248

244249
/**

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.jspecify.annotations.Nullable;
4848
import org.reactivestreams.Publisher;
4949
import org.reactivestreams.Subscriber;
50+
5051
import org.springframework.beans.BeansException;
5152
import org.springframework.context.ApplicationContext;
5253
import org.springframework.context.ApplicationContextAware;
@@ -1015,30 +1016,38 @@ public <O> Flux<O> aggregate(Aggregation aggregation, String collectionName, Cla
10151016

10161017
protected <O> Flux<O> doAggregate(Aggregation aggregation, String collectionName, @Nullable Class<?> inputType,
10171018
Class<O> outputType) {
1018-
return doAggregate(aggregation, collectionName, inputType, outputType, QueryResultConverter.entity());
1019+
1020+
AggregationDefinition context = queryOperations.createAggregation(aggregation, inputType);
1021+
return doAggregate(aggregation, collectionName, outputType, QueryResultConverter.entity(), context);
10191022
}
10201023

10211024
<T, O> Flux<O> doAggregate(Aggregation aggregation, String collectionName, @Nullable Class<?> inputType,
10221025
Class<T> outputType, QueryResultConverter<? super T, ? extends O> resultConverter) {
10231026

1027+
AggregationDefinition context = queryOperations.createAggregation(aggregation, inputType);
1028+
return doAggregate(aggregation, collectionName, outputType, resultConverter, context);
1029+
}
1030+
1031+
<T, O> Flux<O> doAggregate(Aggregation aggregation, String collectionName, Class<T> outputType,
1032+
QueryResultConverter<? super T, ? extends O> resultConverter, AggregationDefinition definition) {
1033+
10241034
Assert.notNull(aggregation, "Aggregation pipeline must not be null");
10251035
Assert.hasText(collectionName, "Collection name must not be null or empty");
10261036
Assert.notNull(outputType, "Output type must not be null");
10271037

10281038
AggregationOptions options = aggregation.getOptions();
10291039
Assert.isTrue(!options.isExplain(), "Cannot use explain option with streaming");
10301040

1031-
AggregationDefinition ctx = queryOperations.createAggregation(aggregation, inputType);
10321041

10331042
if (LOGGER.isDebugEnabled()) {
10341043
LOGGER.debug(String.format("Streaming aggregation: %s in collection %s",
1035-
serializeToJsonSafely(ctx.getAggregationPipeline()), collectionName));
1044+
serializeToJsonSafely(definition.getAggregationPipeline()), collectionName));
10361045
}
10371046

10381047
DocumentCallback<O> readCallback = new QueryResultConverterCallback<>(resultConverter,
10391048
new ReadDocumentCallback<>(mongoConverter, outputType, collectionName));
1040-
return execute(collectionName, collection -> aggregateAndMap(collection, ctx.getAggregationPipeline(),
1041-
ctx.isOutOrMerge(), options, readCallback, ctx.getInputType()));
1049+
return execute(collectionName, collection -> aggregateAndMap(collection, definition.getAggregationPipeline(),
1050+
definition.isOutOrMerge(), options, readCallback, definition.getInputType()));
10421051
}
10431052

10441053
private <O> Flux<O> aggregateAndMap(MongoCollection<Document> collection, List<Document> pipeline,
@@ -1093,6 +1102,33 @@ protected <T> Flux<GeoResult<T>> geoNear(NearQuery near, Class<?> entityClass, S
10931102
return doGeoNear(near, entityClass, collectionName, returnType, QueryResultConverter.entity());
10941103
}
10951104

1105+
Mono<Long> doGeoNearCount(NearQuery near, Class<?> domainType, String collectionName) {
1106+
1107+
Builder optionsBuilder = AggregationOptions.builder().collation(near.getCollation());
1108+
1109+
if (near.hasReadPreference()) {
1110+
optionsBuilder.readPreference(near.getReadPreference());
1111+
}
1112+
1113+
if (near.hasReadConcern()) {
1114+
optionsBuilder.readConcern(near.getReadConcern());
1115+
}
1116+
1117+
String distanceField = operations.nearQueryDistanceFieldName(domainType);
1118+
Aggregation $geoNear = TypedAggregation.newAggregation(domainType,
1119+
Aggregation.geoNear(near, distanceField).skip(-1).limit(-1), Aggregation.count().as("_totalCount"))
1120+
.withOptions(optionsBuilder.build());
1121+
1122+
AggregationDefinition definition = queryOperations.createAggregation($geoNear, (AggregationOperationContext) null);
1123+
1124+
Flux<Document> results = doAggregate($geoNear, collectionName, Document.class, QueryResultConverter.entity(),
1125+
definition);
1126+
1127+
return results.last()
1128+
.map(doc -> NumberUtils.convertNumberToTargetClass(doc.get("_totalCount", Integer.class), Long.class))
1129+
.defaultIfEmpty(0L);
1130+
}
1131+
10961132
@SuppressWarnings("unchecked")
10971133
<T, R> Flux<GeoResult<R>> doGeoNear(NearQuery near, Class<?> entityClass, String collectionName, Class<T> returnType,
10981134
QueryResultConverter<? super T, ? extends R> resultConverter) {

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ interface Placeholder {
5050

5151
/**
5252
* Create a new placeholder for index bindable parameter.
53-
*
53+
*
5454
* @param position the index of the parameter to bind.
5555
* @return new instance of {@link Placeholder}.
5656
*/
@@ -65,20 +65,4 @@ static Placeholder placeholder(String expression) {
6565
Object getValue();
6666
}
6767

68-
static class PlaceholderImpl implements Placeholder {
69-
private final Object expression;
70-
71-
public PlaceholderImpl(Object expression) {
72-
this.expression = expression;
73-
}
74-
75-
@Override
76-
public Object getValue() {
77-
return expression;
78-
}
79-
80-
public String toString() {
81-
return getValue().toString();
82-
}
83-
}
8468
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,9 @@ private String getCommand(Shape shape) {
7676

7777
Assert.notNull(shape, "Shape must not be null");
7878

79-
if(shape instanceof GeoJson<?>) {
79+
if (shape instanceof GeoJson<?>) {
8080
return "$geometry";
81-
}
82-
if (shape instanceof Box) {
81+
} else if (shape instanceof Box) {
8382
return "$box";
8483
} else if (shape instanceof Circle) {
8584
return "$center";
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.core.query;
17+
18+
/**
19+
* @author Christoph Strobl
20+
* @since 5.0
21+
*/
22+
record PlaceholderImpl(Object expression) implements CriteriaDefinition.Placeholder {
23+
24+
@Override
25+
public Object getValue() {
26+
return expression;
27+
}
28+
29+
public String toString() {
30+
return getValue().toString();
31+
}
32+
33+
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
package org.springframework.data.mongodb.repository.aot;
1717

1818
import java.util.ArrayList;
19-
import java.util.LinkedHashMap;
19+
import java.util.Collection;
2020
import java.util.List;
21-
import java.util.Map;
2221
import java.util.stream.Stream;
2322

2423
import org.bson.Document;
2524
import org.jspecify.annotations.NullUnmarked;
25+
2626
import org.springframework.core.ResolvableType;
2727
import org.springframework.core.annotation.MergedAnnotation;
2828
import org.springframework.data.domain.SliceImpl;
@@ -47,6 +47,8 @@
4747
import org.springframework.util.StringUtils;
4848

4949
/**
50+
* Code blocks for building aggregation pipelines and execution statements for MongoDB repositories.
51+
*
5052
* @author Christoph Strobl
5153
* @since 5.0
5254
*/
@@ -160,7 +162,7 @@ static class AggregationCodeBlockBuilder {
160162

161163
private final AotQueryMethodGenerationContext context;
162164
private final MongoQueryMethod queryMethod;
163-
private final Map<String, CodeBlock> arguments;
165+
private final String parameterNames;
164166

165167
private AggregationInteraction source;
166168

@@ -170,9 +172,8 @@ static class AggregationCodeBlockBuilder {
170172
AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
171173

172174
this.context = context;
173-
this.arguments = new LinkedHashMap<>();
174-
context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it)));
175175
this.queryMethod = queryMethod;
176+
this.parameterNames = StringUtils.collectionToDelimitedString(context.getAllParameterNames(), ", ");
176177
}
177178

178179
AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) {
@@ -220,33 +221,18 @@ private CodeBlock pipeline(String pipelineVariableName) {
220221
String limitParameter = context.getLimitParameterName();
221222
String pageableParameter = context.getPageableParameterName();
222223

223-
boolean mightBeSorted = StringUtils.hasText(sortParameter);
224-
boolean mightBeLimited = StringUtils.hasText(limitParameter);
225-
boolean mightBePaged = StringUtils.hasText(pageableParameter);
226-
227-
int stageCount = source.stages().size();
228-
if (mightBeSorted) {
229-
stageCount++;
230-
}
231-
if (mightBeLimited) {
232-
stageCount++;
233-
}
234-
if (mightBePaged) {
235-
stageCount += 3;
236-
}
237-
238224
Builder builder = CodeBlock.builder();
239-
builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments));
225+
builder.add(aggregationStages(context.localVariable("stages"), source.stages()));
240226

241-
if (mightBeSorted) {
227+
if (StringUtils.hasText(sortParameter)) {
242228
builder.add(sortingStage(sortParameter));
243229
}
244230

245-
if (mightBeLimited) {
231+
if (StringUtils.hasText(limitParameter)) {
246232
builder.add(limitingStage(limitParameter));
247233
}
248234

249-
if (mightBePaged) {
235+
if (StringUtils.hasText(pageableParameter)) {
250236
builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery()));
251237
}
252238

@@ -259,6 +245,7 @@ private CodeBlock aggregationOptions(String aggregationVariableName) {
259245

260246
Builder builder = CodeBlock.builder();
261247
List<CodeBlock> options = new ArrayList<>(5);
248+
262249
if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) {
263250
options.add(CodeBlock.of(".skipOutput()"));
264251
}
@@ -299,20 +286,21 @@ private CodeBlock aggregationOptions(String aggregationVariableName) {
299286
return builder.build();
300287
}
301288

302-
private CodeBlock aggregationStages(String stageListVariableName, Iterable<String> stages, int stageCount,
303-
Map<String, CodeBlock> arguments) {
289+
private CodeBlock aggregationStages(String stageListVariableName, Collection<String> stages) {
304290

305291
Builder builder = CodeBlock.builder();
306292
builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class,
307-
stageCount);
293+
stages.size());
308294
int stageCounter = 0;
309295

310296
for (String stage : stages) {
311297

312298
VariableSnippet stageSnippet = Snippet.declare(builder)
313-
.variable(Document.class, context.localVariable("stage_%s".formatted(stageCounter++)))
314-
.of(MongoCodeBlocks.asDocument(stage, arguments));
299+
.variable(Document.class, context.localVariable("stage_%s".formatted(stageCounter)))
300+
.of(MongoCodeBlocks.asDocument(stage, parameterNames));
315301
builder.addStatement("$L.add($L)", stageListVariableName, stageSnippet.getVariableName());
302+
303+
stageCounter++;
316304
}
317305

318306
return builder.build();

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package org.springframework.data.mongodb.repository.aot;
1717

18+
import org.jspecify.annotations.NullUnmarked;
19+
1820
import org.springframework.data.geo.Distance;
1921
import org.springframework.data.geo.GeoPage;
2022
import org.springframework.data.geo.GeoResults;
@@ -27,11 +29,14 @@
2729
import org.springframework.util.ClassUtils;
2830

2931
/**
32+
* Code blocks for generating code related to geo-near queries in MongoDB repositories.
33+
*
3034
* @author Christoph Strobl
3135
* @since 5.0
3236
*/
3337
class GeoBlocks {
3438

39+
@NullUnmarked
3540
static class GeoNearCodeBlockBuilder {
3641

3742
private final AotQueryMethodGenerationContext context;
@@ -91,16 +96,15 @@ public GeoNearCodeBlockBuilder usingQueryVariableName(String variableName) {
9196
}
9297
}
9398

99+
@NullUnmarked
94100
static class GeoNearExecutionCodeBlockBuilder {
95101

96102
private final AotQueryMethodGenerationContext context;
97-
private final MongoQueryMethod queryMethod;
98103
private String queryVariableName;
99104

100-
GeoNearExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) {
105+
GeoNearExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context) {
101106

102107
this.context = context;
103-
this.queryMethod = queryMethod;
104108
}
105109

106110
GeoNearExecutionCodeBlockBuilder referencing(String queryVariableName) {

0 commit comments

Comments
 (0)