Skip to content

Commit 883a81d

Browse files
committed
Consider multiple selections for the same property.
1 parent 3e5da60 commit 883a81d

File tree

4 files changed

+64
-27
lines changed

4 files changed

+64
-27
lines changed

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/convert/QueryMapper.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public List<Selector> getMappedSelectors(Columns columns, CassandraPersistentEnt
199199

200200
Field field = createPropertyField(entity, column);
201201

202-
columns.getSelector(column).ifPresent(selector -> {
202+
columns.getSelector(column).forEach(selector -> {
203203

204204
List<CqlIdentifier> mappedColumnNames = getCqlIdentifier(column, field);
205205

@@ -301,8 +301,12 @@ public List<CqlIdentifier> getMappedColumnNames(Columns columns, CassandraPersis
301301
Field field = createPropertyField(entity, column);
302302
field.getProperty().ifPresent(seen::add);
303303

304-
columns.getSelector(column).filter(selector -> selector instanceof ColumnSelector)
305-
.ifPresent(columnSelector -> columnNames.addAll(getCqlIdentifier(column, field)));
304+
columns.getSelector(column).forEach(columnSelector -> {
305+
306+
if (columnSelector instanceof ColumnSelector) {
307+
columnNames.addAll(getCqlIdentifier(column, field));
308+
}
309+
});
306310
}
307311

308312
if (columns.isEmpty()) {

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/query/Columns.java

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.data.cassandra.core.query;
1717

18+
import java.util.ArrayList;
1819
import java.util.Arrays;
1920
import java.util.Collections;
2021
import java.util.Iterator;
@@ -52,9 +53,9 @@
5253
*/
5354
public class Columns implements Iterable<ColumnName> {
5455

55-
private final Map<ColumnName, Selector> columns;
56+
private final Map<ColumnName, List<Selector>> columns;
5657

57-
private Columns(Map<ColumnName, Selector> columns) {
58+
private Columns(Map<ColumnName, List<Selector>> columns) {
5859
this.columns = Collections.unmodifiableMap(columns);
5960
}
6061

@@ -77,10 +78,10 @@ public static Columns from(String... columnNames) {
7778

7879
Assert.notNull(columnNames, "Column names must not be null");
7980

80-
Map<ColumnName, Selector> columns = new LinkedHashMap<>(columnNames.length, 1);
81+
Map<ColumnName, List<Selector>> columns = new LinkedHashMap<>(columnNames.length, 1);
8182

8283
for (String columnName : columnNames) {
83-
columns.put(ColumnName.from(columnName), ColumnSelector.from(columnName));
84+
columns.put(ColumnName.from(columnName), new ArrayList<>(List.of(ColumnSelector.from(columnName))));
8485
}
8586

8687
return new Columns(columns);
@@ -96,10 +97,10 @@ public static Columns from(CqlIdentifier... columnNames) {
9697

9798
Assert.notNull(columnNames, "Column names must not be null");
9899

99-
Map<ColumnName, Selector> columns = new LinkedHashMap<>(columnNames.length, 1);
100+
Map<ColumnName, List<Selector>> columns = new LinkedHashMap<>(columnNames.length, 1);
100101

101102
for (CqlIdentifier cqlId : columnNames) {
102-
columns.put(ColumnName.from(cqlId), ColumnSelector.from(cqlId));
103+
columns.put(ColumnName.from(cqlId), new ArrayList<>(List.of(ColumnSelector.from(cqlId))));
103104
}
104105

105106
return new Columns(columns);
@@ -210,8 +211,8 @@ public Columns select(CqlIdentifier columnName, Selector selector) {
210211
*/
211212
private Columns select(ColumnName columnName, Selector selector) {
212213

213-
Map<ColumnName, Selector> result = new LinkedHashMap<>(this.columns);
214-
result.put(columnName, selector);
214+
Map<ColumnName, List<Selector>> result = new LinkedHashMap<>(this.columns);
215+
result.computeIfAbsent(columnName, it -> new ArrayList<>()).add(selector);
215216

216217
return new Columns(result);
217218
}
@@ -232,9 +233,10 @@ public boolean isEmpty() {
232233
*/
233234
public Columns and(Columns columns) {
234235

235-
Map<ColumnName, Selector> result = new LinkedHashMap<>(this.columns);
236+
Map<ColumnName, List<Selector>> result = new LinkedHashMap<>(this.columns);
236237

237-
result.putAll(columns.columns);
238+
columns.columns
239+
.forEach((col, selectors) -> result.computeIfAbsent(col, columnName -> new ArrayList<>()).addAll(selectors));
238240

239241
return new Columns(result);
240242
}
@@ -248,11 +250,12 @@ public Iterator<ColumnName> iterator() {
248250
* @param columnName must not be {@literal null}.
249251
* @return the {@link Optional} {@link Selector} for {@link ColumnName}.
250252
*/
251-
public Optional<Selector> getSelector(ColumnName columnName) {
253+
public List<Selector> getSelector(ColumnName columnName) {
252254

253255
Assert.notNull(columnName, "ColumnName must not be null");
254256

255-
return Optional.ofNullable(this.columns.get(columnName));
257+
List<Selector> selectors = this.columns.get(columnName);
258+
return selectors == null ? List.of() : selectors;
256259
}
257260

258261
@Override
@@ -281,7 +284,7 @@ public int hashCode() {
281284
@Override
282285
public String toString() {
283286

284-
Iterator<Entry<ColumnName, Selector>> iterator = this.columns.entrySet().iterator();
287+
Iterator<Entry<ColumnName, List<Selector>>> iterator = this.columns.entrySet().iterator();
285288
StringBuilder builder = toString(iterator);
286289

287290
if (builder.isEmpty()) {
@@ -291,24 +294,26 @@ public String toString() {
291294
return builder.toString();
292295
}
293296

294-
private StringBuilder toString(Iterator<Entry<ColumnName, Selector>> iterator) {
297+
private StringBuilder toString(Iterator<Entry<ColumnName, List<Selector>>> iterator) {
295298

296299
StringBuilder builder = new StringBuilder();
297300
boolean first = true;
298301

299302
while (iterator.hasNext()) {
300303

301-
Entry<ColumnName, Selector> entry = iterator.next();
304+
Entry<ColumnName, List<Selector>> entry = iterator.next();
302305

303-
Selector expression = entry.getValue();
306+
for (Selector selector : entry.getValue()) {
304307

305-
if (first) {
306-
first = false;
307-
} else {
308-
builder.append(", ");
308+
if (first) {
309+
first = false;
310+
} else {
311+
builder.append(", ");
312+
}
313+
314+
builder.append(selector);
309315
}
310316

311-
builder.append(expression.toString());
312317
}
313318

314319
return builder;
@@ -670,8 +675,9 @@ public Selector dotProduct() {
670675

671676
@Override
672677
public Selector using(SimilarityFunction similarityFunction) {
673-
return FunctionCall.from("similarity_" + similarityFunction.name().toLowerCase(Locale.ROOT), columnName,
674-
vector).as(columnName);
678+
return FunctionCall
679+
.from("similarity_" + similarityFunction.name().toLowerCase(Locale.ROOT), columnName, vector)
680+
.as(columnName);
675681
}
676682
};
677683
}

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/query/QueryStatementCreator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ class QueryStatementCreator {
6262

6363
private static final Map<ScoringFunction, SimilarityFunction> SIMILARITY_FUNCTIONS = Map.of(
6464
VectorScoringFunctions.COSINE, SimilarityFunction.COSINE, VectorScoringFunctions.EUCLIDEAN,
65-
SimilarityFunction.EUCLIDEAN, VectorScoringFunctions.DOT, SimilarityFunction.DOT_PRODUCT);
65+
SimilarityFunction.EUCLIDEAN, VectorScoringFunctions.DOT, SimilarityFunction.DOT_PRODUCT,
66+
VectorScoringFunctions.INNER_PRODUCT, SimilarityFunction.DOT_PRODUCT);
6667

6768
private static final Log LOG = LogFactory.getLog(QueryStatementCreator.class);
6869

spring-data-cassandra/src/test/java/org/springframework/data/cassandra/core/convert/QueryMapperUnitTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import org.springframework.data.domain.Sort;
6666
import org.springframework.data.domain.Sort.Direction;
6767
import org.springframework.data.domain.Sort.Order;
68+
import org.springframework.data.domain.Vector;
6869
import org.springframework.lang.Nullable;
6970

7071
import com.datastax.oss.driver.api.core.CqlIdentifier;
@@ -368,6 +369,31 @@ void shouldMapColumnWithCompositePrimaryKeyClass() {
368369
assertThat(mappedObject).contains(CqlIdentifier.fromCql("first_name"));
369370
}
370371

372+
@Test //
373+
void shouldMapMultipleColumnNames() {
374+
375+
Columns columnNames = Columns.from("array").select("array",
376+
selectorBuilder -> selectorBuilder.similarity(Vector.of(1, 2)).cosine().as("score"));
377+
378+
List<CqlIdentifier> mappedObject = queryMapper.getMappedColumnNames(columnNames,
379+
mappingContext.getRequiredPersistentEntity(WithVector.class));
380+
381+
assertThat(mappedObject).contains(CqlIdentifier.fromCql("array"));
382+
}
383+
384+
@Test //
385+
void shouldMapMultipleSelectorsNames() {
386+
387+
Columns columnNames = Columns.from("array").select("array",
388+
selectorBuilder -> selectorBuilder.similarity(Vector.of(1, 2)).cosine().as("score"));
389+
390+
List<Selector> mappedObject = queryMapper.getMappedSelectors(columnNames,
391+
mappingContext.getRequiredPersistentEntity(WithVector.class));
392+
393+
assertThat(mappedObject).extracting(Selector::toString).contains("array",
394+
"similarity_cosine(array, [1.0, 2.0]) AS score");
395+
}
396+
371397
@Test // DATACASS-523
372398
@SuppressWarnings("all")
373399
void shouldMapTuple() {

0 commit comments

Comments
 (0)