Skip to content

Commit bf33816

Browse files
author
Artemiy Degtyarev
committed
add: basic keyset pagination support (without directions)
Signed-off-by: Artemiy Degtyarev <[email protected]>
1 parent 6a8f392 commit bf33816

File tree

3 files changed

+145
-66
lines changed

3 files changed

+145
-66
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java

Lines changed: 101 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,6 @@
1515
*/
1616
package org.springframework.data.jdbc.repository.query;
1717

18-
import static org.springframework.data.jdbc.repository.query.JdbcQueryExecution.*;
19-
20-
import java.sql.ResultSet;
21-
import java.util.ArrayList;
22-
import java.util.Collection;
23-
import java.util.List;
24-
import java.util.function.Function;
25-
import java.util.function.IntFunction;
26-
import java.util.function.LongSupplier;
27-
import java.util.function.Supplier;
28-
2918
import org.jspecify.annotations.Nullable;
3019
import org.springframework.core.convert.converter.Converter;
3120
import org.springframework.data.domain.*;
@@ -50,6 +39,17 @@
5039
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
5140
import org.springframework.util.Assert;
5241

42+
import java.lang.reflect.Field;
43+
import java.sql.ResultSet;
44+
import java.util.*;
45+
import java.util.function.Function;
46+
import java.util.function.IntFunction;
47+
import java.util.function.LongSupplier;
48+
import java.util.function.Supplier;
49+
import java.util.stream.Collectors;
50+
51+
import static org.springframework.data.jdbc.repository.query.JdbcQueryExecution.ResultProcessingConverter;
52+
5353
/**
5454
* An {@link AbstractJdbcQuery} implementation based on a {@link PartTree}.
5555
*
@@ -73,61 +73,61 @@ public class PartTreeJdbcQuery extends AbstractJdbcQuery {
7373
/**
7474
* Creates a new {@link PartTreeJdbcQuery}.
7575
*
76-
* @param queryMethod must not be {@literal null}.
77-
* @param operations must not be {@literal null}.
76+
* @param queryMethod must not be {@literal null}.
77+
* @param operations must not be {@literal null}.
7878
* @param rowMapperFactory must not be {@literal null}.
7979
* @since 4.0
8080
*/
8181
public PartTreeJdbcQuery(JdbcQueryMethod queryMethod, JdbcAggregateOperations operations,
82-
org.springframework.data.jdbc.repository.query.RowMapperFactory rowMapperFactory) {
82+
org.springframework.data.jdbc.repository.query.RowMapperFactory rowMapperFactory) {
8383
this(operations.getConverter().getMappingContext(), queryMethod, operations.getDataAccessStrategy().getDialect(),
84-
operations.getConverter(), operations.getDataAccessStrategy().getJdbcOperations(), rowMapperFactory);
84+
operations.getConverter(), operations.getDataAccessStrategy().getJdbcOperations(), rowMapperFactory);
8585
}
8686

8787
/**
8888
* Creates a new {@link PartTreeJdbcQuery}.
8989
*
90-
* @param queryMethod must not be {@literal null}.
91-
* @param dialect must not be {@literal null}.
92-
* @param converter must not be {@literal null}.
93-
* @param operations must not be {@literal null}.
90+
* @param queryMethod must not be {@literal null}.
91+
* @param dialect must not be {@literal null}.
92+
* @param converter must not be {@literal null}.
93+
* @param operations must not be {@literal null}.
9494
* @param rowMapperFactory must not be {@literal null}.
9595
*/
9696
public PartTreeJdbcQuery(JdbcQueryMethod queryMethod, Dialect dialect, JdbcConverter converter,
97-
NamedParameterJdbcOperations operations,
98-
org.springframework.data.jdbc.repository.query.RowMapperFactory rowMapperFactory) {
97+
NamedParameterJdbcOperations operations,
98+
org.springframework.data.jdbc.repository.query.RowMapperFactory rowMapperFactory) {
9999
this(converter.getMappingContext(), queryMethod, dialect, converter, operations, rowMapperFactory);
100100
}
101101

