Skip to content

Commit 9d207e3

Browse files
ilayaperumalg authored and markpollack committed
Refactor SimpleVectorStore
- Remove SimpleVectorStore's dependency on deprecated embeddings from Document object
- Create a custom Content object that represents the SimpleVectorStore's contents and embedding
- Add tests
1 parent d173572 commit 9d207e3

File tree

4 files changed

+504
-14
lines changed

4 files changed

+504
-14
lines changed

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,15 @@
6767
* @author Mark Pollack
6868
* @author Christian Tzolov
6969
* @author Sebastien Deleuze
70+
* @author Ilayaperumal Gopinathan
7071
*/
7172
public class SimpleVectorStore extends AbstractObservationVectorStore {
7273

7374
private static final Logger logger = LoggerFactory.getLogger(SimpleVectorStore.class);
7475

7576
private final ObjectMapper objectMapper;
7677

77-
protected Map<String, Document> store = new ConcurrentHashMap<>();
78+
protected Map<String, SimpleVectorStoreContent> store = new ConcurrentHashMap<>();
7879

7980
protected EmbeddingModel embeddingModel;
8081

@@ -94,11 +95,17 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
9495

9596
@Override
9697
public void doAdd(List<Document> documents) {
98+
Objects.requireNonNull(documents, "Documents list cannot be null");
99+
if (documents.isEmpty()) {
100+
throw new IllegalArgumentException("Documents list cannot be empty");
101+
}
102+
97103
for (Document document : documents) {
98104
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
99105
float[] embedding = this.embeddingModel.embed(document);
100-
document.setEmbedding(embedding);
101-
this.store.put(document.getId(), document);
106+
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(),
107+
document.getContent(), document.getMetadata(), embedding);
108+
this.store.put(document.getId(), storeContent);
102109
}
103110
}
104111

