diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java index 3919994db48..02c18f6b26a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java @@ -114,7 +114,7 @@ public record Value(Object value) implements Operand { */ public enum ExpressionType { - AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN + AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN, NOT } @@ -131,6 +131,9 @@ public enum ExpressionType { * be another {@link Expression}. */ public record Expression(ExpressionType type, Operand left, Operand right) implements Operand { + public Expression(ExpressionType type, Operand operand) { + this(type, operand, null); + } } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java index 7b6595e860e..b0c3e9c2505 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java @@ -122,4 +122,8 @@ public Op group(Op content) { return new Op(new Filter.Group(content.build())); } + public Op not(Op content) { + return new Op(new Filter.Expression(ExpressionType.NOT, content.expression, null)); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java index dddbf50bafd..bbaff204491 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java @@ -35,6 +35,7 @@ import org.springframework.ai.vectorstore.filter.antlr4.FiltersBaseVisitor; import org.springframework.ai.vectorstore.filter.antlr4.FiltersLexer; import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser; +import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser.NotExpressionContext; import org.springframework.core.NestedExceptionUtils; import org.springframework.util.Assert; @@ -263,6 +264,11 @@ public Filter.Operand visitGroupExpression(FiltersParser.GroupExpressionContext return new Filter.Group(castToExpression(this.visit(ctx.booleanExpression()))); } + @Override + public Filter.Operand visitNotExpression(NotExpressionContext ctx) { + return new Filter.Expression(Filter.ExpressionType.NOT, this.visit(ctx.booleanExpression()), null); + } + public Filter.Expression castToExpression(Filter.Operand expression) { if (expression instanceof Filter.Group group) { // Remove the top-level grouping. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java new file mode 100644 index 00000000000..0bf221774f6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java @@ -0,0 +1,205 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Operand; +import org.springframework.ai.vectorstore.filter.converter.FilterExpressionConverter; +import org.springframework.util.Assert; + +/** + * Helper class providing various boolean transformation. + * + * @author Christian Tzolov + */ +public class FilterHelper { + + private FilterHelper() { + } + + private final static Map TYPE_NEGATION_MAP = Map.of(ExpressionType.AND, + ExpressionType.OR, ExpressionType.OR, ExpressionType.AND, ExpressionType.EQ, ExpressionType.NE, + ExpressionType.NE, ExpressionType.EQ, ExpressionType.GT, ExpressionType.LTE, ExpressionType.GTE, + ExpressionType.LT, ExpressionType.LT, ExpressionType.GTE, ExpressionType.LTE, ExpressionType.GT, + ExpressionType.IN, ExpressionType.NIN, ExpressionType.NIN, ExpressionType.IN); + + /** + * Transforms the input expression into a semantically equivalent one with negation + * operators propagated thought the expression tree by following the negation rules: + * + *
+	 * 	NOT(NOT(a)) = a
+	 *
+	 * 	NOT(a AND b) = NOT(a) OR NOT(b)
+	 * 	NOT(a OR b) = NOT(a) AND NOT(b)
+	 *
+	 * 	NOT(a EQ b) = a NE b
+	 * 	NOT(a NE b) = a EQ b
+	 *
+	 * 	NOT(a GT b) = a LTE b
+	 * 	NOT(a GTE b) = a LT b
+	 *
+	 * 	NOT(a LT b) = a GTE b
+	 * 	NOT(a LTE b) = a GT b
+	 *
+	 * 	NOT(a IN [...]) = a NIN [...]
+	 * 	NOT(a NIN [...]) = a IN [...]
+	 * 
+ * @param operand Filter expression to negate. + * @return Returns an negation of the input expression. + */ + public static Filter.Operand negate(Filter.Operand operand) { + + if (operand instanceof Filter.Group group) { + Operand inEx = negate(group.content()); + if (inEx instanceof Filter.Group inEx2) { + inEx = inEx2.content(); + } + return new Filter.Group((Expression) inEx); + } + else if (operand instanceof Filter.Expression exp) { + switch (exp.type()) { + case NOT: // NOT(NOT(a)) = a + return negate(exp.left()); + case AND: // NOT(a AND b) = NOT(a) OR NOT(b) + case OR: // NOT(a OR b) = NOT(a) AND NOT(b) + return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), negate(exp.left()), + negate(exp.right())); + case EQ: // NOT(e EQ b) = e NE b + case NE: // NOT(e NE b) = e EQ b + case GT: // NOT(e GT b) = e LTE b + case GTE: // NOT(e GTE b) = e LT b + case LT: // NOT(e LT b) = e GTE b + case LTE: // NOT(e LTE b) = e GT b + return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right()); + case IN: // NOT(e IN [...]) = e NIN [...] + case NIN: // NOT(e NIN [...]) = e IN [...] + return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right()); + default: + throw new IllegalArgumentException("Unknown expression type: " + exp.type()); + } + } + else { + throw new IllegalArgumentException("Can not negate operand of type: " + operand.getClass()); + } + } + + /** + * Expands the IN into a semantically equivalent boolean expressions of ORs of EQs. + * Useful for providers that don't provide native IN support. + * + * For example the
+	 * foo IN ["bar1", "bar2", "bar3"]
+	 * 
+ * + * expression is equivalent to + * + *
+	 * {@code foo == "bar1" || foo == "bar2" || foo == "bar3" (e.g. OR(foo EQ "bar1" OR(foo EQ "bar2" OR(foo EQ "bar3")))}
+	 * 
+ * @param exp input IN expression. + * @param context Output native expression. + * @param filterExpressionConverter {@link FilterExpressionConverter} used to compose + * the OR and EQ expanded expressions. + */ + public static void expandIn(Expression exp, StringBuilder context, + FilterExpressionConverter filterExpressionConverter) { + Assert.isTrue(exp.type() == ExpressionType.IN, "Expected IN expressions but was: " + exp.type()); + expandInNinExpressions(ExpressionType.OR, ExpressionType.EQ, exp, context, filterExpressionConverter); + } + + /** + * + * Expands the NIN (e.g. NOT IN) into a semantically equivalent boolean expressions of + * ANDs of NEs. Useful for providers that don't provide native NIN support.
+ * + * For example the + * + *
+	 * foo NIN ["bar1", "bar2", "bar3"] (or foo NOT IN ["bar1", "bar2", "bar3"])
+	 * 
+ * + * express is equivalent to + * + *
+	 * {@code foo != "bar1" && foo != "bar2" && foo != "bar3" (e.g. AND(foo NE "bar1" AND( foo NE "bar2" OR(foo NE "bar3"))) )}
+	 * 
+ * @param exp input NIN expression. + * @param context Output native expression. + * @param filterExpressionConverter {@link FilterExpressionConverter} used to compose + * the AND and NE expanded expressions. + */ + public static void expandNin(Expression exp, StringBuilder context, + FilterExpressionConverter filterExpressionConverter) { + Assert.isTrue(exp.type() == ExpressionType.NIN, "Expected NIN expressions but was: " + exp.type()); + expandInNinExpressions(ExpressionType.AND, ExpressionType.NE, exp, context, filterExpressionConverter); + } + + private static void expandInNinExpressions(Filter.ExpressionType outerExpressionType, + Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context, + FilterExpressionConverter expressionConverter) { + if (exp.right() instanceof Filter.Value value) { + if (value.value() instanceof List list) { + // 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo == "bar1" || + // foo == "bar2" || foo == "bar3" + // or equivalent to OR(foo == "bar1" OR( foo == "bar2" OR(foo == "bar3"))) + // 2. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo != "bar1" && + // foo != "bar2" && foo != "bar3" + // or equivalent to AND(foo != "bar1" AND( foo != "bar2" OR(foo != + // "bar3"))) + List eqExprs = new ArrayList<>(); + for (Object o : list) { + eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o))); + } + context.append(expressionConverter.convertExpression(aggregate(outerExpressionType, eqExprs))); + } + else { + // 1. foo IN ["bar"] is equivalent to foo == "BAR" + // 2. foo NIN ["bar"] is equivalent to foo != "BAR" + context.append(expressionConverter + .convertExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right()))); + } + } + else { + throw new IllegalStateException( + "Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass()); + } + } + + /** + * Recursively aggregates a list of expression into a binary tree with 'aggregateType' + * join nodes. + * @param aggregateType type all tree splits. + * @param expressions list of expressions to aggregate. + * @return Returns a binary tree expression. + */ + private static Filter.Expression aggregate(Filter.ExpressionType aggregateType, + List expressions) { + + if (expressions.size() == 1) { + return expressions.get(0); + } + return new Filter.Expression(aggregateType, expressions.get(0), + aggregate(aggregateType, expressions.subList(1, expressions.size()))); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp index 61589278601..51775a8a55c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp @@ -66,4 +66,4 @@ constant atn: -[4, 1, 26, 87, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 38, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 46, 8, 1, 10, 1, 12, 1, 49, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 55, 8, 2, 10, 2, 12, 2, 58, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 69, 8, 4, 1, 5, 3, 5, 72, 8, 5, 1, 5, 1, 5, 3, 5, 76, 8, 5, 1, 5, 1, 5, 4, 5, 80, 8, 5, 11, 5, 12, 5, 81, 1, 5, 3, 5, 85, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 95, 0, 12, 1, 0, 0, 0, 2, 37, 1, 0, 0, 0, 4, 50, 1, 0, 0, 0, 6, 61, 1, 0, 0, 0, 8, 68, 1, 0, 0, 0, 10, 84, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 38, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 38, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 38, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 38, 1, 0, 0, 0, 37, 16, 1, 0, 0, 0, 37, 21, 1, 0, 0, 0, 37, 25, 1, 0, 0, 0, 37, 33, 1, 0, 0, 0, 38, 47, 1, 0, 0, 0, 39, 40, 10, 3, 0, 0, 40, 41, 5, 16, 0, 0, 41, 46, 3, 2, 1, 4, 42, 43, 10, 2, 0, 0, 43, 44, 5, 17, 0, 0, 44, 46, 3, 2, 1, 3, 45, 39, 1, 0, 0, 0, 45, 42, 1, 0, 0, 0, 46, 49, 1, 0, 0, 0, 47, 45, 1, 0, 0, 0, 47, 48, 1, 0, 0, 0, 48, 3, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 50, 51, 5, 4, 0, 0, 51, 56, 3, 10, 5, 0, 52, 53, 5, 3, 0, 0, 53, 55, 3, 10, 5, 0, 54, 52, 1, 0, 0, 0, 55, 58, 1, 0, 0, 0, 56, 54, 1, 0, 0, 0, 56, 57, 1, 0, 0, 0, 57, 59, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 59, 60, 5, 5, 0, 0, 60, 5, 1, 0, 0, 0, 61, 62, 7, 0, 0, 0, 62, 7, 1, 0, 0, 0, 63, 64, 5, 25, 0, 0, 64, 65, 5, 2, 0, 0, 65, 69, 5, 25, 0, 0, 66, 69, 5, 25, 0, 0, 67, 69, 5, 22, 0, 0, 68, 63, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 67, 1, 0, 0, 0, 69, 9, 1, 0, 0, 0, 70, 72, 7, 1, 0, 0, 71, 70, 1, 0, 0, 0, 71, 72, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 85, 5, 23, 0, 0, 74, 76, 7, 1, 0, 0, 75, 74, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 77, 1, 0, 0, 0, 77, 85, 5, 24, 0, 0, 78, 80, 5, 22, 0, 0, 79, 78, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 85, 1, 0, 0, 0, 83, 85, 5, 21, 0, 0, 84, 71, 1, 0, 0, 0, 84, 75, 1, 0, 0, 0, 84, 79, 1, 0, 0, 0, 84, 83, 1, 0, 0, 0, 85, 11, 1, 0, 0, 0, 10, 29, 37, 45, 47, 56, 68, 71, 75, 81, 84] \ No newline at end of file +[4, 1, 26, 89, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 40, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 48, 8, 1, 10, 1, 12, 1, 51, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 57, 8, 2, 10, 2, 12, 2, 60, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 71, 8, 4, 1, 5, 3, 5, 74, 8, 5, 1, 5, 1, 5, 3, 5, 78, 8, 5, 1, 5, 1, 5, 4, 5, 82, 8, 5, 11, 5, 12, 5, 83, 1, 5, 3, 5, 87, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 98, 0, 12, 1, 0, 0, 0, 2, 39, 1, 0, 0, 0, 4, 52, 1, 0, 0, 0, 6, 63, 1, 0, 0, 0, 8, 70, 1, 0, 0, 0, 10, 86, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 40, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 40, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 40, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 40, 1, 0, 0, 0, 37, 38, 5, 20, 0, 0, 38, 40, 3, 2, 1, 1, 39, 16, 1, 0, 0, 0, 39, 21, 1, 0, 0, 0, 39, 25, 1, 0, 0, 0, 39, 33, 1, 0, 0, 0, 39, 37, 1, 0, 0, 0, 40, 49, 1, 0, 0, 0, 41, 42, 10, 4, 0, 0, 42, 43, 5, 16, 0, 0, 43, 48, 3, 2, 1, 5, 44, 45, 10, 3, 0, 0, 45, 46, 5, 17, 0, 0, 46, 48, 3, 2, 1, 4, 47, 41, 1, 0, 0, 0, 47, 44, 1, 0, 0, 0, 48, 51, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 49, 50, 1, 0, 0, 0, 50, 3, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 52, 53, 5, 4, 0, 0, 53, 58, 3, 10, 5, 0, 54, 55, 5, 3, 0, 0, 55, 57, 3, 10, 5, 0, 56, 54, 1, 0, 0, 0, 57, 60, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 58, 59, 1, 0, 0, 0, 59, 61, 1, 0, 0, 0, 60, 58, 1, 0, 0, 0, 61, 62, 5, 5, 0, 0, 62, 5, 1, 0, 0, 0, 63, 64, 7, 0, 0, 0, 64, 7, 1, 0, 0, 0, 65, 66, 5, 25, 0, 0, 66, 67, 5, 2, 0, 0, 67, 71, 5, 25, 0, 0, 68, 71, 5, 25, 0, 0, 69, 71, 5, 22, 0, 0, 70, 65, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 70, 69, 1, 0, 0, 0, 71, 9, 1, 0, 0, 0, 72, 74, 7, 1, 0, 0, 73, 72, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 87, 5, 23, 0, 0, 76, 78, 7, 1, 0, 0, 77, 76, 1, 0, 0, 0, 77, 78, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 87, 5, 24, 0, 0, 80, 82, 5, 22, 0, 0, 81, 80, 1, 0, 0, 0, 82, 83, 1, 0, 0, 0, 83, 81, 1, 0, 0, 0, 83, 84, 1, 0, 0, 0, 84, 87, 1, 0, 0, 0, 85, 87, 5, 21, 0, 0, 86, 73, 1, 0, 0, 0, 86, 77, 1, 0, 0, 0, 86, 81, 1, 0, 0, 0, 86, 85, 1, 0, 0, 0, 87, 11, 1, 0, 0, 0, 10, 29, 39, 47, 49, 58, 70, 73, 77, 83, 86] \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java index c8e98410aa5..962a36c6796 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java @@ -121,6 +121,28 @@ public void enterInExpression(FiltersParser.InExpressionContext ctx) { public void exitInExpression(FiltersParser.InExpressionContext ctx) { } + /** + * {@inheritDoc} + * + *

+ * The default implementation does nothing. + *

+ */ + @Override + public void enterNotExpression(FiltersParser.NotExpressionContext ctx) { + } + + /** + * {@inheritDoc} + * + *

+ * The default implementation does nothing. + *

+ */ + @Override + public void exitNotExpression(FiltersParser.NotExpressionContext ctx) { + } + /** * {@inheritDoc} * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java index 5c0c81193ee..555a6962913 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java @@ -86,6 +86,19 @@ public T visitInExpression(FiltersParser.InExpressionContext ctx) { return visitChildren(ctx); } + /** + * {@inheritDoc} + * + *

+ * The default implementation returns the result of calling {@link #visitChildren} on + * {@code ctx}. + *

+ */ + @Override + public T visitNotExpression(FiltersParser.NotExpressionContext ctx) { + return visitChildren(ctx); + } + /** * {@inheritDoc} * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java index f6b920479b3..77444e52747 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java @@ -83,6 +83,20 @@ public interface FiltersListener extends ParseTreeListener { */ void exitInExpression(FiltersParser.InExpressionContext ctx); + /** + * Enter a parse tree produced by the {@code NotExpression} labeled alternative in + * {@link FiltersParser#booleanExpression}. + * @param ctx the parse tree + */ + void enterNotExpression(FiltersParser.NotExpressionContext ctx); + + /** + * Exit a parse tree produced by the {@code NotExpression} labeled alternative in + * {@link FiltersParser#booleanExpression}. + * @param ctx the parse tree + */ + void exitNotExpression(FiltersParser.NotExpressionContext ctx); + /** * Enter a parse tree produced by the {@code CompareExpression} labeled alternative in * {@link FiltersParser#booleanExpression}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java index 3171c199132..a17e355d7ee 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java @@ -358,6 +358,43 @@ public T accept(ParseTreeVisitor visitor) { } + @SuppressWarnings("CheckReturnValue") + public static class NotExpressionContext extends BooleanExpressionContext { + + public TerminalNode NOT() { + return getToken(FiltersParser.NOT, 0); + } + + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); + } + + public NotExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) + ((FiltersListener) listener).enterNotExpression(this); + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) + ((FiltersListener) listener).exitNotExpression(this); + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) + return ((FiltersVisitor) visitor).visitNotExpression(this); + else + return visitor.visitChildren(this); + } + + } + @SuppressWarnings("CheckReturnValue") public static class CompareExpressionContext extends BooleanExpressionContext { @@ -502,7 +539,7 @@ private BooleanExpressionContext booleanExpression(int _p) throws RecognitionExc int _alt; enterOuterAlt(_localctx, 1); { - setState(37); + setState(39); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 1, _ctx)) { case 1: { @@ -570,9 +607,19 @@ private BooleanExpressionContext booleanExpression(int _p) throws RecognitionExc match(RIGHT_PARENTHESIS); } break; + case 5: { + _localctx = new NotExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; + setState(37); + match(NOT); + setState(38); + booleanExpression(1); + } + break; } _ctx.stop = _input.LT(-1); - setState(47); + setState(49); _errHandler.sync(this); _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER) { @@ -581,7 +628,7 @@ private BooleanExpressionContext booleanExpression(int _p) throws RecognitionExc triggerExitRuleEvent(); _prevctx = _localctx; { - setState(45); + setState(47); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 2, _ctx)) { case 1: { @@ -589,13 +636,13 @@ private BooleanExpressionContext booleanExpression(int _p) throws RecognitionExc new BooleanExpressionContext(_parentctx, _parentState)); ((AndExpressionContext) _localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); - setState(39); - if (!(precpred(_ctx, 3))) - throw new FailedPredicateException(this, "precpred(_ctx, 3)"); - setState(40); - ((AndExpressionContext) _localctx).operator = match(AND); setState(41); - ((AndExpressionContext) _localctx).right = booleanExpression(4); + if (!(precpred(_ctx, 4))) + throw new FailedPredicateException(this, "precpred(_ctx, 4)"); + setState(42); + ((AndExpressionContext) _localctx).operator = match(AND); + setState(43); + ((AndExpressionContext) _localctx).right = booleanExpression(5); } break; case 2: { @@ -603,19 +650,19 @@ private BooleanExpressionContext booleanExpression(int _p) throws RecognitionExc new BooleanExpressionContext(_parentctx, _parentState)); ((OrExpressionContext) _localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); - setState(42); - if (!(precpred(_ctx, 2))) - throw new FailedPredicateException(this, "precpred(_ctx, 2)"); - setState(43); - ((OrExpressionContext) _localctx).operator = match(OR); setState(44); - ((OrExpressionContext) _localctx).right = booleanExpression(3); + if (!(precpred(_ctx, 3))) + throw new FailedPredicateException(this, "precpred(_ctx, 3)"); + setState(45); + ((OrExpressionContext) _localctx).operator = match(OR); + setState(46); + ((OrExpressionContext) _localctx).right = booleanExpression(4); } break; } } } - setState(49); + setState(51); _errHandler.sync(this); _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); } @@ -697,27 +744,27 @@ public final ConstantArrayContext constantArray() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(50); + setState(52); match(LEFT_SQUARE_BRACKETS); - setState(51); + setState(53); constant(); - setState(56); + setState(58); _errHandler.sync(this); _la = _input.LA(1); while (_la == COMMA) { { { - setState(52); + setState(54); match(COMMA); - setState(53); + setState(55); constant(); } } - setState(58); + setState(60); _errHandler.sync(this); _la = _input.LA(1); } - setState(59); + setState(61); match(RIGHT_SQUARE_BRACKETS); } } @@ -797,7 +844,7 @@ public final CompareContext compare() throws RecognitionException { try { enterOuterAlt(_localctx, 1); { - setState(61); + setState(63); _la = _input.LA(1); if (!((((_la) & ~0x3f) == 0 && ((1L << _la) & 63744L) != 0))) { _errHandler.recoverInline(this); @@ -875,28 +922,28 @@ public final IdentifierContext identifier() throws RecognitionException { IdentifierContext _localctx = new IdentifierContext(_ctx, getState()); enterRule(_localctx, 8, RULE_identifier); try { - setState(68); + setState(70); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 5, _ctx)) { case 1: enterOuterAlt(_localctx, 1); { - setState(63); + setState(65); match(IDENTIFIER); - setState(64); + setState(66); match(DOT); - setState(65); + setState(67); match(IDENTIFIER); } break; case 2: enterOuterAlt(_localctx, 2); { - setState(66); + setState(68); match(IDENTIFIER); } break; case 3: enterOuterAlt(_localctx, 3); { - setState(67); + setState(69); match(QUOTED_STRING); } break; @@ -1092,18 +1139,18 @@ public final ConstantContext constant() throws RecognitionException { int _la; try { int _alt; - setState(84); + setState(86); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 9, _ctx)) { case 1: _localctx = new IntegerConstantContext(_localctx); enterOuterAlt(_localctx, 1); { - setState(71); + setState(73); _errHandler.sync(this); _la = _input.LA(1); if (_la == MINUS || _la == PLUS) { { - setState(70); + setState(72); _la = _input.LA(1); if (!(_la == MINUS || _la == PLUS)) { _errHandler.recoverInline(this); @@ -1117,19 +1164,19 @@ public final ConstantContext constant() throws RecognitionException { } } - setState(73); + setState(75); match(INTEGER_VALUE); } break; case 2: _localctx = new DecimalConstantContext(_localctx); enterOuterAlt(_localctx, 2); { - setState(75); + setState(77); _errHandler.sync(this); _la = _input.LA(1); if (_la == MINUS || _la == PLUS) { { - setState(74); + setState(76); _la = _input.LA(1); if (!(_la == MINUS || _la == PLUS)) { _errHandler.recoverInline(this); @@ -1143,21 +1190,21 @@ public final ConstantContext constant() throws RecognitionException { } } - setState(77); + setState(79); match(DECIMAL_VALUE); } break; case 3: _localctx = new TextConstantContext(_localctx); enterOuterAlt(_localctx, 3); { - setState(79); + setState(81); _errHandler.sync(this); _alt = 1; do { switch (_alt) { case 1: { { - setState(78); + setState(80); match(QUOTED_STRING); } } @@ -1165,7 +1212,7 @@ public final ConstantContext constant() throws RecognitionException { default: throw new NoViableAltException(this); } - setState(81); + setState(83); _errHandler.sync(this); _alt = getInterpreter().adaptivePredict(_input, 8, _ctx); } @@ -1175,7 +1222,7 @@ public final ConstantContext constant() throws RecognitionException { case 4: _localctx = new BooleanConstantContext(_localctx); enterOuterAlt(_localctx, 4); { - setState(83); + setState(85); match(BOOLEAN_VALUE); } break; @@ -1203,67 +1250,68 @@ public boolean sempred(RuleContext _localctx, int ruleIndex, int predIndex) { private boolean booleanExpression_sempred(BooleanExpressionContext _localctx, int predIndex) { switch (predIndex) { case 0: - return precpred(_ctx, 3); + return precpred(_ctx, 4); case 1: - return precpred(_ctx, 2); + return precpred(_ctx, 3); } return true; } - public static final String _serializedATN = "\u0004\u0001\u001aW\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" + public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" + "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002" + "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001" + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0003\u0001&\b\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0005\u0001.\b" - + "\u0001\n\u0001\f\u00011\t\u0001\u0001\u0002\u0001\u0002\u0001\u0002\u0001" - + "\u0002\u0005\u00027\b\u0002\n\u0002\f\u0002:\t\u0002\u0001\u0002\u0001" - + "\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001\u0004\u0001\u0004\u0001" - + "\u0004\u0001\u0004\u0003\u0004E\b\u0004\u0001\u0005\u0003\u0005H\b\u0005" - + "\u0001\u0005\u0001\u0005\u0003\u0005L\b\u0005\u0001\u0005\u0001\u0005" - + "\u0004\u0005P\b\u0005\u000b\u0005\f\u0005Q\u0001\u0005\u0003\u0005U\b" - + "\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000\u0002\u0004\u0006\b\n" - + "\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000\t\n_\u0000\f\u0001" - + "\u0000\u0000\u0000\u0002%\u0001\u0000\u0000\u0000\u00042\u0001\u0000\u0000" - + "\u0000\u0006=\u0001\u0000\u0000\u0000\bD\u0001\u0000\u0000\u0000\nT\u0001" - + "\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000\r\u000e\u0003\u0002\u0001" - + "\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f\u0001\u0001\u0000\u0000" - + "\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000\u0011\u0012\u0003\b\u0004" - + "\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013\u0014\u0003\n\u0005\u0000" - + "\u0014&\u0001\u0000\u0000\u0000\u0015\u0016\u0003\b\u0004\u0000\u0016" - + "\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003\u0004\u0002\u0000\u0018" - + "&\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b\u0004\u0000\u001a\u001b" - + "\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012\u0000\u0000\u001c\u001e" - + "\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000\u0000\u0000\u001d\u001c" - + "\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000\u0000\u0000\u001f \u0003" - + "\u0004\u0002\u0000 &\u0001\u0000\u0000\u0000!\"\u0005\u0006\u0000\u0000" - + "\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000\u0000$&\u0001\u0000\u0000" - + "\u0000%\u0010\u0001\u0000\u0000\u0000%\u0015\u0001\u0000\u0000\u0000%" - + "\u0019\u0001\u0000\u0000\u0000%!\u0001\u0000\u0000\u0000&/\u0001\u0000" - + "\u0000\u0000\'(\n\u0003\u0000\u0000()\u0005\u0010\u0000\u0000).\u0003" - + "\u0002\u0001\u0004*+\n\u0002\u0000\u0000+,\u0005\u0011\u0000\u0000,.\u0003" - + "\u0002\u0001\u0003-\'\u0001\u0000\u0000\u0000-*\u0001\u0000\u0000\u0000" - + ".1\u0001\u0000\u0000\u0000/-\u0001\u0000\u0000\u0000/0\u0001\u0000\u0000" - + "\u00000\u0003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000\u000023\u0005" - + "\u0004\u0000\u000038\u0003\n\u0005\u000045\u0005\u0003\u0000\u000057\u0003" - + "\n\u0005\u000064\u0001\u0000\u0000\u00007:\u0001\u0000\u0000\u000086\u0001" - + "\u0000\u0000\u000089\u0001\u0000\u0000\u00009;\u0001\u0000\u0000\u0000" - + ":8\u0001\u0000\u0000\u0000;<\u0005\u0005\u0000\u0000<\u0005\u0001\u0000" - + "\u0000\u0000=>\u0007\u0000\u0000\u0000>\u0007\u0001\u0000\u0000\u0000" - + "?@\u0005\u0019\u0000\u0000@A\u0005\u0002\u0000\u0000AE\u0005\u0019\u0000" - + "\u0000BE\u0005\u0019\u0000\u0000CE\u0005\u0016\u0000\u0000D?\u0001\u0000" - + "\u0000\u0000DB\u0001\u0000\u0000\u0000DC\u0001\u0000\u0000\u0000E\t\u0001" - + "\u0000\u0000\u0000FH\u0007\u0001\u0000\u0000GF\u0001\u0000\u0000\u0000" - + "GH\u0001\u0000\u0000\u0000HI\u0001\u0000\u0000\u0000IU\u0005\u0017\u0000" - + "\u0000JL\u0007\u0001\u0000\u0000KJ\u0001\u0000\u0000\u0000KL\u0001\u0000" - + "\u0000\u0000LM\u0001\u0000\u0000\u0000MU\u0005\u0018\u0000\u0000NP\u0005" - + "\u0016\u0000\u0000ON\u0001\u0000\u0000\u0000PQ\u0001\u0000\u0000\u0000" - + "QO\u0001\u0000\u0000\u0000QR\u0001\u0000\u0000\u0000RU\u0001\u0000\u0000" - + "\u0000SU\u0005\u0015\u0000\u0000TG\u0001\u0000\u0000\u0000TK\u0001\u0000" - + "\u0000\u0000TO\u0001\u0000\u0000\u0000TS\u0001\u0000\u0000\u0000U\u000b" - + "\u0001\u0000\u0000\u0000\n\u001d%-/8DGKQT"; + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001" + + "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t" + + "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001" + + "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001" + + "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005" + + "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001" + + "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000" + + "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000" + + "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000" + + "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001" + + "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000" + + "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f" + + "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000" + + "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013" + + "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016" + + "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003" + + "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b" + + "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012" + + "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000" + + "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000" + + "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000" + + "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000" + + "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002" + + "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000" + + "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001" + + "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005" + + "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005" + + "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000" + + "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000" + + "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001" + + "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005" + + "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001" + + "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000" + + ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000" + + "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007" + + "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000" + + "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016" + + "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001" + + "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000" + + "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000" + + "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000" + + "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005" + + "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000" + + "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000" + + "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000" + + "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001" + + "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV"; public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); static { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java index 8312712736b..27413bd6e3f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java @@ -63,6 +63,14 @@ public interface FiltersVisitor extends ParseTreeVisitor { */ T visitInExpression(FiltersParser.InExpressionContext ctx); + /** + * Visit a parse tree produced by the {@code NotExpression} labeled alternative in + * {@link FiltersParser#booleanExpression}. + * @param ctx the parse tree + * @return the visitor result + */ + T visitNotExpression(FiltersParser.NotExpressionContext ctx); + /** * Visit a parse tree produced by the {@code CompareExpression} labeled alternative in * {@link FiltersParser#booleanExpression}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java index 46880565415..8769d97ecfc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java @@ -19,6 +19,7 @@ import java.util.List; import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterHelper; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; @@ -52,14 +53,27 @@ else if (operand instanceof Filter.Value value) { this.doValue(value, context); } else if (operand instanceof Filter.Expression expression) { - if ((expression.type() != ExpressionType.AND && expression.type() != ExpressionType.OR) - && !(expression.right() instanceof Filter.Value)) { + if ((expression.type() != ExpressionType.NOT && expression.type() != ExpressionType.AND + && expression.type() != ExpressionType.OR) && !(expression.right() instanceof Filter.Value)) { throw new RuntimeException("Non AND/OR expression must have Value right argument!"); } - this.doExpression(expression, context); + if (expression.type() == ExpressionType.NOT) { + this.doNot(expression, context); + } + else { + this.doExpression(expression, context); + } } } + protected void doNot(Filter.Expression expression, StringBuilder context) { + // Default behavior is to convert the NOT expression into its semantically + // equivalent negation expression. + // Effectively removing the NOT types form the boolean expression tree before + // passing it to the doExpression. + this.convertOperand(FilterHelper.negate(expression), context); + } + protected abstract void doExpression(Filter.Expression expression, StringBuilder context); protected abstract void doKey(Filter.Key filterKey, StringBuilder context); diff --git a/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 b/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 index 21bb45ad400..086766b6abf 100644 --- a/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 +++ b/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 @@ -33,6 +33,7 @@ booleanExpression | left=booleanExpression operator=AND right=booleanExpression # AndExpression | left=booleanExpression operator=OR right=booleanExpression # OrExpression | LEFT_PARENTHESIS booleanExpression RIGHT_PARENTHESIS # GroupExpression + | NOT booleanExpression # NotExpression ; constantArray diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java index c05cba98bad..d5f8e077a38 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java @@ -33,6 +33,7 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT; /** * @author Christian Tzolov @@ -97,4 +98,18 @@ public void tesIn2() { new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); } + @Test + public void tesNot() { + // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] + var exp = b.not(b.and(b.and(b.eq("isOpen", true), b.gte("year", 2020)), b.in("country", "BG", "NL", "US"))) + .build(); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))), + null)); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java index 057ae86cc4a..27e62f1d8fa 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java @@ -30,10 +30,11 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE; /** * @author Christian Tzolov @@ -111,6 +112,61 @@ public void tesBoolean() { .get("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]")).isEqualTo(exp); } + @Test + public void tesNot() { + // NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"]) + Expression exp = parser.parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Group(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))), + null)); + + assertThat(parser.getCache() + .get("WHERE " + "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])")) + .isEqualTo(exp); + } + + @Test + public void tesNotNin() { + // NOT(country NOT IN ["BG", "NL", "US"]) + Expression exp = parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Group(new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US")))), null)); + } + + @Test + public void tesNotNin2() { + // NOT country NOT IN ["BG", "NL", "US"] + Expression exp = parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US"))), null)); + } + + @Test + public void tesNestedNot() { + // NOT(isOpen == true AND year >= 2020 AND NOT(country IN ["BG", "NL", "US"])) + Expression exp = parser + .parse("not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Group(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(NOT, + new Group(new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))), + null))), + null)); + + assertThat(parser.getCache() + .get("WHERE " + "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))")) + .isEqualTo(exp); + } + @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java new file mode 100644 index 00000000000..154c94490a5 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java @@ -0,0 +1,171 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.converter.PrintFilterExpressionConverter; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class FilterHelperTests { + + @Test + public void negateEQ() { + assertThat(Filter.parser().parse("NOT key == 'UK' ")).isEqualTo(new Filter.Expression(ExpressionType.NOT, + new Filter.Expression(ExpressionType.EQ, new Key("key"), new Value("UK")), null)); + + assertThat(FilterHelper.negate(Filter.parser().parse("NOT key == 'UK' "))) + .isEqualTo(new Filter.Expression(ExpressionType.NE, new Key("key"), new Value("UK"))); + + assertThat(FilterHelper.negate(Filter.parser().parse("NOT (key == 'UK') "))) + .isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.NE, new Key("key"), new Value("UK")))); + } + + @Test + public void negateNE() { + var exp = Filter.parser().parse("NOT key != 'UK' "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.EQ, new Key("key"), new Value("UK"))); + + } + + @Test + public void negateGT() { + var exp = Filter.parser().parse("NOT key > 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.LTE, new Key("key"), new Value(13))); + + } + + @Test + public void negateGTE() { + var exp = Filter.parser().parse("NOT key >= 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(13))); + } + + @Test + public void negateLT() { + var exp = Filter.parser().parse("NOT key < 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(13))); + } + + @Test + public void negateLTE() { + var exp = Filter.parser().parse("NOT key <= 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.GT, new Key("key"), new Value(13))); + } + + @Test + public void negateIN() { + var exp = Filter.parser().parse("NOT key IN [11, 12, 13] "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.NIN, new Key("key"), new Value(List.of(11, 12, 13)))); + } + + @Test + public void negateNIN() { + var exp = Filter.parser().parse("NOT key NIN [11, 12, 13] "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.IN, new Key("key"), new Value(List.of(11, 12, 13)))); + } + + @Test + public void negateNIN2() { + var exp = Filter.parser().parse("NOT key NOT IN [11, 12, 13] "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.IN, new Key("key"), new Value(List.of(11, 12, 13)))); + } + + @Test + public void negateAND() { + var exp = Filter.parser().parse("NOT(key >= 11 AND key < 13)"); + assertThat(FilterHelper.negate(exp)).isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.OR, + new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)), + new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(13))))); + } + + @Test + public void negateOR() { + var exp = Filter.parser().parse("NOT(key >= 11 OR key < 13)"); + assertThat(FilterHelper.negate(exp)).isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.AND, + new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)), + new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(13))))); + } + + @Test + public void negateNot() { + var exp = Filter.parser().parse("NOT NOT(key >= 11)"); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)))); + } + + @Test + public void negateNestedNot() { + var exp = Filter.parser().parse("NOT(NOT(key >= 11))"); + assertThat(exp).isEqualTo( + new Filter.Expression(ExpressionType.NOT, new Filter.Group(new Filter.Expression(ExpressionType.NOT, + new Filter.Group(new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(11))))))); + + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)))); + } + + @Test + public void expandIN() { + var exp = Filter.parser().parse("key IN [11, 12, 13] "); + assertThat(new InNinTestConverter().convertExpression(exp)).isEqualTo("key EQ 11 OR key EQ 12 OR key EQ 13"); + } + + @Test + public void expandNIN() { + var exp1 = Filter.parser().parse("key NIN [11, 12, 13] "); + var exp2 = Filter.parser().parse("key NOT IN [11, 12, 13] "); + assertThat(exp1).isEqualTo(exp2); + assertThat(new InNinTestConverter().convertExpression(exp1)).isEqualTo("key NE 11 AND key NE 12 AND key NE 13"); + } + + private static class InNinTestConverter extends PrintFilterExpressionConverter { + + @Override + public void doExpression(Expression expression, StringBuilder context) { + if (expression.type() == ExpressionType.IN) { + FilterHelper.expandIn(expression, context, this); + } + else if (expression.type() == ExpressionType.NIN) { + FilterHelper.expandNin(expression, context, this); + } + else { + super.doExpression(expression, context); + } + } + + }; + +} diff --git a/spring-ai-docs/concepts-staging.adoc b/spring-ai-docs/concepts-staging.adoc index 26873e2d373..37693931c5e 100644 --- a/spring-ai-docs/concepts-staging.adoc +++ b/spring-ai-docs/concepts-staging.adoc @@ -113,7 +113,7 @@ One approach involves presenting both the user's request and the AI model's resp Furthermore, leveraging the information stored in the Vector Database as supplementary data can enhance the evaluation process, aiding in the determination of response relevance. -The Spring AI project currenlty provides some very basic examples of how you can evaluate the responses in the form of prompts to include in a JUnit test. +The Spring AI project currently provides some very basic examples of how you can evaluate the responses in the form of prompts to include in a JUnit test. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc index b5f8be5138d..73dfb679f6f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc @@ -87,8 +87,7 @@ country == 'UK' && year >= 2020 && isActive == true. These are the available implementations of the `VectorStore` interface: -* `InMemoryVectorStore` -* `SimplePersistentVectorStore` +* `InMemoryVectorStore` and `SimplePersistentVectorStore`. * Pinecone: https://www.pinecone.io/[PineCone] vector store. * PgVector [`PgVectorStore`]: The https://github.com/pgvector/pgvector[PostgreSQL/PGVector] vector store. * Milvus [`MilvusVectorStore`]: The https://milvus.io/[Milvus] vector store @@ -117,7 +116,7 @@ The `VectorStore` implementation computes the embeddings and stores the JSON and @Autowired VectorStore vectorStore; - void load(String sourceFile) {} + void load(String sourceFile) { JsonReader jsonReader = new JsonReader(new FileSystemResource(sourceFile), "price", "name", "shortDescription", "description", "tags"); List documents = jsonReader.get(); diff --git a/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index 84c59636e4e..a99c22af63b 100644 --- a/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -166,11 +166,20 @@ public void searchWithFilters() throws InterruptedException { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() - .withFilterExpression("country nin ['BG']")); + .withFilterExpression("country not in ['BG']")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country not in ['BG'])")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + // List results = // vectorStore.similaritySearch(SearchRequest.query("The World") // .withTopK(5) diff --git a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java index f6e88c0f1e9..12f00cf170b 100644 --- a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java @@ -122,6 +122,11 @@ public void addAndSearchWithFilters() { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch( + request.withSimilarityThresholdAll().withFilterExpression("NOT(country == 'Netherland')")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + // Remove all documents from the store vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); }); diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java index fd19b8fdf66..6d035d6455c 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java @@ -188,6 +188,16 @@ public void searchWithFilters(String metricType) throws InterruptedException { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'BG' && year == 2020)")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + }); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index d6d154f09ce..adf88701116 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -182,6 +182,12 @@ public void searchWithFilters(String distanceType) { assertThat(results.get(0).getId()).isIn(bgDocument.getId(), nlDocument.getId()); assertThat(results.get(1).getId()).isIn(bgDocument.getId(), nlDocument.getId()); + results = vectorStore.similaritySearch(searchRequest + .withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() diff --git a/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index 63baf358b52..2506aa19f9b 100644 --- a/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -156,6 +156,13 @@ public void addAndSearchWithFilters() { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch(searchRequest.withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'Netherland')")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + // Remove all documents from the store vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java index 3083749c50b..1c66e3a1f36 100644 --- a/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java +++ b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java @@ -16,7 +16,6 @@ package org.springframework.ai.vectorstore; -import java.util.ArrayList; import java.util.Date; import java.util.List; @@ -27,6 +26,7 @@ import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.FilterHelper; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; import org.springframework.util.Assert; @@ -62,10 +62,10 @@ public void setMapIntegerToNumberValue(boolean mapIntegerToNumberValue) { protected void doExpression(Expression exp, StringBuilder context) { if (exp.type() == ExpressionType.IN) { - rewriteInNinExpressions(Filter.ExpressionType.OR, Filter.ExpressionType.EQ, exp, context); + FilterHelper.expandIn(exp, context, this); } else if (exp.type() == ExpressionType.NIN) { - rewriteInNinExpressions(Filter.ExpressionType.AND, Filter.ExpressionType.NE, exp, context); + FilterHelper.expandNin(exp, context, this); } else if (exp.type() == ExpressionType.AND || exp.type() == ExpressionType.OR) { context.append(getOperationSymbol(exp)); @@ -82,51 +82,6 @@ else if (exp.type() == ExpressionType.AND || exp.type() == ExpressionType.OR) { } } - /** - * Recursively aggregates a list of expression into a binary tree with 'aggregateType' - * join nodes. - * @param aggregateType type all tree splits. - * @param expressions list of expressions to aggregate. - * @return Returns a binary tree expression. - */ - private Filter.Expression aggregate(Filter.ExpressionType aggregateType, List expressions) { - - if (expressions.size() == 1) { - return expressions.get(0); - } - return new Filter.Expression(aggregateType, expressions.get(0), - aggregate(aggregateType, expressions.subList(1, expressions.size()))); - } - - private void rewriteInNinExpressions(Filter.ExpressionType outerExpressionType, - Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context) { - if (exp.right() instanceof Filter.Value value) { - if (value.value() instanceof List list) { - // 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo == "bar1" || - // foo == "bar2" || foo == "bar3" - // or equivalent to OR(foo == "bar1" OR( foo == "bar2" OR(foo == "bar3"))) - // 2. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo != "bar1" && - // foo != "bar2" && foo != "bar3" - // or equivalent to AND(foo != "bar1" AND( foo != "bar2" OR(foo != - // "bar3"))) - List eqExprs = new ArrayList<>(); - for (Object o : list) { - eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o))); - } - this.doExpression(aggregate(outerExpressionType, eqExprs), context); - } - else { - // 1. foo IN ["bar"] is equivalent to foo == "BAR" - // 2. foo NIN ["bar"] is equivalent to foo != "BAR" - this.doExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right()), context); - } - } - else { - throw new IllegalStateException( - "Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass()); - } - } - private String getOperationSymbol(Expression exp) { switch (exp.type()) { case AND: diff --git a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java index 9f4b9264678..6b3084a2c11 100644 --- a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; diff --git a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java index 336cecce140..ed0c6e8704c 100644 --- a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -152,6 +152,14 @@ public void searchWithFilters() throws InterruptedException { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); + vectorStore.delete(List.of(bgDocument.getId(), nlDocument.getId(), bgDocument2.getId())); }); }