1515 */
1616package org .springframework .data .jdbc .repository .query ;
1717
18- import static org .springframework .data .jdbc .repository .query .JdbcQueryExecution .*;
19-
20- import java .sql .ResultSet ;
21- import java .util .ArrayList ;
22- import java .util .Collection ;
23- import java .util .List ;
24- import java .util .function .Function ;
25- import java .util .function .IntFunction ;
26- import java .util .function .LongSupplier ;
27- import java .util .function .Supplier ;
28-
2918import org .jspecify .annotations .Nullable ;
3019import org .springframework .core .convert .converter .Converter ;
3120import org .springframework .data .domain .*;
5039import org .springframework .jdbc .core .namedparam .SqlParameterSource ;
5140import org .springframework .util .Assert ;
5241
42+ import java .lang .reflect .Field ;
43+ import java .sql .ResultSet ;
44+ import java .util .*;
45+ import java .util .function .Function ;
46+ import java .util .function .IntFunction ;
47+ import java .util .function .LongSupplier ;
48+ import java .util .function .Supplier ;
49+ import java .util .stream .Collectors ;
50+
51+ import static org .springframework .data .jdbc .repository .query .JdbcQueryExecution .ResultProcessingConverter ;
52+
5353/**
5454 * An {@link AbstractJdbcQuery} implementation based on a {@link PartTree}.
5555 *
@@ -73,61 +73,61 @@ public class PartTreeJdbcQuery extends AbstractJdbcQuery {
7373 /**
7474 * Creates a new {@link PartTreeJdbcQuery}.
7575 *
76- * @param queryMethod must not be {@literal null}.
77- * @param operations must not be {@literal null}.
76+ * @param queryMethod must not be {@literal null}.
77+ * @param operations must not be {@literal null}.
7878 * @param rowMapperFactory must not be {@literal null}.
7979 * @since 4.0
8080 */
8181 public PartTreeJdbcQuery (JdbcQueryMethod queryMethod , JdbcAggregateOperations operations ,
82- org .springframework .data .jdbc .repository .query .RowMapperFactory rowMapperFactory ) {
82+ org .springframework .data .jdbc .repository .query .RowMapperFactory rowMapperFactory ) {
8383 this (operations .getConverter ().getMappingContext (), queryMethod , operations .getDataAccessStrategy ().getDialect (),
84- operations .getConverter (), operations .getDataAccessStrategy ().getJdbcOperations (), rowMapperFactory );
84+ operations .getConverter (), operations .getDataAccessStrategy ().getJdbcOperations (), rowMapperFactory );
8585 }
8686
8787 /**
8888 * Creates a new {@link PartTreeJdbcQuery}.
8989 *
90- * @param queryMethod must not be {@literal null}.
91- * @param dialect must not be {@literal null}.
92- * @param converter must not be {@literal null}.
93- * @param operations must not be {@literal null}.
90+ * @param queryMethod must not be {@literal null}.
91+ * @param dialect must not be {@literal null}.
92+ * @param converter must not be {@literal null}.
93+ * @param operations must not be {@literal null}.
9494 * @param rowMapperFactory must not be {@literal null}.
9595 */
9696 public PartTreeJdbcQuery (JdbcQueryMethod queryMethod , Dialect dialect , JdbcConverter converter ,
97- NamedParameterJdbcOperations operations ,
98- org .springframework .data .jdbc .repository .query .RowMapperFactory rowMapperFactory ) {
97+ NamedParameterJdbcOperations operations ,
98+ org .springframework .data .jdbc .repository .query .RowMapperFactory rowMapperFactory ) {
9999 this (converter .getMappingContext (), queryMethod , dialect , converter , operations , rowMapperFactory );
100100 }
101101
102102 /**
103103 * Creates a new {@link PartTreeJdbcQuery}.
104104 *
105- * @param context must not be {@literal null}.
105+ * @param context must not be {@literal null}.
106106 * @param queryMethod must not be {@literal null}.
107- * @param dialect must not be {@literal null}.
108- * @param converter must not be {@literal null}.
109- * @param operations must not be {@literal null}.
110- * @param rowMapper must not be {@literal null}.
107+ * @param dialect must not be {@literal null}.
108+ * @param converter must not be {@literal null}.
109+ * @param operations must not be {@literal null}.
110+ * @param rowMapper must not be {@literal null}.
111111 */
112112 public PartTreeJdbcQuery (RelationalMappingContext context , JdbcQueryMethod queryMethod , Dialect dialect ,
113- JdbcConverter converter , NamedParameterJdbcOperations operations , RowMapper <Object > rowMapper ) {
113+ JdbcConverter converter , NamedParameterJdbcOperations operations , RowMapper <Object > rowMapper ) {
114114 this (context , queryMethod , dialect , converter , operations , it -> rowMapper );
115115 }
116116
117117 /**
118118 * Creates a new {@link PartTreeJdbcQuery}.
119119 *
120- * @param context must not be {@literal null}.
121- * @param queryMethod must not be {@literal null}.
122- * @param dialect must not be {@literal null}.
123- * @param converter must not be {@literal null}.
124- * @param operations must not be {@literal null}.
120+ * @param context must not be {@literal null}.
121+ * @param queryMethod must not be {@literal null}.
122+ * @param dialect must not be {@literal null}.
123+ * @param converter must not be {@literal null}.
124+ * @param operations must not be {@literal null}.
125125 * @param rowMapperFactory must not be {@literal null}.
126126 * @since 2.3
127127 */
128128 public PartTreeJdbcQuery (RelationalMappingContext context , JdbcQueryMethod queryMethod , Dialect dialect ,
129- JdbcConverter converter , NamedParameterJdbcOperations operations ,
130- org .springframework .data .jdbc .repository .query .RowMapperFactory rowMapperFactory ) {
129+ JdbcConverter converter , NamedParameterJdbcOperations operations ,
130+ org .springframework .data .jdbc .repository .query .RowMapperFactory rowMapperFactory ) {
131131
132132 super (queryMethod , operations );
133133
@@ -146,7 +146,7 @@ public PartTreeJdbcQuery(RelationalMappingContext context, JdbcQueryMethod query
146146 JdbcQueryCreator .validate (this .tree , this .parameters , this .converter .getMappingContext ());
147147
148148 this .cachedRowMapperFactory = new CachedRowMapperFactory (tree , rowMapperFactory , converter ,
149- queryMethod .getResultProcessor ());
149+ queryMethod .getResultProcessor ());
150150 }
151151
152152 private Sort getDynamicSort (RelationalParameterAccessor accessor ) {
@@ -158,7 +158,7 @@ private Sort getDynamicSort(RelationalParameterAccessor accessor) {
158158 public Object execute (Object [] values ) {
159159
160160 RelationalParametersParameterAccessor accessor = new RelationalParametersParameterAccessor (getQueryMethod (),
161- values );
161+ values );
162162
163163 if (tree .isDelete ()) {
164164 JdbcQueryExecution <?> execution = createModifyingQueryExecutor ();
@@ -180,17 +180,18 @@ public Object execute(Object[] values) {
180180 }
181181
182182 private JdbcQueryExecution <?> getQueryExecution (ResultProcessor processor ,
183- RelationalParametersParameterAccessor accessor ) {
183+ RelationalParametersParameterAccessor accessor ) {
184184
185185 ResultSetExtractor <Boolean > extractor = tree .isExistsProjection () ? (ResultSet ::next ) : null ;
186186 Supplier <RowMapper <?>> rowMapper = parameters .hasDynamicProjection ()
187- ? () -> cachedRowMapperFactory .getRowMapper (processor )
188- : cachedRowMapperFactory ;
187+ ? () -> cachedRowMapperFactory .getRowMapper (processor )
188+ : cachedRowMapperFactory ;
189189
190190 JdbcQueryExecution <?> queryExecution = getJdbcQueryExecution (extractor , rowMapper );
191191
192192 if (getQueryMethod ().isScrollQuery ()) {
193- return new ScrollQueryExecution <>((JdbcQueryExecution <Collection <Object >>) queryExecution , accessor .getScrollPosition (), this .tree .getMaxResults ());
193+ //noinspection unchecked
194+ return new ScrollQueryExecution <>((JdbcQueryExecution <Collection <Object >>) queryExecution , accessor .getScrollPosition (), this .tree .getMaxResults (), tree .getSort ());
194195 }
195196
196197 if (getQueryMethod ().isSliceQuery ()) {
@@ -202,23 +203,23 @@ private JdbcQueryExecution<?> getQueryExecution(ResultProcessor processor,
202203
203204 // noinspection unchecked
204205 return new PageQueryExecution <>((JdbcQueryExecution <Collection <Object >>) queryExecution , accessor .getPageable (),
205- () -> {
206+ () -> {
206207
207- RelationalEntityMetadata <?> entityMetadata = getQueryMethod ().getEntityInformation ();
208+ RelationalEntityMetadata <?> entityMetadata = getQueryMethod ().getEntityInformation ();
208209
209- JdbcCountQueryCreator queryCreator = new JdbcCountQueryCreator (context , tree , converter , dialect ,
210- entityMetadata , accessor , false , processor .getReturnedType (), getQueryMethod ().lookupLockAnnotation (), false );
210+ JdbcCountQueryCreator queryCreator = new JdbcCountQueryCreator (context , tree , converter , dialect ,
211+ entityMetadata , accessor , false , processor .getReturnedType (), getQueryMethod ().lookupLockAnnotation (), false );
211212
212- ParametrizedQuery countQuery = queryCreator .createQuery (Sort .unsorted ());
213- Object count = singleObjectQuery (new SingleColumnRowMapper <>(Number .class )).execute (countQuery .getQuery (),
214- countQuery .getParameterSource (dialect .getLikeEscaper ()));
213+ ParametrizedQuery countQuery = queryCreator .createQuery (Sort .unsorted ());
214+ Object count = singleObjectQuery (new SingleColumnRowMapper <>(Number .class )).execute (countQuery .getQuery (),
215+ countQuery .getParameterSource (dialect .getLikeEscaper ()));
215216
216- Long converted = converter .getConversionService ().convert (count , Long .class );
217+ Long converted = converter .getConversionService ().convert (count , Long .class );
217218
218- Assert .state (converted != null , "Count must not be null" );
219+ Assert .state (converted != null , "Count must not be null" );
219220
220- return converted ;
221- });
221+ return converted ;
222+ });
222223 }
223224
224225 return queryExecution ;
@@ -229,7 +230,7 @@ ParametrizedQuery createQuery(RelationalParametersParameterAccessor accessor, Re
229230 RelationalEntityMetadata <?> entityMetadata = getQueryMethod ().getEntityInformation ();
230231
231232 JdbcQueryCreator queryCreator = new JdbcQueryCreator (context , tree , converter , dialect , entityMetadata , accessor ,
232- getQueryMethod ().isSliceQuery (), returnedType , this .getQueryMethod ().lookupLockAnnotation (), getQueryMethod ().isScrollQuery ());
233+ getQueryMethod ().isSliceQuery (), returnedType , this .getQueryMethod ().lookupLockAnnotation (), getQueryMethod ().isScrollQuery ());
233234 return queryCreator .createQuery (getDynamicSort (accessor ));
234235 }
235236
@@ -238,12 +239,12 @@ private List<ParametrizedQuery> createDeleteQueries(RelationalParametersParamete
238239 RelationalEntityMetadata <?> entityMetadata = getQueryMethod ().getEntityInformation ();
239240
240241 JdbcDeleteQueryCreator queryCreator = new JdbcDeleteQueryCreator (context , tree , converter , dialect , entityMetadata ,
241- accessor );
242+ accessor );
242243 return queryCreator .createQuery ();
243244 }
244245
245246 private JdbcQueryExecution <?> getJdbcQueryExecution (@ Nullable ResultSetExtractor <Boolean > extractor ,
246- Supplier <RowMapper <?>> rowMapper ) {
247+ Supplier <RowMapper <?>> rowMapper ) {
247248
248249 if (getQueryMethod ().isPageQuery () || getQueryMethod ().isSliceQuery () || getQueryMethod ().isScrollQuery ()) {
249250 return collectionQuery (rowMapper .get ());
@@ -260,12 +261,14 @@ private JdbcQueryExecution<?> getJdbcQueryExecution(@Nullable ResultSetExtractor
260261 static class ScrollQueryExecution <T > implements JdbcQueryExecution <Window <T >> {
261262 private final JdbcQueryExecution <? extends Collection <T >> delegate ;
262263 private final ScrollPosition position ;
263- @ Nullable private final Integer maxResults ;
264+ private final @ Nullable Integer maxResults ;
265+ private final Sort sort ;
264266
265- ScrollQueryExecution (JdbcQueryExecution <? extends Collection <T >> delegate , ScrollPosition position , @ Nullable Integer maxResults ) {
267+ ScrollQueryExecution (JdbcQueryExecution <? extends Collection <T >> delegate , ScrollPosition position , @ Nullable Integer maxResults , Sort sort ) {
266268 this .delegate = delegate ;
267269 this .position = position ;
268270 this .maxResults = maxResults ;
271+ this .sort = sort ;
269272 }
270273
271274 @ Override
@@ -277,10 +280,46 @@ static class ScrollQueryExecution<T> implements JdbcQueryExecution<Window<T>> {
277280 if (position instanceof OffsetScrollPosition )
278281 positionFunction = ((OffsetScrollPosition ) position ).positionFunction ();
279282
283+ if (position instanceof KeysetScrollPosition ) {
284+ Map <String , Object > keys = ((KeysetScrollPosition ) position ).getKeys ();
285+
286+ if (keys .isEmpty ()) {
287+ List <String > orders = sort .get ().map (Sort .Order ::getProperty ).toList ();
288+
289+ keys = extractKeys (resultList , orders );
290+ }
291+
292+ Map <String , Object > finalKeys = keys ;
293+ positionFunction = (ignoredI ) -> ScrollPosition .of (finalKeys , ((KeysetScrollPosition ) position ).getDirection ()) ;
294+ }
295+
280296 boolean hasNext = resultList .size () >= maxResults ;
281297
282298 return Window .from (resultList , positionFunction , hasNext );
283299 }
300+
301+ private Map <String , Object > extractKeys (List <T > resultList , List <String > orders ) {
302+ T last = resultList .get (resultList .size () - 1 );
303+
304+ Field [] fields = last .getClass ().getDeclaredFields ();
305+
306+ return Arrays
307+ .stream (fields )
308+ .filter (it -> orders .contains (it .getName ()))
309+ .peek (it -> it .setAccessible (true ))
310+ .collect (
311+ Collectors .toMap (
312+ Field ::getName ,
313+ it -> {
314+ try {
315+ return it .get (last );
316+ } catch (Exception e ) {
317+ throw new RuntimeException (e );
318+ }
319+ }
320+ )
321+ );
322+ }
284323 }
285324
286325 /**
@@ -329,7 +368,7 @@ static class PageQueryExecution<T> implements JdbcQueryExecution<Slice<T>> {
329368 private final LongSupplier countSupplier ;
330369
331370 PageQueryExecution (JdbcQueryExecution <? extends Collection <T >> delegate , Pageable pageable ,
332- LongSupplier countSupplier ) {
371+ LongSupplier countSupplier ) {
333372 this .delegate = delegate ;
334373 this .pageable = pageable ;
335374 this .countSupplier = countSupplier ;
@@ -341,7 +380,7 @@ public Slice<T> execute(String query, SqlParameterSource parameter) {
341380 Collection <T > result = delegate .execute (query , parameter );
342381
343382 return PageableExecutionUtils .getPage (result instanceof List ? (List <T >) result : new ArrayList <>(result ),
344- pageable , countSupplier );
383+ pageable , countSupplier );
345384 }
346385
347386 }
@@ -356,18 +395,18 @@ class CachedRowMapperFactory implements Supplier<RowMapper<?>> {
356395 private final Function <ResultProcessor , RowMapper <?>> rowMapperFunction ;
357396
358397 public CachedRowMapperFactory (PartTree tree ,
359- RowMapperFactory rowMapperFactory , RelationalConverter converter ,
360- ResultProcessor defaultResultProcessor ) {
398+ RowMapperFactory rowMapperFactory , RelationalConverter converter ,
399+ ResultProcessor defaultResultProcessor ) {
361400
362401 this .rowMapperFunction = processor -> {
363402
364403 if (tree .isCountProjection () || tree .isExistsProjection ()) {
365404 return rowMapperFactory .create (resolveTypeToRead (processor ));
366405 }
367406 Converter <Object , Object > resultProcessingConverter = new ResultProcessingConverter (processor ,
368- converter .getMappingContext (), converter .getEntityInstantiators ());
407+ converter .getMappingContext (), converter .getEntityInstantiators ());
369408 return new ConvertingRowMapper (
370- rowMapperFactory .create (processor .getReturnedType ().getDomainType ()), resultProcessingConverter );
409+ rowMapperFactory .create (processor .getReturnedType ().getDomainType ()), resultProcessingConverter );
371410 };
372411
373412 this .rowMapper = Lazy .of (() -> this .rowMapperFunction .apply (defaultResultProcessor ));
0 commit comments