2323import java .util .function .Consumer ;
2424import java .util .stream .Collectors ;
2525
26+ import org .bson .BinaryVector ;
2627import org .bson .Document ;
2728
2829import org .springframework .data .domain .Limit ;
30+ import org .springframework .data .domain .Vector ;
31+ import org .springframework .data .mongodb .core .mapping .MongoVector ;
2932import org .springframework .data .mongodb .core .query .Criteria ;
3033import org .springframework .data .mongodb .core .query .CriteriaDefinition ;
3134import org .springframework .lang .Contract ;
@@ -54,13 +57,13 @@ public class VectorSearchOperation implements AggregationOperation {
5457 private final Limit limit ;
5558 private final @ Nullable Integer numCandidates ;
5659 private final QueryPaths path ;
57- private final List <? extends Number > vector ;
60+ private final Vector vector ;
5861 private final String score ;
5962 private final Consumer <Criteria > scoreCriteria ;
6063
6164 private VectorSearchOperation (SearchType searchType , @ Nullable CriteriaDefinition filter , String indexName ,
62- Limit limit , @ Nullable Integer numCandidates , QueryPaths path , List <? extends Number > vector ,
63- @ Nullable String searchScore , Consumer <Criteria > scoreCriteria ) {
65+ Limit limit , @ Nullable Integer numCandidates , QueryPaths path , Vector vector , @ Nullable String searchScore ,
66+ Consumer <Criteria > scoreCriteria ) {
6467
6568 this .searchType = searchType ;
6669 this .filter = filter ;
@@ -73,7 +76,7 @@ private VectorSearchOperation(SearchType searchType, @Nullable CriteriaDefinitio
7376 this .scoreCriteria = scoreCriteria ;
7477 }
7578
76- VectorSearchOperation (String indexName , QueryPaths path , Limit limit , List <? extends Number > vector ) {
79+ VectorSearchOperation (String indexName , QueryPaths path , Limit limit , Vector vector ) {
7780 this (SearchType .DEFAULT , null , indexName , limit , null , path , vector , null , null );
7881 }
7982
@@ -249,8 +252,18 @@ public Document toDocument(AggregationOperationContext context) {
249252 path = mappedObject .keySet ().iterator ().next ();
250253 }
251254
255+ Object source = vector .getSource ();
256+
257+ if (source instanceof float []) {
258+ source = vector .toDoubleArray ();
259+ }
260+
261+ if (source instanceof double [] ds ) {
262+ source = Arrays .stream (ds ).boxed ().collect (Collectors .toList ());
263+ }
264+
252265 $vectorSearch .append ("path" , path );
253- $vectorSearch .append ("queryVector" , vector );
266+ $vectorSearch .append ("queryVector" , source );
254267
255268 return new Document (getOperator (), $vectorSearch );
256269 }
@@ -288,7 +301,7 @@ private static class VectorSearchBuilder implements PathContributor, VectorContr
288301
289302 String index ;
290303 QueryPath <String > paths ;
291- private List <? extends Number > vector ;
304+ Vector vector ;
292305
293306 PathContributor index (String index ) {
294307 this .index = index ;
@@ -308,8 +321,8 @@ public VectorSearchOperation limit(Limit limit) {
308321 }
309322
310323 @ Override
311- public LimitContributor vector (List <? extends Number > vectors ) {
312- this .vector = vectors ;
324+ public LimitContributor vector (Vector vector ) {
325+ this .vector = vector ;
313326 return this ;
314327 }
315328 }
@@ -428,28 +441,63 @@ public interface PathContributor {
428441 public interface VectorContributor {
429442
430443 /**
431- * Array of numbers of the BSON double, BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or
432- * int8 type that represent the query vector. The number type must match the indexed field value type. Otherwise,
433- * Atlas Vector Search doesn't return any results or errors.
444+ * Array of float numbers that represent the query vector. The number type must match the indexed field value type.
445+ * Otherwise, Atlas Vector Search doesn't return any results or errors.
446+ *
447+ * @param vector the query vector.
448+ * @return
449+ */
450+ @ Contract ("_ -> this" )
451+ default LimitContributor vector (float ... vector ) {
452+ return vector (Vector .of (vector ));
453+ }
454+
455+ /**
456+ * Array of double numbers that represent the query vector. The number type must match the indexed field value type.
457+ * Otherwise, Atlas Vector Search doesn't return any results or errors.
458+ *
459+ * @param vector the query vector.
460+ * @return
461+ */
462+ @ Contract ("_ -> this" )
463+ default LimitContributor vector (double ... vector ) {
464+ return vector (Vector .of (vector ));
465+ }
466+
467+ /**
468+ * Array of numbers that represent the query vector. The number type must match the indexed field value type.
469+ * Otherwise, Atlas Vector Search doesn't return any results or errors.
470+ *
471+ * @param vector the query vector.
472+ * @return
473+ */
474+ @ Contract ("_ -> this" )
475+ default LimitContributor vector (List <? extends Number > vector ) {
476+ return vector (Vector .of (vector ));
477+ }
478+
479+ /**
480+ * Binary vector (BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or int8 type) that
481+ * represent the query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector
482+ * Search doesn't return any results or errors.
434483 *
435- * @param vectors
484+ * @param vector the query vector.
436485 * @return
437486 */
438487 @ Contract ("_ -> this" )
439- default LimitContributor vector (Double ... vectors ) {
440- return vector (Arrays . asList ( vectors ));
488+ default LimitContributor vector (BinaryVector vector ) {
489+ return vector (MongoVector . of ( vector ));
441490 }
442491
443492 /**
444- * Array of numbers of the BSON double, BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or
445- * int8 type that represent the query vector. The number type must match the indexed field value type. Otherwise,
446- * Atlas Vector Search doesn't return any results or errors.
493+ * The query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector Search doesn't
494+ * return any results or errors.
447495 *
448- * @param vectors
496+ * @param vector the query vector.
449497 * @return
450498 */
451499 @ Contract ("_ -> this" )
452- LimitContributor vector (List <? extends Number > vectors );
500+ LimitContributor vector (Vector vector );
453501 }
454502
455503 public interface LimitContributor {
0 commit comments