Skip to content

Commit 9dff20d

Browse files
committed
Support for SimpleVectorStore with metdata filter expressions
Signed-off-by: Jemin Huh <[email protected]>
1 parent 3d3c20d commit 9dff20d

File tree

4 files changed

+614
-23
lines changed

4 files changed

+614
-23
lines changed

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,33 @@
1616

1717
package org.springframework.ai.vectorstore;
1818

19-
import java.io.File;
20-
import java.io.FileOutputStream;
21-
import java.io.IOException;
22-
import java.io.OutputStream;
23-
import java.io.OutputStreamWriter;
24-
import java.io.Writer;
25-
import java.nio.charset.StandardCharsets;
26-
import java.nio.file.FileAlreadyExistsException;
27-
import java.nio.file.Files;
28-
import java.util.Comparator;
29-
import java.util.HashMap;
30-
import java.util.List;
31-
import java.util.Map;
32-
import java.util.Objects;
33-
import java.util.Optional;
34-
import java.util.concurrent.ConcurrentHashMap;
35-
3619
import com.fasterxml.jackson.core.JsonProcessingException;
3720
import com.fasterxml.jackson.core.type.TypeReference;
3821
import com.fasterxml.jackson.databind.ObjectMapper;
3922
import com.fasterxml.jackson.databind.ObjectWriter;
4023
import com.fasterxml.jackson.databind.json.JsonMapper;
41-
4224
import org.springframework.ai.document.Document;
4325
import org.springframework.ai.embedding.EmbeddingModel;
4426
import org.springframework.ai.observation.conventions.VectorStoreProvider;
4527
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
4628
import org.springframework.ai.util.JacksonUtils;
29+
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
30+
import org.springframework.ai.vectorstore.filter.converter.SimpleVectorStoreFilterExpressionConverter;
4731
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
4832
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
4933
import org.springframework.core.io.Resource;
5034
import org.springframework.core.log.LogAccessor;
35+
import org.springframework.expression.ExpressionParser;
36+
import org.springframework.expression.spel.standard.SpelExpressionParser;
37+
import org.springframework.expression.spel.support.StandardEvaluationContext;
38+
39+
import java.io.*;
40+
import java.nio.charset.StandardCharsets;
41+
import java.nio.file.FileAlreadyExistsException;
42+
import java.nio.file.Files;
43+
import java.util.*;
44+
import java.util.concurrent.ConcurrentHashMap;
45+
import java.util.function.Predicate;
5146

