diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index b653d1e7e02..87b721c5367 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -446,6 +446,12 @@ ${project.version} + + org.springframework.ai + spring-ai-gemfire-store-spring-boot-starter + ${project.version} + + org.springframework.ai spring-ai-zhipuai-spring-boot-starter diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc index d67d9f67da9..fd52c1dd064 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc @@ -1,122 +1,136 @@ = GemFire Vector Store -This section walks you through setting up the GemFire VectorStore to store document embeddings and perform similarity searches. +This section walks you through setting up the `GemFireVectorStore` to store document embeddings and perform similarity searches. -link:https://tanzu.vmware.com/gemfire[GemFire] is an ultra high speed in-memory data and compute grid, with vector extensions to store and search vectors efficiently. +link:https://tanzu.vmware.com/gemfire[GemFire] is a distributed, in-memory, key-value store that performs read and write operations at blazingly fast speeds. It offers highly available parallel message queues, continuous availability, and an event-driven architecture you can scale dynamically with no downtime. As your data size requirements increase to support high-performance, real-time apps, GemFire can scale linearly with ease. -link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/overview.html[GemFire VectorDB] extends GemFire's capabilities, serving as a versatile vector database that efficiently stores, retrieves, and performs vector searches through a distributed and resilient infrastructure: - -Capabilities: -- Create Indexes -- Store vectors and the associated metadata -- Perform vector searches based on similarity +link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/overview.html[GemFire VectorDB] extends GemFire's capabilities, serving as a versatile vector database that efficiently stores, retrieves, and performs vector similarity searches. == Prerequisites -Access to a GemFire cluster with the link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/install.html[GemFire Vector Database] extension installed. -You can download the GemFire VectorDB extension from the link:https://network.pivotal.io/products/gemfire-vectordb/[VMware Tanzu Network] after signing in. +1. A GemFire cluster with the GemFire VectorDB extension enabled +- link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/install.html[Install GemFire VectorDB extension] + +2. An `EmbeddingModel` bean to compute the document embeddings. Refer to the xref:api/embeddings.adoc#available-implementations[EmbeddingModel] section for more information. == Dependencies Add these dependencies to your project: -- Embedding Model boot starter, required for calculating embeddings. -- Transformers Embedding (Local) and follow the ONNX Transformers Embedding instructions. +- Embedding Model boot starter, required for calculating embeddings +- Transformers Embedding (Local) and follow the ONNX Transformers Embedding instructions + +=== Maven -[source,xml] +To use just the `GemFireVectorStore`, add the following dependency to your project’s Maven `pom.xml`: + +[source, xml] ---- - org.springframework.ai - spring-ai-transformers + org.springframework.ai + spring-ai-gemfire-store ---- -- Add the GemFire VectorDB dependencies +To enable Spring Boot’s Auto-Configuration for the `GemFireVectorStore`, also add the following to your project’s Maven `pom.xml`: -[source,xml] +[source, xml] ---- org.springframework.ai - spring-ai-gemfire-store + spring-ai-gemfire-store-spring-boot-starter + ---- -TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. +=== Gradle +For Gradle users, add the following to your `build.gradle` file under the dependencies block to use just the `GemFireVectorStore`: -== Sample Code +[souce, xml] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-gemfire-store' +} +---- -- To configure GemFire in your application, use the following setup: +To enable Spring Boot’s Auto-Configuration for the `GemFireVectorStore`, also include this dependency: -[source,java] +[source, xml] ---- -@Bean -public GemFireVectorStoreConfig gemFireVectorStoreConfig() { - return GemFireVectorStoreConfig.builder() - .withUrl("http://localhost:8080") - .withIndexName("spring-ai-test-index") - .build(); +dependencies { + implementation 'org.springframework.ai:spring-ai-gemfire-store-spring-boot-starter' } ---- -- Create a GemFireVectorStore instance connected to your GemFire VectorDB: +== Usage + +- Create a `GemFireVectorStore` instance connected to the GemFire cluster: [source,java] ---- @Bean -public VectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel) { - return new GemFireVectorStore(config, embeddingModel); +public VectorStore vectorStore(EmbeddingModel embeddingModel) { + return new GemFireVectorStore(new GemFireVectorStoreConfig() + .setIndexName("my-vector-index") + .setPort(7071), embeddingClient); } ---- -- Create a Vector Index which will configure GemFire region. + +[NOTE] +==== +The default configuration connects to a GemFire cluster at `localhost:8080` +==== + +- In your application, create a few documents: [source,java] ---- - public void createIndex() { - try { - CreateRequest createRequest = new CreateRequest(); - createRequest.setName(INDEX_NAME); - createRequest.setBeamWidth(20); - createRequest.setMaxConnections(16); - ObjectMapper objectMapper = new ObjectMapper(); - String index = objectMapper.writeValueAsString(createRequest); - client.post() - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(index) - .retrieve() - .bodyToMono(Void.class) - .block(); - } - catch (Exception e) { - logger.warn("An unexpected error occurred while creating the index"); - } - } ----- - -- Create some documents: +List documents = List.of( + new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Map.of("country", "UK", "year", 2020)), + new Document("The World is Big and Salvation Lurks Around the Corner", Map.of()), + new Document("You walk forward facing the past and you turn back toward the future.", Map.of("country", "NL", "year", 2023))); +---- + +- Add the documents to the vector store: [source,java] ---- - List documents = List.of( - new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), - new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), - new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); +vectorStore.add(documents); ---- -- Add the documents to GemFire VectorDB: +- And to retrieve documents using similarity search: [source,java] ---- -vectorStore.add(List.of(document)); +List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5)); ---- -- And finally, retrieve documents similar to a query: +You should retrieve the document containing the text "Spring AI rocks!!". +You can also limit the number of results using a similarity threshold: [source,java] ---- - List results = vectorStore.similaritySearch("Spring", 5); +List results = vectorStore.similaritySearch( + SearchRequest.query("Spring").withTopK(5) + .withSimilarityThreshold(0.5d)); ---- -If all goes well, you should retrieve the document containing the text "Spring AI rocks!!". +== GemFireVectorStore properties + +You can use the following properties in your Spring Boot configuration to further configure the `GemFireVectorStore`. + +|=== +|Property|Default value +|`spring.ai.vectorstore.gemfire.host`|localhost +|`spring.ai.vectorstore.gemfire.port`|8080 +|`spring.ai.vectorstore.gemfire.index-name`|spring-ai-gemfire-store +|`spring.ai.vectorstore.gemfire.beam-width`|100 +|`spring.ai.vectorstore.gemfire.max-connections`|16 +|`spring.ai.vectorstore.gemfire.vector-similarity-function`|COSINE +|`spring.ai.vectorstore.gemfire.fields`|[] +|`spring.ai.vectorstore.gemfire.buckets`|0 +|=== diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 291129335cb..97f3e56da41 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -297,6 +297,13 @@ true + + org.springframework.ai + spring-ai-gemfire-store + ${project.parent.version} + true + + org.springframework.ai spring-ai-minimax @@ -458,6 +465,13 @@ test + + dev.gemfire + gemfire-testcontainers + 2.3.0 + test + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java new file mode 100644 index 00000000000..7792c045131 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vectorstore.gemfire; + +import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails; + +/** + * @author Geet Rawat + */ +public interface GemFireConnectionDetails extends ConnectionDetails { + + String getHost(); + + int getPort(); + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java new file mode 100644 index 00000000000..098f2172fa4 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java @@ -0,0 +1,84 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.gemfire; + +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.GemFireVectorStore; +import org.springframework.ai.vectorstore.GemFireVectorStoreConfig; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +/** + * @author Geet Rawat + */ +@AutoConfiguration +@ConditionalOnClass({ GemFireVectorStore.class, EmbeddingModel.class }) +@EnableConfigurationProperties(GemFireVectorStoreProperties.class) +@ConditionalOnProperty(prefix = "spring.ai.vectorstore.gemfire", value = { "index-name" }) +public class GemFireVectorStoreAutoConfiguration { + + @Bean + @ConditionalOnMissingBean(GemFireConnectionDetails.class) + GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails gemfireConnectionDetails( + GemFireVectorStoreProperties properties) { + return new GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails(properties); + } + + @Bean + @ConditionalOnMissingBean + public GemFireVectorStore vectorStore(EmbeddingModel embeddingModel, GemFireVectorStoreProperties properties, + GemFireConnectionDetails gemFireConnectionDetails) { + var config = new GemFireVectorStoreConfig(); + + config.setHost(gemFireConnectionDetails.getHost()) + .setPort(gemFireConnectionDetails.getPort()) + .setIndexName(properties.getIndexName()) + .setBeamWidth(properties.getBeamWidth()) + .setMaxConnections(properties.getMaxConnections()) + .setBuckets(properties.getBuckets()) + .setVectorSimilarityFunction(properties.getVectorSimilarityFunction()) + .setFields(properties.getFields()) + .setSslEnabled(properties.isSslEnabled()); + return new GemFireVectorStore(config, embeddingModel); + } + + private static class PropertiesGemFireConnectionDetails implements GemFireConnectionDetails { + + private final GemFireVectorStoreProperties properties; + + PropertiesGemFireConnectionDetails(GemFireVectorStoreProperties properties) { + + this.properties = properties; + } + + @Override + public String getHost() { + return this.properties.getHost(); + } + + @Override + public int getPort() { + return this.properties.getPort(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java new file mode 100644 index 00000000000..cdcc27c04cd --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java @@ -0,0 +1,166 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.gemfire; + +import org.springframework.ai.vectorstore.GemFireVectorStoreConfig; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Geet Rawat + */ +@ConfigurationProperties(GemFireVectorStoreProperties.CONFIG_PREFIX) +public class GemFireVectorStoreProperties { + + /** + * Configuration prefix for Spring AI VectorStore GemFire. + */ + public static final String CONFIG_PREFIX = "spring.ai.vectorstore.gemfire"; + + /** + * The host of the GemFire to connect to. To specify a custom host, use + * "spring.ai.vectorstore.gemfire.host"; + * + */ + private String host = GemFireVectorStoreConfig.DEFAULT_HOST; + + /** + * The port of the GemFire to connect to. To specify a custom port, use + * "spring.ai.vectorstore.gemfire.port"; + */ + private int port = GemFireVectorStoreConfig.DEFAULT_PORT; + + /** + * The name of the index in the GemFire. To specify a custom index, use + * "spring.ai.vectorstore.gemfire.index-name"; + */ + private String indexName = GemFireVectorStoreConfig.DEFAULT_INDEX_NAME; + + /** + * The beam width for similarity queries. Default value is {@code 100}. To specify a + * custom beam width, use "spring.ai.vectorstore.gemfire.beam-width"; + */ + private int beamWidth = GemFireVectorStoreConfig.DEFAULT_BEAM_WIDTH; + + /** + * The maximum number of connections allowed. Default value is {@code 16}. To specify + * custom number of connections, use "spring.ai.vectorstore.gemfire.max-connections"; + */ + private int maxConnections = GemFireVectorStoreConfig.DEFAULT_MAX_CONNECTIONS; + + /** + * The similarity function to be used for vector comparisons. Default value is + * {@code "COSINE"}. To specify custom vectorSimilarityFunction, use + * "spring.ai.vectorstore.gemfire.vector-similarity-function"; + * + */ + private String vectorSimilarityFunction = GemFireVectorStoreConfig.DEFAULT_SIMILARITY_FUNCTION; + + /** + * The fields to be used for queries. Default value is an array containing + * {@code "vector"}. To specify custom fields, use + * "spring.ai.vectorstore.gemfire.fields" + */ + private String[] fields = GemFireVectorStoreConfig.DEFAULT_FIELDS; + + /** + * The number of buckets to use for partitioning the data. Default value is {@code 0}. + * + * To specify custom buckets, use "spring.ai.vectorstore.gemfire.buckets"; + * + */ + private int buckets = GemFireVectorStoreConfig.DEFAULT_BUCKETS; + + /** + * Set to true if GemFire cluster is ssl enabled + * + * To specify sslEnabled, use "spring.ai.vectorstore.gemfire.ssl-enabled"; + */ + private boolean sslEnabled = GemFireVectorStoreConfig.DEFAULT_SSL_ENABLED; + + public int getBeamWidth() { + return beamWidth; + } + + public void setBeamWidth(int beamWidth) { + this.beamWidth = beamWidth; + } + + public int getPort() { + return port; + } + + public void setPort(int port) { + this.port = port; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public String getIndexName() { + return indexName; + } + + public void setIndexName(String indexName) { + this.indexName = indexName; + } + + public int getMaxConnections() { + return maxConnections; + } + + public void setMaxConnections(int maxConnections) { + this.maxConnections = maxConnections; + } + + public String getVectorSimilarityFunction() { + return vectorSimilarityFunction; + } + + public void setVectorSimilarityFunction(String vectorSimilarityFunction) { + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + public String[] getFields() { + return fields; + } + + public void setFields(String[] fields) { + this.fields = fields; + } + + public int getBuckets() { + return buckets; + } + + public void setBuckets(int buckets) { + this.buckets = buckets; + } + + public boolean isSslEnabled() { + return sslEnabled; + } + + public void setSslEnabled(boolean sslEnabled) { + this.sslEnabled = sslEnabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 5117fd476e9..804d5c8fa0f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -32,6 +32,7 @@ org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAu org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration org.springframework.ai.autoconfigure.watsonxai.WatsonxAiAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreAutoConfiguration +org.springframework.ai.autoconfigure.vectorstore.gemfire.GemFireVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.cassandra.CassandraVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration org.springframework.ai.autoconfigure.chat.client.ChatClientAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java new file mode 100644 index 00000000000..9c6f9ca607d --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java @@ -0,0 +1,200 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.gemfire; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.hasSize; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.dockerjava.api.model.ExposedPort; +import com.github.dockerjava.api.model.PortBinding; +import com.github.dockerjava.api.model.Ports; +import com.vmware.gemfire.testcontainers.GemFireCluster; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.ai.ResourceUtils; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.GemFireVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * @author Geet Rawat + */ + +class GemFireVectorStoreAutoConfigurationIT { + + private static GemFireCluster gemFireCluster; + + private static final String INDEX_NAME = "spring-ai-index"; + + private static final int BEAM_WIDTH = 50; + + private static final int MAX_CONNECTIONS = 8; + + private static final String SIMILARITY_FUNCTION = "DOT_PRODUCT"; + + private static final String[] FIELDS = { "someField1", "someField2" }; + + private static final int BUCKET_COUNT = 2; + + private static final int HTTP_SERVICE_PORT = 9090; + + private static final int LOCATOR_COUNT = 1; + + private static final int SERVER_COUNT = 1; + + @AfterAll + public static void stopGemFireCluster() { + gemFireCluster.close(); + } + + List documents = List.of( + new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document( + ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(GemFireVectorStoreAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.gemfire.index-name=" + INDEX_NAME) + .withPropertyValues("spring.ai.vectorstore.gemfire.beam-width=" + BEAM_WIDTH) + .withPropertyValues("spring.ai.vectorstore.gemfire.max-connections=" + MAX_CONNECTIONS) + .withPropertyValues("spring.ai.vectorstore.gemfire.vector-similarity-function=" + SIMILARITY_FUNCTION) + .withPropertyValues("spring.ai.vectorstore.gemfire.buckets=" + BUCKET_COUNT) + .withPropertyValues("spring.ai.vectorstore.gemfire.fields=someField1,someField2") + .withPropertyValues("spring.ai.vectorstore.gemfire.host=localhost") + .withPropertyValues("spring.ai.vectorstore.gemfire.port=" + HTTP_SERVICE_PORT); + + @BeforeAll + public static void startGemFireCluster() { + Ports.Binding hostPort = Ports.Binding.bindPort(HTTP_SERVICE_PORT); + ExposedPort exposedPort = new ExposedPort(HTTP_SERVICE_PORT); + PortBinding mappedPort = new PortBinding(hostPort, exposedPort); + gemFireCluster = new GemFireCluster("gemfire/gemfire-all:10.1-jdk17", LOCATOR_COUNT, SERVER_COUNT); + gemFireCluster.withConfiguration(GemFireCluster.SERVER_GLOB, + container -> container.withExposedPorts(HTTP_SERVICE_PORT) + .withCreateContainerCmdModifier(cmd -> cmd.getHostConfig().withPortBindings(mappedPort))); + gemFireCluster.withGemFireProperty(GemFireCluster.SERVER_GLOB, "http-service-port", + Integer.toString(HTTP_SERVICE_PORT)); + gemFireCluster.acceptLicense().start(); + + System.setProperty("spring.data.gemfire.pool.locators", + String.format("localhost[%d]", gemFireCluster.getLocatorPort())); + } + + @Test + void ensureGemFireVectorStoreCustomConfiguration() { + this.contextRunner.run(context -> { + GemFireVectorStore store = context.getBean(GemFireVectorStore.class); + Assertions.assertNotNull(store); + assertThat(store.getIndexName()).isEqualTo(INDEX_NAME); + assertThat(store.getBeamWidth()).isEqualTo(BEAM_WIDTH); + assertThat(store.getMaxConnections()).isEqualTo(MAX_CONNECTIONS); + assertThat(store.getVectorSimilarityFunction()).isEqualTo(SIMILARITY_FUNCTION); + assertThat(store.getFields()).isEqualTo(FIELDS); + + String indexJson = store.getIndex(); + Map index = parseIndex(indexJson); + assertThat(index.get("name")).isEqualTo(INDEX_NAME); + assertThat(index.get("beam-width")).isEqualTo(BEAM_WIDTH); + assertThat(index.get("max-connections")).isEqualTo(MAX_CONNECTIONS); + assertThat(index.get("vector-similarity-function")).isEqualTo(SIMILARITY_FUNCTION); + assertThat(index.get("buckets")).isEqualTo(BUCKET_COUNT); + + }); + } + + @Test + public void addAndSearchTest() { + contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + vectorStore.add(documents); + + Awaitility.await().until(() -> { + return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + }, hasSize(1)); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); + + // Remove all documents from the store + vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + + Awaitility.await().until(() -> { + return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + }, hasSize(0)); + }); + } + + private Map parseIndex(String json) { + try { + JsonNode rootNode = new ObjectMapper().readTree(json); + Map indexDetails = new HashMap<>(); + if (rootNode.isObject()) { + if (rootNode.has("name")) + indexDetails.put("name", rootNode.get("name").asText()); + if (rootNode.has("beam-width")) + indexDetails.put("beam-width", rootNode.get("beam-width").asInt()); + if (rootNode.has("max-connections")) + indexDetails.put("max-connections", rootNode.get("max-connections").asInt()); + if (rootNode.has("vector-similarity-function")) + indexDetails.put("vector-similarity-function", rootNode.get("vector-similarity-function").asText()); + if (rootNode.has("buckets")) + indexDetails.put("buckets", rootNode.get("buckets").asInt()); + if (rootNode.has("number-of-embeddings")) + indexDetails.put("number-of-embeddings", rootNode.get("number-of-embeddings").asInt()); + } + return indexDetails; + } + catch (Exception e) { + return new HashMap<>(); + } + } + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + public EmbeddingModel embeddingModel() { + return new TransformersEmbeddingModel(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java new file mode 100644 index 00000000000..559e07778b2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.gemfire; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.vectorstore.GemFireVectorStoreConfig; + +/** + * @author Geet Rawat + */ +class GemFireVectorStorePropertiesTests { + + @Test + void defaultValues() { + var props = new GemFireVectorStoreProperties(); + assertThat(props.getIndexName()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_INDEX_NAME); + assertThat(props.getHost()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_HOST); + assertThat(props.getPort()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_PORT); + assertThat(props.getBeamWidth()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_BEAM_WIDTH); + assertThat(props.getMaxConnections()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_MAX_CONNECTIONS); + assertThat(props.getFields()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_FIELDS); + assertThat(props.getBuckets()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_BUCKETS); + } + + @Test + void customValues() { + var props = new GemFireVectorStoreProperties(); + props.setIndexName("spring-ai-index"); + props.setHost("localhost"); + props.setPort(9090); + props.setBeamWidth(10); + props.setMaxConnections(10); + props.setFields(new String[] { "test" }); + props.setBuckets(10); + + assertThat(props.getIndexName()).isEqualTo("spring-ai-index"); + assertThat(props.getHost()).isEqualTo("localhost"); + assertThat(props.getPort()).isEqualTo(9090); + assertThat(props.getBeamWidth()).isEqualTo(10); + assertThat(props.getMaxConnections()).isEqualTo(10); + assertThat(props.getFields()).isEqualTo(new String[] { "test" }); + assertThat(props.getBuckets()).isEqualTo(10); + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml new file mode 100644 index 00000000000..de48a933f30 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-gemfire-store-spring-boot-starter + jar + Spring AI Starter - GemFire Vector Store + Spring AI GemFire Vector Store Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-gemfire-store + ${project.parent.version} + + + + diff --git a/vector-stores/spring-ai-gemfire-store/pom.xml b/vector-stores/spring-ai-gemfire-store/pom.xml index 45f7e71637e..01ca725646b 100644 --- a/vector-stores/spring-ai-gemfire-store/pom.xml +++ b/vector-stores/spring-ai-gemfire-store/pom.xml @@ -38,6 +38,13 @@ + + dev.gemfire + gemfire-testcontainers + 2.3.0 + test + + org.springframework.ai spring-ai-openai @@ -71,10 +78,6 @@ 3.0.0 test - - org.apache.logging.log4j - log4j-core - diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index c193b21b667..0393de9c520 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.beans.factory.InitializingBean; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.util.Assert; @@ -48,13 +49,11 @@ * * @author Geet Rawat */ -public class GemFireVectorStore implements VectorStore { - - public static final String QUERY = "/query"; +public class GemFireVectorStore implements VectorStore, InitializingBean { private static final Logger logger = LoggerFactory.getLogger(GemFireVectorStore.class); - private static final String DISTANCE_METADATA_FIELD_NAME = "distance"; + private static final String DEFAULT_URI = "http{ssl}://{host}:{port}/gemfire-vectordb/v1/indexes"; private static final String EMBEDDINGS = "/embeddings"; @@ -62,179 +61,135 @@ public class GemFireVectorStore implements VectorStore { private final EmbeddingModel embeddingModel; - private final int topKPerBucket; - - private final int topK; - - private final String documentField; - - public static final class GemFireVectorStoreConfig { - - private final WebClient client; - - private final String index; - - private final int topKPerBucket; - - public final int topK; - - private final String documentField; - - public static Builder builder() { - return new Builder(); - } - - private GemFireVectorStoreConfig(Builder builder) { - String base = UriComponentsBuilder.fromUriString(DEFAULT_URI) - .build(builder.sslEnabled ? "s" : "", builder.host, builder.port) - .toString(); - this.index = builder.index; - this.client = WebClient.create(base); - this.topKPerBucket = builder.topKPerBucket; - this.topK = builder.topK; - this.documentField = builder.documentField; - } - - public static class Builder { - - private String host; - - private int port = DEFAULT_PORT; - - private boolean sslEnabled; + private static final String DOCUMENT_FIELD = "document"; - private long connectionTimeout; + // Create Index Parameters - private long requestTimeout; + private String indexName; - private String index; + public String getIndexName() { + return indexName; + } - private int topKPerBucket = DEFAULT_TOP_K_PER_BUCKET; + private int beamWidth; - private int topK = DEFAULT_TOP_K; + public int getBeamWidth() { + return beamWidth; + } - private String documentField = DEFAULT_DOCUMENT_FIELD; + private int maxConnections; - public Builder withHost(String host) { - Assert.hasText(host, "host must have a value"); - this.host = host; - return this; - } + public int getMaxConnections() { + return maxConnections; + } - public Builder withPort(int port) { - Assert.isTrue(port > 0, "port must be positive"); - this.port = port; - return this; - } + private int buckets; - public Builder withSslEnabled(boolean sslEnabled) { - this.sslEnabled = sslEnabled; - return this; - } - - public Builder withConnectionTimeout(long timeout) { - Assert.isTrue(timeout >= 0, "timeout must be >= 0"); - this.connectionTimeout = timeout; - return this; - } + public int getBuckets() { + return buckets; + } - public Builder withRequestTimeout(long timeout) { - Assert.isTrue(timeout >= 0, "timeout must be >= 0"); - this.requestTimeout = timeout; - return this; - } + private String vectorSimilarityFunction; - public Builder withIndex(String index) { - Assert.hasText(index, "index must have a value"); - this.index = index; - return this; - } + public String getVectorSimilarityFunction() { + return vectorSimilarityFunction; + } - public Builder withTopKPerBucket(int topKPerBucket) { - Assert.isTrue(topKPerBucket > 0, "topKPerBucket must be positive"); - this.topKPerBucket = topKPerBucket; - return this; - } + private String[] fields; - public Builder withTopK(int topK) { - Assert.isTrue(topK > 0, "topK must be positive"); - this.topK = topK; - return this; - } + public String[] getFields() { + return fields; + } - public Builder withDocumentField(String documentField) { - Assert.hasText(documentField, "documentField must have a value"); - this.documentField = documentField; - return this; - } + // Query Defaults + private static final String QUERY = "/query"; - public GemFireVectorStoreConfig build() { - return new GemFireVectorStoreConfig(this); - } + private static final String DISTANCE_METADATA_FIELD_NAME = "distance"; + /** + * Initializes the GemFireVectorStore after properties are set. This method is called + * after all bean properties have been set and allows the bean to perform any + * initialization it requires. + */ + @Override + public void afterPropertiesSet() throws Exception { + if (indexExists()) { + deleteIndex(); } + createIndex(); } - private static final int DEFAULT_PORT = 9090; - - public static final String DEFAULT_URI = "http{ssl}://{host}:{port}/gemfire-vectordb/v1/indexes"; - - private static final int DEFAULT_TOP_K_PER_BUCKET = 10; - - private static final int DEFAULT_TOP_K = 10; - - private static final String DEFAULT_DOCUMENT_FIELD = "document"; - - public String indexName; + /** + * Checks if the index exists in the GemFireVectorStore. + * @return {@code true} if the index exists, {@code false} otherwise + */ + public boolean indexExists() { + String indexResponse = getIndex(); + return !indexResponse.isEmpty(); + } - public void setIndexName(String indexName) { - this.indexName = indexName; + public String getIndex() { + return client.get().uri("/" + indexName).retrieve().bodyToMono(String.class).onErrorReturn("").block(); } - public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embedding) { + /** + * Configures and initializes a GemFireVectorStore instance based on the provided + * configuration. + * @param config the configuration for the GemFireVectorStore + * @param embeddingModel the embedding client used for generating embeddings + */ + + public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel) { Assert.notNull(config, "GemFireVectorStoreConfig must not be null"); - Assert.notNull(embedding, "EmbeddingModel must not be null"); - this.client = config.client; - this.embeddingModel = embedding; - this.topKPerBucket = config.topKPerBucket; - this.topK = config.topK; - this.documentField = config.documentField; + Assert.notNull(embeddingModel, "EmbeddingModel must not be null"); + this.indexName = config.indexName; + this.embeddingModel = embeddingModel; + this.beamWidth = config.beamWidth; + this.maxConnections = config.maxConnections; + this.buckets = config.buckets; + this.vectorSimilarityFunction = config.vectorSimilarityFunction; + this.fields = config.fields; + + String base = UriComponentsBuilder.fromUriString(DEFAULT_URI) + .build(config.sslEnabled ? "s" : "", config.host, config.port) + .toString(); + this.client = WebClient.create(base); } - private static final class CreateRequest { + public static class CreateRequest { @JsonProperty("name") - private String name; + private String indexName; @JsonProperty("beam-width") - private int beamWidth = 100; + private int beamWidth; @JsonProperty("max-connections") - private int maxConnections = 16; + private int maxConnections; @JsonProperty("vector-similarity-function") - private String vectorSimilarityFunction = "COSINE"; + private String vectorSimilarityFunction; @JsonProperty("fields") - private String[] fields = new String[] { "vector" }; + private String[] fields; @JsonProperty("buckets") - private int buckets = 0; + private int buckets; public CreateRequest() { } - public CreateRequest(String name) { - this.name = name; + public CreateRequest(String indexName) { + this.indexName = indexName; } - public String getName() { - return name; + public String getIndexName() { + return indexName; } - public void setName(String name) { - this.name = name; + public void setIndexName(String indexName) { + this.indexName = indexName; } public int getBeamWidth() { @@ -419,7 +374,7 @@ public void add(List documents) { // Compute and assign an embedding to the document. document.setEmbedding(this.embeddingModel.embed(document)); List floatVector = document.getEmbedding().stream().map(Double::floatValue).toList(); - return new UploadRequest.Embedding(document.getId(), floatVector, documentField, document.getContent(), + return new UploadRequest.Embedding(document.getId(), floatVector, DOCUMENT_FIELD, document.getContent(), document.getMetadata()); }).toList()); @@ -463,22 +418,26 @@ public Optional delete(List idList) { @Override public List similaritySearch(SearchRequest request) { if (request.hasFilterExpression()) { - throw new UnsupportedOperationException("Gemfire does not support metadata filter expressions yet."); + throw new UnsupportedOperationException("GemFire currently does not support metadata filter expressions."); } List vector = this.embeddingModel.embed(request.getQuery()); List floatVector = vector.stream().map(Double::floatValue).toList(); - return client.post() .uri("/" + indexName + QUERY) .contentType(MediaType.APPLICATION_JSON) - .bodyValue(new QueryRequest(floatVector, request.getTopK(), topKPerBucket, true)) + .bodyValue(new QueryRequest(floatVector, request.getTopK(), request.getTopK(), // TopKPerBucket + true)) .retrieve() .bodyToFlux(QueryResponse.class) .filter(r -> r.score >= request.getSimilarityThreshold()) .map(r -> { Map metadata = r.metadata; + if (r.metadata == null) { + metadata = new HashMap<>(); + metadata.put(DOCUMENT_FIELD, "--Deleted--"); + } metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score); - String content = (String) metadata.remove(documentField); + String content = (String) metadata.remove(DOCUMENT_FIELD); return new Document(r.key, content, metadata); }) .collectList() @@ -486,10 +445,22 @@ public List similaritySearch(SearchRequest request) { .block(); } - public void createIndex(String indexName) throws JsonProcessingException { + /** + * Creates a new index in the GemFireVectorStore using specified parameters. This + * method is invoked during initialization. + * @throws JsonProcessingException if an error occurs during JSON processing + */ + public void createIndex() throws JsonProcessingException { CreateRequest createRequest = new CreateRequest(indexName); + createRequest.setBeamWidth(beamWidth); + createRequest.setMaxConnections(maxConnections); + createRequest.setBuckets(buckets); + createRequest.setVectorSimilarityFunction(vectorSimilarityFunction); + createRequest.setFields(fields); + ObjectMapper objectMapper = new ObjectMapper(); String index = objectMapper.writeValueAsString(createRequest); + client.post() .contentType(MediaType.APPLICATION_JSON) .bodyValue(index) @@ -499,9 +470,8 @@ public void createIndex(String indexName) throws JsonProcessingException { .block(); } - public void deleteIndex(String indexName) { + public void deleteIndex() { DeleteRequest deleteRequest = new DeleteRequest(); - deleteRequest.setDeleteData(true); client.method(HttpMethod.DELETE) .uri("/" + indexName) .body(BodyInserters.fromValue(deleteRequest)) @@ -511,6 +481,12 @@ public void deleteIndex(String indexName) { .block(); } + /** + * Handles exceptions that occur during HTTP client operations and maps them to + * appropriate runtime exceptions. + * @param ex the exception that occurred during HTTP client operation + * @return a mapped runtime exception corresponding to the HTTP client exception + */ private Throwable handleHttpClientException(Throwable ex) { if (!(ex instanceof WebClientResponseException clientException)) { throw new RuntimeException(String.format("Got an unexpected error: %s", ex)); diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStoreConfig.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStoreConfig.java new file mode 100644 index 00000000000..58dd6049633 --- /dev/null +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStoreConfig.java @@ -0,0 +1,108 @@ +package org.springframework.ai.vectorstore; + +import org.springframework.util.Assert; + +public final class GemFireVectorStoreConfig { + + // Create Index DEFAULT Values + public static final String DEFAULT_HOST = "localhost"; + + public static final int DEFAULT_PORT = 8080; + + public static final String DEFAULT_INDEX_NAME = "spring-ai-gemfire-index"; + + public static final int UPPER_BOUND_BEAM_WIDTH = 3200; + + public static final int DEFAULT_BEAM_WIDTH = 100; + + private static final int UPPER_BOUND_MAX_CONNECTIONS = 512; + + public static final int DEFAULT_MAX_CONNECTIONS = 16; + + public static final String DEFAULT_SIMILARITY_FUNCTION = "COSINE"; + + public static final String[] DEFAULT_FIELDS = new String[] {}; + + public static final int DEFAULT_BUCKETS = 0; + + public static final boolean DEFAULT_SSL_ENABLED = false; + + String host = GemFireVectorStoreConfig.DEFAULT_HOST; + + int port = DEFAULT_PORT; + + String indexName = DEFAULT_INDEX_NAME; + + int beamWidth = DEFAULT_BEAM_WIDTH; + + int maxConnections = DEFAULT_MAX_CONNECTIONS; + + String vectorSimilarityFunction = DEFAULT_SIMILARITY_FUNCTION; + + String[] fields = DEFAULT_FIELDS; + + int buckets = DEFAULT_BUCKETS; + + boolean sslEnabled = DEFAULT_SSL_ENABLED; + + public GemFireVectorStoreConfig setHost(String host) { + Assert.hasText(host, "host must have a value"); + this.host = host; + return this; + } + + public GemFireVectorStoreConfig setPort(int port) { + Assert.isTrue(port > 0, "port must be positive"); + this.port = port; + return this; + } + + public GemFireVectorStoreConfig setSslEnabled(boolean sslEnabled) { + this.sslEnabled = sslEnabled; + return this; + } + + public GemFireVectorStoreConfig setIndexName(String indexName) { + Assert.hasText(indexName, "indexName must have a value"); + this.indexName = indexName; + return this; + } + + public GemFireVectorStoreConfig setBeamWidth(int beamWidth) { + Assert.isTrue(beamWidth > 0, "beamWidth must be positive"); + Assert.isTrue(beamWidth <= GemFireVectorStoreConfig.UPPER_BOUND_BEAM_WIDTH, + "beamWidth must be less than or equal to " + GemFireVectorStoreConfig.UPPER_BOUND_BEAM_WIDTH); + this.beamWidth = beamWidth; + return this; + } + + public GemFireVectorStoreConfig setMaxConnections(int maxConnections) { + Assert.isTrue(maxConnections > 0, "maxConnections must be positive"); + Assert.isTrue(maxConnections <= GemFireVectorStoreConfig.UPPER_BOUND_MAX_CONNECTIONS, + "maxConnections must be less than or equal to " + GemFireVectorStoreConfig.UPPER_BOUND_MAX_CONNECTIONS); + this.maxConnections = maxConnections; + return this; + } + + public GemFireVectorStoreConfig setBuckets(int buckets) { + Assert.isTrue(buckets >= 0, "bucket must be 1 or more"); + this.buckets = buckets; + return this; + } + + public GemFireVectorStoreConfig setVectorSimilarityFunction(String vectorSimilarityFunction) { + Assert.hasText(vectorSimilarityFunction, "vectorSimilarityFunction must have a value"); + this.vectorSimilarityFunction = vectorSimilarityFunction; + return this; + } + + public GemFireVectorStoreConfig setFields(String[] fields) { + this.fields = fields; + return this; + } + + public GemFireVectorStoreConfig() { + + } + +} diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java index 21de25e509c..72ac0760ed2 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java @@ -26,16 +26,17 @@ import java.util.Map; import java.util.UUID; +import com.github.dockerjava.api.model.ExposedPort; +import com.github.dockerjava.api.model.PortBinding; +import com.github.dockerjava.api.model.Ports; +import com.vmware.gemfire.testcontainers.GemFireCluster; import org.awaitility.Awaitility; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.transformers.TransformersEmbeddingModel; -import org.springframework.ai.vectorstore.GemFireVectorStore.GemFireVectorStoreConfig; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -46,11 +47,40 @@ * @author Geet Rawat * @since 1.0.0 */ -@EnabledIfEnvironmentVariable(named = "GEMFIRE_HOST", matches = ".+") public class GemFireVectorStoreIT { public static final String INDEX_NAME = "spring-ai-index1"; + private static GemFireCluster gemFireCluster; + + private static final int HTTP_SERVICE_PORT = 9090; + + private static final int LOCATOR_COUNT = 1; + + private static final int SERVER_COUNT = 1; + + @AfterAll + public static void stopGemFireCluster() { + gemFireCluster.close(); + } + + @BeforeAll + public static void startGemFireCluster() { + Ports.Binding hostPort = Ports.Binding.bindPort(HTTP_SERVICE_PORT); + ExposedPort exposedPort = new ExposedPort(HTTP_SERVICE_PORT); + PortBinding mappedPort = new PortBinding(hostPort, exposedPort); + gemFireCluster = new GemFireCluster("gemfire/gemfire-all:10.1-jdk17", LOCATOR_COUNT, SERVER_COUNT); + gemFireCluster.withConfiguration(GemFireCluster.SERVER_GLOB, + container -> container.withExposedPorts(HTTP_SERVICE_PORT) + .withCreateContainerCmdModifier(cmd -> cmd.getHostConfig().withPortBindings(mappedPort))); + gemFireCluster.withGemFireProperty(GemFireCluster.SERVER_GLOB, "http-service-port", + Integer.toString(HTTP_SERVICE_PORT)); + gemFireCluster.acceptLicense().start(); + + System.setProperty("spring.data.gemfire.pool.locators", + String.format("localhost[%d]", gemFireCluster.getLocatorPort())); + } + List documents = List.of( new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), @@ -69,25 +99,16 @@ public static String getText(String uri) { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(TestApplication.class); - @BeforeEach - public void createIndex() { - contextRunner.run(c -> c.getBean(GemFireVectorStore.class).createIndex(INDEX_NAME)); - } - - @AfterEach - public void deleteIndex() { - contextRunner.run(c -> c.getBean(GemFireVectorStore.class).deleteIndex(INDEX_NAME)); - } - @Test public void addAndDeleteEmbeddingTest() { contextRunner.run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); vectorStore.add(documents); vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().atMost(1, MINUTES).until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3)); - }, hasSize(0)); + Awaitility.await() + .atMost(1, MINUTES) + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3)), + hasSize(0)); }); } @@ -97,14 +118,15 @@ public void addAndSearchTest() { VectorStore vectorStore = context.getBean(VectorStore.class); vectorStore.add(documents); - Awaitility.await().atMost(1, MINUTES).until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .atMost(1, MINUTES) + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), + hasSize(1)); List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(5)); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); - assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); + assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock"); assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); @@ -120,9 +142,10 @@ public void documentUpdateTest() { Collections.singletonMap("meta1", "meta1")); vectorStore.add(List.of(document)); SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5); - Awaitility.await().atMost(1, MINUTES).until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .atMost(1, MINUTES) + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), + hasSize(1)); List results = vectorStore.similaritySearch(springSearchRequest); Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); @@ -131,7 +154,7 @@ public void documentUpdateTest() { assertThat(resultDoc.getMetadata()).containsKey("distance"); Document sameIdDocument = new Document(document.getId(), - "The World is Big and Salvation Lurks Around the Corner", + "The World is Big and Salvation Lurks " + "Around the Corner", Collections.singletonMap("meta2", "meta2")); vectorStore.add(List.of(sameIdDocument)); @@ -141,7 +164,7 @@ public void documentUpdateTest() { assertThat(results).hasSize(1); resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(document.getId()); - assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation" + " Lurks Around the Corner"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); }); @@ -154,10 +177,11 @@ public void searchThresholdTest() { VectorStore vectorStore = context.getBean(VectorStore.class); vectorStore.add(documents); - Awaitility.await().atMost(1, MINUTES).until(() -> { - return vectorStore - .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll()); - }, hasSize(3)); + Awaitility.await() + .atMost(1, MINUTES) + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll()), + hasSize(3)); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); @@ -173,7 +197,7 @@ public void searchThresholdTest() { Document resultDoc = results.get(0); assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); - assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); + assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock"); assertThat(resultDoc.getMetadata()).containsKey("meta2"); assertThat(resultDoc.getMetadata()).containsKey("distance"); }); @@ -185,14 +209,14 @@ public static class TestApplication { @Bean public GemFireVectorStoreConfig gemfireVectorStoreConfig() { - return GemFireVectorStoreConfig.builder().withHost("localhost").build(); + return new GemFireVectorStoreConfig().setHost("localhost") + .setPort(HTTP_SERVICE_PORT) + .setIndexName(INDEX_NAME); } @Bean public GemFireVectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel) { - GemFireVectorStore gemFireVectorStore = new GemFireVectorStore(config, embeddingModel); - gemFireVectorStore.setIndexName(INDEX_NAME); - return gemFireVectorStore; + return new GemFireVectorStore(config, embeddingModel); } @Bean @@ -202,4 +226,4 @@ public EmbeddingModel embeddingModel() { } -} \ No newline at end of file +}