diff --git a/pom.xml b/pom.xml index cc5c8607948..a9418372b0c 100644 --- a/pom.xml +++ b/pom.xml @@ -35,6 +35,7 @@ vector-stores/spring-ai-pinecone vector-stores/spring-ai-chroma vector-stores/spring-ai-azure + vector-stores/spring-ai-weaviate @@ -96,6 +97,7 @@ 3.24.4 2.0.42 11.6.0 + 4.4.1 1.19.0 diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 8bdbfa1b629..9c45599f55b 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -117,6 +117,14 @@ true + + + org.springframework.experimental.ai + spring-ai-weaviate-store + ${project.parent.version} + true + + org.springframework.boot diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java new file mode 100644 index 00000000000..95ffbd48d7e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfiguration.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.weaviate; + +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.WeaviateVectorStore; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +/** + * @author Christian Tzolov + */ +@AutoConfiguration +@ConditionalOnClass({ EmbeddingClient.class, WeaviateVectorStore.class }) +@EnableConfigurationProperties({ WeaviateVectorStoreProperties.class }) +public class WeaviateVectorStoreAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public VectorStore vectorStore(EmbeddingClient embeddingClient, WeaviateVectorStoreProperties properties) { + + WeaviateVectorStoreConfig.Builder configBuilder = WeaviateVectorStore.WeaviateVectorStoreConfig.builder() + .withScheme(properties.getScheme()) + .withApiKey(properties.getApiKey()) + .withHost(properties.getHost()) + .withHeaders(properties.getHeaders()) + .withObjectClass(properties.getObjectClass()) + .withFilterableMetadataFields(properties.getFilterField() + .entrySet() + .stream() + .map(e -> new MetadataField(e.getKey(), e.getValue())) + .toList()) + .withConsistencyLevel(properties.getConsistencyLevel()); + + return new WeaviateVectorStore(configBuilder.build(), embeddingClient); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java new file mode 100644 index 00000000000..6a28b999ac5 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreProperties.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.weaviate; + +import java.util.Map; + +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Christian Tzolov + */ +@ConfigurationProperties(WeaviateVectorStoreProperties.CONFIG_PREFIX) +public class WeaviateVectorStoreProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vectorstore.weaviate"; + + private String scheme = "http"; + + private String host = "localhost:8080"; + + private String apiKey = ""; + + private String objectClass = "SpringAiWeaviate"; + + private ConsistentLevel consistencyLevel = WeaviateVectorStoreConfig.ConsistentLevel.ONE; + + /** + * spring.ai.vectorstore.weaviate.filter-field.= + */ + private Map filterField = Map.of(); + + private Map headers = Map.of(); + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public String getScheme() { + return scheme; + } + + public void setHost(String host) { + this.host = host; + } + + public String getHost() { + return host; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getObjectClass() { + return objectClass; + } + + public void setObjectClass(String indexName) { + this.objectClass = indexName; + } + + public ConsistentLevel getConsistencyLevel() { + return consistencyLevel; + } + + public void setConsistencyLevel(ConsistentLevel consistencyLevel) { + this.consistencyLevel = consistencyLevel; + } + + public Map getHeaders() { + return headers; + } + + public void setHeaders(Map headers) { + this.headers = headers; + } + + public Map getFilterField() { + return filterField; + } + + public void setFilterField(Map filterMetadataFields) { + this.filterField = filterMetadataFields; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index e69af94f13f..80adaad6d47 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -8,3 +8,4 @@ org.springframework.ai.autoconfigure.embedding.transformer.TransformersEmbedding org.springframework.ai.autoconfigure.huggingface.HuggingfaceAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.azure.AzureVectorStoreAutoConfiguration +org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationTests.java new file mode 100644 index 00000000000..9b3f413f55e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/weaviate/WeaviateVectorStoreAutoConfigurationTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.weaviate; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.TransformersEmbeddingClient; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +@Testcontainers +public class WeaviateVectorStoreAutoConfigurationTests { + + @Container + static GenericContainer weaviateContainer = new GenericContainer<>("semitechnologies/weaviate:1.22.4") + .withEnv("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true") + .withEnv("PERSISTENCE_DATA_PATH", "/var/lib/weaviate") + .withEnv("QUERY_DEFAULTS_LIMIT", "25") + .withEnv("DEFAULT_VECTORIZER_MODULE", "none") + .withEnv("CLUSTER_HOSTNAME", "node1") + .withExposedPorts(8080); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(WeaviateVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.weaviate.scheme=http", + "spring.ai.vectorstore.weaviate.host=localhost:" + weaviateContainer.getMappedPort(8080), + "spring.ai.vectorstore.weaviate.filter-field.country=TEXT", + "spring.ai.vectorstore.weaviate.filter-field.year=NUMBER", + "spring.ai.vectorstore.weaviate.filter-field.active=BOOLEAN", + "spring.ai.vectorstore.weaviate.filter-field.price=NUMBER"); + + @Test + public void addAndSearchWithFilters() { + + contextRunner.run(context -> { + + WeaviateVectorStoreProperties properties = context.getBean(WeaviateVectorStoreProperties.class); + + assertThat(properties.getFilterField()).hasSize(4); + + assertThat(properties.getFilterField().get("country")).isEqualTo(MetadataField.Type.TEXT); + assertThat(properties.getFilterField().get("year")).isEqualTo(MetadataField.Type.NUMBER); + assertThat(properties.getFilterField().get("active")).isEqualTo(MetadataField.Type.BOOLEAN); + assertThat(properties.getFilterField().get("price")).isEqualTo(MetadataField.Type.NUMBER); + + VectorStore vectorStore = context.getBean(VectorStore.class); + + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "Bulgaria", "price", 3.14, "active", true, "year", 2020)); + var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "Netherland", "price", 1.57, "active", false, "year", 2023)); + + vectorStore.add(List.of(bgDocument, nlDocument)); + + var request = SearchRequest.query("The World").withTopK(5); + + List results = vectorStore.similaritySearch(request); + assertThat(results).hasSize(2); + + results = vectorStore + .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Bulgaria'")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + results = vectorStore + .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Netherland'")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + + results = vectorStore.similaritySearch( + request.withSimilarityThresholdAll().withFilterExpression("price > 1.57 && active == true")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + results = vectorStore + .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("year in [2020, 2023]")); + assertThat(results).hasSize(2); + + results = vectorStore.similaritySearch( + request.withSimilarityThresholdAll().withFilterExpression("year > 2020 && year <= 2023")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + + // Remove all documents from the store + vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); + }); + } + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + public EmbeddingClient embeddingClient() { + return new TransformersEmbeddingClient(); + } + + } + +} diff --git a/vector-stores/spring-ai-weaviate/README.md b/vector-stores/spring-ai-weaviate/README.md new file mode 100644 index 00000000000..01b22beb640 --- /dev/null +++ b/vector-stores/spring-ai-weaviate/README.md @@ -0,0 +1,196 @@ +# Weaviate VectorStore + +This readme will walk you through setting up the Weaviate VectorStore to store document embeddings and perform similarity searches. + +## What is Weaviate? + +[Weaviate](https://weaviate.io/) is an open-source vector database. +It allows you to store data objects and vector embeddings from your favorite ML-models, and scale seamlessly into billions of data objects. +It gives you the tools to store document embeddings, content and metadata and to search through those embeddings including metadata filtering. + +## Prerequisites + +1. `EmbeddingClient` instance to compute the document embeddings. Several options are available: + + - `Transformers Embedding` - computes the embedding in your, local environment. Follow the [Transformers Embedding](../../embedding-clients/transformers-embedding/) instructions. + - `OpenAI Embedding` - uses the OpenAI embedding endpoint. You need to create an account at [OpenAI Signup](https://platform.openai.com/signup) and generate the api-key token at [API Keys](https://platform.openai.com/account/api-keys). + - You can also use the `Azure OpenAI Embedding` or the `PostgresML Embedding Client`. + +2. `Weaviate cluster`. You can a cluster, locally, in a Docker container ([Local Weaviate](#appendix_a)) or create a [Weaviate Cloud Service](https://console.weaviate.cloud/). For later you need to create an weaviate account spin a cluster and get your access api-key from the [dashboard details](https://console.weaviate.cloud/dashboard). + +On startup the `WeaviateVectorStore` creates the required `SpringAiWeaviate` object schema (if such is not already provisioned). + +## Dependencies + +Add these dependencies to your project: + +1. Embedding Client boot starter, required for calculating embeddings. + + - Transformers Embedding (Local) + + ```xml + + org.springframework.experimental.ai + spring-ai-transformers-embedding-spring-boot-starter + 0.7.1-SNAPSHOT + + ``` + + follow the [transformers-embedding](../../embedding-clients/transformers-embedding/README.md) instructions. + + - or OpenAI (Cloud) + + ```xml + + org.springframework.experimental.ai + spring-ai-openai-spring-boot-starter + 0.7.1-SNAPSHOT + + ``` + + you'll need to provide your OpenAI API Key. Set it as an environment variable like so: + + ```bash + export SPRING_AI_OPENAI_API_KEY='Your_OpenAI_API_Key' + ``` + +2. Weaviate VectorStore. + + ```xml + + org.springframework.experimental.ai + spring-ai-weaviate-store + 0.7.1-SNAPSHOT + + ``` + +## Usage + +Create a WeaviateVectorStore instance connected to local Weaviate cluster: + +```java + @Bean + public VectorStore vectorStore(EmbeddingClient embeddingClient) { + WeaviateVectorStoreConfig config = WeaviateVectorStoreConfig.builder() + .withScheme("http") + .withHost("localhost:8080") + // Define the metadata fields to be used + // in the similarity search filters. + .withFilterableMetadataFields(List.of( + MetadataField.text("country"), + MetadataField.number("year"), + MetadataField.bool("active"))) + // Consistency level can be: ONE, QUORUM or ALL. + .withConsistencyLevel(ConsistentLevel.ONE) + .build(); + + return new WeaviateVectorStore(config, embeddingClient); + } +``` + +> [!NOTE] +> You must list explicitly all metadata field names and types (`BOOLEAN`, `TEXT` or `NUMBER`) for any metadata key used in filter expression. +>The `withFilterableMetadataKeys` above registers filterable metadata fields: `country` of type `TEXT`, `year` of type `NUMBER` and `active` of type `BOOLEAN`. +> +> If the filterable metadata fields is expanded with new entires, you have to (re)upload/update the documents with this metadata. +> +> You can use the following, Weaviate [system metadata](https://weaviate.io/developers/weaviate/api/graphql/filters#special-cases) fields without explicit definition: `id`, `_creationTimeUnix` and `_lastUpdateTimeUnix`. + +Then yn your main code, create some documents + +```java +List documents = List.of( + new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Map.of("country", "UK", "active", true, "year", 2020)), + new Document("The World is Big and Salvation Lurks Around the Corner", Map.of()), + new Document("You walk forward facing the past and you turn back toward the future.", Map.of("country", "NL", "active", false, "year", 2023))); +``` + +Add the documents to your vector store: + +```java +vectorStore.add(List.of(document)); +``` + +And finally, retrieve documents similar to a query: + +```java +List results = vectorStore.similaritySearch( + SearchRequest + .query("Spring") + .withTopK(5)); +``` + +If all goes well, you should retrieve the document containing the text "Spring AI rocks!!". + +### Metadata filtering + +You can leverage the generic, portable [metadata filters](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_metadata_filters) with WeaviateVectorStore as well. + +For example you can use either the text expression language: + +```java +vectorStore.similaritySearch( + SearchRequest + .query("The World") + .withTopK(TOP_K) + .withSimilarityThreshold(SIMILARITY_THRESHOLD) + .withFilterExpression("country in ['UK', 'NL'] && year >= 2020")); +``` + +or programmatically using the expression DSL: + +```java +FilterExpressionBuilder b = Filter.builder(); + +vectorStore.similaritySearch( + SearchRequest + .query("The World") + .withTopK(TOP_K) + .withSimilarityThreshold(SIMILARITY_THRESHOLD) + .withFilterExpression(b.and( + b.in("country", "UK", "NL"), + b.gte("year", 2020)).build())); +``` + +The, portable, filter expressions get automatically converted into the proprietary Weaviate [where filters](https://weaviate.io/developers/weaviate/api/graphql/filters). +For example the following, portable, filter expression + +```sql +country in ['UK', 'NL'] && year >= 2020 +``` + +is converted into Weaviate, GraphQL, [where filter expression](https://weaviate.io/developers/weaviate/api/graphql/filters): + +```graphQL +operator:And + operands: + [{ + operator:Or + operands: + [{ + path:["meta_country"] + operator:Equal + valueText:"UK" + }, + { + path:["meta_country"] + operator:Equal + valueText:"NL" + }] + }, + { + path:["meta_year"] + operator:GreaterThanEqual + valueNumber:2020 + }] +``` + +## Appendix A: Run Weaviate cluster in docker container + +Start Weaviate in a docker container: + +```bash +docker run -it --rm --name weaviate -e AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true -e PERSISTENCE_DATA_PATH=/var/lib/weaviate -e QUERY_DEFAULTS_LIMIT=25 -e DEFAULT_VECTORIZER_MODULE=none -e CLUSTER_HOSTNAME=node1 -p 8080:8080 semitechnologies/weaviate:1.22.4 +``` + +Starts a Weaviate cluster at http://localhost:8080/v1 with scheme=`http`, host=`localhost:8080` and apiKey=`""`. Then follow the [usage instructions](#usage). diff --git a/vector-stores/spring-ai-weaviate/pom.xml b/vector-stores/spring-ai-weaviate/pom.xml new file mode 100644 index 00000000000..536edb61b4b --- /dev/null +++ b/vector-stores/spring-ai-weaviate/pom.xml @@ -0,0 +1,84 @@ + + + 4.0.0 + + org.springframework.experimental.ai + spring-ai + 0.7.1-SNAPSHOT + ../../pom.xml + + spring-ai-weaviate-store + jar + spring-ai-weaviate-store + spring-ai-weaviate + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + 17 + 17 + + + + + org.springframework.experimental.ai + spring-ai-core + ${project.parent.version} + + + + io.weaviate + client + ${weaviate-client.version} + + + commons-logging + commons-logging + + + + + + + org.springframework.experimental.ai + transformers-embedding + ${parent.version} + test + + + + org.springframework.experimental.ai + spring-ai-test + ${parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + testcontainers + ${testcontainers.version} + test + + + + org.testcontainers + junit-jupiter + ${testcontainers.version} + test + + + + + diff --git a/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java new file mode 100644 index 00000000000..b713eb182c1 --- /dev/null +++ b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java @@ -0,0 +1,237 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.util.ArrayList; +import java.util.Date; +import java.util.List; + +import org.apache.commons.lang3.time.DateFormatUtils; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; +import org.springframework.util.Assert; + +/** + * Converts {@link Expression} into Weaviate metadata filter expression format. + * (https://weaviate.io/developers/weaviate/api/graphql/filters) + * + * @author Christian Tzolov + */ +public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionConverter { + + private boolean mapIntegerToNumberValue = true; + + // https://weaviate.io/developers/weaviate/api/graphql/filters#special-cases + private static final List SYSTEM_IDENTIFIERS = List.of("id", "_creationTimeUnix", "_lastUpdateTimeUnix"); + + private List allowedIdentifierNames; + + public WeaviateFilterExpressionConverter(List allowedIdentifierNames) { + Assert.notNull(allowedIdentifierNames, "List can be empty but not null."); + this.allowedIdentifierNames = allowedIdentifierNames; + } + + public void setAllowedIdentifierNames(List allowedIdentifierNames) { + this.allowedIdentifierNames = allowedIdentifierNames; + } + + public void setMapIntegerToNumberValue(boolean mapIntegerToNumberValue) { + this.mapIntegerToNumberValue = mapIntegerToNumberValue; + } + + @Override + protected void doExpression(Expression exp, StringBuilder context) { + + if (exp.type() == ExpressionType.IN) { + rewriteInNinExpressions(Filter.ExpressionType.OR, Filter.ExpressionType.EQ, exp, context); + } + else if (exp.type() == ExpressionType.NIN) { + rewriteInNinExpressions(Filter.ExpressionType.AND, Filter.ExpressionType.NE, exp, context); + } + else if (exp.type() == ExpressionType.AND || exp.type() == ExpressionType.OR) { + context.append(getOperationSymbol(exp)); + context.append("operands:[{"); + this.convertOperand(exp.left(), context); + context.append("},\n{"); + this.convertOperand(exp.right(), context); + context.append("}]"); + } + else { + this.convertOperand(exp.left(), context); + context.append(getOperationSymbol(exp)); + this.convertOperand(exp.right(), context); + } + } + + /** + * Recursively aggregates a list of expression into a binary tree with 'aggregateType' + * join nodes. + * @param aggregateType type all tree splits. + * @param expressions list of expressions to aggregate. + * @return Returns a binary tree expression. + */ + private Filter.Expression aggregate(Filter.ExpressionType aggregateType, List expressions) { + + if (expressions.size() == 1) { + return expressions.get(0); + } + return new Filter.Expression(aggregateType, expressions.get(0), + aggregate(aggregateType, expressions.subList(1, expressions.size()))); + } + + private void rewriteInNinExpressions(Filter.ExpressionType outerExpressionType, + Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context) { + if (exp.right() instanceof Filter.Value value) { + if (value.value() instanceof List list) { + // 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo == "bar1" || + // foo == "bar2" || foo == "bar3" + // or equivalent to OR(foo == "bar1" OR( foo == "bar2" OR(foo == "bar3"))) + // 2. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo != "bar1" && + // foo != "bar2" && foo != "bar3" + // or equivalent to AND(foo != "bar1" AND( foo != "bar2" OR(foo != + // "bar3"))) + List eqExprs = new ArrayList<>(); + for (Object o : list) { + eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o))); + } + this.doExpression(aggregate(outerExpressionType, eqExprs), context); + } + else { + // 1. foo IN ["bar"] is equivalent to foo == "BAR" + // 2. foo NIN ["bar"] is equivalent to foo != "BAR" + this.doExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right()), context); + } + } + else { + throw new IllegalStateException( + "Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass()); + } + } + + private String getOperationSymbol(Expression exp) { + switch (exp.type()) { + case AND: + return "operator:And \n"; + case OR: + return "operator:Or \n"; + case EQ: + return "operator:Equal \n"; + case NE: + return "operator:NotEqual \n"; + case LT: + return "operator:LessThan \n"; + case LTE: + return "operator:LessThanEqual \n"; + case GT: + return "operator:GreaterThan \n"; + case GTE: + return "operator:GreaterThanEqual \n"; + case IN: + throw new IllegalStateException( + "The 'IN' operator should have been transformed into chain of OR/EQ expressions."); + case NIN: + throw new IllegalStateException( + "The 'NIN' operator should have been transformed into chain of AND/NEQ expressions."); + default: + throw new UnsupportedOperationException("Not supported expression type:" + exp.type()); + } + } + + @Override + protected void doKey(Key key, StringBuilder context) { + var identifier = (hasOuterQuotes(key.key())) ? removeOuterQuotes(key.key()) : key.key(); + context.append("path:[\"" + withMetaPrefix(identifier) + "\"] \n"); + } + + public String withMetaPrefix(String identifier) { + if (SYSTEM_IDENTIFIERS.contains(identifier)) { + return identifier; + } + + if (this.allowedIdentifierNames.contains(identifier)) { + return "meta_" + identifier; + } + + throw new IllegalArgumentException("Not allowed filter identifier name: " + identifier + + ". Consider adding it to WeaviateVectorStore#filterMetadataKeys."); + } + + @Override + protected void doValue(Filter.Value filterValue, StringBuilder context) { + if (filterValue.value() instanceof List list) { + // nothing + throw new IllegalStateException(""); + } + else { + this.doSingleValue(filterValue.value(), context); + } + } + + @Override + protected void doSingleValue(Object value, StringBuilder context) { + if (value instanceof Integer i) { + if (this.mapIntegerToNumberValue) { + context.append(String.format("valueNumber:%s ", i)); + } + else { + context.append(String.format("valueInt:%s ", i)); + } + } + else if (value instanceof Long l) { + if (this.mapIntegerToNumberValue) { + context.append(String.format("valueNumber:%s ", l)); + } + else { + context.append(String.format("valueInt:%s ", l)); + } + } + else if (value instanceof Double d) { + context.append(String.format("valueNumber:%s ", d)); + } + else if (value instanceof Float f) { + context.append(String.format("valueNumber:%s ", f)); + } + else if (value instanceof Boolean b) { + context.append(String.format("valueBoolean:%s ", b)); + } + else if (value instanceof String s) { + context.append(String.format("valueText:\"%s\" ", s)); + } + else if (value instanceof Date date) { + String dateString = DateFormatUtils.format(date, "yyyy-MM-dd\'T\'HH:mm:ssZZZZZ"); + context.append(String.format("valueDate:\"%s\" ", dateString)); + } + else { + throw new RuntimeException("Unsupported value type: " + value); + } + } + + @Override + protected void doGroup(Group group, StringBuilder context) { + // Replaces the group: AND((foo == "bar" OR bar == "foo"), "boza" == "koza") into + // AND(AND(id != -1, (foo == "bar" OR bar == "foo")), "boza" == "koza") into + this.convertOperand(new Expression(ExpressionType.AND, + new Expression(ExpressionType.NE, new Filter.Key("id"), new Filter.Value("-1")), group.content()), + context); + } + +} \ No newline at end of file diff --git a/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java new file mode 100644 index 00000000000..662fabe2863 --- /dev/null +++ b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -0,0 +1,615 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateAuthClient; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.base.WeaviateErrorMessage; +import io.weaviate.client.v1.auth.exception.AuthException; +import io.weaviate.client.v1.batch.model.BatchDeleteResponse; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.data.model.WeaviateObject; +import io.weaviate.client.v1.filters.Operator; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLError; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument; +import io.weaviate.client.v1.graphql.query.argument.WhereArgument; +import io.weaviate.client.v1.graphql.query.builder.GetBuilder; +import io.weaviate.client.v1.graphql.query.builder.GetBuilder.GetBuilderBuilder; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.client.v1.graphql.query.fields.Fields; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.ConsistentLevel; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * A VectorStore implementation backed by Weaviate vector database. + * + * Note: You can assign arbitrary metadata fields with your Documents. Later will be + * persisted and managed as Document fields. But only the metadata keys listed in + * {@link WeaviateVectorStore#filterMetadataFields} can be used for similarity search + * expression filters. + * + * @author Christian Tzolov + */ +public class WeaviateVectorStore implements VectorStore, InitializingBean { + + public static final String DOCUMENT_METADATA_DISTANCE_KEY_NAME = "distance"; + + private static final String METADATA_FIELD_PREFIX = "meta_"; + + private static final String CONTENT_FIELD_NAME = "content"; + + private static final String METADATA_FIELD_NAME = "metadata"; + + private static final String ADDITIONAL_FIELD_NAME = "_additional"; + + private static final String ADDITIONAL_ID_FIELD_NAME = "id"; + + private static final String ADDITIONAL_CERTAINTY_FIELD_NAME = "certainty"; + + private static final String ADDITIONAL_VECTOR_FIELD_NAME = "vector"; + + private final EmbeddingClient embeddingClient; + + private final WeaviateClient weaviateClient; + + private final ConsistentLevel consistencyLevel; + + private final String weaviateObjectClass; + + /** + * List of metadata fields (as field name and type) that can be used in similarity + * search query filter expressions. The {@link Document#getMetadata()} can contain + * arbitrary number of metadata entries, but only the fields listed here can be used + * in the search filter expressions. + * + * If new entries are added ot the filterMetadataFields the affected documents must be + * (re)updated. + */ + private final List filterMetadataFields; + + /** + * List of weaviate field to retrieve whey performing similarity search. + */ + private final Field[] weaviateSimilaritySearchFields; + + /** + * Converts the generic {@link Filter.Expression} into, native, Weaviate filter + * expressions. + */ + private final WeaviateFilterExpressionConverter filterExpressionConverter; + + /** + * Used to serialize/deserialize the document metadata when stored/retrieved from the + * weaviate vector store. + */ + private final ObjectMapper objetMapper = new ObjectMapper(); + + /** + * Configuration class for the WeaviateVectorStore. + */ + public static final class WeaviateVectorStoreConfig { + + public record MetadataField(String name, Type type) { + public enum Type { + + TEXT, NUMBER, BOOLEAN + + } + + public static MetadataField text(String name) { + return new MetadataField(name, Type.TEXT); + } + + public static MetadataField number(String name) { + return new MetadataField(name, Type.NUMBER); + } + + public static MetadataField bool(String name) { + return new MetadataField(name, Type.BOOLEAN); + } + + } + + /** + * https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-consistency-strategies + */ + public enum ConsistentLevel { + + /** + * Write must receive an acknowledgement from at least one replica node. This + * is the fastest (most available), but least consistent option. + */ + ONE, + + /** + * Write must receive an acknowledgement from at least QUORUM replica nodes. + * QUORUM is calculated as n / 2 + 1, where n is the number of replicas. + */ + QUORUM, + + /** + * Write must receive an acknowledgement from all replica nodes. This is the + * most consistent, but 'slowest'. + */ + ALL + + } + + /** + * The server api key. + */ + private final String apiKey; + + /** + * The URL scheme, such as 'http' or 'https'. + */ + private final String scheme; + + private final String host; + + private final String weaviateObjectClass; + + private final ConsistentLevel consistencyLevel; + + /** + * Known metadata fields to add as a fields to the Weaviate schema. You can add + * arbitrary metadata with your documents but only the metadata fields listed here + * can be used in the expression filters. + */ + private final List filterMetadataFields; + + private final Map headers; + + /** + * Constructor using the builder. + * @param builder The configuration builder. + */ + public WeaviateVectorStoreConfig(Builder builder) { + this.apiKey = builder.apiKey; + this.scheme = builder.scheme; + this.host = builder.host; + this.weaviateObjectClass = builder.objectClass; + this.consistencyLevel = builder.consistencyLevel; + this.filterMetadataFields = builder.filterMetadataFields; + this.headers = builder.headers; + } + + /** + * Start building a new configuration. + * @return The entry point for creating a new configuration. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * {@return the default config} + */ + public static WeaviateVectorStoreConfig defaultConfig() { + return builder().build(); + } + + public static class Builder { + + private String apiKey = ""; + + private String scheme = "http"; + + private String host = "localhost:8080"; + + private String objectClass = "SpringAiWeaviate"; + + private ConsistentLevel consistencyLevel = WeaviateVectorStoreConfig.ConsistentLevel.ONE; + + private List filterMetadataFields = List.of(); + + private Map headers = Map.of(); + + private Builder() { + } + + /** + * Pinecone api key. + * @param apiKey key to use. + * @return this builder. + */ + public Builder withApiKey(String apiKey) { + Assert.notNull(apiKey, "The apiKey can not be null."); + this.apiKey = apiKey; + return this; + } + + /** + * Weaviate scheme. + * @param scheme scheme to use. + * @return this builder. + */ + public Builder withScheme(String scheme) { + Assert.hasText(scheme, "The scheme can not be empty."); + this.scheme = scheme; + return this; + } + + /** + * Weaviate host. + * @param host host to use. + * @return this builder. + */ + public Builder withHost(String host) { + Assert.hasText(host, "The host can not be empty."); + this.host = host; + return this; + } + + /** + * Weaviate known, filterable metadata fields. + * @param filterMetadataFields known metadata fields to use. + * @return this builder. + */ + public Builder withFilterableMetadataFields(List filterMetadataFields) { + Assert.notNull(filterMetadataFields, "The filterMetadataFields can not be null."); + this.filterMetadataFields = filterMetadataFields; + return this; + } + + /** + * Weaviate config headers. + * @param headers config headers to use. + * @return this builder. + */ + public Builder withHeaders(Map headers) { + Assert.notNull(headers, "The headers can not be null."); + this.headers = headers; + return this; + } + + /** + * Weaviate objectClass. + * @param objectClass objectClass to use. + * @return this builder. + */ + public Builder withObjectClass(String objectClass) { + Assert.hasText(objectClass, "The objectClass can not be empty."); + this.objectClass = objectClass; + return this; + } + + /** + * Weaviate consistencyLevel. + * @param consistencyLevel consistencyLevel to use. + * @return this builder. + */ + public Builder withConsistencyLevel(ConsistentLevel consistencyLevel) { + Assert.notNull(consistencyLevel, "The consistencyLevel can not be null."); + this.consistencyLevel = consistencyLevel; + return this; + } + + /** + * {@return the immutable configuration} + */ + public WeaviateVectorStoreConfig build() { + return new WeaviateVectorStoreConfig(this); + } + + } + + } + + /** + * Constructs a new WeaviateVectorStore. + * @param vectorStoreConfig The configuration for the store. + * @param embeddingClient The client for embedding operations. + */ + public WeaviateVectorStore(WeaviateVectorStoreConfig vectorStoreConfig, EmbeddingClient embeddingClient) { + Assert.notNull(vectorStoreConfig, "WeaviateVectorStoreConfig must not be null"); + Assert.notNull(embeddingClient, "EmbeddingClient must not be null"); + + this.embeddingClient = embeddingClient; + this.consistencyLevel = vectorStoreConfig.consistencyLevel; + this.weaviateObjectClass = vectorStoreConfig.weaviateObjectClass; + this.filterMetadataFields = vectorStoreConfig.filterMetadataFields; + this.filterExpressionConverter = new WeaviateFilterExpressionConverter( + this.filterMetadataFields.stream().map(MetadataField::name).toList()); + + try { + this.weaviateClient = WeaviateAuthClient.apiKey( + new Config(vectorStoreConfig.scheme, vectorStoreConfig.host, vectorStoreConfig.headers), + vectorStoreConfig.apiKey); + } + catch (AuthException e) { + throw new IllegalArgumentException(e); + } + + this.weaviateSimilaritySearchFields = buildWeaviateSimilaritySearchFields(); + } + + private Field[] buildWeaviateSimilaritySearchFields() { + + List searchWeaviateFieldList = new ArrayList<>(); + + searchWeaviateFieldList.add(Field.builder().name(CONTENT_FIELD_NAME).build()); + searchWeaviateFieldList.add(Field.builder().name(METADATA_FIELD_NAME).build()); + searchWeaviateFieldList.addAll(this.filterMetadataFields.stream() + .map(mf -> Field.builder().name(METADATA_FIELD_PREFIX + mf.name()).build()) + .toList()); + searchWeaviateFieldList.add(Field.builder() + .name(ADDITIONAL_FIELD_NAME) + // https://weaviate.io/developers/weaviate/api/graphql/get#additional-properties--metadata + .fields(Field.builder().name(ADDITIONAL_ID_FIELD_NAME).build(), + Field.builder().name(ADDITIONAL_CERTAINTY_FIELD_NAME).build(), + Field.builder().name(ADDITIONAL_VECTOR_FIELD_NAME).build()) + .build()); + + return searchWeaviateFieldList.toArray(new Field[0]); + } + + @Override + public void add(List documents) { + + if (CollectionUtils.isEmpty(documents)) { + return; + } + + List weaviateObjects = documents.stream().map(this::toWeaviateObject).toList(); + + Result response = this.weaviateClient.batch() + .objectsBatcher() + .withObjects(weaviateObjects.toArray(new WeaviateObject[0])) + .withConsistencyLevel(this.consistencyLevel.name()) + .run(); + + List errorMessages = new ArrayList<>(); + + if (response.hasErrors()) { + errorMessages.add(response.getError() + .getMessages() + .stream() + .map(wm -> wm.getMessage()) + .collect(Collectors.joining("\n"))); + throw new RuntimeException("Failed to add documents because: \n" + errorMessages); + } + + if (response.getResult() != null) { + for (var r : response.getResult()) { + if (r.getResult() != null && r.getResult().getErrors() != null) { + var error = r.getResult().getErrors(); + errorMessages + .add(error.getError().stream().map(e -> e.getMessage()).collect(Collectors.joining("\n"))); + } + } + } + + if (!CollectionUtils.isEmpty(errorMessages)) { + throw new RuntimeException("Failed to add documents because: \n" + errorMessages); + } + } + + private WeaviateObject toWeaviateObject(Document document) { + + if (CollectionUtils.isEmpty(document.getEmbedding())) { + List embedding = this.embeddingClient.embed(document); + document.setEmbedding(embedding); + } + + // https://weaviate.io/developers/weaviate/config-refs/datatypes + Map fields = new HashMap<>(); + fields.put(CONTENT_FIELD_NAME, document.getContent()); + try { + String metadataString = this.objetMapper.writeValueAsString(document.getMetadata()); + fields.put(METADATA_FIELD_NAME, metadataString); + } + catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize the Document metadata: " + document.getContent()); + } + + // Add the filterable metadata fields as top level fields, allowing filler + // expressions on them. + for (MetadataField mf : this.filterMetadataFields) { + if (document.getMetadata().containsKey(mf.name())) { + fields.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name())); + } + } + + return WeaviateObject.builder() + .className(this.weaviateObjectClass) + .id(document.getId()) + .vector(toFloatArray(document.getEmbedding())) + .properties(fields) + .build(); + } + + @Override + public Optional delete(List documentIds) { + + Result result = this.weaviateClient.batch() + .objectsBatchDeleter() + .withClassName(this.weaviateObjectClass) + .withConsistencyLevel(this.consistencyLevel.name()) + .withWhere(WhereFilter.builder() + .path("id") + .operator(Operator.ContainsAny) + .valueString(documentIds.toArray(new String[0])) + .build()) + .run(); + + if (result.hasErrors()) { + String errorMessages = result.getError() + .getMessages() + .stream() + .map(wm -> wm.getMessage()) + .collect(Collectors.joining(",")); + throw new RuntimeException("Failed to delete documents because: \n" + errorMessages); + } + + return Optional.of(!result.hasErrors()); + } + + @Override + public List similaritySearch(SearchRequest request) { + + Float[] embedding = toFloatArray(this.embeddingClient.embed(request.getQuery())); + + GetBuilder.GetBuilderBuilder builder = GetBuilder.builder(); + + GetBuilderBuilder queryBuilder = builder.className(this.weaviateObjectClass) + .withNearVectorFilter(NearVectorArgument.builder() + .vector(embedding) + .certainty((float) request.getSimilarityThreshold()) + .build()) + .limit(request.getTopK()) + .withWhereFilter(WhereArgument.builder().build()) // adds an empty 'where:{}' + // placeholder. + .fields(Fields.builder().fields(this.weaviateSimilaritySearchFields).build()); + + String graphQLQuery = queryBuilder.build().buildQuery(); + + if (request.hasFilterExpression()) { + // replace the empty 'where:{}' placeholder with real filter. + String filter = this.filterExpressionConverter.convertExpression(request.getFilterExpression()); + graphQLQuery = graphQLQuery.replace("where:{}", String.format("where:{%s}", filter)); + } + else { + // remove the empty 'where:{}' placeholder. + graphQLQuery = graphQLQuery.replace("where:{}", ""); + } + + Result result = this.weaviateClient.graphQL().raw().withQuery(graphQLQuery).run(); + + if (result.hasErrors()) { + throw new IllegalArgumentException(result.getError() + .getMessages() + .stream() + .map(WeaviateErrorMessage::getMessage) + .collect(Collectors.joining("\n"))); + } + + GraphQLError[] errors = result.getResult().getErrors(); + if (errors != null && errors.length > 0) { + throw new IllegalArgumentException( + Arrays.stream(errors).map(GraphQLError::getMessage).collect(Collectors.joining("\n"))); + } + + @SuppressWarnings("unchecked") + Optional>> resGetPart = ((Map>) result.getResult().getData()) + .entrySet() + .stream() + .findFirst(); + if (!resGetPart.isPresent()) { + return List.of(); + } + + Optional resItemsPart = resGetPart.get().getValue().entrySet().stream().findFirst(); + if (!resItemsPart.isPresent()) { + return List.of(); + } + + @SuppressWarnings("unchecked") + List> resItems = ((Map.Entry>>) resItemsPart.get()).getValue(); + + return resItems.stream().map(this::toDocument).toList(); + } + + @SuppressWarnings("unchecked") + private Document toDocument(Map item) { + + // Additional (System) + Map additional = (Map) item.get(ADDITIONAL_FIELD_NAME); + double certainty = (Double) additional.get(ADDITIONAL_CERTAINTY_FIELD_NAME); + String id = (String) additional.get(ADDITIONAL_ID_FIELD_NAME); + List embedding = ((List) additional.get(ADDITIONAL_VECTOR_FIELD_NAME)).stream().toList(); + + // Metadata + Map metadata = new HashMap<>(); + metadata.put(DOCUMENT_METADATA_DISTANCE_KEY_NAME, 1 - certainty); + + try { + String metadataJson = (String) item.get(METADATA_FIELD_NAME); + if (StringUtils.hasText(metadataJson)) { + metadata.putAll(this.objetMapper.readValue(metadataJson, Map.class)); + } + } + catch (Exception e) { + throw new RuntimeException(e); + } + + // Content + String content = (String) item.get(CONTENT_FIELD_NAME); + + var document = new Document(id, content, metadata); + document.setEmbedding(embedding); + + return document; + } + + /** + * Converts a list of doubles to a array of floats. + * @param doubleList The list of doubles. + * @return The converted array of floats. + */ + private Float[] toFloatArray(List doubleList) { + return doubleList.stream().map(d -> d.floatValue()).toList().toArray(new Float[0]); + } + + @Override + public void afterPropertiesSet() throws Exception { + + Map metadata = new HashMap<>(); + if (!CollectionUtils.isEmpty(this.filterMetadataFields)) { + for (MetadataField mf : this.filterMetadataFields) { + switch (mf.type()) { + case TEXT: + metadata.put(mf.name(), "Hello"); + break; + case NUMBER: + metadata.put(mf.name(), 3.14); + break; + case BOOLEAN: + metadata.put(mf.name(), true); + break; + default: + break; + } + } + } + + var document = new Document("Hello world", metadata); + this.add(List.of(document)); + this.delete(List.of(document.getId())); + } + +} diff --git a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java new file mode 100644 index 00000000000..9f4b9264678 --- /dev/null +++ b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java @@ -0,0 +1,274 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.converter.FilterExpressionConverter; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; + +/** + * @author Christian Tzolov + */ +public class WeaviateFilterExpressionConverterTests { + + private static String format(String text) { + return text.trim().replace(" " + System.lineSeparator(), System.lineSeparator()) + "\n"; + } + + @Test + public void testMissingFilterName() { + + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of()); + + assertThatThrownBy(() -> { + converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + }).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Not allowed filter identifier name: country. Consider adding it to WeaviateVectorStore#filterMetadataKeys."); + } + + @Test + public void testSystemIdentifiers() { + + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of()); + + // id == "1" && _creationTimeUnix >= "36" && _lastUpdateTimeUnix <= "100" + + String vectorExpr = converter.convertExpression(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("id"), new Value("1")), + new Expression(GTE, new Key("_creationTimeUnix"), new Value("36"))), + new Expression(LTE, new Key("_lastUpdateTimeUnix"), new Value("100")))); + + assertThat(format(vectorExpr)).isEqualTo(""" + operator:And + operands:[{operator:And + operands:[{path:["id"] + operator:Equal + valueText:"1" }, + {path:["_creationTimeUnix"] + operator:GreaterThanEqual + valueText:"36" }]}, + {path:["_lastUpdateTimeUnix"] + operator:LessThanEqual + valueText:"100" }] + """); + } + + @Test + public void testEQ() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("country")); + + // country == "BG" + String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + assertThat(format(vectorExpr)).isEqualTo(""" + path:["meta_country"] + operator:Equal + valueText:"BG" + """); + } + + @Test + public void tesEqAndGte() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("genre", "year")); + + // genre == "drama" AND year >= 2020 + String vectorExpr = converter + .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), + new Expression(GTE, new Key("year"), new Value(2020)))); + assertThat(format(vectorExpr)).isEqualTo(""" + operator:And + operands:[{path:["meta_genre"] + operator:Equal + valueText:"drama" }, + {path:["meta_year"] + operator:GreaterThanEqual + valueNumber:2020 }] + """); + } + + @Test + public void tesIn() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("genre")); + + // genre in ["comedy", "documentary", "drama"] + String vectorExpr = converter.convertExpression( + new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); + assertThat(format(vectorExpr)).isEqualTo(""" + operator:Or + operands:[{path:["meta_genre"] + operator:Equal + valueText:"comedy" }, + {operator:Or + operands:[{path:["meta_genre"] + operator:Equal + valueText:"documentary" }, + {path:["meta_genre"] + operator:Equal + valueText:"drama" }]}] + """); + } + + @Test + public void testNe() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("city", "year", "country")); + + // year >= 2020 OR country == "BG" AND city != "Sofia" + String vectorExpr = converter + .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), + new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), + new Expression(NE, new Key("city"), new Value("Sofia"))))); + assertThat(format(vectorExpr)).isEqualTo(""" + operator:Or + operands:[{path:["meta_year"] + operator:GreaterThanEqual + valueNumber:2020 }, + {operator:And + operands:[{path:["meta_country"] + operator:Equal + valueText:"BG" }, + {path:["meta_city"] + operator:NotEqual + valueText:"Sofia" }]}] + """); + } + + @Test + public void testGroup() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("city", "year", "country")); + + // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] + String vectorExpr = converter.convertExpression(new Expression(AND, + new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), + new Expression(EQ, new Key("country"), new Value("BG")))), + new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); + + assertThat(format(vectorExpr)).isEqualTo(""" + operator:And + operands:[{operator:And + operands:[{path:["id"] + operator:NotEqual + valueText:"-1" }, + {operator:Or + operands:[{path:["meta_year"] + operator:GreaterThanEqual + valueNumber:2020 }, + {path:["meta_country"] + operator:Equal + valueText:"BG" }]}]}, + {operator:And + operands:[{path:["meta_city"] + operator:NotEqual + valueText:"Sofia" }, + {path:["meta_city"] + operator:NotEqual + valueText:"Plovdiv" }]}] + """); + } + + @Test + public void tesBoolean() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter( + List.of("isOpen", "year", "country")); + + // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] + String vectorExpr = converter.convertExpression(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); + + assertThat(format(vectorExpr)).isEqualTo(""" + operator:And + operands:[{operator:And + operands:[{path:["meta_isOpen"] + operator:Equal + valueBoolean:true }, + {path:["meta_year"] + operator:GreaterThanEqual + valueNumber:2020 }]}, + {operator:Or + operands:[{path:["meta_country"] + operator:Equal + valueText:"BG" }, + {operator:Or + operands:[{path:["meta_country"] + operator:Equal + valueText:"NL" }, + {path:["meta_country"] + operator:Equal + valueText:"US" }]}]}] + """); + } + + @Test + public void testDecimal() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("temperature")); + + // temperature >= -15.6 && temperature <= +20.13 + String vectorExpr = converter + .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), + new Expression(LTE, new Key("temperature"), new Value(20.13)))); + + assertThat(format(vectorExpr)).isEqualTo(""" + operator:And + operands:[{path:["meta_temperature"] + operator:GreaterThanEqual + valueNumber:-15.6 }, + {path:["meta_temperature"] + operator:LessThanEqual + valueNumber:20.13 }] + """); + } + + @Test + public void testComplexIdentifiers() { + FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of("country 1 2 3")); + + String vectorExpr = converter + .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); + assertThat(format(vectorExpr)).isEqualTo(""" + path:["meta_country 1 2 3"] + operator:Equal + valueText:"BG" + """); + + vectorExpr = converter.convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); + assertThat(format(vectorExpr)).isEqualTo(""" + path:["meta_country 1 2 3"] + operator:Equal + valueText:"BG" + """); + } + +} diff --git a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java new file mode 100644 index 00000000000..336cecce140 --- /dev/null +++ b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -0,0 +1,262 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.TransformersEmbeddingClient; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig; +import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.core.io.DefaultResourceLoader; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +@Testcontainers +public class WeaviateVectorStoreIT { + + @Container + static GenericContainer weaviateContainer = new GenericContainer<>("semitechnologies/weaviate:1.22.4") + .withEnv("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true") + .withEnv("PERSISTENCE_DATA_PATH", "/var/lib/weaviate") + .withEnv("QUERY_DEFAULTS_LIMIT", "25") + .withEnv("DEFAULT_VECTORIZER_MODULE", "none") + .withEnv("CLUSTER_HOSTNAME", "node1") + .withExposedPorts(8080); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + List documents = List.of( + new Document("471a8c78-549a-4b2c-bce5-ef3ae6579be3", getText("classpath:/test/data/spring.ai.txt"), + Map.of("meta1", "meta1")), + new Document("bc51d7f7-627b-4ba6-adf4-f0bcd1998f8f", getText("classpath:/test/data/time.shelter.txt"), + Map.of()), + new Document("d0237682-1150-44ff-b4d2-1be9b1731ee5", getText("classpath:/test/data/great.depression.txt"), + Map.of("meta2", "meta2"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + private void resetCollection(VectorStore vectorStore) { + vectorStore.delete(documents.stream().map(Document::getId).toList()); + } + + @Test + public void addAndSearch() { + + contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + resetCollection(vectorStore); + + vectorStore.add(documents); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + + // Remove all documents from the store + vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + + results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + assertThat(results).hasSize(0); + }); + } + + @Test + public void searchWithFilters() throws InterruptedException { + + contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2020)); + var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "NL")); + var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2023)); + + vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); + + List results = vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)); + assertThat(results).hasSize(3); + + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'NL'")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG'")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG' && year == 2020")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + vectorStore.delete(List.of(bgDocument.getId(), nlDocument.getId(), bgDocument2.getId())); + }); + } + + @Test + public void documentUpdate() { + + contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + resetCollection(vectorStore); + + Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", + Collections.singletonMap("meta1", "meta1")); + + vectorStore.add(List.of(document)); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); + assertThat(resultDoc.getMetadata()).containsKey("meta1"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); + + Document sameIdDocument = new Document(document.getId(), + "The World is Big and Salvation Lurks Around the Corner", + Collections.singletonMap("meta2", "meta2")); + + vectorStore.add(List.of(sameIdDocument)); + + results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); + + assertThat(results).hasSize(1); + resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); + + vectorStore.delete(List.of(document.getId())); + + }); + } + + @Test + public void searchWithThreshold() { + + contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + resetCollection(vectorStore); + + vectorStore.add(documents); + + List fullResult = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll()); + + List distances = fullResult.stream() + .map(doc -> (Double) doc.getMetadata().get("distance")) + .toList(); + + assertThat(distances).hasSize(3); + + double threshold = (distances.get(0) + distances.get(1)) / 2; + + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration + public static class TestApplication { + + @Bean + public VectorStore vectorStore(EmbeddingClient embeddingClient) { + WeaviateVectorStoreConfig config = WeaviateVectorStore.WeaviateVectorStoreConfig.builder() + .withScheme("http") + .withHost(String.format("%s:%s", weaviateContainer.getHost(), weaviateContainer.getMappedPort(8080))) + .withFilterableMetadataFields(List.of(MetadataField.text("country"), MetadataField.number("year"))) + .withConsistencyLevel(WeaviateVectorStoreConfig.ConsistentLevel.ONE) + .build(); + + WeaviateVectorStore vectorStore = new WeaviateVectorStore(config, embeddingClient); + + return vectorStore; + } + + @Bean + public EmbeddingClient embeddingClient() { + return new TransformersEmbeddingClient(); + } + + } + +} \ No newline at end of file diff --git a/vector-stores/spring-ai-weaviate/src/test/resources/docker-compose.yml b/vector-stores/spring-ai-weaviate/src/test/resources/docker-compose.yml new file mode 100644 index 00000000000..17bbe004e75 --- /dev/null +++ b/vector-stores/spring-ai-weaviate/src/test/resources/docker-compose.yml @@ -0,0 +1,20 @@ +version: '3.4' +services: + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '8080' + - --scheme + - http + image: semitechnologies/weaviate:1.22.4 + ports: + - 8080:8080 + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + CLUSTER_HOSTNAME: 'node1'