Skip to content

Commit 25d6169

Browse files
committed
Refactor SimpleVectorStore
- Remove SimpleVectorStore's dependency on deprecated embeddings from Document object - Create a custom Content object that represents the SimpleVectorStore's contents and embedding
1 parent 979de19 commit 25d6169

File tree

3 files changed

+228
-14
lines changed

3 files changed

+228
-14
lines changed

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,15 @@
6767
* @author Mark Pollack
6868
* @author Christian Tzolov
6969
* @author Sebastien Deleuze
70+
* @author Ilayaperumal Gopinathan
7071
*/
7172
public class SimpleVectorStore extends AbstractObservationVectorStore {
7273

7374
private static final Logger logger = LoggerFactory.getLogger(SimpleVectorStore.class);
7475

7576
private final ObjectMapper objectMapper;
7677

77-
protected Map<String, Document> store = new ConcurrentHashMap<>();
78+
protected Map<String, SimpleVectorStoreContent> store = new ConcurrentHashMap<>();
7879

7980
protected EmbeddingModel embeddingModel;
8081

@@ -97,8 +98,10 @@ public void doAdd(List<Document> documents) {
9798
for (Document document : documents) {
9899
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
99100
float[] embedding = this.embeddingModel.embed(document);
100-
document.setEmbedding(embedding);
101-
this.store.put(document.getId(), document);
101+
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(),
102+
document.getContent(), document.getMetadata());
103+
storeContent.setEmbedding(embedding);
104+
this.store.put(document.getId(), storeContent);
102105
}
103106
}
104107

