Skip to content

Commit 7cc48a0

Browse files
committed
Add raw Similarity, improve SearchResult detection.
1 parent 0ccd4a3 commit 7cc48a0

File tree

10 files changed

+187
-51
lines changed

10 files changed

+187
-51
lines changed

src/main/java/org/springframework/data/domain/Score.java

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,30 +68,6 @@ public static Range<Score> between(Score min, Score max) {
6868
return Range.from(Range.Bound.inclusive(min)).to(Range.Bound.inclusive(max));
6969
}
7070

71-
/**
72-
* Creates a new {@link Range} by creating minimum and maximum {@link Score} from the given values without
73-
* {@link ScoringFunction#UNSPECIFIED specifying a scoring function}.
74-
*
75-
* @param minValue minimum value.
76-
* @param maxValue maximum value.
77-
* @return the {@link Range} between the given values.
78-
*/
79-
public static Range<Score> between(double minValue, double maxValue) {
80-
return between(minValue, maxValue, ScoringFunction.UNSPECIFIED);
81-
}
82-
83-
/**
84-
* Creates a new {@link Range} by creating minimum and maximum {@link Score} from the given values.
85-
*
86-
* @param minValue minimum value.
87-
* @param maxValue maximum value.
88-
* @param function the scoring function to use.
89-
* @return the {@link Range} between the given values.
90-
*/
91-
public static Range<Score> between(double minValue, double maxValue, ScoringFunction function) {
92-
return between(Score.of(minValue, function), Score.of(maxValue, function));
93-
}
94-
9571
public double getValue() {
9672
return value;
9773
}

src/main/java/org/springframework/data/domain/SearchResults.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public List<SearchResult<T>> getContent() {
5050
}
5151

5252
@Override
53+
@SuppressWarnings("unchecked")
5354
public Iterator<SearchResult<T>> iterator() {
5455
return (Iterator<SearchResult<T>>) results.iterator();
5556
}
@@ -64,11 +65,7 @@ public <U> SearchResults<U> map(Function<? super T, ? extends U> converter) {
6465

6566
Assert.notNull(converter, "Function must not be null");
6667

67-
List<SearchResult<U>> result = results.stream().map(it -> {
68-
69-
SearchResult<U> mapped = it.map(converter);
70-
return mapped;
71-
}).collect(Collectors.toList());
68+
List<SearchResult<U>> result = results.stream().map(it -> it.<U> map(converter)).collect(Collectors.toList());
7269

7370
return new SearchResults<>(result);
7471
}
@@ -93,7 +90,8 @@ public int hashCode() {
9390

9491
@Override
9592
public String toString() {
96-
return String.format("SearchResults: [results: %s]", StringUtils.collectionToCommaDelimitedString(results));
93+
return results.isEmpty() ? "SearchResults: [empty]"
94+
: String.format("SearchResults: [results: %s]", StringUtils.collectionToCommaDelimitedString(results));
9795
}
9896

9997
}

src/main/java/org/springframework/data/domain/Similarity.java

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import org.springframework.util.Assert;
1919

