diff --git a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/FilterStringEscapeUtils.java b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/FilterStringEscapeUtils.java new file mode 100644 index 00000000000..3e81ed868d9 --- /dev/null +++ b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/FilterStringEscapeUtils.java @@ -0,0 +1,227 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter; + +/** + * Utility class for safely escaping strings in filter expressions to prevent injection + * attacks. This class provides methods to escape special characters that could be used to + * break filter expression syntax or cause security vulnerabilities. + * + * @author Spring AI Team + * @since 1.0.0 + */ +public final class FilterStringEscapeUtils { + + private FilterStringEscapeUtils() { + // Utility class - prevent instantiation + throw new IllegalStateException("Utility class"); + } + + /** + * Escape characters for double-quoted strings (used in GraphQL, JSON, etc.). Escapes: + * ", \, \n, \r, \t, \b, \f + * @param input the string to escape + * @return the escaped string safe for use in double-quoted contexts + */ + public static String escapeForDoubleQuotes(String input) { + if (input == null) { + return null; + } + + StringBuilder result = new StringBuilder(); + for (char c : input.toCharArray()) { + switch (c) { + case '"' -> result.append("\\\""); + case '\\' -> result.append("\\\\"); + case '\n' -> result.append("\\n"); + case '\r' -> result.append("\\r"); + case '\t' -> result.append("\\t"); + case '\b' -> result.append("\\b"); + case '\f' -> result.append("\\f"); + default -> result.append(c); + } + } + return result.toString(); + } + + /** + * Escape characters for single-quoted strings (used in SQL, etc.). Escapes: ', \, \n, + * \r, \t, \b, \f + * @param input the string to escape + * @return the escaped string safe for use in single-quoted contexts + */ + public static String escapeForSingleQuotes(String input) { + if (input == null) { + return null; + } + + StringBuilder result = new StringBuilder(); + for (char c : input.toCharArray()) { + switch (c) { + case '\'' -> result.append("\\'"); + case '\\' -> result.append("\\\\"); + case '\n' -> result.append("\\n"); + case '\r' -> result.append("\\r"); + case '\t' -> result.append("\\t"); + case '\b' -> result.append("\\b"); + case '\f' -> result.append("\\f"); + default -> result.append(c); + } + } + return result.toString(); + } + + /** + * Escape characters for SQL identifiers and values. This method handles both single + * quotes and backslashes for SQL contexts. + * @param input the string to escape + * @return the escaped string safe for use in SQL contexts + */ + public static String escapeForSql(String input) { + if (input == null) { + return null; + } + // SQL standard: escape single quotes by doubling them + return input.replace("'", "''"); + } + + /** + * Escape characters for GraphQL string values. This method handles double quotes and + * escape sequences for GraphQL contexts. + * @param input the string to escape + * @return the escaped string safe for use in GraphQL contexts + */ + public static String escapeForGraphQL(String input) { + return escapeForDoubleQuotes(input); + } + + /** + * Escape characters for JSON string values. This method handles double quotes and + * escape sequences for JSON contexts. + * @param input the string to escape + * @return the escaped string safe for use in JSON contexts + */ + public static String escapeForJson(String input) { + return escapeForDoubleQuotes(input); + } + + /** + * Escape characters for Redis search queries. Redis has specific escaping rules for + * special characters. + * @param input the string to escape + * @return the escaped string safe for use in Redis search contexts + */ + public static String escapeForRedis(String input) { + if (input == null) { + return null; + } + + StringBuilder result = new StringBuilder(); + for (char c : input.toCharArray()) { + switch (c) { + case '\\' -> result.append("\\\\"); + case '"' -> result.append("\\\""); + case '\'' -> result.append("\\'"); + case ' ' -> result.append("\\ "); + case '\n' -> result.append("\\n"); + case '\r' -> result.append("\\r"); + case '\t' -> result.append("\\t"); + case '\b' -> result.append("\\b"); + case '\f' -> result.append("\\f"); + default -> result.append(c); + } + } + return result.toString(); + } + + /** + * Escape characters for Elasticsearch queries. Elasticsearch requires specific + * escaping for special characters in query strings. + * @param input the string to escape + * @return the escaped string safe for use in Elasticsearch query contexts + */ + public static String escapeForElasticsearch(String input) { + if (input == null) { + return null; + } + + StringBuilder result = new StringBuilder(); + for (char c : input.toCharArray()) { + switch (c) { + case '\\' -> result.append("\\\\"); + case '"' -> result.append("\\\""); + case '\'' -> result.append("\\'"); + case '+' -> result.append("\\+"); + case '-' -> result.append("\\-"); + case '=' -> result.append("\\="); + case '&' -> result.append("\\&"); + case '|' -> result.append("\\|"); + case '!' -> result.append("\\!"); + case '(' -> result.append("\\("); + case ')' -> result.append("\\)"); + case '{' -> result.append("\\{"); + case '}' -> result.append("\\}"); + case '[' -> result.append("\\["); + case ']' -> result.append("\\]"); + case '^' -> result.append("\\^"); + case '~' -> result.append("\\~"); + case '*' -> result.append("\\*"); + case '?' -> result.append("\\?"); + case ':' -> result.append("\\:"); + case '/' -> result.append("\\/"); + case ' ' -> result.append("\\ "); + case '\n' -> result.append("\\n"); + case '\r' -> result.append("\\r"); + case '\t' -> result.append("\\t"); + default -> result.append(c); + } + } + return result.toString(); + } + + /** + * Generic escape method that takes an escape type parameter. + * @param input the string to escape + * @param escapeType the type of escaping to apply + * @return the escaped string + */ + public static String escape(String input, EscapeType escapeType) { + if (input == null) { + return null; + } + + return switch (escapeType) { + case DOUBLE_QUOTES -> escapeForDoubleQuotes(input); + case SINGLE_QUOTES -> escapeForSingleQuotes(input); + case SQL -> escapeForSql(input); + case GRAPHQL -> escapeForGraphQL(input); + case JSON -> escapeForJson(input); + case REDIS -> escapeForRedis(input); + case ELASTICSEARCH -> escapeForElasticsearch(input); + }; + } + + /** + * Enumeration of different escape types supported by this utility. + */ + public enum EscapeType { + + DOUBLE_QUOTES, SINGLE_QUOTES, SQL, GRAPHQL, JSON, REDIS, ELASTICSEARCH + + } + +} diff --git a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java index 3d63d121713..7f1eca4fc79 100644 --- a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java +++ b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java @@ -24,7 +24,6 @@ import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Operand; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; -import org.springframework.ai.vectorstore.filter.FilterHelper; /** * AbstractFilterExpressionConverter is an abstract class that implements the @@ -99,7 +98,7 @@ protected void doNot(Filter.Expression expression, StringBuilder context) { // equivalent negation expression. // Effectively removing the NOT types form the boolean expression tree before // passing it to the doExpression. - this.convertOperand(FilterHelper.negate(expression), context); + this.convertOperand(org.springframework.ai.vectorstore.filter.FilterHelper.negate(expression), context); } /** @@ -145,7 +144,8 @@ protected void doValue(Filter.Value filterValue, StringBuilder context) { */ protected void doSingleValue(Object value, StringBuilder context) { if (value instanceof String) { - context.append(String.format("\"%s\"", value)); + context.append(String.format("\"%s\"", org.springframework.ai.vectorstore.filter.FilterStringEscapeUtils + .escapeForDoubleQuotes((String) value))); } else { context.append(value); diff --git a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/SimpleVectorStoreFilterExpressionConverter.java b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/SimpleVectorStoreFilterExpressionConverter.java index 838e26a2d42..f31a83bf732 100644 --- a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/SimpleVectorStoreFilterExpressionConverter.java +++ b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/filter/converter/SimpleVectorStoreFilterExpressionConverter.java @@ -26,6 +26,7 @@ import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.FilterStringEscapeUtils; /** * Converts {@link Expression} into SpEL metadata filter expression format. @@ -128,7 +129,7 @@ else if (value instanceof String text) { } } else { - context.append(text); + context.append(FilterStringEscapeUtils.escapeForSingleQuotes(text)); } context.append("'"); } diff --git a/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/filter/FilterStringEscapeUtilsTests.java b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/filter/FilterStringEscapeUtilsTests.java new file mode 100644 index 00000000000..5223b575b3d --- /dev/null +++ b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/filter/FilterStringEscapeUtilsTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link FilterStringEscapeUtils} to ensure proper escaping of special + * characters and prevention of injection attacks in filter expressions. + * + * @author Spring AI Team + * @since 1.0.0 + */ +class FilterStringEscapeUtilsTests { + + @Test + void testEscapeForGraphQL() { + assertThat(FilterStringEscapeUtils.escapeForGraphQL("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForGraphQL("hello\"world")).isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escapeForGraphQL("hello\\world")).isEqualTo("hello\\\\world"); + assertThat(FilterStringEscapeUtils.escapeForGraphQL("hello\"\\world")).isEqualTo("hello\\\"\\\\world"); + } + + @Test + void testEscapeForSql() { + assertThat(FilterStringEscapeUtils.escapeForSql("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForSql("hello'world")).isEqualTo("hello''world"); + assertThat(FilterStringEscapeUtils.escapeForSql("hello''world")).isEqualTo("hello''''world"); + } + + @Test + void testEscapeForDoubleQuotes() { + assertThat(FilterStringEscapeUtils.escapeForDoubleQuotes("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForDoubleQuotes("hello\"world")).isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escapeForDoubleQuotes("hello\\world")).isEqualTo("hello\\\\world"); + assertThat(FilterStringEscapeUtils.escapeForDoubleQuotes("hello\"\\world")).isEqualTo("hello\\\"\\\\world"); + } + + @Test + void testEscapeForSingleQuotes() { + assertThat(FilterStringEscapeUtils.escapeForSingleQuotes("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForSingleQuotes("hello'world")).isEqualTo("hello\\'world"); + assertThat(FilterStringEscapeUtils.escapeForSingleQuotes("hello\\'world")).isEqualTo("hello\\\\\\'world"); + } + + @Test + void testEscapeForJson() { + assertThat(FilterStringEscapeUtils.escapeForJson("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForJson("hello\"world")).isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escapeForJson("hello\\world")).isEqualTo("hello\\\\world"); + assertThat(FilterStringEscapeUtils.escapeForJson("hello\"\\world")).isEqualTo("hello\\\"\\\\world"); + } + + @Test + void testEscapeForRedis() { + assertThat(FilterStringEscapeUtils.escapeForRedis("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForRedis("hello\"world")).isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escapeForRedis("hello\\world")).isEqualTo("hello\\\\world"); + assertThat(FilterStringEscapeUtils.escapeForRedis("hello\"\\world")).isEqualTo("hello\\\"\\\\world"); + } + + @Test + void testEscapeForElasticsearch() { + assertThat(FilterStringEscapeUtils.escapeForElasticsearch("hello")).isEqualTo("hello"); + assertThat(FilterStringEscapeUtils.escapeForElasticsearch("hello\"world")).isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escapeForElasticsearch("hello\\world")).isEqualTo("hello\\\\world"); + assertThat(FilterStringEscapeUtils.escapeForElasticsearch("hello\"\\world")).isEqualTo("hello\\\"\\\\world"); + } + + @Test + void testGenericEscapeMethod() { + assertThat(FilterStringEscapeUtils.escape("hello\"world", FilterStringEscapeUtils.EscapeType.DOUBLE_QUOTES)) + .isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escape("hello'world", FilterStringEscapeUtils.EscapeType.SINGLE_QUOTES)) + .isEqualTo("hello\\'world"); + assertThat(FilterStringEscapeUtils.escape("hello'world", FilterStringEscapeUtils.EscapeType.SQL)) + .isEqualTo("hello''world"); + assertThat(FilterStringEscapeUtils.escape("hello\"world", FilterStringEscapeUtils.EscapeType.GRAPHQL)) + .isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escape("hello\"world", FilterStringEscapeUtils.EscapeType.JSON)) + .isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escape("hello\"world", FilterStringEscapeUtils.EscapeType.REDIS)) + .isEqualTo("hello\\\"world"); + assertThat(FilterStringEscapeUtils.escape("hello\"world", FilterStringEscapeUtils.EscapeType.ELASTICSEARCH)) + .isEqualTo("hello\\\"world"); + } + + @Test + void testInjectionAttackPrevention() { + // Test strings that could be used for injection attacks + String maliciousString = "\"; DROP TABLE users; --"; + String escaped = FilterStringEscapeUtils.escapeForDoubleQuotes(maliciousString); + assertThat(escaped).isEqualTo("\\\"; DROP TABLE users; --"); + // The assertion should check for the presence of escaped characters, not the + // absence of original ones + assertThat(escaped).contains("\\\""); + + String maliciousString2 = "'; DROP TABLE users; --"; + String escaped2 = FilterStringEscapeUtils.escapeForSingleQuotes(maliciousString2); + assertThat(escaped2).isEqualTo("\\'; DROP TABLE users; --"); + assertThat(escaped2).contains("\\'"); + + // Test GraphQL injection attempt + String graphqlInjection = "\"value\": \"injected\", \"malicious\": true"; + String escapedGraphQL = FilterStringEscapeUtils.escapeForGraphQL(graphqlInjection); + assertThat(escapedGraphQL).isEqualTo("\\\"value\\\": \\\"injected\\\", \\\"malicious\\\": true"); + assertThat(escapedGraphQL).contains("\\\""); + } + + @Test + void testUnicodeAndSpecialCharacters() { + // Test Unicode characters + String unicodeString = "Hello δΈ–η•Œ 🌍"; + assertThat(FilterStringEscapeUtils.escapeForDoubleQuotes(unicodeString)).isEqualTo(unicodeString); + + // Test mixed special characters + String mixedString = "Hello\"World'Test\\New\nLine\tTab"; + String escaped = FilterStringEscapeUtils.escapeForDoubleQuotes(mixedString); + assertThat(escaped).isEqualTo("Hello\\\"World'Test\\\\New\\nLine\\tTab"); + assertThat(escaped).contains("\\\""); + assertThat(escaped).contains("\\n"); + assertThat(escaped).contains("\\t"); + } + +} diff --git a/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/filter/converter/FilterExpressionConverterSecurityTests.java b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/filter/converter/FilterExpressionConverterSecurityTests.java new file mode 100644 index 00000000000..fce47364473 --- /dev/null +++ b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/filter/converter/FilterExpressionConverterSecurityTests.java @@ -0,0 +1,160 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter.converter; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; + +/** + * Tests to verify that FilterExpressionConverter implementations properly escape special + * characters to prevent injection attacks. + * + * @author Spring AI Team + * @since 1.0.0 + */ +class FilterExpressionConverterSecurityTests { + + @Test + void testSimpleVectorStoreFilterExpressionConverterEscaping() { + SimpleVectorStoreFilterExpressionConverter converter = new SimpleVectorStoreFilterExpressionConverter(); + + // Test with malicious string containing quotes and escape sequences + String maliciousValue = "'; DROP TABLE users; --"; + FilterExpressionBuilder builder = new FilterExpressionBuilder(); + Filter.Expression filter = builder.eq("testField", maliciousValue).build(); + + String result = converter.convertExpression(filter); + + // Verify that the malicious string is properly escaped + assertThat(result).contains("\\'"); + assertThat(result).contains("DROP TABLE users"); + } + + @Test + void testSimpleVectorStoreFilterExpressionConverterWithNewlines() { + SimpleVectorStoreFilterExpressionConverter converter = new SimpleVectorStoreFilterExpressionConverter(); + + // Test with string containing newlines and tabs + String valueWithNewlines = "line1\nline2\tline3"; + FilterExpressionBuilder builder = new FilterExpressionBuilder(); + Filter.Expression filter = builder.eq("testField", valueWithNewlines).build(); + + String result = converter.convertExpression(filter); + + // Verify that newlines and tabs are properly escaped + assertThat(result).contains("\\n"); + assertThat(result).contains("\\t"); + } + + @Test + void testSimpleVectorStoreFilterExpressionConverterWithBackslashes() { + SimpleVectorStoreFilterExpressionConverter converter = new SimpleVectorStoreFilterExpressionConverter(); + + // Test with string containing backslashes + String valueWithBackslashes = "path\\to\\file"; + FilterExpressionBuilder builder = new FilterExpressionBuilder(); + Filter.Expression filter = builder.eq("testField", valueWithBackslashes).build(); + + String result = converter.convertExpression(filter); + + // Verify that backslashes are properly escaped + assertThat(result).contains("\\\\"); + } + + @Test + void testAbstractFilterExpressionConverterEscaping() { + // Create a test converter that extends AbstractFilterExpressionConverter + TestFilterExpressionConverter converter = new TestFilterExpressionConverter(); + + // Test with malicious string + String maliciousValue = "\"; DROP TABLE users; --"; + FilterExpressionBuilder builder = new FilterExpressionBuilder(); + Filter.Expression filter = builder.eq("testField", maliciousValue).build(); + + String result = converter.convertExpression(filter); + + // Verify that the malicious string is properly escaped + assertThat(result).contains("\\\""); + assertThat(result).contains("DROP TABLE users"); + } + + @Test + void testComplexInjectionAttempts() { + SimpleVectorStoreFilterExpressionConverter converter = new SimpleVectorStoreFilterExpressionConverter(); + + // Test various injection patterns + String[] injectionPatterns = { "'; DROP TABLE users; --", "' OR 1=1 --", + "'; INSERT INTO users VALUES ('hacker', 'password'); --", "' UNION SELECT * FROM users --", + "'; UPDATE users SET password='hacked'; --" }; + + for (String pattern : injectionPatterns) { + FilterExpressionBuilder builder = new FilterExpressionBuilder(); + Filter.Expression filter = builder.eq("testField", pattern).build(); + + String result = converter.convertExpression(filter); + + // Verify that quotes are properly escaped + assertThat(result).contains("\\'"); + } + } + + @Test + void testUnicodeAndSpecialCharacters() { + SimpleVectorStoreFilterExpressionConverter converter = new SimpleVectorStoreFilterExpressionConverter(); + + // Test with Unicode characters + String unicodeValue = "Hello δΈ–η•Œ 🌍"; + FilterExpressionBuilder builder = new FilterExpressionBuilder(); + Filter.Expression filter = builder.eq("testField", unicodeValue).build(); + + String result = converter.convertExpression(filter); + + // Unicode characters should be preserved + assertThat(result).contains("Hello δΈ–η•Œ 🌍"); + } + + /** + * Test implementation of AbstractFilterExpressionConverter for testing purposes. + */ + private static class TestFilterExpressionConverter extends AbstractFilterExpressionConverter { + + @Override + protected void doExpression(Filter.Expression expression, StringBuilder context) { + this.convertOperand(expression.left(), context); + context.append(" == "); + this.convertOperand(expression.right(), context); + } + + @Override + protected void doKey(Filter.Key key, StringBuilder context) { + context.append(key.key()); + } + + @Override + protected void doGroup(Filter.Group group, StringBuilder context) { + context.append("("); + super.doGroup(group, context); + context.append(")"); + } + + } + +} diff --git a/src/checkstyle/checkstyle-suppressions.xml b/src/checkstyle/checkstyle-suppressions.xml index a6ba36d1ee5..4647eba4a27 100644 --- a/src/checkstyle/checkstyle-suppressions.xml +++ b/src/checkstyle/checkstyle-suppressions.xml @@ -62,5 +62,7 @@ + + diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java index 7de6a51c7f5..57ad6934b62 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java @@ -29,7 +29,7 @@ import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.FilterStringEscapeUtils; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; import org.springframework.util.Assert; @@ -96,7 +96,7 @@ private String getOperationSymbol(Expression exp) { } @Override - public void doKey(Key key, StringBuilder context) { + protected void doKey(Filter.Key key, StringBuilder context) { var hasOuterQuotes = hasOuterQuotes(key.key()); var identifier = (hasOuterQuotes) ? removeOuterQuotes(key.key()) : key.key(); var prefixedIdentifier = withMetaPrefix(identifier); @@ -150,7 +150,7 @@ else if (value instanceof String text) { } } else { - context.append(String.format("'%s'", text)); + context.append(String.format("'%s'", FilterStringEscapeUtils.escapeForSql(text))); } } else { diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBFilterExpressionConverter.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBFilterExpressionConverter.java index a5812bf4a1d..50e0e428250 100644 --- a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBFilterExpressionConverter.java +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBFilterExpressionConverter.java @@ -19,7 +19,7 @@ import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.Group; -import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.FilterStringEscapeUtils; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; /** @@ -43,10 +43,15 @@ protected void doExpression(Expression expression, StringBuilder context) { this.convertOperand(expression.right(), context); } + @Override + protected void doKey(Filter.Key key, StringBuilder context) { + context.append(String.format("JSON_EXTRACT(%s, '$.%s')", this.metadataFieldName, key.key())); + } + @Override protected void doSingleValue(Object value, StringBuilder context) { if (value instanceof String) { - context.append(String.format("\'%s\'", value)); + context.append(String.format("\'%s\'", FilterStringEscapeUtils.escapeForSql((String) value))); } else { context.append(value); @@ -70,11 +75,6 @@ private String getOperationSymbol(Expression exp) { }; } - @Override - protected void doKey(Key key, StringBuilder context) { - context.append("JSON_VALUE(" + this.metadataFieldName + ", '$." + key.key() + "')"); - } - protected void doStartValueRange(Filter.Value listValue, StringBuilder context) { context.append("("); } diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateFilterExpressionConverter.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateFilterExpressionConverter.java index 3321cd179f2..8c9123cb1c7 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateFilterExpressionConverter.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateFilterExpressionConverter.java @@ -27,6 +27,7 @@ import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; import org.springframework.ai.vectorstore.filter.FilterHelper; +import org.springframework.ai.vectorstore.filter.FilterStringEscapeUtils; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; import org.springframework.util.Assert; @@ -182,7 +183,7 @@ else if (value instanceof Boolean b) { context.append(String.format("valueBoolean:%s ", b)); } else if (value instanceof String s) { - context.append(String.format("valueText:\"%s\" ", s)); + context.append(String.format("valueText:\"%s\" ", FilterStringEscapeUtils.escapeForGraphQL(s))); } else if (value instanceof Date date) { String dateString = DateFormatUtils.format(date, "yyyy-MM-dd\'T\'HH:mm:ssZZZZZ");