5247
/**
5348
* SimpleVectorStore is a simple implementation of the VectorStore interface.
@@ -66,18 +61,25 @@
6661
* @author Sebastien Deleuze
6762
* @author Ilayaperumal Gopinathan
6863
* @author Thomas Vitale
64+
* @author Jemin Huh
6965
*/
7066
public class SimpleVectorStore extends AbstractObservationVectorStore {
7167

7268
private static final LogAccessor logger = new LogAccessor(SimpleVectorStore.class);
7369

7470
private final ObjectMapper objectMapper;
7571

72+
private final ExpressionParser expressionParser;
73+
74+
private final FilterExpressionConverter filterExpressionConverter;
75+
7676
protected Map<String, SimpleVectorStoreContent> store = new ConcurrentHashMap<>();
7777

7878
protected SimpleVectorStore(SimpleVectorStoreBuilder builder) {
7979
super(builder);
8080
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
81+
this.expressionParser = new SpelExpressionParser();
82+
this.filterExpressionConverter = new SimpleVectorStoreFilterExpressionConverter();
8183
}
8284

8385
/**
@@ -114,14 +116,11 @@ public Optional<Boolean> doDelete(List<String> idList) {
114116

115117
@Override
116118
public List<Document> doSimilaritySearch(SearchRequest request) {
117-
if (request.getFilterExpression() != null) {
118-
throw new UnsupportedOperationException(
119-
"The [" + this.getClass() + "] doesn't support metadata filtering!");
120-
}
121-
119+
Predicate<SimpleVectorStoreContent> documentFilterPredicate = doFilterPredicate(request);
122120
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
123121
return this.store.values()
124122
.stream()
123+
.filter(documentFilterPredicate)
125124
.map(content -> content
126125
.toDocument(EmbeddingMath.cosineSimilarity(userQueryEmbedding, content.getEmbedding())))
127126
.filter(document -> document.getScore() >= request.getSimilarityThreshold())
@@ -130,6 +129,16 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
130129
.toList();
131130
}
132131

132+
private Predicate<SimpleVectorStoreContent> doFilterPredicate(SearchRequest request) {
133+
return request.hasFilterExpression() ? document -> {
134+
StandardEvaluationContext context = new StandardEvaluationContext();
135+
context.setVariable("metadata", document.getMetadata());
136+
return this.expressionParser
137+
.parseExpression(this.filterExpressionConverter.convertExpression(request.getFilterExpression()))
138+
.getValue(context, Boolean.class);
139+
} : document -> true;
140+
}
141+
133142
/**
134143
* Serialize the vector store content into a file in JSON format.
135144
* @param file the file to save the vector store content
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore.filter.converter;
18+
19+
import org.springframework.ai.vectorstore.filter.Filter;
20+
import org.springframework.ai.vectorstore.filter.Filter.Expression;
21+
22+
import java.text.ParseException;
23+
import java.text.SimpleDateFormat;
24+
import java.util.Date;
25+
import java.util.List;
26+
import java.util.TimeZone;
27+
import java.util.regex.Pattern;
28+
29+
/**
30+
* Converts {@link Expression} into SpEL metadata filter expression format.
31+
* (https://docs.spring.io/spring-framework/reference/core/expressions.html)
32+
*
33+
* @author Jemin Huh
34+
*/
35+
public class SimpleVectorStoreFilterExpressionConverter extends AbstractFilterExpressionConverter {
36+
37+
private static final Pattern DATE_FORMAT_PATTERN = Pattern.compile("\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}Z");
38+
39+
private final SimpleDateFormat dateFormat;
40+
41+
public SimpleVectorStoreFilterExpressionConverter() {
42+
this.dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'");
43+
this.dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
44+
}
45+
46+
@Override
47+
protected void doExpression(Filter.Expression expression, StringBuilder context) {
48+
this.convertOperand(expression.left(), context);
49+
context.append(getOperationSymbol(expression));
50+
this.convertOperand(expression.right(), context);
51+
}
52+
53+
private String getOperationSymbol(Filter.Expression exp) {
54+
return switch (exp.type()) {
55+
case AND -> " and ";
56+
case OR -> " or ";
57+
case EQ -> " == ";
58+
case LT -> " < ";
59+
case LTE -> " <= ";
60+
case GT -> " > ";
61+
case GTE -> " >= ";
62+
case NE -> " != ";
63+
case IN -> " in ";
64+
case NIN -> " not in ";
65+
default -> throw new RuntimeException("Not supported expression type: " + exp.type());
66+
};
67+
}
68+
69+
@Override
70+
protected void doKey(Filter.Key key, StringBuilder context) {
71+
var identifier = hasOuterQuotes(key.key()) ? removeOuterQuotes(key.key()) : key.key();
72+
context.append("#metadata['").append(identifier).append("']");
73+
}
74+
75+
@Override
76+
protected void doValue(Filter.Value filterValue, StringBuilder context) {
77+
if (filterValue.value() instanceof List<?> list) {
78+
var formattedList = new StringBuilder("{");
79+
int c = 0;
80+
for (Object v : list) {
81+
this.doSingleValue(v, formattedList);
82+
if (c++ < list.size() - 1) {
83+
this.doAddValueRangeSpitter(filterValue, formattedList);
84+
}
85+
}
86+
formattedList.append("}");
87+
88+
if (context.lastIndexOf("in ") == -1) {
89+
context.append(formattedList);
90+
}
91+
else {
92+
appendSpELContains(formattedList, context);
93+
}
94+
}
95+
else {
96+
this.doSingleValue(filterValue.value(), context);
97+
}
98+
}
99+
100+
private void appendSpELContains(StringBuilder formattedList, StringBuilder context) {
101+
int metadataStart = context.lastIndexOf("#metadata");
102+
if (metadataStart == -1)
103+
throw new RuntimeException("Wrong SpEL expression: " + context);
104+
105+
int metadataEnd = context.indexOf(" ", metadataStart);
106+
String metadata = context.substring(metadataStart, metadataEnd);
107+
context.setLength(context.lastIndexOf("in "));
108+
context.delete(metadataStart, metadataEnd + 1);
109+
context.append(formattedList).append(".contains(").append(metadata).append(")");
110+
}
111+
112+
@Override
113+
protected void doSingleValue(Object value, StringBuilder context) {
114+
if (value instanceof Date date) {
115+
context.append("'");
116+
context.append(this.dateFormat.format(date));
117+
context.append("'");
118+
}
119+
else if (value instanceof String text) {
120+
context.append("'");
121+
if (DATE_FORMAT_PATTERN.matcher(text).matches()) {
122+
try {
123+
Date date = this.dateFormat.parse(text);
124+
context.append(this.dateFormat.format(date));
125+
}
126+
catch (ParseException e) {
127+
throw new IllegalArgumentException("Invalid date type:" + text, e);
128+
}
129+
}
130+
else {
131+
context.append(text);
132+
}
133+
context.append("'");
134+
}
135+
else {
136+
context.append(value);
137+
}
138+
}
139+
140+
@Override
141+
protected void doGroup(Filter.Group group, StringBuilder context) {
142+
context.append("(");
143+
super.doGroup(group, context);
144+
context.append(")");
145+
}
146+
147+
}

0 commit comments

Comments
 (0)