Skip to content

Commit 0b4f9db

Browse files
committed
Support for SimpleVectorStore with metadata filter expressions
Signed-off-by: Jemin Huh <[email protected]>
1 parent b525309 commit 0b4f9db

File tree

4 files changed

+610
-5
lines changed

4 files changed

+610
-5
lines changed

spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.Objects;
3333
import java.util.Optional;
3434
import java.util.concurrent.ConcurrentHashMap;
35+
import java.util.function.Predicate;
3536

3637
import com.fasterxml.jackson.core.JsonProcessingException;
3738
import com.fasterxml.jackson.core.type.TypeReference;
@@ -46,9 +47,14 @@
4647
import org.springframework.ai.observation.conventions.VectorStoreProvider;
4748
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
4849
import org.springframework.ai.util.JacksonUtils;
50+
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
51+
import org.springframework.ai.vectorstore.filter.converter.SimpleVectorStoreFilterExpressionConverter;
4952
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
5053
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
5154
import org.springframework.core.io.Resource;
55+
import org.springframework.expression.ExpressionParser;
56+
import org.springframework.expression.spel.standard.SpelExpressionParser;
57+
import org.springframework.expression.spel.support.StandardEvaluationContext;
5258

5359
/**
5460
* SimpleVectorStore is a simple implementation of the VectorStore interface.
@@ -67,18 +73,25 @@
6773
* @author Sebastien Deleuze
6874
* @author Ilayaperumal Gopinathan
6975
* @author Thomas Vitale
76+
* @author Jemin Huh
7077
*/
7178
public class SimpleVectorStore extends AbstractObservationVectorStore {
7279

7380
private static final Logger logger = LoggerFactory.getLogger(SimpleVectorStore.class);
7481

7582
private final ObjectMapper objectMapper;
7683

84+
private final ExpressionParser expressionParser;
85+
86+
private final FilterExpressionConverter filterExpressionConverter;
87+
7788
protected Map<String, SimpleVectorStoreContent> store = new ConcurrentHashMap<>();
7889

7990
protected SimpleVectorStore(SimpleVectorStoreBuilder builder) {
8091
super(builder);
8192
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
93+
this.expressionParser = new SpelExpressionParser();
94+
this.filterExpressionConverter = new SimpleVectorStoreFilterExpressionConverter();
8295
}
8396

8497
/**
@@ -115,14 +128,11 @@ public Optional<Boolean> doDelete(List<String> idList) {
115128

116129
@Override
117130
public List<Document> doSimilaritySearch(SearchRequest request) {
118-
if (request.getFilterExpression() != null) {
119-
throw new UnsupportedOperationException(
120-
"The [" + this.getClass() + "] doesn't support metadata filtering!");
121-
}
122-
131+
Predicate<SimpleVectorStoreContent> documentFilterPredicate = doFilterPredicate(request);
123132
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
124133
return this.store.values()
125134
.stream()
135+
.filter(documentFilterPredicate)
126136
.map(content -> content
127137
.toDocument(EmbeddingMath.cosineSimilarity(userQueryEmbedding, content.getEmbedding())))
128138
.filter(document -> document.getScore() >= request.getSimilarityThreshold())
@@ -131,6 +141,16 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
131141
.toList();
132142
}
133143

144+
private Predicate<SimpleVectorStoreContent> doFilterPredicate(SearchRequest request) {
145+
return request.hasFilterExpression() ? document -> {
146+
StandardEvaluationContext context = new StandardEvaluationContext();
147+
context.setVariable("metadata", document.getMetadata());
148+
return this.expressionParser
149+
.parseExpression(this.filterExpressionConverter.convertExpression(request.getFilterExpression()))
150+
.getValue(context, Boolean.class);
151+
} : document -> true;
152+
}
153+
134154
/**
135155
* Serialize the vector store content into a file in JSON format.
136156
* @param file the file to save the vector store content
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vectorstore.filter.converter;
18+
19+
import org.springframework.ai.vectorstore.filter.Filter;
20+
import org.springframework.ai.vectorstore.filter.Filter.Expression;
21+
22+
import java.text.ParseException;
23+
import java.text.SimpleDateFormat;
24+
import java.util.Date;
25+
import java.util.List;
26+
import java.util.TimeZone;
27+
import java.util.regex.Pattern;
28+
29+
/**
30+
* Converts {@link Expression} into SpEL metadata filter expression format.
31+
* (https://docs.spring.io/spring-framework/reference/core/expressions.html)
32+
*
33+
* @author Jemin Huh
34+
*/
35+
public class SimpleVectorStoreFilterExpressionConverter extends AbstractFilterExpressionConverter {
36+
37+
private static final Pattern DATE_FORMAT_PATTERN = Pattern.compile("\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}Z");
38+
39+
private final SimpleDateFormat dateFormat;
40+
41+
public SimpleVectorStoreFilterExpressionConverter() {
42+
this.dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'");
43+
this.dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
44+
}
45+
46+
@Override
47+
protected void doExpression(Filter.Expression expression, StringBuilder context) {
48+
this.convertOperand(expression.left(), context);
49+
context.append(getOperationSymbol(expression));
50+
this.convertOperand(expression.right(), context);
51+
}
52+
53+
private String getOperationSymbol(Filter.Expression exp) {
54+
return switch (exp.type()) {
55+
case AND -> " and ";
56+
case OR -> " or ";
57+
case EQ -> " == ";
58+
case LT -> " < ";
59+
case LTE -> " <= ";
60+
case GT -> " > ";
61+
case GTE -> " >= ";
62+
case NE -> " != ";
63+
case IN -> " in ";
64+
case NIN -> " not in ";
65+
default -> throw new RuntimeException("Not supported expression type: " + exp.type());
66+
};
67+
}
68+
69+
@Override
70+
protected void doKey(Filter.Key key, StringBuilder context) {
71+
var identifier = hasOuterQuotes(key.key()) ? removeOuterQuotes(key.key()) : key.key();
72+
context.append("#metadata['").append(identifier).append("']");
73+
}
74+
75+
@Override
76+
protected void doValue(Filter.Value filterValue, StringBuilder context) {
77+
if (filterValue.value() instanceof List<?> list) {
78+
var formattedList = new StringBuilder("{");
79+
int c = 0;
80+
for (Object v : list) {
81+
this.doSingleValue(v, formattedList);
82+
if (c++ < list.size() - 1) {
83+
this.doAddValueRangeSpitter(filterValue, formattedList);
84+
}
85+
}
86+
formattedList.append("}");
87+
88+
if (context.lastIndexOf("in ") == -1) {
89+
context.append(formattedList);
90+
}
91+
else {
92+
appendSpELContains(formattedList, context);
93+
}
94+
}
95+
else {
96+
this.doSingleValue(filterValue.value(), context);
97+
}
98+
}
99+
100+
private void appendSpELContains(StringBuilder formattedList, StringBuilder context) {
101+
int metadataStart = context.lastIndexOf("#metadata");
102+
if (metadataStart == -1)
103+
throw new RuntimeException("Wrong SpEL expression: " + context);
104+
105+
int metadataEnd = context.indexOf(" ", metadataStart);
106+
String metadata = context.substring(metadataStart, metadataEnd);
107+
context.setLength(context.lastIndexOf("in "));
108+
context.delete(metadataStart, metadataEnd + 1);
109+
context.append(formattedList).append(".contains(").append(metadata).append(")");
110+
}
111+
112+
@Override
113+
protected void doSingleValue(Object value, StringBuilder context) {
114+
if (value instanceof Date date) {
115+
context.append("'");
116+
context.append(this.dateFormat.format(date));
117+
context.append("'");
118+
}
119+
else if (value instanceof String text) {
120+
context.append("'");
121+
if (DATE_FORMAT_PATTERN.matcher(text).matches()) {
122+
try {
123+
Date date = this.dateFormat.parse(text);
124+
context.append(this.dateFormat.format(date));
125+
}
126+
catch (ParseException e) {
127+
throw new IllegalArgumentException("Invalid date type:" + text, e);
128+
}
129+
}
130+
else {
131+
context.append(text);
132+
}
133+
context.append("'");
134+
}
135+
else {
136+
context.append(value);
137+
}
138+
}
139+
140+
@Override
141+
protected void doGroup(Filter.Group group, StringBuilder context) {
142+
context.append("(");
143+
super.doGroup(group, context);
144+
context.append(")");
145+
}
146+
147+
}

0 commit comments

Comments
 (0)