@@ -120,12 +123,12 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
120123
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
121124
return this.store.values()
122125
.stream()
123-
.map(entry -> new Similarity(entry.getId(),
126+
.map(entry -> new Similarity(entry,
124127
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
125128
.filter(s -> s.score >= request.getSimilarityThreshold())
126129
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
127130
.limit(request.getTopK())
128-
.map(s -> this.store.get(s.key))
131+
.map(s -> s.getDocument())
129132
.toList();
130133
}
131134

@@ -176,12 +179,11 @@ public void save(File file) {
176179
* @param file the file to load the vector store content
177180
*/
178181
public void load(File file) {
179-
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
182+
TypeReference<HashMap<String, SimpleVectorStoreContent>> typeRef = new TypeReference<>() {
180183

181184
};
182185
try {
183-
Map<String, Document> deserializedMap = this.objectMapper.readValue(file, typeRef);
184-
this.store = deserializedMap;
186+
this.store = this.objectMapper.readValue(file, typeRef);
185187
}
186188
catch (IOException ex) {
187189
throw new RuntimeException(ex);
@@ -193,12 +195,11 @@ public void load(File file) {
193195
* @param resource the resource to load the vector store content
194196
*/
195197
public void load(Resource resource) {
196-
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
198+
TypeReference<HashMap<String, SimpleVectorStoreContent>> typeRef = new TypeReference<>() {
197199

198200
};
199201
try {
200-
Map<String, Document> deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef);
201-
this.store = deserializedMap;
202+
this.store = this.objectMapper.readValue(resource.getInputStream(), typeRef);
202203
}
203204
catch (IOException ex) {
204205
throw new RuntimeException(ex);
@@ -232,15 +233,23 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232233

233234
public static class Similarity {
234235

235-
private String key;
236+
private SimpleVectorStoreContent content;
236237

237238
private double score;
238239

239-
public Similarity(String key, double score) {
240-
this.key = key;
240+
public Similarity(SimpleVectorStoreContent content, double score) {
241+
this.content = content;
241242
this.score = score;
242243
}
243244

245+
Document getDocument() {
246+
return Document.builder()
247+
.withId(this.content.getId())
248+
.withContent(this.content.getContent())
249+
.withMetadata(this.content.getMetadata())
250+
.build();
251+
}
252+
244253
}
245254

246255
public final class EmbeddingMath {
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore;
18+
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
import com.fasterxml.jackson.annotation.JsonCreator;
23+
import com.fasterxml.jackson.annotation.JsonProperty;
24+
25+
import org.springframework.ai.document.id.IdGenerator;
26+
import org.springframework.ai.document.id.RandomIdGenerator;
27+
import org.springframework.ai.model.Content;
28+
import org.springframework.util.Assert;
29+
30+
/**
31+
* A simple {@link Content} object which represents the content, metadata along its
32+
* embeddings.
33+
*/
34+
public class SimpleVectorStoreContent implements Content {
35+
36+
/**
37+
* Unique ID
38+
*/
39+
private final String id;
40+
41+
/**
42+
* Document content.
43+
*/
44+
private final String content;
45+
46+
/**
47+
* Metadata for the document. It should not be nested and values should be restricted
48+
* to string, int, float, boolean for simple use with Vector Dbs.
49+
*/
50+
private Map<String, Object> metadata;
51+
52+
/**
53+
* Embedding of the document. Note: ephemeral field.
54+
*/
55+
@JsonProperty(index = 100)
56+
private float[] embedding = new float[0];
57+
58+
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
59+
public SimpleVectorStoreContent(@JsonProperty("content") String content) {
60+
this(content, new HashMap<>());
61+
}
62+
63+
public SimpleVectorStoreContent(String content, Map<String, Object> metadata) {
64+
this(content, metadata, new RandomIdGenerator());
65+
}
66+
67+
public SimpleVectorStoreContent(String content, Map<String, Object> metadata, IdGenerator idGenerator) {
68+
this(idGenerator.generateId(content, metadata), content, metadata);
69+
}
70+
71+
public SimpleVectorStoreContent(String id, String content, Map<String, Object> metadata) {
72+
Assert.hasText(id, "id must not be null or empty");
73+
Assert.notNull(content, "content must not be null");
74+
Assert.notNull(metadata, "metadata must not be null");
75+
76+
this.id = id;
77+
this.content = content;
78+
this.metadata = metadata;
79+
}
80+
81+
public String getId() {
82+
return this.id;
83+
}
84+
85+
@Override
86+
public String getContent() {
87+
return this.content;
88+
}
89+
90+
@Override
91+
public Map<String, Object> getMetadata() {
92+
return this.metadata;
93+
}
94+
95+
public float[] getEmbedding() {
96+
return this.embedding;
97+
}
98+
99+
public void setEmbedding(float[] embedding) {
100+
Assert.notNull(embedding, "embedding must not be null");
101+
this.embedding = embedding;
102+
}
103+
104+
@Override
105+
public int hashCode() {
106+
final int prime = 31;
107+
int result = 1;
108+
result = prime * result + ((this.id == null) ? 0 : this.id.hashCode());
109+
result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode());
110+
result = prime * result + ((this.content == null) ? 0 : this.content.hashCode());
111+
return result;
112+
}
113+
114+
@Override
115+
public boolean equals(Object obj) {
116+
if (this == obj) {
117+
return true;
118+
}
119+
if (obj == null) {
120+
return false;
121+
}
122+
if (getClass() != obj.getClass()) {
123+
return false;
124+
}
125+
SimpleVectorStoreContent other = (SimpleVectorStoreContent) obj;
126+
if (this.id == null) {
127+
if (other.id != null) {
128+
return false;
129+
}
130+
}
131+
else if (!this.id.equals(other.id)) {
132+
return false;
133+
}
134+
if (this.metadata == null) {
135+
if (other.metadata != null) {
136+
return false;
137+
}
138+
}
139+
else if (!this.metadata.equals(other.metadata)) {
140+
return false;
141+
}
142+
if (this.content == null) {
143+
if (other.content != null) {
144+
return false;
145+
}
146+
}
147+
else if (!this.content.equals(other.content)) {
148+
return false;
149+
}
150+
return true;
151+
}
152+
153+
@Override
154+
public String toString() {
155+
return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content
156+
+ '}';
157+
}
158+
159+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore;
18+
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
import org.junit.Test;
23+
24+
import org.springframework.ai.document.Document;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* @author Ilayaperumal Gopinathan
30+
*/
31+
public class SimpleVectorStoreSimilarityTests {
32+
33+
@Test
34+
public void testSimilarity() {
35+
Map<String, Object> metadata = new HashMap<>();
36+
metadata.put("foo", "bar");
37+
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata);
38+
SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d);
39+
Document document = similarity.getDocument();
40+
assertThat(document).isNotNull();
41+
assertThat(document.getId()).isEqualTo("1");
42+
assertThat(document.getContent()).isEqualTo("hello, how are you?");
43+
assertThat(document.getMetadata().get("foo")).isEqualTo("bar");
44+
}
45+
46+
}

0 commit comments

Comments
 (0)