diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java
index 0b2ccf2fd51..18d8e23bc16 100644
--- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java
+++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java
@@ -15,24 +15,84 @@
  */
 package org.springframework.ai.vectorstore;
 
+import org.springframework.ai.vectorstore.filter.Filter;
 import org.springframework.ai.vectorstore.filter.Filter.Expression;
 import org.springframework.ai.vectorstore.filter.Filter.Group;
 import org.springframework.ai.vectorstore.filter.Filter.Key;
 import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter;
+import java.util.List;
 
 /**
  * Converts {@link Expression} into PgVector metadata filter expression format.
  * (https://www.postgresql.org/docs/current/functions-json.html)
  *
+ * @author Muthukumaran Navaneethakrishnan
  * @author Christian Tzolov
  */
 public class PgVectorFilterExpressionConverter extends AbstractFilterExpressionConverter {
 
 	@Override
 	protected void doExpression(Expression expression, StringBuilder context) {
-		this.convertOperand(expression.left(), context);
-		context.append(getOperationSymbol(expression));
-		this.convertOperand(expression.right(), context);
+		if (expression.type() == Filter.ExpressionType.IN) {
+			handleIn(expression, context);
+		}
+		else if (expression.type() == Filter.ExpressionType.NIN) {
+			handleNotIn(expression, context);
+		}
+		else {
+			this.convertOperand(expression.left(), context);
+			context.append(getOperationSymbol(expression));
+			this.convertOperand(expression.right(), context);
+		}
+	}
+
+	/**
+	 * Renders an IN expression as a parenthesized chain of equality checks joined by
+	 * {@code ||}, e.g. {@code ($.genre == "comedy" || $.genre == "drama")}.
+	 */
+	private void handleIn(Expression expression, StringBuilder context) {
+		context.append("(");
+		convertToConditions(expression, context);
+		context.append(")");
+	}
+
+	/**
+	 * Appends one {@code key == value} condition per element of the expression's
+	 * right-hand list, separating consecutive conditions with {@code ||}.
+	 * @throws IllegalArgumentException if the right-hand value is {@code null}, is
+	 * not a {@link List}, or is an empty list
+	 */
+	private void convertToConditions(Expression expression, StringBuilder context) {
+		Filter.Value right = (Filter.Value) expression.right();
+		Object value = right.value();
+		if (!(value instanceof List)) {
+			// instanceof is false for null, so getClass() must not be called blindly.
+			throw new IllegalArgumentException("Expected a List, but got: "
+					+ (value != null ? value.getClass().getSimpleName() : "null"));
+		}
+		List<?> values = (List<?>) value;
+		if (values.isEmpty()) {
+			// An empty list would render as an empty group "()"; fail fast instead.
+			throw new IllegalArgumentException("Expected a non-empty List for " + expression.type());
+		}
+		for (int i = 0; i < values.size(); i++) {
+			this.convertOperand(expression.left(), context);
+			context.append(" == ");
+			this.doSingleValue(values.get(i), context);
+			if (i < values.size() - 1) {
+				context.append(" || ");
+			}
+		}
+	}
+
+	/**
+	 * Renders a NIN expression as the negation of the equivalent IN group, e.g.
+	 * {@code !($.city == "Sofia" || $.city == "Plovdiv")}.
+	 */
+	private void handleNotIn(Expression expression, StringBuilder context) {
+		context.append("!(");
+		convertToConditions(expression, context);
+		context.append(")");
 	}
 
 	private String getOperationSymbol(Expression exp) {
@@ -53,10 +113,6 @@ private String getOperationSymbol(Expression exp) {
 				return " > ";
 			case GTE:
 				return " >= ";
-			case IN:
-				return " in ";
-			case NIN:
-				return " nin ";
 			default:
 				throw new RuntimeException("Not supported expression type: " + exp.type());
 		}
diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java
index 9cdc6b7bdbb..ca662ab8e09 100644
--- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java
+++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java
@@ -34,6 +34,7 @@
 import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
 
 /**
+ * @author Muthukumaran Navaneethakrishnan
  * @author Christian Tzolov
  */
 public class PgVectorFilterExpressionConverterTests {
@@ -61,7 +62,8 @@ public void tesIn() {
 		// genre in ["comedy", "documentary", "drama"]
 		String vectorExpr = converter.convertExpression(
 				new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama"))));
-		assertThat(vectorExpr).isEqualTo("$.genre in [\"comedy\",\"documentary\",\"drama\"]");
+		assertThat(vectorExpr)
+			.isEqualTo("($.genre == \"comedy\" || $.genre == \"documentary\" || $.genre == \"drama\")");
 	}
 
 	@Test
@@ -82,7 +84,7 @@ public void testGroup() {
 						new Expression(EQ, new Key("country"), new Value("BG")))),
 				new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv")))));
 		assertThat(vectorExpr)
-			.isEqualTo("($.year >= 2020 || $.country == \"BG\") && $.city nin [\"Sofia\",\"Plovdiv\"]");
+			.isEqualTo("($.year >= 2020 || $.country == \"BG\") && !($.city == \"Sofia\" || $.city == \"Plovdiv\")");
 	}
 
 	@Test
@@ -93,7 +95,8 @@ public void tesBoolean() {
 						new Expression(GTE, new Key("year"), new Value(2020))),
 				new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))));
 
