2626import java .nio .file .FileAlreadyExistsException ;
2727import java .nio .file .Files ;
2828import java .util .Comparator ;
29- import java .util .HashMap ;
3029import java .util .List ;
3130import java .util .Map ;
32- import java .util .Objects ;
3331import java .util .Optional ;
3432import java .util .concurrent .ConcurrentHashMap ;
3533
3634import com .fasterxml .jackson .core .JsonProcessingException ;
3735import com .fasterxml .jackson .core .type .TypeReference ;
3836import com .fasterxml .jackson .databind .ObjectMapper ;
39- import com .fasterxml .jackson .databind .ObjectWriter ;
4037import com .fasterxml .jackson .databind .json .JsonMapper ;
4138import io .micrometer .observation .ObservationRegistry ;
4239import org .slf4j .Logger ;
5047import org .springframework .ai .vectorstore .observation .AbstractObservationVectorStore ;
5148import org .springframework .ai .vectorstore .observation .VectorStoreObservationContext ;
5249import org .springframework .ai .vectorstore .observation .VectorStoreObservationConvention ;
50+ import org .springframework .core .io .FileSystemResource ;
5351import org .springframework .core .io .Resource ;
52+ import org .springframework .util .Assert ;
5453
/**
 * Simple, in-memory implementation of the {@link VectorStore} interface.
 * <p>
 * It also provides methods to save the current state of the vectors to a file, and to
 * load vectors from a file.
 * <p>
 * For a deeper understanding of the mathematical concepts and computations involved in
 * calculating similarity scores among vectors, refer to this
 * <a href="https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors">resource</a>.
6766 * @author Mark Pollack
6867 * @author Christian Tzolov
6968 * @author Sebastien Deleuze
69+ * @author John Blum
70+ * @see VectorStore
7071 */
7172public class SimpleVectorStore extends AbstractObservationVectorStore {
7273
@@ -87,54 +88,72 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
8788
8889 super (observationRegistry , customObservationConvention );
8990
90- Objects .requireNonNull (embeddingModel , "EmbeddingModel must not be null" );
91+ Assert .notNull (embeddingModel , "EmbeddingModel must not be null" );
92+
9193 this .embeddingModel = embeddingModel ;
9294 this .objectMapper = JsonMapper .builder ().addModules (JacksonUtils .instantiateAvailableModules ()).build ();
9395 }
9496
9597 @ Override
9698 public void doAdd (List <Document > documents ) {
9799 for (Document document : documents ) {
98- logger .info ("Calling EmbeddingModel for document id = {}" , document .getId ());
99- float [] embedding = this .embeddingModel .embed (document );
100- document .setEmbedding (embedding );
100+ logger .info ("Calling EmbeddingModel for Document id = {}" , document .getId ());
101+ document = embed (document );
101102 this .store .put (document .getId (), document );
102103 }
103104 }
104105
106+ protected Document embed (Document document ) {
107+ float [] documentEmbedding = this .embeddingModel .embed (document );
108+ document .setEmbedding (documentEmbedding );
109+ return document ;
110+ }
111+
105112 @ Override
106113 public Optional <Boolean > doDelete (List <String > idList ) {
107- for (String id : idList ) {
108- this .store .remove (id );
109- }
114+ idList .forEach (this .store ::remove );
110115 return Optional .of (true );
111116 }
112117
113118 @ Override
114119 public List <Document > doSimilaritySearch (SearchRequest request ) {
120+
115121 if (request .getFilterExpression () != null ) {
116122 throw new UnsupportedOperationException (
117- "The [" + this . getClass () + " ] doesn't support metadata filtering!" );
123+ "[%s ] doesn't support metadata filtering" . formatted ( getClass (). getName ()) );
118124 }
119125
120- float [] userQueryEmbedding = getUserQueryEmbedding (request .getQuery ());
121- return this .store .values ()
122- .stream ()
123- .map (entry -> new Similarity (entry .getId (),
124- EmbeddingMath .cosineSimilarity (userQueryEmbedding , entry .getEmbedding ())))
125- .filter (s -> s .score >= request .getSimilarityThreshold ())
126- .sorted (Comparator .<Similarity >comparingDouble (s -> s .score ).reversed ())
126+ // @formatter:off
127+ return this .store .values ().stream ()
128+ .map (document -> computeSimilarity (request , document ))
129+ .filter (similarity -> similarity .score >= request .getSimilarityThreshold ())
130+ .sorted (Comparator .<Similarity >comparingDouble (similarity -> similarity .score ).reversed ())
127131 .limit (request .getTopK ())
128- .map (s -> this .store .get (s .key ))
132+ .map (similarity -> this .store .get (similarity .key ))
129133 .toList ();
134+ // @formatter:on
135+ }
136+
137+ protected Similarity computeSimilarity (SearchRequest request , Document document ) {
138+
139+ float [] userQueryEmbedding = getUserQueryEmbedding (request );
140+ float [] documentEmbedding = document .getEmbedding ();
141+
142+ double score = computeCosineSimilarity (userQueryEmbedding , documentEmbedding );
143+
144+ return new Similarity (document .getId (), score );
145+ }
146+
147+ protected double computeCosineSimilarity (float [] userQueryEmbedding , float [] storedDocumentEmbedding ) {
148+ return EmbeddingMath .cosineSimilarity (userQueryEmbedding , storedDocumentEmbedding );
130149 }
131150
132151 /**
133152 * Serialize the vector store content into a file in JSON format.
134153 * @param file the file to save the vector store content
135154 */
136155 public void save (File file ) {
137- String json = getVectorDbAsJson ();
156+
138157 try {
139158 if (!file .exists ()) {
140159 logger .info ("Creating new vector store file: {}" , file );
@@ -145,28 +164,30 @@ public void save(File file) {
145164 throw new RuntimeException ("File already exists: " + file , e );
146165 }
147166 catch (IOException e ) {
148- throw new RuntimeException ("Failed to create new file: " + file + ". Reason: " + e .getMessage (), e );
167+ throw new RuntimeException ("Failed to create new file: " + file + "; Reason: " + e .getMessage (), e );
149168 }
150169 }
151170 else {
152171 logger .info ("Overwriting existing vector store file: {}" , file );
153172 }
173+
154174 try (OutputStream stream = new FileOutputStream (file );
155175 Writer writer = new OutputStreamWriter (stream , StandardCharsets .UTF_8 )) {
176+ String json = getVectorDbAsJson ();
156177 writer .write (json );
157178 writer .flush ();
158179 }
159180 }
160181 catch (IOException ex ) {
161- logger .error ("IOException occurred while saving vector store file. " , ex );
182+ logger .error ("IOException occurred while saving vector store file" , ex );
162183 throw new RuntimeException (ex );
163184 }
164185 catch (SecurityException ex ) {
165- logger .error ("SecurityException occurred while saving vector store file. " , ex );
186+ logger .error ("SecurityException occurred while saving vector store file" , ex );
166187 throw new RuntimeException (ex );
167188 }
168189 catch (NullPointerException ex ) {
169- logger .error ("NullPointerException occurred while saving vector store file. " , ex );
190+ logger .error ("NullPointerException occurred while saving vector store file" , ex );
170191 throw new RuntimeException (ex );
171192 }
172193 }
@@ -176,45 +197,40 @@ public void save(File file) {
176197 * @param file the file to load the vector store content
177198 */
178199 public void load (File file ) {
179- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
180-
181- };
182- try {
183- Map <String , Document > deserializedMap = this .objectMapper .readValue (file , typeRef );
184- this .store = deserializedMap ;
185- }
186- catch (IOException ex ) {
187- throw new RuntimeException (ex );
188- }
200+ load (new FileSystemResource (file ));
189201 }
190202
191203 /**
192204 * Deserialize the vector store content from a resource in JSON format into memory.
193205 * @param resource the resource to load the vector store content
194206 */
195207 public void load (Resource resource ) {
196- TypeReference <HashMap <String , Document >> typeRef = new TypeReference <>() {
197208
198- };
199209 try {
200- Map <String , Document > deserializedMap = this .objectMapper .readValue (resource .getInputStream (), typeRef );
201- this .store = deserializedMap ;
210+ this .store = this .objectMapper .readValue (resource .getInputStream (), documentMapTypeRef ());
202211 }
203212 catch (IOException ex ) {
204213 throw new RuntimeException (ex );
205214 }
206215 }
207216
217+ private TypeReference <Map <String , Document >> documentMapTypeRef () {
218+ return new TypeReference <>() {
219+ };
220+ }
221+
208222 private String getVectorDbAsJson () {
209- ObjectWriter objectWriter = this .objectMapper .writerWithDefaultPrettyPrinter ();
210- String json ;
223+
211224 try {
212- json = objectWriter .writeValueAsString (this .store );
225+ return this . objectMapper . writerWithDefaultPrettyPrinter () .writeValueAsString (this .store );
213226 }
214227 catch (JsonProcessingException e ) {
215- throw new RuntimeException ("Error serializing documentMap to JSON. " , e );
228+ throw new RuntimeException ("Error serializing Map of Documents to JSON" , e );
216229 }
217- return json ;
230+ }
231+
232+ private float [] getUserQueryEmbedding (SearchRequest request ) {
233+ return getUserQueryEmbedding (request .getQuery ());
218234 }
219235
220236 private float [] getUserQueryEmbedding (String query ) {
@@ -232,9 +248,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str
232248
233249 public static class Similarity {
234250
235- private String key ;
251+ private final String key ;
236252
237- private double score ;
253+ private final double score ;
238254
239255 public Similarity (String key , double score ) {
240256 this .key = key ;
@@ -243,16 +259,18 @@ public Similarity(String key, double score) {
243259
244260 }
245261
246- public final class EmbeddingMath {
262+ public static final class EmbeddingMath {
247263
/**
 * Not meant to be called: this class only exposes static helpers, so any attempt to
 * instantiate it fails.
 * @throws UnsupportedOperationException always
 */
private EmbeddingMath() {
    String reason = "This is a utility class and cannot be instantiated";
    throw new UnsupportedOperationException(reason);
}
251267
252268 public static double cosineSimilarity (float [] vectorX , float [] vectorY ) {
269+
253270 if (vectorX == null || vectorY == null ) {
254- throw new RuntimeException ("Vectors must not be null" );
271+ throw new IllegalArgumentException ("Vectors must not be null" );
255272 }
273+
256274 if (vectorX .length != vectorY .length ) {
257275 throw new IllegalArgumentException ("Vectors lengths must be equal" );
258276 }
@@ -268,20 +286,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
268286 return dotProduct / (Math .sqrt (normX ) * Math .sqrt (normY ));
269287 }
270288
271- public static float dotProduct (float [] vectorX , float [] vectorY ) {
289+ private static float dotProduct (float [] vectorX , float [] vectorY ) {
290+
272291 if (vectorX .length != vectorY .length ) {
273292 throw new IllegalArgumentException ("Vectors lengths must be equal" );
274293 }
275294
276295 float result = 0 ;
277- for (int i = 0 ; i < vectorX .length ; ++i ) {
278- result += vectorX [i ] * vectorY [i ];
296+
297+ for (int index = 0 ; index < vectorX .length ; ++index ) {
298+ result += vectorX [index ] * vectorY [index ];
279299 }
280300
281301 return result ;
282302 }
283303
284- public static float norm (float [] vector ) {
304+ private static float norm (float [] vector ) {
285305 return dotProduct (vector , vector );
286306 }
287307
0 commit comments