102102
/**
103103
* Creates a new {@link PartTreeJdbcQuery}.
104104
*
105-
* @param context must not be {@literal null}.
105+
* @param context must not be {@literal null}.
106106
* @param queryMethod must not be {@literal null}.
107-
* @param dialect must not be {@literal null}.
108-
* @param converter must not be {@literal null}.
109-
* @param operations must not be {@literal null}.
110-
* @param rowMapper must not be {@literal null}.
107+
* @param dialect must not be {@literal null}.
108+
* @param converter must not be {@literal null}.
109+
* @param operations must not be {@literal null}.
110+
* @param rowMapper must not be {@literal null}.
111111
*/
112112
public PartTreeJdbcQuery(RelationalMappingContext context, JdbcQueryMethod queryMethod, Dialect dialect,
113-
JdbcConverter converter, NamedParameterJdbcOperations operations, RowMapper<Object> rowMapper) {
113+
JdbcConverter converter, NamedParameterJdbcOperations operations, RowMapper<Object> rowMapper) {
114114
this(context, queryMethod, dialect, converter, operations, it -> rowMapper);
115115
}
116116

117117
/**
118118
* Creates a new {@link PartTreeJdbcQuery}.
119119
*
120-
* @param context must not be {@literal null}.
121-
* @param queryMethod must not be {@literal null}.
122-
* @param dialect must not be {@literal null}.
123-
* @param converter must not be {@literal null}.
124-
* @param operations must not be {@literal null}.
120+
* @param context must not be {@literal null}.
121+
* @param queryMethod must not be {@literal null}.
122+
* @param dialect must not be {@literal null}.
123+
* @param converter must not be {@literal null}.
124+
* @param operations must not be {@literal null}.
125125
* @param rowMapperFactory must not be {@literal null}.
126126
* @since 2.3
127127
*/
128128
public PartTreeJdbcQuery(RelationalMappingContext context, JdbcQueryMethod queryMethod, Dialect dialect,
129-
JdbcConverter converter, NamedParameterJdbcOperations operations,
130-
org.springframework.data.jdbc.repository.query.RowMapperFactory rowMapperFactory) {
129+
JdbcConverter converter, NamedParameterJdbcOperations operations,
130+
org.springframework.data.jdbc.repository.query.RowMapperFactory rowMapperFactory) {
131131

132132
super(queryMethod, operations);
133133

@@ -146,7 +146,7 @@ public PartTreeJdbcQuery(RelationalMappingContext context, JdbcQueryMethod query
146146
JdbcQueryCreator.validate(this.tree, this.parameters, this.converter.getMappingContext());
147147

148148
this.cachedRowMapperFactory = new CachedRowMapperFactory(tree, rowMapperFactory, converter,
149-
queryMethod.getResultProcessor());
149+
queryMethod.getResultProcessor());
150150
}
151151

