Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public record Value(Object value) implements Operand {
*/
public enum ExpressionType {

AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN
AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN, NOT

}

Expand All @@ -131,6 +131,9 @@ public enum ExpressionType {
* be another {@link Expression}.
*/
public record Expression(ExpressionType type, Operand left, Operand right) implements Operand {
public Expression(ExpressionType type, Operand operand) {
this(type, operand, null);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,8 @@ public Op group(Op content) {
return new Op(new Filter.Group(content.build()));
}

public Op not(Op content) {
return new Op(new Filter.Expression(ExpressionType.NOT, content.expression, null));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.springframework.ai.vectorstore.filter.antlr4.FiltersBaseVisitor;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersLexer;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser.NotExpressionContext;
import org.springframework.core.NestedExceptionUtils;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -263,6 +264,11 @@ public Filter.Operand visitGroupExpression(FiltersParser.GroupExpressionContext
return new Filter.Group(castToExpression(this.visit(ctx.booleanExpression())));
}

@Override
public Filter.Operand visitNotExpression(NotExpressionContext ctx) {
return new Filter.Expression(Filter.ExpressionType.NOT, this.visit(ctx.booleanExpression()), null);
}

public Filter.Expression castToExpression(Filter.Operand expression) {
if (expression instanceof Filter.Group group) {
// Remove the top-level grouping.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vectorstore.filter;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
import org.springframework.ai.vectorstore.filter.Filter.Operand;
import org.springframework.ai.vectorstore.filter.converter.FilterExpressionConverter;
import org.springframework.util.Assert;

/**
* Helper class providing various boolean transformation.
*
* @author Christian Tzolov
*/
public class FilterHelper {

private FilterHelper() {
}

private final static Map<ExpressionType, ExpressionType> TYPE_NEGATION_MAP = Map.of(ExpressionType.AND,
ExpressionType.OR, ExpressionType.OR, ExpressionType.AND, ExpressionType.EQ, ExpressionType.NE,
ExpressionType.NE, ExpressionType.EQ, ExpressionType.GT, ExpressionType.LTE, ExpressionType.GTE,
ExpressionType.LT, ExpressionType.LT, ExpressionType.GTE, ExpressionType.LTE, ExpressionType.GT,
ExpressionType.IN, ExpressionType.NIN, ExpressionType.NIN, ExpressionType.IN);

/**
* Transforms the input expression into a semantically equivalent one with negation
* operators propagated thought the expression tree by following the negation rules:
*
* <pre>
* NOT(NOT(a)) = a
*
* NOT(a AND b) = NOT(a) OR NOT(b)
* NOT(a OR b) = NOT(a) AND NOT(b)
*
* NOT(a EQ b) = a NE b
* NOT(a NE b) = a EQ b
*
* NOT(a GT b) = a LTE b
* NOT(a GTE b) = a LT b
*
* NOT(a LT b) = a GTE b
* NOT(a LTE b) = a GT b
*
* NOT(a IN [...]) = a NIN [...]
* NOT(a NIN [...]) = a IN [...]
* </pre>
* @param operand Filter expression to negate.
* @return Returns an negation of the input expression.
*/
public static Filter.Operand negate(Filter.Operand operand) {

if (operand instanceof Filter.Group group) {
Operand inEx = negate(group.content());
if (inEx instanceof Filter.Group inEx2) {
inEx = inEx2.content();
}
return new Filter.Group((Expression) inEx);
}
else if (operand instanceof Filter.Expression exp) {
switch (exp.type()) {
case NOT: // NOT(NOT(a)) = a
return negate(exp.left());
case AND: // NOT(a AND b) = NOT(a) OR NOT(b)
case OR: // NOT(a OR b) = NOT(a) AND NOT(b)
return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), negate(exp.left()),
negate(exp.right()));
case EQ: // NOT(e EQ b) = e NE b
case NE: // NOT(e NE b) = e EQ b
case GT: // NOT(e GT b) = e LTE b
case GTE: // NOT(e GTE b) = e LT b
case LT: // NOT(e LT b) = e GTE b
case LTE: // NOT(e LTE b) = e GT b
return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right());
case IN: // NOT(e IN [...]) = e NIN [...]
case NIN: // NOT(e NIN [...]) = e IN [...]
return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right());
default:
throw new IllegalArgumentException("Unknown expression type: " + exp.type());
}
}
else {
throw new IllegalArgumentException("Can not negate operand of type: " + operand.getClass());
}
}

/**
* Expands the IN into a semantically equivalent boolean expressions of ORs of EQs.
* Useful for providers that don't provide native IN support.
*
* For example the <pre>
* foo IN ["bar1", "bar2", "bar3"]
* </pre>
*
* expression is equivalent to
*
* <pre>
* {@code foo == "bar1" || foo == "bar2" || foo == "bar3" (e.g. OR(foo EQ "bar1" OR(foo EQ "bar2" OR(foo EQ "bar3")))}
* </pre>
* @param exp input IN expression.
* @param context Output native expression.
* @param filterExpressionConverter {@link FilterExpressionConverter} used to compose
* the OR and EQ expanded expressions.
*/
public static void expandIn(Expression exp, StringBuilder context,
FilterExpressionConverter filterExpressionConverter) {
Assert.isTrue(exp.type() == ExpressionType.IN, "Expected IN expressions but was: " + exp.type());
expandInNinExpressions(ExpressionType.OR, ExpressionType.EQ, exp, context, filterExpressionConverter);
}

/**
*
* Expands the NIN (e.g. NOT IN) into a semantically equivalent boolean expressions of
* ANDs of NEs. Useful for providers that don't provide native NIN support.<br/>
*
* For example the
*
* <pre>
* foo NIN ["bar1", "bar2", "bar3"] (or foo NOT IN ["bar1", "bar2", "bar3"])
* </pre>
*
* express is equivalent to
*
* <pre>
* {@code foo != "bar1" && foo != "bar2" && foo != "bar3" (e.g. AND(foo NE "bar1" AND( foo NE "bar2" OR(foo NE "bar3"))) )}
* </pre>
* @param exp input NIN expression.
* @param context Output native expression.
* @param filterExpressionConverter {@link FilterExpressionConverter} used to compose
* the AND and NE expanded expressions.
*/
public static void expandNin(Expression exp, StringBuilder context,
FilterExpressionConverter filterExpressionConverter) {
Assert.isTrue(exp.type() == ExpressionType.NIN, "Expected NIN expressions but was: " + exp.type());
expandInNinExpressions(ExpressionType.AND, ExpressionType.NE, exp, context, filterExpressionConverter);
}

private static void expandInNinExpressions(Filter.ExpressionType outerExpressionType,
Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context,
FilterExpressionConverter expressionConverter) {
if (exp.right() instanceof Filter.Value value) {
if (value.value() instanceof List list) {
// 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo == "bar1" ||
// foo == "bar2" || foo == "bar3"
// or equivalent to OR(foo == "bar1" OR( foo == "bar2" OR(foo == "bar3")))
// 2. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo != "bar1" &&
// foo != "bar2" && foo != "bar3"
// or equivalent to AND(foo != "bar1" AND( foo != "bar2" OR(foo !=
// "bar3")))
List<Filter.Expression> eqExprs = new ArrayList<>();
for (Object o : list) {
eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o)));
}
context.append(expressionConverter.convertExpression(aggregate(outerExpressionType, eqExprs)));
}
else {
// 1. foo IN ["bar"] is equivalent to foo == "BAR"
// 2. foo NIN ["bar"] is equivalent to foo != "BAR"
context.append(expressionConverter
.convertExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right())));
}
}
else {
throw new IllegalStateException(
"Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass());
}
}

