Skip to content

Commit 467dbc3

Browse files
committed
Support similarity scores in Document API
Document
* Introduced a “score” attribute in the Document API. It stores the similarity score.
* Consolidated the “distance” metadata for Documents. It stores the distance measurement.
* Adopted a prefix-less naming convention in Document.Builder and deprecated the old methods.
* Deprecated the many overloaded Document constructors in favour of Document.Builder.

Vector Stores
* Every vector store implementation now configures a “score” attribute with the similarity score of the Document embedding. It also includes the “distance” metadata with the distance measurement.
* Fixed an error in Elasticsearch where distance and similarity were mixed up.
* Added missing integration tests for SimpleVectorStore.
* The Azure Vector Store and HanaDB Vector Store do not include those measurements because the product documentation does not include information about how the similarity score is returned, and without access to the cloud products I could not verify that via debugging.
* Improved tests to actually assert the result of the similarity search based on the returned score.

Signed-off-by: Thomas Vitale <[email protected]>
1 parent 67a8896 commit 467dbc3

File tree

51 files changed

+859
-387
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+859
-387
lines changed

spring-ai-core/src/main/java/org/springframework/ai/document/Document.java

Lines changed: 127 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.HashMap;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.Objects;
2425

2526
import com.fasterxml.jackson.annotation.JsonCreator;
2627
import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -31,6 +32,7 @@
3132
import org.springframework.ai.document.id.RandomIdGenerator;
3233
import org.springframework.ai.model.Media;
3334
import org.springframework.ai.model.MediaContent;
35+
import org.springframework.lang.Nullable;
3436
import org.springframework.util.Assert;
3537
import org.springframework.util.StringUtils;
3638