152152
private Sort getDynamicSort(RelationalParameterAccessor accessor) {
@@ -158,7 +158,7 @@ private Sort getDynamicSort(RelationalParameterAccessor accessor) {
158158
public Object execute(Object[] values) {
159159

160160
RelationalParametersParameterAccessor accessor = new RelationalParametersParameterAccessor(getQueryMethod(),
161-
values);
161+
values);
162162

163163
if (tree.isDelete()) {
164164
JdbcQueryExecution<?> execution = createModifyingQueryExecutor();
@@ -180,17 +180,18 @@ public Object execute(Object[] values) {
180180
}
181181

182182
private JdbcQueryExecution<?> getQueryExecution(ResultProcessor processor,
183-
RelationalParametersParameterAccessor accessor) {
183+
RelationalParametersParameterAccessor accessor) {
184184

185185
ResultSetExtractor<Boolean> extractor = tree.isExistsProjection() ? (ResultSet::next) : null;
186186
Supplier<RowMapper<?>> rowMapper = parameters.hasDynamicProjection()
187-
? () -> cachedRowMapperFactory.getRowMapper(processor)
188-
: cachedRowMapperFactory;
187+
? () -> cachedRowMapperFactory.getRowMapper(processor)
188+
: cachedRowMapperFactory;
189189

190190
JdbcQueryExecution<?> queryExecution = getJdbcQueryExecution(extractor, rowMapper);
191191

192192
if (getQueryMethod().isScrollQuery()) {
193-
return new ScrollQueryExecution<>((JdbcQueryExecution<Collection<Object>>) queryExecution, accessor.getScrollPosition(), this.tree.getMaxResults());
193+
//noinspection unchecked
194+
return new ScrollQueryExecution<>((JdbcQueryExecution<Collection<Object>>) queryExecution, accessor.getScrollPosition(), this.tree.getMaxResults(), tree.getSort());
194195
}
195196

196197
if (getQueryMethod().isSliceQuery()) {
@@ -202,23 +203,23 @@ private JdbcQueryExecution<?> getQueryExecution(ResultProcessor processor,
202203

203204
// noinspection unchecked
204205
return new PageQueryExecution<>((JdbcQueryExecution<Collection<Object>>) queryExecution, accessor.getPageable(),
205-
() -> {
206+
() -> {
206207

207-
RelationalEntityMetadata<?> entityMetadata = getQueryMethod().getEntityInformation();
208+
RelationalEntityMetadata<?> entityMetadata = getQueryMethod().getEntityInformation();
208209

209-
JdbcCountQueryCreator queryCreator = new JdbcCountQueryCreator(context, tree, converter, dialect,
210-
entityMetadata, accessor, false, processor.getReturnedType(), getQueryMethod().lookupLockAnnotation(), false);
210+
JdbcCountQueryCreator queryCreator = new JdbcCountQueryCreator(context, tree, converter, dialect,
211+
entityMetadata, accessor, false, processor.getReturnedType(), getQueryMethod().lookupLockAnnotation(), false);
211212

212-
ParametrizedQuery countQuery = queryCreator.createQuery(Sort.unsorted());
213-
Object count = singleObjectQuery(new SingleColumnRowMapper<>(Number.class)).execute(countQuery.getQuery(),
214-
countQuery.getParameterSource(dialect.getLikeEscaper()));
213+
ParametrizedQuery countQuery = queryCreator.createQuery(Sort.unsorted());
214+
Object count = singleObjectQuery(new SingleColumnRowMapper<>(Number.class)).execute(countQuery.getQuery(),
215+
countQuery.getParameterSource(dialect.getLikeEscaper()));
215216

216-
Long converted = converter.getConversionService().convert(count, Long.class);
217+
Long converted = converter.getConversionService().convert(count, Long.class);
217218

218-
Assert.state(converted != null, "Count must not be null");
219+
Assert.state(converted != null, "Count must not be null");
219220

220-
return converted;
221-
});
221+
return converted;
222+
});
222223
}
223224

224225
return queryExecution;
@@ -229,7 +230,7 @@ ParametrizedQuery createQuery(RelationalParametersParameterAccessor accessor, Re
229230
RelationalEntityMetadata<?> entityMetadata = getQueryMethod().getEntityInformation();
230231

231232
JdbcQueryCreator queryCreator = new JdbcQueryCreator(context, tree, converter, dialect, entityMetadata, accessor,
232-
getQueryMethod().isSliceQuery(), returnedType, this.getQueryMethod().lookupLockAnnotation(), getQueryMethod().isScrollQuery());
233+
getQueryMethod().isSliceQuery(), returnedType, this.getQueryMethod().lookupLockAnnotation(), getQueryMethod().isScrollQuery());
233234
return queryCreator.createQuery(getDynamicSort(accessor));
234235
}
235236

@@ -238,12 +239,12 @@ private List<ParametrizedQuery> createDeleteQueries(RelationalParametersParamete
238239
RelationalEntityMetadata<?> entityMetadata = getQueryMethod().getEntityInformation();
239240

240241
JdbcDeleteQueryCreator queryCreator = new JdbcDeleteQueryCreator(context, tree, converter, dialect, entityMetadata,
241-
accessor);
242+
accessor);
242243
return queryCreator.createQuery();
243244
}
244245

245246
private JdbcQueryExecution<?> getJdbcQueryExecution(@Nullable ResultSetExtractor<Boolean> extractor,
246-
Supplier<RowMapper<?>> rowMapper) {
247+
Supplier<RowMapper<?>> rowMapper) {
247248

248249
if (getQueryMethod().isPageQuery() || getQueryMethod().isSliceQuery() || getQueryMethod().isScrollQuery()) {
249250
return collectionQuery(rowMapper.get());
@@ -260,12 +261,14 @@ private JdbcQueryExecution<?> getJdbcQueryExecution(@Nullable ResultSetExtractor
260261
static class ScrollQueryExecution<T> implements JdbcQueryExecution<Window<T>> {
261262
private final JdbcQueryExecution<? extends Collection<T>> delegate;
262263
private final ScrollPosition position;
263-
@Nullable private final Integer maxResults;
264+
private final @Nullable Integer maxResults;
265+
private final Sort sort;
264266

265-
ScrollQueryExecution(JdbcQueryExecution<? extends Collection<T>> delegate, ScrollPosition position, @Nullable Integer maxResults) {
267+
ScrollQueryExecution(JdbcQueryExecution<? extends Collection<T>> delegate, ScrollPosition position, @Nullable Integer maxResults, Sort sort) {
266268
this.delegate = delegate;
267269
this.position = position;
268270
this.maxResults = maxResults;
271+
this.sort = sort;
269272
}
270273

271274
@Override
@@ -277,10 +280,46 @@ static class ScrollQueryExecution<T> implements JdbcQueryExecution<Window<T>> {
277280
if (position instanceof OffsetScrollPosition)
278281
positionFunction = ((OffsetScrollPosition) position).positionFunction();
279282

283+
if (position instanceof KeysetScrollPosition) {
284+
Map<String, Object> keys = ((KeysetScrollPosition) position).getKeys();
285+
286+
if (keys.isEmpty()) {
287+
List<String> orders = sort.get().map(Sort.Order::getProperty).toList();
288+
289+
keys = extractKeys(resultList, orders);
290+
}
291+
292+
Map<String, Object> finalKeys = keys;
293+
positionFunction = (ignoredI) -> ScrollPosition.of(finalKeys, ((KeysetScrollPosition) position).getDirection()) ;
294+
}
295+
280296
boolean hasNext = resultList.size() >= maxResults;
281297

282298
return Window.from(resultList, positionFunction, hasNext);
283299
}
300+
301+
private Map<String, Object> extractKeys(List<T> resultList, List<String> orders) {
302+
T last = resultList.get(resultList.size() - 1);
303+
304+
Field[] fields = last.getClass().getDeclaredFields();
305+
306+
return Arrays
307+
.stream(fields)
308+
.filter(it -> orders.contains(it.getName()))
309+
.peek(it -> it.setAccessible(true))
310+
.collect(
311+
Collectors.toMap(
312+
Field::getName,
313+
it -> {
314+
try {
315+
return it.get(last);
316+
} catch (Exception e) {
317+
throw new RuntimeException(e);
318+
}
319+
}
320+
)
321+
);
322+
}
284323
}
285324

286325
/**
@@ -329,7 +368,7 @@ static class PageQueryExecution<T> implements JdbcQueryExecution<Slice<T>> {
329368
private final LongSupplier countSupplier;
330369

331370
PageQueryExecution(JdbcQueryExecution<? extends Collection<T>> delegate, Pageable pageable,
332-
LongSupplier countSupplier) {
371+
LongSupplier countSupplier) {
333372
this.delegate = delegate;
334373
this.pageable = pageable;
335374
this.countSupplier = countSupplier;
@@ -341,7 +380,7 @@ public Slice<T> execute(String query, SqlParameterSource parameter) {
341380
Collection<T> result = delegate.execute(query, parameter);
342381

343382
return PageableExecutionUtils.getPage(result instanceof List ? (List<T>) result : new ArrayList<>(result),
344-
pageable, countSupplier);
383+
pageable, countSupplier);
345384
}
346385

347386
}
@@ -356,18 +395,18 @@ class CachedRowMapperFactory implements Supplier<RowMapper<?>> {
356395
private final Function<ResultProcessor, RowMapper<?>> rowMapperFunction;
357396

358397
public CachedRowMapperFactory(PartTree tree,
359-
RowMapperFactory rowMapperFactory, RelationalConverter converter,
360-
ResultProcessor defaultResultProcessor) {
398+
RowMapperFactory rowMapperFactory, RelationalConverter converter,
399+
ResultProcessor defaultResultProcessor) {
361400

362401
this.rowMapperFunction = processor -> {
363402

364403
if (tree.isCountProjection() || tree.isExistsProjection()) {
365404
return rowMapperFactory.create(resolveTypeToRead(processor));
366405
}
367406
Converter<Object, Object> resultProcessingConverter = new ResultProcessingConverter(processor,
368-
converter.getMappingContext(), converter.getEntityInstantiators());
407+
converter.getMappingContext(), converter.getEntityInstantiators());
369408
return new ConvertingRowMapper(
370-
rowMapperFactory.create(processor.getReturnedType().getDomainType()), resultProcessingConverter);
409+
rowMapperFactory.create(processor.getReturnedType().getDomainType()), resultProcessingConverter);
371410
};
372411

373412
this.rowMapper = Lazy.of(() -> this.rowMapperFunction.apply(defaultResultProcessor));

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StatementFactory.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
*/
1616
package org.springframework.data.jdbc.repository.query;
1717

18-
import java.util.ArrayList;
19-
import java.util.Arrays;
20-
import java.util.Collection;
21-
import java.util.List;
18+
import java.util.*;
2219
import java.util.function.Predicate;
2320

2421
import org.jspecify.annotations.Nullable;
@@ -208,6 +205,8 @@ public String build(MapSqlParameterSource parameterSource) {
208205

209206
SelectBuilder.SelectLimitOffset limitOffsetBuilder = createSelectClause(entity, table);
210207
SelectBuilder.SelectWhere whereBuilder = applyLimitAndOffset(limitOffsetBuilder);
208+
criteria = applyScrollCriteria(criteria, scrollPosition);
209+
211210
SelectBuilder.SelectOrdered selectOrderBuilder = applyCriteria(criteria, entity, table, parameterSource,
212211
whereBuilder);
213212
selectOrderBuilder = applyOrderBy(sort, entity, table, selectOrderBuilder);
@@ -222,6 +221,31 @@ public String build(MapSqlParameterSource parameterSource) {
222221
return SqlRenderer.create(renderContextFactory.createRenderContext()).render(select);
223222
}
224223

224+
@Nullable Criteria applyScrollCriteria(@Nullable Criteria criteria, @Nullable ScrollPosition position) {
225+
if (!(position instanceof KeysetScrollPosition) || position.isInitial() || ((KeysetScrollPosition) position).getKeys().isEmpty())
226+
return criteria;
227+
228+
criteria = criteria == null ? Criteria.empty() : criteria;
229+
230+
Map<String, Object> keys = ((KeysetScrollPosition) position).getKeys();
231+
List<String> columns = new ArrayList<>(keys.keySet());
232+
List<Object> values = new ArrayList<>(keys.values());
233+
234+
Criteria result = null;
235+
236+
for (int i = 0; i < keys.size(); i++) {
237+
Criteria orCriteria = Criteria.where(columns.get(i)).greaterThan(values.get(i));
238+
239+
for (int j = 0; j < i; j++) {
240+
orCriteria = Criteria.where(columns.get(j)).is(values.get(j)).and(orCriteria);
241+
}
242+
243+
result = (result == null) ? orCriteria : result.or(orCriteria);
244+
}
245+
246+
return criteria.and(result);
247+
}
248+
225249
SelectBuilder.SelectOrdered applyOrderBy(Sort sort, RelationalPersistentEntity<?> entity, Table table,
226250
SelectBuilder.SelectOrdered selectOrdered) {
227251

0 commit comments

Comments
 (0)