/**
* Recursively aggregates a list of expression into a binary tree with 'aggregateType'
* join nodes.
* @param aggregateType type all tree splits.
* @param expressions list of expressions to aggregate.
* @return Returns a binary tree expression.
*/
private static Filter.Expression aggregate(Filter.ExpressionType aggregateType,
List<Filter.Expression> expressions) {

if (expressions.size() == 1) {
return expressions.get(0);
}
return new Filter.Expression(aggregateType, expressions.get(0),
aggregate(aggregateType, expressions.subList(1, expressions.size())));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ constant


atn:
[4, 1, 26, 87, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 38, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 46, 8, 1, 10, 1, 12, 1, 49, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 55, 8, 2, 10, 2, 12, 2, 58, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 69, 8, 4, 1, 5, 3, 5, 72, 8, 5, 1, 5, 1, 5, 3, 5, 76, 8, 5, 1, 5, 1, 5, 4, 5, 80, 8, 5, 11, 5, 12, 5, 81, 1, 5, 3, 5, 85, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 95, 0, 12, 1, 0, 0, 0, 2, 37, 1, 0, 0, 0, 4, 50, 1, 0, 0, 0, 6, 61, 1, 0, 0, 0, 8, 68, 1, 0, 0, 0, 10, 84, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 38, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 38, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 38, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 38, 1, 0, 0, 0, 37, 16, 1, 0, 0, 0, 37, 21, 1, 0, 0, 0, 37, 25, 1, 0, 0, 0, 37, 33, 1, 0, 0, 0, 38, 47, 1, 0, 0, 0, 39, 40, 10, 3, 0, 0, 40, 41, 5, 16, 0, 0, 41, 46, 3, 2, 1, 4, 42, 43, 10, 2, 0, 0, 43, 44, 5, 17, 0, 0, 44, 46, 3, 2, 1, 3, 45, 39, 1, 0, 0, 0, 45, 42, 1, 0, 0, 0, 46, 49, 1, 0, 0, 0, 47, 45, 1, 0, 0, 0, 47, 48, 1, 0, 0, 0, 48, 3, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 50, 51, 5, 4, 0, 0, 51, 56, 3, 10, 5, 0, 52, 53, 5, 3, 0, 0, 53, 55, 3, 10, 5, 0, 54, 52, 1, 0, 0, 0, 55, 58, 1, 0, 0, 0, 56, 54, 1, 0, 0, 0, 56, 57, 1, 0, 0, 0, 57, 59, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 59, 60, 5, 5, 0, 0, 60, 5, 1, 0, 0, 0, 61, 62, 7, 0, 0, 0, 62, 7, 1, 0, 0, 0, 63, 64, 5, 25, 0, 0, 64, 65, 5, 2, 0, 0, 65, 69, 5, 25, 0, 0, 66, 69, 5, 25, 0, 0, 67, 69, 5, 22, 0, 0, 68, 63, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 67, 1, 0, 0, 0, 69, 9, 1, 0, 0, 0, 70, 72, 7, 1, 0, 0, 71, 70, 1, 0, 0, 0, 71, 72, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 85, 5, 23, 0, 0, 74, 76, 7, 1, 0, 0, 75, 74, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 77, 1, 0, 0, 0, 77, 85, 5, 24, 0, 0, 78, 80, 5, 22, 0, 0, 79, 78, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 85, 1, 0, 0, 0, 83, 85, 5, 21, 0, 0, 84, 71, 1, 0, 0, 0, 84, 75, 1, 0, 0, 0, 84, 79, 1, 0, 0, 0, 84, 83, 1, 0, 0, 0, 85, 11, 1, 0, 0, 0, 10, 29, 37, 45, 47, 56, 68, 71, 75, 81, 84]
[4, 1, 26, 89, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 40, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 48, 8, 1, 10, 1, 12, 1, 51, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 57, 8, 2, 10, 2, 12, 2, 60, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 71, 8, 4, 1, 5, 3, 5, 74, 8, 5, 1, 5, 1, 5, 3, 5, 78, 8, 5, 1, 5, 1, 5, 4, 5, 82, 8, 5, 11, 5, 12, 5, 83, 1, 5, 3, 5, 87, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 98, 0, 12, 1, 0, 0, 0, 2, 39, 1, 0, 0, 0, 4, 52, 1, 0, 0, 0, 6, 63, 1, 0, 0, 0, 8, 70, 1, 0, 0, 0, 10, 86, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 40, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 40, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 40, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 40, 1, 0, 0, 0, 37, 38, 5, 20, 0, 0, 38, 40, 3, 2, 1, 1, 39, 16, 1, 0, 0, 0, 39, 21, 1, 0, 0, 0, 39, 25, 1, 0, 0, 0, 39, 33, 1, 0, 0, 0, 39, 37, 1, 0, 0, 0, 40, 49, 1, 0, 0, 0, 41, 42, 10, 4, 0, 0, 42, 43, 5, 16, 0, 0, 43, 48, 3, 2, 1, 5, 44, 45, 10, 3, 0, 0, 45, 46, 5, 17, 0, 0, 46, 48, 3, 2, 1, 4, 47, 41, 1, 0, 0, 0, 47, 44, 1, 0, 0, 0, 48, 51, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 49, 50, 1, 0, 0, 0, 50, 3, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 52, 53, 5, 4, 0, 0, 53, 58, 3, 10, 5, 0, 54, 55, 5, 3, 0, 0, 55, 57, 3, 10, 5, 0, 56, 54, 1, 0, 0, 0, 57, 60, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 58, 59, 1, 0, 0, 0, 59, 61, 1, 0, 0, 0, 60, 58, 1, 0, 0, 0, 61, 62, 5, 5, 0, 0, 62, 5, 1, 0, 0, 0, 63, 64, 7, 0, 0, 0, 64, 7, 1, 0, 0, 0, 65, 66, 5, 25, 0, 0, 66, 67, 5, 2, 0, 0, 67, 71, 5, 25, 0, 0, 68, 71, 5, 25, 0, 0, 69, 71, 5, 22, 0, 0, 70, 65, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 70, 69, 1, 0, 0, 0, 71, 9, 1, 0, 0, 0, 72, 74, 7, 1, 0, 0, 73, 72, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 87, 5, 23, 0, 0, 76, 78, 7, 1, 0, 0, 77, 76, 1, 0, 0, 0, 77, 78, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 87, 5, 24, 0, 0, 80, 82, 5, 22, 0, 0, 81, 80, 1, 0, 0, 0, 82, 83, 1, 0, 0, 0, 83, 81, 1, 0, 0, 0, 83, 84, 1, 0, 0, 0, 84, 87, 1, 0, 0, 0, 85, 87, 5, 21, 0, 0, 86, 73, 1, 0, 0, 0, 86, 77, 1, 0, 0, 0, 86, 81, 1, 0, 0, 0, 86, 85, 1, 0, 0, 0, 87, 11, 1, 0, 0, 0, 10, 29, 39, 47, 49, 58, 70, 73, 77, 83, 86]
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,28 @@ public void enterInExpression(FiltersParser.InExpressionContext ctx) {
public void exitInExpression(FiltersParser.InExpressionContext ctx) {
}

/**
* {@inheritDoc}
*
* <p>
* The default implementation does nothing.
* </p>
*/
@Override
public void enterNotExpression(FiltersParser.NotExpressionContext ctx) {
}

/**
* {@inheritDoc}
*
* <p>
* The default implementation does nothing.
* </p>
*/
@Override
public void exitNotExpression(FiltersParser.NotExpressionContext ctx) {
}

/**
* {@inheritDoc}
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ public T visitInExpression(FiltersParser.InExpressionContext ctx) {
return visitChildren(ctx);
}

/**
* {@inheritDoc}
*
* <p>
* The default implementation returns the result of calling {@link #visitChildren} on
* {@code ctx}.
* </p>
*/
@Override
public T visitNotExpression(FiltersParser.NotExpressionContext ctx) {
return visitChildren(ctx);
}

/**
* {@inheritDoc}
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ public interface FiltersListener extends ParseTreeListener {
*/
void exitInExpression(FiltersParser.InExpressionContext ctx);

/**
* Enter a parse tree produced by the {@code NotExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.
* @param ctx the parse tree
*/
void enterNotExpression(FiltersParser.NotExpressionContext ctx);

/**
* Exit a parse tree produced by the {@code NotExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.
* @param ctx the parse tree
*/
void exitNotExpression(FiltersParser.NotExpressionContext ctx);

/**
* Enter a parse tree produced by the {@code CompareExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.
Expand Down
Loading