@@ -61,7 +63,15 @@ public class Document implements MediaContent {
6163
* Metadata for the document. It should not be nested and values should be restricted
6264
* to string, int, float, boolean for simple use with Vector Dbs.
6365
*/
64-
private Map<String, Object> metadata;
66+
private final Map<String, Object> metadata;
67+
68+
/**
69+
* Measure of similarity between the document embedding and the query vector. The
70+
* higher the score, the more they are similar. It's the opposite of the distance
71+
* measure.
72+
*/
73+
@Nullable
74+
private Double score;
6575

6676
/**
6777
* Embedding of the document. Note: ephemeral field.
@@ -80,31 +90,61 @@ public Document(@JsonProperty("content") String content) {
8090
this(content, new HashMap<>());
8191
}
8292

93+
/**
94+
* @deprecated Use builder instead: {@link Document#builder()}.
95+
*/
96+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
8397
public Document(String content, Map<String, Object> metadata) {
8498
this(content, metadata, new RandomIdGenerator());
8599
}
86100

101+
/**
102+
* @deprecated Use builder instead: {@link Document#builder()}.
103+
*/
104+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
87105
public Document(String content, Collection<Media> media, Map<String, Object> metadata) {
88106
this(new RandomIdGenerator().generateId(content, metadata), content, media, metadata);
89107
}
90108

109+
/**
110+
* @deprecated Use builder instead: {@link Document#builder()}.
111+
*/
112+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
91113
public Document(String content, Map<String, Object> metadata, IdGenerator idGenerator) {
92114
this(idGenerator.generateId(content, metadata), content, metadata);
93115
}
94116

117+
/**
118+
* @deprecated Use builder instead: {@link Document#builder()}.
119+
*/
120+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
95121
public Document(String id, String content, Map<String, Object> metadata) {
96122
this(id, content, List.of(), metadata);
97123
}
98124

125+
/**
126+
* @deprecated Use builder instead: {@link Document#builder()}.
127+
*/
128+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
99129
public Document(String id, String content, Collection<Media> media, Map<String, Object> metadata) {
100-
Assert.hasText(id, "id must not be null or empty");
101-
Assert.notNull(content, "content must not be null");
102-
Assert.notNull(metadata, "metadata must not be null");
130+
this(id, content, media, metadata, null);
131+
}
132+
133+
public Document(String id, String content, Collection<Media> media, Map<String, Object> metadata,
134+
@Nullable Double score) {
135+
Assert.hasText(id, "id cannot be null or empty");
136+
Assert.notNull(content, "content cannot be null");
137+
Assert.notNull(media, "media cannot be null");
138+
Assert.noNullElements(media, "media cannot have null elements");
139+
Assert.notNull(metadata, "metadata cannot be null");
140+
Assert.noNullElements(metadata.keySet(), "metadata cannot have null keys");
141+
Assert.noNullElements(metadata.values(), "metadata cannot have null values");
103142

104143
this.id = id;
105144
this.content = content;
106-
this.media = media;
107-
this.metadata = metadata;
145+
this.media = media != null ? media : List.of();
146+
this.metadata = metadata != null ? metadata : new HashMap<>();
147+
this.score = score;
108148
}
109149

110150
public static Builder builder() {
@@ -149,6 +189,15 @@ public Map<String, Object> getMetadata() {
149189
return this.metadata;
150190
}
151191

192+
@Nullable
193+
public Double getScore() {
194+
return this.score;
195+
}
196+
197+
public void setScore(@Nullable Double score) {
198+
this.score = score;
199+
}
200+
152201
/**
153202
* Return the embedding that were calculated.
154203
* @deprecated We are considering getting rid of this, please comment on
@@ -186,57 +235,24 @@ public void setContentFormatter(ContentFormatter contentFormatter) {
186235

187236
@Override
188237
public int hashCode() {
189-
final int prime = 31;
190-
int result = 1;
191-
result = prime * result + ((this.id == null) ? 0 : this.id.hashCode());
192-
result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode());
193-
result = prime * result + ((this.content == null) ? 0 : this.content.hashCode());
194-
return result;
238+
return Objects.hash(id, content, media, metadata);
195239
}
196240

197241
@Override
198-
public boolean equals(Object obj) {
199-
if (this == obj) {
242+
public boolean equals(Object o) {
243+
if (this == o)
200244
return true;
201-
}
202-
if (obj == null) {
203-
return false;
204-
}
205-
if (getClass() != obj.getClass()) {
206-
return false;
207-
}
208-
Document other = (Document) obj;
209-
if (this.id == null) {
210-
if (other.id != null) {
211-
return false;
212-
}
213-
}
214-
else if (!this.id.equals(other.id)) {
215-
return false;
216-
}
217-
if (this.metadata == null) {
218-
if (other.metadata != null) {
219-
return false;
220-
}
221-
}
222-
else if (!this.metadata.equals(other.metadata)) {
245+
if (o == null || getClass() != o.getClass())
223246
return false;
224-
}
225-
if (this.content == null) {
226-
if (other.content != null) {
227-
return false;
228-
}
229-
}
230-
else if (!this.content.equals(other.content)) {
231-
return false;
232-
}
233-
return true;
247+
Document document = (Document) o;
248+
return Objects.equals(id, document.id) && Objects.equals(content, document.content)
249+
&& Objects.equals(media, document.media) && Objects.equals(metadata, document.metadata);
234250
}
235251

236252
@Override
237253
public String toString() {
238-
return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content
239-
+ '\'' + ", media=" + this.media + '}';
254+
return "Document{" + "id='" + id + '\'' + ", content='" + content + '\'' + ", media=" + media + ", metadata="
255+
+ metadata + ", score=" + score + '}';
240256
}
241257

242258
public static class Builder {
@@ -249,56 +265,102 @@ public static class Builder {
249265

250266
private Map<String, Object> metadata = new HashMap<>();
251267

268+
private float[] embedding = new float[0];
269+
270+
private Double score;
271+
252272
private IdGenerator idGenerator = new RandomIdGenerator();
253273

254-
public Builder withIdGenerator(IdGenerator idGenerator) {
255-
Assert.notNull(idGenerator, "idGenerator must not be null");
274+
public Builder idGenerator(IdGenerator idGenerator) {
275+
Assert.notNull(idGenerator, "idGenerator cannot be null");
256276
this.idGenerator = idGenerator;
257277
return this;
258278
}
259279

260-
public Builder withId(String id) {
261-
Assert.hasText(id, "id must not be null or empty");
280+
public Builder id(String id) {
281+
Assert.hasText(id, "id cannot be null or empty");
262282
this.id = id;
263283
return this;
264284
}
265285

266-
public Builder withContent(String content) {
267-
Assert.notNull(content, "content must not be null");
286+
public Builder content(String content) {
268287
this.content = content;
269288
return this;
270289
}
271290

272-
public Builder withMedia(List<Media> media) {
273-
Assert.notNull(media, "media must not be null");
291+
public Builder media(List<Media> media) {
274292
this.media = media;
275293
return this;
276294
}
277295

278-
public Builder withMedia(Media media) {
279-
Assert.notNull(media, "media must not be null");
280-
this.media.add(media);
296+
public Builder media(Media... media) {
297+
Assert.noNullElements(media, "media cannot contain null elements");
298+
this.media.addAll(List.of(media));
281299
return this;
282300
}
283301

284-
public Builder withMetadata(Map<String, Object> metadata) {
285-
Assert.notNull(metadata, "metadata must not be null");
302+
public Builder metadata(Map<String, Object> metadata) {
286303
this.metadata = metadata;
287304
return this;
288305
}
289306

290-
public Builder withMetadata(String key, Object value) {
291-
Assert.notNull(key, "key must not be null");
292-
Assert.notNull(value, "value must not be null");
307+
public Builder metadata(String key, Object value) {
293308
this.metadata.put(key, value);
294309
return this;
295310
}
296311

312+
public Builder embedding(float[] embedding) {
313+
this.embedding = embedding;
314+
return this;
315+
}
316+
317+
public Builder score(Double score) {
318+
this.score = score;
319+
return this;
320+
}
321+
322+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
323+
public Builder withIdGenerator(IdGenerator idGenerator) {
324+
return idGenerator(idGenerator);
325+
}
326+
327+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
328+
public Builder withId(String id) {
329+
return id(id);
330+
}
331+
332+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
333+
public Builder withContent(String content) {
334+
return content(content);
335+
}
336+
337+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
338+
public Builder withMedia(List<Media> media) {
339+
return media(media);
340+
}
341+
342+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
343+
public Builder withMedia(Media media) {
344+
return media(media);
345+
}
346+
347+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
348+
public Builder withMetadata(Map<String, Object> metadata) {
349+
return metadata(metadata);
350+
}
351+
352+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
353+
public Builder withMetadata(String key, Object value) {
354+
return metadata(key, value);
355+
}
356+
297357
public Document build() {
298358
if (!StringUtils.hasText(this.id)) {
299359
this.id = this.idGenerator.generateId(this.content, this.metadata);
300360
}
301-
return new Document(this.id, this.content, this.media, this.metadata);
361+
var document = new Document(this.id, this.content, this.media, this.metadata, this.score);
362+
document.setEmbedding(this.embedding);
363+
return document;
302364
}
303365

304366
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.document;
18+
19+
import org.springframework.ai.vectorstore.VectorStore;
20+
21+
/**
22+
* Common set of metadata keys used in {@link Document}s by {@link DocumentReader}s and
23+
* {@link VectorStore}s.
24+
*
25+
* @author Thomas Vitale
26+
* @since 1.0.0
27+
*/
28+
public enum DocumentMetadata {
29+
30+
// @formatter:off
31+
32+
/**
33+
* Measure of distance between the document embedding and the query vector.
34+
* The lower the distance, the more they are similar.
35+
* It's the opposite of the similarity score.
36+
*/
37+
DISTANCE("distance");
38+
39+
private final String value;
40+
41+
DocumentMetadata(String value) {
42+
this.value = value;
43+
}
44+
public String value() {
45+
return this.value;
46+
}
47+
48+
// @formatter:on
49+
50+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
@NonNullApi
18+
@NonNullFields
19+
package org.springframework.ai.document;
20+
21+
import org.springframework.lang.NonNullApi;
22+
import org.springframework.lang.NonNullFields;

0 commit comments

Comments
 (0)