-		assertThat(vectorExpr).isEqualTo("$.isOpen == true && $.year >= 2020 && $.country in [\"BG\",\"NL\",\"US\"]");
+		assertThat(vectorExpr).isEqualTo(
+				"$.isOpen == true && $.year >= 2020 && ($.country == \"BG\" || $.country == \"NL\" || $.country == \"US\")");
 	}
 
 	@Test
diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java
index 19586503d31..cfdcf86d973 100644
--- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java
+++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java
@@ -24,12 +24,15 @@
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.stream.Stream;
 
 import javax.sql.DataSource;
 
 import org.junit.Assert;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.junit.jupiter.params.provider.ValueSource;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.embedding.EmbeddingModel;
@@ -57,6 +60,7 @@
 import com.zaxxer.hikari.HikariDataSource;
 
 /**
+ * @author Muthukumaran Navaneethakrishnan
  * @author Christian Tzolov
  */
 @Testcontainers
@@ -128,6 +132,46 @@ public void addAndSearch(String distanceType) {
 			});
 	}
 
+	static Stream<Arguments> provideFilters() {
+		return Stream.of(Arguments.of("country in ['BG','NL']", 3), // String Filters In
+				Arguments.of("year in [2020]", 1), // Numeric Filters In
+				Arguments.of("country not in ['BG']", 1), // String Filter Not In
+				Arguments.of("year not in [2020]", 2) // Numeric Filter Not In
+		);
+	}
+
+	@ParameterizedTest(name = "Filter expression {0} should return {1} records ")
+	@MethodSource("provideFilters")
+	public void searchWithInFilter(String expression, Integer expectedRecords) {
+
+		contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE")
+			.run(context -> {
+
+				VectorStore vectorStore = context.getBean(VectorStore.class);
+
+				var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
+						Map.of("country", "BG", "year", 2020, "foo bar 1", "bar.foo"));
+				var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
+						Map.of("country", "NL"));
+				var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner",
+						Map.of("country", "BG", "year", 2023));
+
+				vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2));
+
+				SearchRequest searchRequest = SearchRequest.query("The World")
+					.withFilterExpression(expression)
+					.withTopK(5)
+					.withSimilarityThresholdAll();
+
+				List<Document> results = vectorStore.similaritySearch(searchRequest);
+
+				assertThat(results).hasSize(expectedRecords);
+
+				// Remove all documents from the store
+				dropTable(context);
+			});
+	}
+
 	@ParameterizedTest(name = "{0} : {displayName} ")
 	@ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" })
 	public void searchWithFilters(String distanceType) {