2020
/**
21-
* Value object to represent a similarity score determined by a {@link ScoringFunction}. Similarity is expressed through
21+
* Value object to represent a similarity value determined by a {@link ScoringFunction}. Similarity is expressed through
2222
* a numerical value ranging between {@code 0} and {@code 1} where zero represents the lowest similarity and one the
2323
* highest similarity.
2424
* <p>
@@ -35,27 +35,40 @@ private Similarity(double value, ScoringFunction function) {
3535
}
3636

3737
/**
38-
* Creates a new {@link Similarity} from a plain {@code score} value using {@link ScoringFunction#UNSPECIFIED}.
38+
* Creates a new {@link Similarity} from a plain {@code similarity} value using {@link ScoringFunction#UNSPECIFIED}.
3939
*
40-
* @param score the score value without a specific {@link ScoringFunction}, ranging between {@code 0} and {@code 1}.
40+
* @param similarity the similarity value without a specific {@link ScoringFunction}, ranging between {@code 0} and
41+
* {@code 1}.
4142
* @return the new {@link Similarity}.
4243
*/
43-
public static Similarity of(double score) {
44-
return of(score, ScoringFunction.UNSPECIFIED);
44+
public static Similarity of(double similarity) {
45+
return of(similarity, ScoringFunction.UNSPECIFIED);
4546
}
4647

4748
/**
48-
* Creates a new {@link Similarity} from a {@code score} value using the given {@link ScoringFunction}.
49+
* Creates a new {@link Similarity} from a {@code similarity} value using the given {@link ScoringFunction}.
4950
*
50-
* @param score the score value, ranging between {@code 0} and {@code 1}.
51-
* @param function the scoring function that has computed the {@code score}.
51+
* @param similarity the similarity value, ranging between {@code 0} and {@code 1}.
52+
* @param function the scoring function that has computed the {@code similarity}.
5253
* @return the new {@link Similarity}.
5354
*/
54-
public static Similarity of(double score, ScoringFunction function) {
55+
public static Similarity of(double similarity, ScoringFunction function) {
5556

56-
Assert.isTrue(score >= (double) 0.0F && score <= (double) 1.0F, "Similarity must be in [0,1] range.");
57+
Assert.isTrue(similarity >= 0.0 && similarity <= 1.0, "Similarity must be in [0,1] range.");
5758

58-
return new Similarity(score, function);
59+
return new Similarity(similarity, function);
60+
}
61+
62+
/**
63+
* Creates a new raw {@link Similarity} from a {@code similarity} value using the given {@link ScoringFunction}.
64+
* Typically, this method is used when accepting external similarity values coming from a database search result.
65+
*
66+
* @param similarity the similarity value, ranging between {@code 0} and {@code 1}.
67+
* @param function the scoring function that has computed the {@code similarity}.
68+
* @return the new {@link Similarity}.
69+
*/
70+
public static Similarity raw(double similarity, ScoringFunction function) {
71+
return new Similarity(similarity, function);
5972
}
6073

6174
/**
@@ -77,7 +90,7 @@ public static Range<Similarity> between(Similarity min, Similarity max) {
7790
* @param maxValue maximum value, ranging between {@code 0} and {@code 1}.
7891
* @return the {@link Range} between the given values.
7992
*/
80-
public static Range<Score> between(double minValue, double maxValue) {
93+
public static Range<Similarity> between(double minValue, double maxValue) {
8194
return between(minValue, maxValue, ScoringFunction.UNSPECIFIED);
8295
}
8396

@@ -89,8 +102,8 @@ public static Range<Score> between(double minValue, double maxValue) {
89102
* @param function the scoring function to use.
90103
* @return the {@link Range} between the given values.
91104
*/
92-
public static Range<Score> between(double minValue, double maxValue, ScoringFunction function) {
93-
return (Range) between(Similarity.of(minValue, function), Similarity.of(maxValue, function));
105+
public static Range<Similarity> between(double minValue, double maxValue, ScoringFunction function) {
106+
return between(Similarity.of(minValue, function), Similarity.of(maxValue, function));
94107
}
95108

96109
@Override
@@ -100,4 +113,5 @@ public boolean equals(Object o) {
100113
}
101114
return super.equals(other);
102115
}
116+
103117
}

src/main/java/org/springframework/data/repository/query/Parameters.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ protected Parameters(ParametersSource parametersSource, Function<MethodParameter
124124
}
125125

126126
if (Range.class.isAssignableFrom(parameter.getType())
127-
&& ResolvableType.forMethodParameter(methodParameter).getGeneric(0)
128-
.isAssignableFrom(Score.class)) {
127+
&& Score.class.isAssignableFrom(ResolvableType.forMethodParameter(methodParameter).getGeneric(0).toClass())) {
129128
scoreRangeIndex = i;
130129
}
131130

src/main/java/org/springframework/data/repository/query/QueryMethod.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.data.domain.Page;
2727
import org.springframework.data.domain.Pageable;
2828
import org.springframework.data.domain.ScrollPosition;
29+
import org.springframework.data.domain.SearchResult;
2930
import org.springframework.data.domain.SearchResults;
3031
import org.springframework.data.domain.Slice;
3132
import org.springframework.data.domain.Sort;
@@ -42,6 +43,7 @@
4243
import org.springframework.data.util.TypeInformation;
4344
import org.springframework.lang.Nullable;
4445
import org.springframework.util.Assert;
46+
import org.springframework.util.ClassUtils;
4547

4648
/**
4749
* Abstraction of a method that is designated to execute a finder query. Enriches the standard {@link Method} interface
@@ -282,13 +284,21 @@ public final boolean isPageQuery() {
282284
}
283285

284286
/**
285-
* Returns whether the finder will return a {@link SearchResults} of results.
287+
* Returns whether the finder will return a {@link SearchResults} (or collection of {@link SearchResult}) of results.
286288
*
287289
* @return
288290
* @since 4.0
289291
*/
290-
public final boolean isSearchQuery() {
291-
return org.springframework.util.ClassUtils.isAssignable(SearchResults.class, unwrappedReturnType);
292+
public boolean isSearchQuery() {
293+
294+
if (ClassUtils.isAssignable(SearchResults.class, unwrappedReturnType)) {
295+
return true;
296+
}
297+
298+
TypeInformation<?> returnType = metadata.getReturnType(method);
299+
TypeInformation<?> componentType = returnType.getComponentType();
300+
301+
return componentType != null && SearchResult.class.isAssignableFrom(componentType.getType());
292302
}
293303

294304
/**

src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.springframework.core.convert.support.ConfigurableConversionService;
3939
import org.springframework.core.convert.support.DefaultConversionService;
4040
import org.springframework.data.domain.Page;
41+
import org.springframework.data.domain.SearchResult;
4142
import org.springframework.data.domain.SearchResults;
4243
import org.springframework.data.domain.Slice;
4344
import org.springframework.data.domain.Window;
@@ -255,6 +256,7 @@ public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, Typ
255256
boolean needToUnwrap = type.isCollectionLike() //
256257
|| Slice.class.isAssignableFrom(rawType) //
257258
|| GeoResults.class.isAssignableFrom(rawType) //
259+
|| SearchResult.class.isAssignableFrom(rawType) //
258260
|| SearchResults.class.isAssignableFrom(rawType) //
259261
|| rawType.isArray() //
260262
|| supports(rawType) //
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.domain;
17+
18+
import static org.assertj.core.api.Assertions.*;
19+
20+
import org.junit.jupiter.api.Test;
21+
22+
/**
23+
* Unit tests for {@link Similarity}.
24+
*
25+
* @author Mark Paluch
26+
*/
27+
class SimilarityUnitTests {
28+
29+
@Test
30+
void shouldBeBounded() {
31+
32+
assertThatIllegalArgumentException().isThrownBy(() -> Similarity.of(-1));
33+
assertThatIllegalArgumentException().isThrownBy(() -> Similarity.of(1.01));
34+
}
35+
36+
@Test
37+
void shouldConstructRawSimilarity() {
38+
39+
Similarity similarity = Similarity.raw(2, ScoringFunction.UNSPECIFIED);
40+
41+
assertThat(similarity.getValue()).isEqualTo(2);
42+
}
43+
44+
@Test
45+
void shouldConstructGenericSimilarity() {
46+
47+
Similarity similarity = Similarity.of(1);
48+
49+
assertThat(similarity).isEqualTo(Similarity.of(1)).isNotEqualTo(Score.of(1)).isNotEqualTo(Similarity.of(0.5));
50+
assertThat(similarity).hasToString("1.0");
51+
assertThat(similarity.getFunction()).isEqualTo(ScoringFunction.UNSPECIFIED);
52+
}
53+
54+
@Test
55+
void shouldConstructMeteredSimilarity() {
56+
57+
Similarity similarity = Similarity.of(1, VectorScoringFunctions.COSINE);
58+
59+
assertThat(similarity).isEqualTo(Similarity.of(1, VectorScoringFunctions.COSINE))
60+
.isNotEqualTo(Score.of(1, VectorScoringFunctions.COSINE)).isNotEqualTo(Similarity.of(1));
61+
assertThat(similarity).hasToString("1.0 (COSINE)");
62+
assertThat(similarity.getFunction()).isEqualTo(VectorScoringFunctions.COSINE);
63+
}
64+
65+
@Test
66+
void shouldConstructRange() {
67+
68+
Range<Similarity> range = Similarity.between(0.5, 1);
69+
70+
assertThat(range.getLowerBound().getValue()).contains(Similarity.of(0.5));
71+
assertThat(range.getLowerBound().isInclusive()).isTrue();
72+
73+
assertThat(range.getUpperBound().getValue()).contains(Similarity.of(1));
74+
assertThat(range.getUpperBound().isInclusive()).isTrue();
75+
}
76+
77+
@Test
78+
void shouldConstructRangeWithFunction() {
79+
80+
Range<Similarity> range = Similarity.between(0.5, 1, VectorScoringFunctions.COSINE);
81+
82+
assertThat(range.getLowerBound().getValue()).contains(Similarity.of(0.5, VectorScoringFunctions.COSINE));
83+
assertThat(range.getLowerBound().isInclusive()).isTrue();
84+
85+
assertThat(range.getUpperBound().getValue()).contains(Similarity.of(1, VectorScoringFunctions.COSINE));
86+
assertThat(range.getUpperBound().isInclusive()).isTrue();
87+
}
88+
89+
}

src/test/java/org/springframework/data/repository/query/ParametersUnitTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
import org.springframework.data.domain.OffsetScrollPosition;
3232
import org.springframework.data.domain.Page;
3333
import org.springframework.data.domain.Pageable;
34+
import org.springframework.data.domain.Range;
35+
import org.springframework.data.domain.Score;
36+
import org.springframework.data.domain.Similarity;
3437
import org.springframework.data.domain.Sort;
3538
import org.springframework.data.domain.Window;
3639
import org.springframework.data.repository.Repository;
@@ -230,6 +233,22 @@ void considersGenericType() throws Exception {
230233
assertThat(parameters.getParameter(0).getType()).isEqualTo(Long.class);
231234
}
232235

236+
@Test // GH-
237+
void considersScoreRange() throws Exception {
238+
239+
var parameters = getParametersFor("methodWithScoreRange", Range.class);
240+
241+
assertThat(parameters.hasScoreRangeParameter()).isTrue();
242+
}
243+
244+
@Test // GH-
245+
void considersSimilarityRange() throws Exception {
246+
247+
var parameters = getParametersFor("methodWithSimilarityRange", Range.class);
248+
249+
assertThat(parameters.hasScoreRangeParameter()).isTrue();
250+
}
251+
233252
private Parameters<?, Parameter> getParametersFor(String methodName, Class<?>... parameterTypes)
234253
throws SecurityException, NoSuchMethodException {
235254

@@ -268,6 +287,10 @@ interface SampleDao extends Repository<User, String> {
268287

269288
void methodWithSingle(Single<String> single);
270289

290+
void methodWithScoreRange(Range<Score> single);
291+
292+
void methodWithSimilarityRange(Range<Similarity> single);
293+
271294
Page<Object> customPageable(SomePageable pageable);
272295

273296
Window<Object> customScrollPosition(OffsetScrollPosition request);

0 commit comments

Comments
 (0)