diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java index 9223bec60d6..6908a7fd26f 100644 --- a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java @@ -485,9 +485,13 @@ public Document mapRow(ResultSet rs, int rowNum) throws SQLException { Map metadata = toMap(rs.getString(3)); float distance = rs.getFloat(4); - metadata.put("distance", distance); - - return new Document(id, content, metadata); + // @formatter:off + return Document.builder() + .id(id) + .text(content) + .metadata(metadata) + .score(1.0 - distance) + .build(); // @formatter:on } private Map toMap(String source) { diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java index d4d1e8edb92..41764924b2b 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java @@ -122,20 +122,20 @@ static Stream provideFilters() { ); } - private static boolean isSortedByDistance(List docs) { + private static boolean isSortedByScore(List docs) { - List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + List scores = docs.stream().map(Document::getScore).toList(); - if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + if (CollectionUtils.isEmpty(scores) || scores.size() == 1) { return true; } - Iterator iter = distances.iterator(); - Float current; - Float previous = iter.next(); + Iterator iter = scores.iterator(); + Double current; + Double previous = iter.next(); while (iter.hasNext()) { current = iter.next(); - if (previous > current) { + if (previous < current) { return false; } previous = current; @@ -166,7 +166,8 @@ public void addAndSearch(String distanceType) { assertThat(results).hasSize(1); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2"); + assertThat(resultDoc.getScore()).isBetween(0.0, 1.0); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); @@ -315,7 +316,8 @@ public void documentUpdate(String distanceType) { Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getText()).isEqualTo("Spring AI rocks!!"); - assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1"); + assertThat(resultDoc.getScore()).isBetween(0.0, 1.0); Document sameIdDocument = new Document(document.getId(), "The World is Big and Salvation Lurks Around the Corner", @@ -329,7 +331,8 @@ public void documentUpdate(String distanceType) { resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); assertThat(resultDoc.getText()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); - assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2"); + assertThat(resultDoc.getScore()).isBetween(0.0, 1.0); dropTable(context); }); @@ -350,19 +353,14 @@ public void searchWithThreshold(String distanceType) { assertThat(fullResult).hasSize(3); - assertThat(isSortedByDistance(fullResult)).isTrue(); + assertThat(isSortedByScore(fullResult)).isTrue(); - List distances = fullResult.stream() - .map(doc -> (Float) doc.getMetadata().get("distance")) - .toList(); + List scores = fullResult.stream().map(Document::getScore).toList(); - float threshold = (distances.get(0) + distances.get(1)) / 2; + double threshold = (scores.get(0) + scores.get(1)) / 2; - List results = vectorStore.similaritySearch(SearchRequest.builder() - .query("Time Shelter") - .topK(5) - .similarityThreshold(1 - threshold) - .build()); + List results = vectorStore.similaritySearch( + SearchRequest.builder().query("Time Shelter").topK(5).similarityThreshold(threshold).build()); assertThat(results).hasSize(1); Document resultDoc = results.get(0);