@@ -120,12 +127,12 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
120127
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
121128
return this.store.values()
122129
.stream()
123-
.map(entry -> new Similarity(entry.getId(),
130+
.map(entry -> new Similarity(entry,
124131
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
125132
.filter(s -> s.score >= request.getSimilarityThreshold())
126133
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
127134
.limit(request.getTopK())
128-
.map(s -> this.store.get(s.key))
135+
.map(s -> s.getDocument())
129136
.toList();
130137
}
131138

@@ -176,12 +183,11 @@ public void save(File file) {
176183
* @param file the file to load the vector store content
177184
*/
178185
public void load(File file) {
179-
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
186+
TypeReference<HashMap<String, SimpleVectorStoreContent>> typeRef = new TypeReference<>() {
180187

181188
};
182189
try {
183-
Map<String, Document> deserializedMap = this.objectMapper.readValue(file, typeRef);
184-
this.store = deserializedMap;
190+
this.store = this.objectMapper.readValue(file, typeRef);
185191
}
186192
catch (IOException ex) {
187193
throw new RuntimeException(ex);
@@ -193,12 +199,11 @@ public void load(File file) {
193199
* @param resource the resource to load the vector store content
194200
*/
195201
public void load(Resource resource) {
196-
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
202+
TypeReference<HashMap<String, SimpleVectorStoreContent>> typeRef = new TypeReference<>() {
197203

198204
};
199205
try {
200-
Map<String, Document> deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef);
201-
this.store = deserializedMap;
206+
this.store = this.objectMapper.readValue(resource.getInputStream(), typeRef);
202207
}
203208
catch (IOException ex) {
204209
throw new RuntimeException(ex);
@@ -232,15 +237,23 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232237

233238
public static class Similarity {
234239

235-
private String key;
240+
private SimpleVectorStoreContent content;
236241

237242
private double score;
238243

239-
public Similarity(String key, double score) {
240-
this.key = key;
244+
public Similarity(SimpleVectorStoreContent content, double score) {
245+
this.content = content;
241246
this.score = score;
242247
}
243248

249+
Document getDocument() {
250+
return Document.builder()
251+
.withId(this.content.getId())
252+
.withContent(this.content.getContent())
253+
.withMetadata(this.content.getMetadata())
254+
.build();
255+
}
256+
244257
}
245258

246259
public final class EmbeddingMath {
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore;
18+
19+
import java.util.Arrays;
20+
import java.util.Collections;
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
import java.util.Objects;
24+
25+
import com.fasterxml.jackson.annotation.JsonCreator;
26+
import com.fasterxml.jackson.annotation.JsonProperty;
27+
28+
import org.springframework.ai.document.id.IdGenerator;
29+
import org.springframework.ai.document.id.RandomIdGenerator;
30+
import org.springframework.ai.model.Content;
31+
import org.springframework.util.Assert;
32+
33+
/**
34+
* An immutable {@link Content} implementation representing content, metadata, and its
35+
* embeddings. This class is thread-safe and all its fields are final and deeply
36+
* immutable. The embedding vector is required for all instances of this class.
37+
*/
38+
public final class SimpleVectorStoreContent implements Content {
39+
40+
private final String id;
41+
42+
private final String content;
43+
44+
private final Map<String, Object> metadata;
45+
46+
private final float[] embedding;
47+
48+
/**
49+
* Creates a new instance with the given content, empty metadata, and embedding
50+
* vector.
51+
* @param content the content text, must not be null
52+
* @param embedding the embedding vector, must not be null
53+
*/
54+
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
55+
public SimpleVectorStoreContent(@JsonProperty("content") String content,
56+
@JsonProperty("embedding") float[] embedding) {
57+
this(content, new HashMap<>(), embedding);
58+
}
59+
60+
/**
61+
* Creates a new instance with the given content, metadata, and embedding vector.
62+
* @param content the content text, must not be null
63+
* @param metadata the metadata map, must not be null
64+
* @param embedding the embedding vector, must not be null
65+
*/
66+
public SimpleVectorStoreContent(String content, Map<String, Object> metadata, float[] embedding) {
67+
this(content, metadata, new RandomIdGenerator(), embedding);
68+
}
69+
70+
/**
71+
* Creates a new instance with the given content, metadata, custom ID generator, and
72+
* embedding vector.
73+
* @param content the content text, must not be null
74+
* @param metadata the metadata map, must not be null
75+
* @param idGenerator the ID generator to use, must not be null
76+
* @param embedding the embedding vector, must not be null
77+
*/
78+
public SimpleVectorStoreContent(String content, Map<String, Object> metadata, IdGenerator idGenerator,
79+
float[] embedding) {
80+
this(idGenerator.generateId(content, metadata), content, metadata, embedding);
81+
}
82+
83+
/**
84+
* Creates a new instance with all fields specified.
85+
* @param id the unique identifier, must not be empty
86+
* @param content the content text, must not be null
87+
* @param metadata the metadata map, must not be null
88+
* @param embedding the embedding vector, must not be null
89+
* @throws IllegalArgumentException if any parameter is null or if id is empty
90+
*/
91+
public SimpleVectorStoreContent(String id, String content, Map<String, Object> metadata, float[] embedding) {
92+
Assert.hasText(id, "id must not be null or empty");
93+
Assert.notNull(content, "content must not be null");
94+
Assert.notNull(metadata, "metadata must not be null");
95+
Assert.notNull(embedding, "embedding must not be null");
96+
Assert.isTrue(embedding.length > 0, "embedding vector must not be empty");
97+
98+
this.id = id;
99+
this.content = content;
100+
this.metadata = Collections.unmodifiableMap(new HashMap<>(metadata));
101+
this.embedding = Arrays.copyOf(embedding, embedding.length);
102+
}
103+
104+
/**
105+
* Creates a new instance with an updated embedding vector.
106+
* @param embedding the new embedding vector, must not be null
107+
* @return a new instance with the updated embedding
108+
* @throws IllegalArgumentException if embedding is null or empty
109+
*/
110+
public SimpleVectorStoreContent withEmbedding(float[] embedding) {
111+
Assert.notNull(embedding, "embedding must not be null");
112+
Assert.isTrue(embedding.length > 0, "embedding vector must not be empty");
113+
return new SimpleVectorStoreContent(this.id, this.content, this.metadata, embedding);
114+
}
115+
116+
public String getId() {
117+
return this.id;
118+
}
119+
120+
@Override
121+
public String getContent() {
122+
return this.content;
123+
}
124+
125+
@Override
126+
public Map<String, Object> getMetadata() {
127+
return this.metadata;
128+
}
129+
130+
/**
131+
* Returns a defensive copy of the embedding vector.
132+
* @return a new array containing the embedding vector
133+
*/
134+
public float[] getEmbedding() {
135+
return Arrays.copyOf(this.embedding, this.embedding.length);
136+
}
137+
138+
@Override
139+
public boolean equals(Object o) {
140+
if (this == o)
141+
return true;
142+
if (o == null || getClass() != o.getClass())
143+
return false;
144+
145+
SimpleVectorStoreContent that = (SimpleVectorStoreContent) o;
146+
return Objects.equals(this.id, that.id) && Objects.equals(this.content, that.content)
147+
&& Objects.equals(this.metadata, that.metadata) && Arrays.equals(this.embedding, that.embedding);
148+
}
149+
150+
@Override
151+
public int hashCode() {
152+
int result = Objects.hashCode(this.id);
153+
result = 31 * result + Objects.hashCode(this.content);
154+
result = 31 * result + Objects.hashCode(this.metadata);
155+
result = 31 * result + Arrays.hashCode(this.embedding);
156+
return result;
157+
}
158+
159+
@Override
160+
public String toString() {
161+
return "SimpleVectorStoreContent{" + "id='" + this.id + '\'' + ", content='" + this.content + '\''
162+
+ ", metadata=" + this.metadata + ", embedding=" + Arrays.toString(embedding) + '}';
163+
}
164+
165+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore;
18+
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
import org.junit.Test;
23+
24+
import org.springframework.ai.document.Document;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* @author Ilayaperumal Gopinathan
30+
*/
31+
public class SimpleVectorStoreSimilarityTests {
32+
33+
@Test
34+
public void testSimilarity() {
35+
Map<String, Object> metadata = new HashMap<>();
36+
metadata.put("foo", "bar");
37+
float[] testEmbedding = new float[] { 1.0f, 2.0f, 3.0f };
38+
39+
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata,
40+
testEmbedding);
41+
SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d);
42+
Document document = similarity.getDocument();
43+
assertThat(document).isNotNull();
44+
assertThat(document.getId()).isEqualTo("1");
45+
assertThat(document.getContent()).isEqualTo("hello, how are you?");
46+
assertThat(document.getMetadata().get("foo")).isEqualTo("bar");
47+
}
48+
49+
}

0 commit comments

Comments
 (0)