diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml
index b653d1e7e02..87b721c5367 100644
--- a/spring-ai-bom/pom.xml
+++ b/spring-ai-bom/pom.xml
@@ -446,6 +446,12 @@
${project.version}
+
+ org.springframework.ai
+ spring-ai-gemfire-store-spring-boot-starter
+ ${project.version}
+
+
org.springframework.ai
spring-ai-zhipuai-spring-boot-starter
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc
index d67d9f67da9..fd52c1dd064 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/gemfire.adoc
@@ -1,122 +1,136 @@
= GemFire Vector Store
-This section walks you through setting up the GemFire VectorStore to store document embeddings and perform similarity searches.
+This section walks you through setting up the `GemFireVectorStore` to store document embeddings and perform similarity searches.
-link:https://tanzu.vmware.com/gemfire[GemFire] is an ultra high speed in-memory data and compute grid, with vector extensions to store and search vectors efficiently.
+link:https://tanzu.vmware.com/gemfire[GemFire] is a distributed, in-memory, key-value store that performs read and write operations at blazingly fast speeds. It offers highly available parallel message queues, continuous availability, and an event-driven architecture you can scale dynamically with no downtime. As your data size requirements increase to support high-performance, real-time apps, GemFire can scale linearly with ease.
-link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/overview.html[GemFire VectorDB] extends GemFire's capabilities, serving as a versatile vector database that efficiently stores, retrieves, and performs vector searches through a distributed and resilient infrastructure:
-
-Capabilities:
-- Create Indexes
-- Store vectors and the associated metadata
-- Perform vector searches based on similarity
+link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/overview.html[GemFire VectorDB] extends GemFire's capabilities, serving as a versatile vector database that efficiently stores, retrieves, and performs vector similarity searches.
== Prerequisites
-Access to a GemFire cluster with the link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/install.html[GemFire Vector Database] extension installed.
-You can download the GemFire VectorDB extension from the link:https://network.pivotal.io/products/gemfire-vectordb/[VMware Tanzu Network] after signing in.
+1. A GemFire cluster with the GemFire VectorDB extension enabled
+- link:https://docs.vmware.com/en/VMware-GemFire-VectorDB/1.0/gemfire-vectordb/install.html[Install GemFire VectorDB extension]
+
+2. An `EmbeddingModel` bean to compute the document embeddings. Refer to the xref:api/embeddings.adoc#available-implementations[EmbeddingModel] section for more information.
== Dependencies
Add these dependencies to your project:
-- Embedding Model boot starter, required for calculating embeddings.
-- Transformers Embedding (Local) and follow the ONNX Transformers Embedding instructions.
+- Embedding Model boot starter, required for calculating embeddings
+- Transformers Embedding (Local) and follow the ONNX Transformers Embedding instructions
+
+=== Maven
-[source,xml]
+To use just the `GemFireVectorStore`, add the following dependency to your project’s Maven `pom.xml`:
+
+[source, xml]
----
- org.springframework.ai
- spring-ai-transformers
+ org.springframework.ai
+ spring-ai-gemfire-store
----
-- Add the GemFire VectorDB dependencies
+To enable Spring Boot’s Auto-Configuration for the `GemFireVectorStore`, also add the following to your project’s Maven `pom.xml`:
-[source,xml]
+[source, xml]
----
org.springframework.ai
- spring-ai-gemfire-store
+ spring-ai-gemfire-store-spring-boot-starter
+
----
-TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+=== Gradle
+For Gradle users, add the following to your `build.gradle` file under the dependencies block to use just the `GemFireVectorStore`:
-== Sample Code
+[souce, xml]
+----
+dependencies {
+ implementation 'org.springframework.ai:spring-ai-gemfire-store'
+}
+----
-- To configure GemFire in your application, use the following setup:
+To enable Spring Boot’s Auto-Configuration for the `GemFireVectorStore`, also include this dependency:
-[source,java]
+[source, xml]
----
-@Bean
-public GemFireVectorStoreConfig gemFireVectorStoreConfig() {
- return GemFireVectorStoreConfig.builder()
- .withUrl("http://localhost:8080")
- .withIndexName("spring-ai-test-index")
- .build();
+dependencies {
+ implementation 'org.springframework.ai:spring-ai-gemfire-store-spring-boot-starter'
}
----
-- Create a GemFireVectorStore instance connected to your GemFire VectorDB:
+== Usage
+
+- Create a `GemFireVectorStore` instance connected to the GemFire cluster:
[source,java]
----
@Bean
-public VectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel) {
- return new GemFireVectorStore(config, embeddingModel);
+public VectorStore vectorStore(EmbeddingModel embeddingModel) {
+ return new GemFireVectorStore(new GemFireVectorStoreConfig()
+ .setIndexName("my-vector-index")
+ .setPort(7071), embeddingClient);
}
----
-- Create a Vector Index which will configure GemFire region.
+
+[NOTE]
+====
+The default configuration connects to a GemFire cluster at `localhost:8080`
+====
+
+- In your application, create a few documents:
[source,java]
----
- public void createIndex() {
- try {
- CreateRequest createRequest = new CreateRequest();
- createRequest.setName(INDEX_NAME);
- createRequest.setBeamWidth(20);
- createRequest.setMaxConnections(16);
- ObjectMapper objectMapper = new ObjectMapper();
- String index = objectMapper.writeValueAsString(createRequest);
- client.post()
- .contentType(MediaType.APPLICATION_JSON)
- .bodyValue(index)
- .retrieve()
- .bodyToMono(Void.class)
- .block();
- }
- catch (Exception e) {
- logger.warn("An unexpected error occurred while creating the index");
- }
- }
-----
-
-- Create some documents:
+List documents = List.of(
+ new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Map.of("country", "UK", "year", 2020)),
+ new Document("The World is Big and Salvation Lurks Around the Corner", Map.of()),
+ new Document("You walk forward facing the past and you turn back toward the future.", Map.of("country", "NL", "year", 2023)));
+----
+
+- Add the documents to the vector store:
[source,java]
----
- List documents = List.of(
- new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
- new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
- new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2")));
+vectorStore.add(documents);
----
-- Add the documents to GemFire VectorDB:
+- And to retrieve documents using similarity search:
[source,java]
----
-vectorStore.add(List.of(document));
+List results = vectorStore.similaritySearch(
+ SearchRequest.query("Spring").withTopK(5));
----
-- And finally, retrieve documents similar to a query:
+You should retrieve the document containing the text "Spring AI rocks!!".
+You can also limit the number of results using a similarity threshold:
[source,java]
----
- List results = vectorStore.similaritySearch("Spring", 5);
+List results = vectorStore.similaritySearch(
+ SearchRequest.query("Spring").withTopK(5)
+ .withSimilarityThreshold(0.5d));
----
-If all goes well, you should retrieve the document containing the text "Spring AI rocks!!".
+== GemFireVectorStore properties
+
+You can use the following properties in your Spring Boot configuration to further configure the `GemFireVectorStore`.
+
+|===
+|Property|Default value
+|`spring.ai.vectorstore.gemfire.host`|localhost
+|`spring.ai.vectorstore.gemfire.port`|8080
+|`spring.ai.vectorstore.gemfire.index-name`|spring-ai-gemfire-store
+|`spring.ai.vectorstore.gemfire.beam-width`|100
+|`spring.ai.vectorstore.gemfire.max-connections`|16
+|`spring.ai.vectorstore.gemfire.vector-similarity-function`|COSINE
+|`spring.ai.vectorstore.gemfire.fields`|[]
+|`spring.ai.vectorstore.gemfire.buckets`|0
+|===
diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml
index 291129335cb..97f3e56da41 100644
--- a/spring-ai-spring-boot-autoconfigure/pom.xml
+++ b/spring-ai-spring-boot-autoconfigure/pom.xml
@@ -297,6 +297,13 @@
true
+
+ org.springframework.ai
+ spring-ai-gemfire-store
+ ${project.parent.version}
+ true
+
+
org.springframework.ai
spring-ai-minimax
@@ -458,6 +465,13 @@
test
+
+ dev.gemfire
+ gemfire-testcontainers
+ 2.3.0
+ test
+
+
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java
new file mode 100644
index 00000000000..7792c045131
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireConnectionDetails.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.vectorstore.gemfire;
+
+import org.springframework.boot.autoconfigure.service.connection.ConnectionDetails;
+
+/**
+ * @author Geet Rawat
+ */
+public interface GemFireConnectionDetails extends ConnectionDetails {
+
+ String getHost();
+
+ int getPort();
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java
new file mode 100644
index 00000000000..098f2172fa4
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfiguration.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.autoconfigure.vectorstore.gemfire;
+
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.vectorstore.GemFireVectorStore;
+import org.springframework.ai.vectorstore.GemFireVectorStoreConfig;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+
+/**
+ * @author Geet Rawat
+ */
+@AutoConfiguration
+@ConditionalOnClass({ GemFireVectorStore.class, EmbeddingModel.class })
+@EnableConfigurationProperties(GemFireVectorStoreProperties.class)
+@ConditionalOnProperty(prefix = "spring.ai.vectorstore.gemfire", value = { "index-name" })
+public class GemFireVectorStoreAutoConfiguration {
+
+ @Bean
+ @ConditionalOnMissingBean(GemFireConnectionDetails.class)
+ GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails gemfireConnectionDetails(
+ GemFireVectorStoreProperties properties) {
+ return new GemFireVectorStoreAutoConfiguration.PropertiesGemFireConnectionDetails(properties);
+ }
+
+ @Bean
+ @ConditionalOnMissingBean
+ public GemFireVectorStore vectorStore(EmbeddingModel embeddingModel, GemFireVectorStoreProperties properties,
+ GemFireConnectionDetails gemFireConnectionDetails) {
+ var config = new GemFireVectorStoreConfig();
+
+ config.setHost(gemFireConnectionDetails.getHost())
+ .setPort(gemFireConnectionDetails.getPort())
+ .setIndexName(properties.getIndexName())
+ .setBeamWidth(properties.getBeamWidth())
+ .setMaxConnections(properties.getMaxConnections())
+ .setBuckets(properties.getBuckets())
+ .setVectorSimilarityFunction(properties.getVectorSimilarityFunction())
+ .setFields(properties.getFields())
+ .setSslEnabled(properties.isSslEnabled());
+ return new GemFireVectorStore(config, embeddingModel);
+ }
+
+ private static class PropertiesGemFireConnectionDetails implements GemFireConnectionDetails {
+
+ private final GemFireVectorStoreProperties properties;
+
+ PropertiesGemFireConnectionDetails(GemFireVectorStoreProperties properties) {
+
+ this.properties = properties;
+ }
+
+ @Override
+ public String getHost() {
+ return this.properties.getHost();
+ }
+
+ @Override
+ public int getPort() {
+ return this.properties.getPort();
+ }
+
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java
new file mode 100644
index 00000000000..cdcc27c04cd
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreProperties.java
@@ -0,0 +1,166 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.autoconfigure.vectorstore.gemfire;
+
+import org.springframework.ai.vectorstore.GemFireVectorStoreConfig;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * @author Geet Rawat
+ */
+@ConfigurationProperties(GemFireVectorStoreProperties.CONFIG_PREFIX)
+public class GemFireVectorStoreProperties {
+
+ /**
+ * Configuration prefix for Spring AI VectorStore GemFire.
+ */
+ public static final String CONFIG_PREFIX = "spring.ai.vectorstore.gemfire";
+
+ /**
+ * The host of the GemFire to connect to. To specify a custom host, use
+ * "spring.ai.vectorstore.gemfire.host";
+ *
+ */
+ private String host = GemFireVectorStoreConfig.DEFAULT_HOST;
+
+ /**
+ * The port of the GemFire to connect to. To specify a custom port, use
+ * "spring.ai.vectorstore.gemfire.port";
+ */
+ private int port = GemFireVectorStoreConfig.DEFAULT_PORT;
+
+ /**
+ * The name of the index in the GemFire. To specify a custom index, use
+ * "spring.ai.vectorstore.gemfire.index-name";
+ */
+ private String indexName = GemFireVectorStoreConfig.DEFAULT_INDEX_NAME;
+
+ /**
+ * The beam width for similarity queries. Default value is {@code 100}. To specify a
+ * custom beam width, use "spring.ai.vectorstore.gemfire.beam-width";
+ */
+ private int beamWidth = GemFireVectorStoreConfig.DEFAULT_BEAM_WIDTH;
+
+ /**
+ * The maximum number of connections allowed. Default value is {@code 16}. To specify
+ * custom number of connections, use "spring.ai.vectorstore.gemfire.max-connections";
+ */
+ private int maxConnections = GemFireVectorStoreConfig.DEFAULT_MAX_CONNECTIONS;
+
+ /**
+ * The similarity function to be used for vector comparisons. Default value is
+ * {@code "COSINE"}. To specify custom vectorSimilarityFunction, use
+ * "spring.ai.vectorstore.gemfire.vector-similarity-function";
+ *
+ */
+ private String vectorSimilarityFunction = GemFireVectorStoreConfig.DEFAULT_SIMILARITY_FUNCTION;
+
+ /**
+ * The fields to be used for queries. Default value is an array containing
+ * {@code "vector"}. To specify custom fields, use
+ * "spring.ai.vectorstore.gemfire.fields"
+ */
+ private String[] fields = GemFireVectorStoreConfig.DEFAULT_FIELDS;
+
+ /**
+ * The number of buckets to use for partitioning the data. Default value is {@code 0}.
+ *
+ * To specify custom buckets, use "spring.ai.vectorstore.gemfire.buckets";
+ *
+ */
+ private int buckets = GemFireVectorStoreConfig.DEFAULT_BUCKETS;
+
+ /**
+ * Set to true if GemFire cluster is ssl enabled
+ *
+ * To specify sslEnabled, use "spring.ai.vectorstore.gemfire.ssl-enabled";
+ */
+ private boolean sslEnabled = GemFireVectorStoreConfig.DEFAULT_SSL_ENABLED;
+
+ public int getBeamWidth() {
+ return beamWidth;
+ }
+
+ public void setBeamWidth(int beamWidth) {
+ this.beamWidth = beamWidth;
+ }
+
+ public int getPort() {
+ return port;
+ }
+
+ public void setPort(int port) {
+ this.port = port;
+ }
+
+ public String getHost() {
+ return host;
+ }
+
+ public void setHost(String host) {
+ this.host = host;
+ }
+
+ public String getIndexName() {
+ return indexName;
+ }
+
+ public void setIndexName(String indexName) {
+ this.indexName = indexName;
+ }
+
+ public int getMaxConnections() {
+ return maxConnections;
+ }
+
+ public void setMaxConnections(int maxConnections) {
+ this.maxConnections = maxConnections;
+ }
+
+ public String getVectorSimilarityFunction() {
+ return vectorSimilarityFunction;
+ }
+
+ public void setVectorSimilarityFunction(String vectorSimilarityFunction) {
+ this.vectorSimilarityFunction = vectorSimilarityFunction;
+ }
+
+ public String[] getFields() {
+ return fields;
+ }
+
+ public void setFields(String[] fields) {
+ this.fields = fields;
+ }
+
+ public int getBuckets() {
+ return buckets;
+ }
+
+ public void setBuckets(int buckets) {
+ this.buckets = buckets;
+ }
+
+ public boolean isSslEnabled() {
+ return sslEnabled;
+ }
+
+ public void setSslEnabled(boolean sslEnabled) {
+ this.sslEnabled = sslEnabled;
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
index 5117fd476e9..804d5c8fa0f 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
+++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -32,6 +32,7 @@ org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAu
org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration
org.springframework.ai.autoconfigure.watsonxai.WatsonxAiAutoConfiguration
org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreAutoConfiguration
+org.springframework.ai.autoconfigure.vectorstore.gemfire.GemFireVectorStoreAutoConfiguration
org.springframework.ai.autoconfigure.vectorstore.cassandra.CassandraVectorStoreAutoConfiguration
org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration
org.springframework.ai.autoconfigure.chat.client.ChatClientAutoConfiguration
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java
new file mode 100644
index 00000000000..9c6f9ca607d
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.autoconfigure.vectorstore.gemfire;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.Matchers.hasSize;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.github.dockerjava.api.model.ExposedPort;
+import com.github.dockerjava.api.model.PortBinding;
+import com.github.dockerjava.api.model.Ports;
+import com.vmware.gemfire.testcontainers.GemFireCluster;
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.ResourceUtils;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.transformers.TransformersEmbeddingModel;
+import org.springframework.ai.vectorstore.GemFireVectorStore;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+/**
+ * @author Geet Rawat
+ */
+
+class GemFireVectorStoreAutoConfigurationIT {
+
+ private static GemFireCluster gemFireCluster;
+
+ private static final String INDEX_NAME = "spring-ai-index";
+
+ private static final int BEAM_WIDTH = 50;
+
+ private static final int MAX_CONNECTIONS = 8;
+
+ private static final String SIMILARITY_FUNCTION = "DOT_PRODUCT";
+
+ private static final String[] FIELDS = { "someField1", "someField2" };
+
+ private static final int BUCKET_COUNT = 2;
+
+ private static final int HTTP_SERVICE_PORT = 9090;
+
+ private static final int LOCATOR_COUNT = 1;
+
+ private static final int SERVER_COUNT = 1;
+
+ @AfterAll
+ public static void stopGemFireCluster() {
+ gemFireCluster.close();
+ }
+
+ List documents = List.of(
+ new Document(ResourceUtils.getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")),
+ new Document(ResourceUtils.getText("classpath:/test/data/time.shelter.txt")), new Document(
+ ResourceUtils.getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad")));
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+ .withConfiguration(AutoConfigurations.of(GemFireVectorStoreAutoConfiguration.class))
+ .withUserConfiguration(Config.class)
+ .withPropertyValues("spring.ai.vectorstore.gemfire.index-name=" + INDEX_NAME)
+ .withPropertyValues("spring.ai.vectorstore.gemfire.beam-width=" + BEAM_WIDTH)
+ .withPropertyValues("spring.ai.vectorstore.gemfire.max-connections=" + MAX_CONNECTIONS)
+ .withPropertyValues("spring.ai.vectorstore.gemfire.vector-similarity-function=" + SIMILARITY_FUNCTION)
+ .withPropertyValues("spring.ai.vectorstore.gemfire.buckets=" + BUCKET_COUNT)
+ .withPropertyValues("spring.ai.vectorstore.gemfire.fields=someField1,someField2")
+ .withPropertyValues("spring.ai.vectorstore.gemfire.host=localhost")
+ .withPropertyValues("spring.ai.vectorstore.gemfire.port=" + HTTP_SERVICE_PORT);
+
+ @BeforeAll
+ public static void startGemFireCluster() {
+ Ports.Binding hostPort = Ports.Binding.bindPort(HTTP_SERVICE_PORT);
+ ExposedPort exposedPort = new ExposedPort(HTTP_SERVICE_PORT);
+ PortBinding mappedPort = new PortBinding(hostPort, exposedPort);
+ gemFireCluster = new GemFireCluster("gemfire/gemfire-all:10.1-jdk17", LOCATOR_COUNT, SERVER_COUNT);
+ gemFireCluster.withConfiguration(GemFireCluster.SERVER_GLOB,
+ container -> container.withExposedPorts(HTTP_SERVICE_PORT)
+ .withCreateContainerCmdModifier(cmd -> cmd.getHostConfig().withPortBindings(mappedPort)));
+ gemFireCluster.withGemFireProperty(GemFireCluster.SERVER_GLOB, "http-service-port",
+ Integer.toString(HTTP_SERVICE_PORT));
+ gemFireCluster.acceptLicense().start();
+
+ System.setProperty("spring.data.gemfire.pool.locators",
+ String.format("localhost[%d]", gemFireCluster.getLocatorPort()));
+ }
+
+ @Test
+ void ensureGemFireVectorStoreCustomConfiguration() {
+ this.contextRunner.run(context -> {
+ GemFireVectorStore store = context.getBean(GemFireVectorStore.class);
+ Assertions.assertNotNull(store);
+ assertThat(store.getIndexName()).isEqualTo(INDEX_NAME);
+ assertThat(store.getBeamWidth()).isEqualTo(BEAM_WIDTH);
+ assertThat(store.getMaxConnections()).isEqualTo(MAX_CONNECTIONS);
+ assertThat(store.getVectorSimilarityFunction()).isEqualTo(SIMILARITY_FUNCTION);
+ assertThat(store.getFields()).isEqualTo(FIELDS);
+
+ String indexJson = store.getIndex();
+ Map index = parseIndex(indexJson);
+ assertThat(index.get("name")).isEqualTo(INDEX_NAME);
+ assertThat(index.get("beam-width")).isEqualTo(BEAM_WIDTH);
+ assertThat(index.get("max-connections")).isEqualTo(MAX_CONNECTIONS);
+ assertThat(index.get("vector-similarity-function")).isEqualTo(SIMILARITY_FUNCTION);
+ assertThat(index.get("buckets")).isEqualTo(BUCKET_COUNT);
+
+ });
+ }
+
+ @Test
+ public void addAndSearchTest() {
+ contextRunner.run(context -> {
+ VectorStore vectorStore = context.getBean(VectorStore.class);
+ vectorStore.add(documents);
+
+ Awaitility.await().until(() -> {
+ return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1));
+ }, hasSize(1));
+
+ List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1));
+
+ assertThat(results).hasSize(1);
+ Document resultDoc = results.get(0);
+ assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId());
+ assertThat(resultDoc.getContent()).contains(
+ "Spring AI provides abstractions that serve as the foundation for developing AI applications.");
+ assertThat(resultDoc.getMetadata()).hasSize(2);
+ assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance");
+
+ // Remove all documents from the store
+ vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());
+
+ Awaitility.await().until(() -> {
+ return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1));
+ }, hasSize(0));
+ });
+ }
+
+ private Map parseIndex(String json) {
+ try {
+ JsonNode rootNode = new ObjectMapper().readTree(json);
+ Map indexDetails = new HashMap<>();
+ if (rootNode.isObject()) {
+ if (rootNode.has("name"))
+ indexDetails.put("name", rootNode.get("name").asText());
+ if (rootNode.has("beam-width"))
+ indexDetails.put("beam-width", rootNode.get("beam-width").asInt());
+ if (rootNode.has("max-connections"))
+ indexDetails.put("max-connections", rootNode.get("max-connections").asInt());
+ if (rootNode.has("vector-similarity-function"))
+ indexDetails.put("vector-similarity-function", rootNode.get("vector-similarity-function").asText());
+ if (rootNode.has("buckets"))
+ indexDetails.put("buckets", rootNode.get("buckets").asInt());
+ if (rootNode.has("number-of-embeddings"))
+ indexDetails.put("number-of-embeddings", rootNode.get("number-of-embeddings").asInt());
+ }
+ return indexDetails;
+ }
+ catch (Exception e) {
+ return new HashMap<>();
+ }
+ }
+
+ @Configuration(proxyBeanMethods = false)
+ static class Config {
+
+ @Bean
+ public EmbeddingModel embeddingModel() {
+ return new TransformersEmbeddingModel();
+ }
+
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java
new file mode 100644
index 00000000000..559e07778b2
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStorePropertiesTests.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.autoconfigure.vectorstore.gemfire;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.vectorstore.GemFireVectorStoreConfig;
+
+/**
+ * @author Geet Rawat
+ */
+class GemFireVectorStorePropertiesTests {
+
+ @Test
+ void defaultValues() {
+ var props = new GemFireVectorStoreProperties();
+ assertThat(props.getIndexName()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_INDEX_NAME);
+ assertThat(props.getHost()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_HOST);
+ assertThat(props.getPort()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_PORT);
+ assertThat(props.getBeamWidth()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_BEAM_WIDTH);
+ assertThat(props.getMaxConnections()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_MAX_CONNECTIONS);
+ assertThat(props.getFields()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_FIELDS);
+ assertThat(props.getBuckets()).isEqualTo(GemFireVectorStoreConfig.DEFAULT_BUCKETS);
+ }
+
+ @Test
+ void customValues() {
+ var props = new GemFireVectorStoreProperties();
+ props.setIndexName("spring-ai-index");
+ props.setHost("localhost");
+ props.setPort(9090);
+ props.setBeamWidth(10);
+ props.setMaxConnections(10);
+ props.setFields(new String[] { "test" });
+ props.setBuckets(10);
+
+ assertThat(props.getIndexName()).isEqualTo("spring-ai-index");
+ assertThat(props.getHost()).isEqualTo("localhost");
+ assertThat(props.getPort()).isEqualTo(9090);
+ assertThat(props.getBeamWidth()).isEqualTo(10);
+ assertThat(props.getMaxConnections()).isEqualTo(10);
+ assertThat(props.getFields()).isEqualTo(new String[] { "test" });
+ assertThat(props.getBuckets()).isEqualTo(10);
+
+ }
+
+}
diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml
new file mode 100644
index 00000000000..de48a933f30
--- /dev/null
+++ b/spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store/pom.xml
@@ -0,0 +1,42 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai
+ 1.0.0-SNAPSHOT
+ ../../pom.xml
+
+ spring-ai-gemfire-store-spring-boot-starter
+ jar
+ Spring AI Starter - GemFire Vector Store
+ Spring AI GemFire Vector Store Auto Configuration
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter
+
+
+
+ org.springframework.ai
+ spring-ai-spring-boot-autoconfigure
+ ${project.parent.version}
+
+
+
+ org.springframework.ai
+ spring-ai-gemfire-store
+ ${project.parent.version}
+
+
+
+
diff --git a/vector-stores/spring-ai-gemfire-store/pom.xml b/vector-stores/spring-ai-gemfire-store/pom.xml
index 45f7e71637e..01ca725646b 100644
--- a/vector-stores/spring-ai-gemfire-store/pom.xml
+++ b/vector-stores/spring-ai-gemfire-store/pom.xml
@@ -38,6 +38,13 @@
+
+ dev.gemfire
+ gemfire-testcontainers
+ 2.3.0
+ test
+
+
org.springframework.ai
spring-ai-openai
@@ -71,10 +78,6 @@
3.0.0
test
-
- org.apache.logging.log4j
- log4j-core
-
diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
index c193b21b667..0393de9c520 100644
--- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
+++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2023 - 2024 the original author or authors.
+ * Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
@@ -48,13 +49,11 @@
*
* @author Geet Rawat
*/
-public class GemFireVectorStore implements VectorStore {
-
- public static final String QUERY = "/query";
+public class GemFireVectorStore implements VectorStore, InitializingBean {
private static final Logger logger = LoggerFactory.getLogger(GemFireVectorStore.class);
- private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
+ private static final String DEFAULT_URI = "http{ssl}://{host}:{port}/gemfire-vectordb/v1/indexes";
private static final String EMBEDDINGS = "/embeddings";
@@ -62,179 +61,135 @@ public class GemFireVectorStore implements VectorStore {
private final EmbeddingModel embeddingModel;
- private final int topKPerBucket;
-
- private final int topK;
-
- private final String documentField;
-
- public static final class GemFireVectorStoreConfig {
-
- private final WebClient client;
-
- private final String index;
-
- private final int topKPerBucket;
-
- public final int topK;
-
- private final String documentField;
-
- public static Builder builder() {
- return new Builder();
- }
-
- private GemFireVectorStoreConfig(Builder builder) {
- String base = UriComponentsBuilder.fromUriString(DEFAULT_URI)
- .build(builder.sslEnabled ? "s" : "", builder.host, builder.port)
- .toString();
- this.index = builder.index;
- this.client = WebClient.create(base);
- this.topKPerBucket = builder.topKPerBucket;
- this.topK = builder.topK;
- this.documentField = builder.documentField;
- }
-
- public static class Builder {
-
- private String host;
-
- private int port = DEFAULT_PORT;
-
- private boolean sslEnabled;
+ private static final String DOCUMENT_FIELD = "document";
- private long connectionTimeout;
+ // Create Index Parameters
- private long requestTimeout;
+ private String indexName;
- private String index;
+ public String getIndexName() {
+ return indexName;
+ }
- private int topKPerBucket = DEFAULT_TOP_K_PER_BUCKET;
+ private int beamWidth;
- private int topK = DEFAULT_TOP_K;
+ public int getBeamWidth() {
+ return beamWidth;
+ }
- private String documentField = DEFAULT_DOCUMENT_FIELD;
+ private int maxConnections;
- public Builder withHost(String host) {
- Assert.hasText(host, "host must have a value");
- this.host = host;
- return this;
- }
+ public int getMaxConnections() {
+ return maxConnections;
+ }
- public Builder withPort(int port) {
- Assert.isTrue(port > 0, "port must be positive");
- this.port = port;
- return this;
- }
+ private int buckets;
- public Builder withSslEnabled(boolean sslEnabled) {
- this.sslEnabled = sslEnabled;
- return this;
- }
-
- public Builder withConnectionTimeout(long timeout) {
- Assert.isTrue(timeout >= 0, "timeout must be >= 0");
- this.connectionTimeout = timeout;
- return this;
- }
+ public int getBuckets() {
+ return buckets;
+ }
- public Builder withRequestTimeout(long timeout) {
- Assert.isTrue(timeout >= 0, "timeout must be >= 0");
- this.requestTimeout = timeout;
- return this;
- }
+ private String vectorSimilarityFunction;
- public Builder withIndex(String index) {
- Assert.hasText(index, "index must have a value");
- this.index = index;
- return this;
- }
+ public String getVectorSimilarityFunction() {
+ return vectorSimilarityFunction;
+ }
- public Builder withTopKPerBucket(int topKPerBucket) {
- Assert.isTrue(topKPerBucket > 0, "topKPerBucket must be positive");
- this.topKPerBucket = topKPerBucket;
- return this;
- }
+ private String[] fields;
- public Builder withTopK(int topK) {
- Assert.isTrue(topK > 0, "topK must be positive");
- this.topK = topK;
- return this;
- }
+ public String[] getFields() {
+ return fields;
+ }
- public Builder withDocumentField(String documentField) {
- Assert.hasText(documentField, "documentField must have a value");
- this.documentField = documentField;
- return this;
- }
+ // Query Defaults
+ private static final String QUERY = "/query";
- public GemFireVectorStoreConfig build() {
- return new GemFireVectorStoreConfig(this);
- }
+ private static final String DISTANCE_METADATA_FIELD_NAME = "distance";
+ /**
+ * Initializes the GemFireVectorStore after properties are set. This method is called
+ * after all bean properties have been set and allows the bean to perform any
+ * initialization it requires.
+ */
+ @Override
+ public void afterPropertiesSet() throws Exception {
+ if (indexExists()) {
+ deleteIndex();
}
+ createIndex();
}
- private static final int DEFAULT_PORT = 9090;
-
- public static final String DEFAULT_URI = "http{ssl}://{host}:{port}/gemfire-vectordb/v1/indexes";
-
- private static final int DEFAULT_TOP_K_PER_BUCKET = 10;
-
- private static final int DEFAULT_TOP_K = 10;
-
- private static final String DEFAULT_DOCUMENT_FIELD = "document";
-
- public String indexName;
+ /**
+ * Checks if the index exists in the GemFireVectorStore.
+ * @return {@code true} if the index exists, {@code false} otherwise
+ */
+ public boolean indexExists() {
+ String indexResponse = getIndex();
+ return !indexResponse.isEmpty();
+ }
- public void setIndexName(String indexName) {
- this.indexName = indexName;
+ public String getIndex() {
+ return client.get().uri("/" + indexName).retrieve().bodyToMono(String.class).onErrorReturn("").block();
}
- public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embedding) {
+ /**
+ * Configures and initializes a GemFireVectorStore instance based on the provided
+ * configuration.
+ * @param config the configuration for the GemFireVectorStore
+ * @param embeddingModel the embedding client used for generating embeddings
+ */
+
+ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel) {
Assert.notNull(config, "GemFireVectorStoreConfig must not be null");
- Assert.notNull(embedding, "EmbeddingModel must not be null");
- this.client = config.client;
- this.embeddingModel = embedding;
- this.topKPerBucket = config.topKPerBucket;
- this.topK = config.topK;
- this.documentField = config.documentField;
+ Assert.notNull(embeddingModel, "EmbeddingModel must not be null");
+ this.indexName = config.indexName;
+ this.embeddingModel = embeddingModel;
+ this.beamWidth = config.beamWidth;
+ this.maxConnections = config.maxConnections;
+ this.buckets = config.buckets;
+ this.vectorSimilarityFunction = config.vectorSimilarityFunction;
+ this.fields = config.fields;
+
+ String base = UriComponentsBuilder.fromUriString(DEFAULT_URI)
+ .build(config.sslEnabled ? "s" : "", config.host, config.port)
+ .toString();
+ this.client = WebClient.create(base);
}
- private static final class CreateRequest {
+ public static class CreateRequest {
@JsonProperty("name")
- private String name;
+ private String indexName;
@JsonProperty("beam-width")
- private int beamWidth = 100;
+ private int beamWidth;
@JsonProperty("max-connections")
- private int maxConnections = 16;
+ private int maxConnections;
@JsonProperty("vector-similarity-function")
- private String vectorSimilarityFunction = "COSINE";
+ private String vectorSimilarityFunction;
@JsonProperty("fields")
- private String[] fields = new String[] { "vector" };
+ private String[] fields;
@JsonProperty("buckets")
- private int buckets = 0;
+ private int buckets;
public CreateRequest() {
}
- public CreateRequest(String name) {
- this.name = name;
+ public CreateRequest(String indexName) {
+ this.indexName = indexName;
}
- public String getName() {
- return name;
+ public String getIndexName() {
+ return indexName;
}
- public void setName(String name) {
- this.name = name;
+ public void setIndexName(String indexName) {
+ this.indexName = indexName;
}
public int getBeamWidth() {
@@ -419,7 +374,7 @@ public void add(List documents) {
// Compute and assign an embedding to the document.
document.setEmbedding(this.embeddingModel.embed(document));
List floatVector = document.getEmbedding().stream().map(Double::floatValue).toList();
- return new UploadRequest.Embedding(document.getId(), floatVector, documentField, document.getContent(),
+ return new UploadRequest.Embedding(document.getId(), floatVector, DOCUMENT_FIELD, document.getContent(),
document.getMetadata());
}).toList());
@@ -463,22 +418,26 @@ public Optional delete(List idList) {
@Override
public List similaritySearch(SearchRequest request) {
if (request.hasFilterExpression()) {
- throw new UnsupportedOperationException("Gemfire does not support metadata filter expressions yet.");
+ throw new UnsupportedOperationException("GemFire currently does not support metadata filter expressions.");
}
List vector = this.embeddingModel.embed(request.getQuery());
List floatVector = vector.stream().map(Double::floatValue).toList();
-
return client.post()
.uri("/" + indexName + QUERY)
.contentType(MediaType.APPLICATION_JSON)
- .bodyValue(new QueryRequest(floatVector, request.getTopK(), topKPerBucket, true))
+ .bodyValue(new QueryRequest(floatVector, request.getTopK(), request.getTopK(), // TopKPerBucket
+ true))
.retrieve()
.bodyToFlux(QueryResponse.class)
.filter(r -> r.score >= request.getSimilarityThreshold())
.map(r -> {
Map metadata = r.metadata;
+ if (r.metadata == null) {
+ metadata = new HashMap<>();
+ metadata.put(DOCUMENT_FIELD, "--Deleted--");
+ }
metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score);
- String content = (String) metadata.remove(documentField);
+ String content = (String) metadata.remove(DOCUMENT_FIELD);
return new Document(r.key, content, metadata);
})
.collectList()
@@ -486,10 +445,22 @@ public List similaritySearch(SearchRequest request) {
.block();
}
- public void createIndex(String indexName) throws JsonProcessingException {
+ /**
+ * Creates a new index in the GemFireVectorStore using specified parameters. This
+ * method is invoked during initialization.
+ * @throws JsonProcessingException if an error occurs during JSON processing
+ */
+ public void createIndex() throws JsonProcessingException {
CreateRequest createRequest = new CreateRequest(indexName);
+ createRequest.setBeamWidth(beamWidth);
+ createRequest.setMaxConnections(maxConnections);
+ createRequest.setBuckets(buckets);
+ createRequest.setVectorSimilarityFunction(vectorSimilarityFunction);
+ createRequest.setFields(fields);
+
ObjectMapper objectMapper = new ObjectMapper();
String index = objectMapper.writeValueAsString(createRequest);
+
client.post()
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(index)
@@ -499,9 +470,8 @@ public void createIndex(String indexName) throws JsonProcessingException {
.block();
}
- public void deleteIndex(String indexName) {
+ public void deleteIndex() {
DeleteRequest deleteRequest = new DeleteRequest();
- deleteRequest.setDeleteData(true);
client.method(HttpMethod.DELETE)
.uri("/" + indexName)
.body(BodyInserters.fromValue(deleteRequest))
@@ -511,6 +481,12 @@ public void deleteIndex(String indexName) {
.block();
}
+ /**
+ * Handles exceptions that occur during HTTP client operations and maps them to
+ * appropriate runtime exceptions.
+ * @param ex the exception that occurred during HTTP client operation
+ * @return a mapped runtime exception corresponding to the HTTP client exception
+ */
private Throwable handleHttpClientException(Throwable ex) {
if (!(ex instanceof WebClientResponseException clientException)) {
throw new RuntimeException(String.format("Got an unexpected error: %s", ex));
diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStoreConfig.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStoreConfig.java
new file mode 100644
index 00000000000..58dd6049633
--- /dev/null
+++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStoreConfig.java
@@ -0,0 +1,108 @@
+package org.springframework.ai.vectorstore;
+
+import org.springframework.util.Assert;
+
+public final class GemFireVectorStoreConfig {
+
+ // Create Index DEFAULT Values
+ public static final String DEFAULT_HOST = "localhost";
+
+ public static final int DEFAULT_PORT = 8080;
+
+ public static final String DEFAULT_INDEX_NAME = "spring-ai-gemfire-index";
+
+ public static final int UPPER_BOUND_BEAM_WIDTH = 3200;
+
+ public static final int DEFAULT_BEAM_WIDTH = 100;
+
+ private static final int UPPER_BOUND_MAX_CONNECTIONS = 512;
+
+ public static final int DEFAULT_MAX_CONNECTIONS = 16;
+
+ public static final String DEFAULT_SIMILARITY_FUNCTION = "COSINE";
+
+ public static final String[] DEFAULT_FIELDS = new String[] {};
+
+ public static final int DEFAULT_BUCKETS = 0;
+
+ public static final boolean DEFAULT_SSL_ENABLED = false;
+
+ String host = GemFireVectorStoreConfig.DEFAULT_HOST;
+
+ int port = DEFAULT_PORT;
+
+ String indexName = DEFAULT_INDEX_NAME;
+
+ int beamWidth = DEFAULT_BEAM_WIDTH;
+
+ int maxConnections = DEFAULT_MAX_CONNECTIONS;
+
+ String vectorSimilarityFunction = DEFAULT_SIMILARITY_FUNCTION;
+
+ String[] fields = DEFAULT_FIELDS;
+
+ int buckets = DEFAULT_BUCKETS;
+
+ boolean sslEnabled = DEFAULT_SSL_ENABLED;
+
+ public GemFireVectorStoreConfig setHost(String host) {
+ Assert.hasText(host, "host must have a value");
+ this.host = host;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setPort(int port) {
+ Assert.isTrue(port > 0, "port must be positive");
+ this.port = port;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setSslEnabled(boolean sslEnabled) {
+ this.sslEnabled = sslEnabled;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setIndexName(String indexName) {
+ Assert.hasText(indexName, "indexName must have a value");
+ this.indexName = indexName;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setBeamWidth(int beamWidth) {
+ Assert.isTrue(beamWidth > 0, "beamWidth must be positive");
+ Assert.isTrue(beamWidth <= GemFireVectorStoreConfig.UPPER_BOUND_BEAM_WIDTH,
+ "beamWidth must be less than or equal to " + GemFireVectorStoreConfig.UPPER_BOUND_BEAM_WIDTH);
+ this.beamWidth = beamWidth;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setMaxConnections(int maxConnections) {
+ Assert.isTrue(maxConnections > 0, "maxConnections must be positive");
+ Assert.isTrue(maxConnections <= GemFireVectorStoreConfig.UPPER_BOUND_MAX_CONNECTIONS,
+ "maxConnections must be less than or equal to " + GemFireVectorStoreConfig.UPPER_BOUND_MAX_CONNECTIONS);
+ this.maxConnections = maxConnections;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setBuckets(int buckets) {
+ Assert.isTrue(buckets >= 0, "bucket must be 1 or more");
+ this.buckets = buckets;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setVectorSimilarityFunction(String vectorSimilarityFunction) {
+ Assert.hasText(vectorSimilarityFunction, "vectorSimilarityFunction must have a value");
+ this.vectorSimilarityFunction = vectorSimilarityFunction;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig setFields(String[] fields) {
+ this.fields = fields;
+ return this;
+ }
+
+ public GemFireVectorStoreConfig() {
+
+ }
+
+}
diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
index 21de25e509c..72ac0760ed2 100644
--- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
+++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java
@@ -26,16 +26,17 @@
import java.util.Map;
import java.util.UUID;
+import com.github.dockerjava.api.model.ExposedPort;
+import com.github.dockerjava.api.model.PortBinding;
+import com.github.dockerjava.api.model.Ports;
+import com.vmware.gemfire.testcontainers.GemFireCluster;
import org.awaitility.Awaitility;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
-
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
-import org.springframework.ai.vectorstore.GemFireVectorStore.GemFireVectorStoreConfig;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -46,11 +47,40 @@
* @author Geet Rawat
* @since 1.0.0
*/
-@EnabledIfEnvironmentVariable(named = "GEMFIRE_HOST", matches = ".+")
public class GemFireVectorStoreIT {
public static final String INDEX_NAME = "spring-ai-index1";
+ private static GemFireCluster gemFireCluster;
+
+ private static final int HTTP_SERVICE_PORT = 9090;
+
+ private static final int LOCATOR_COUNT = 1;
+
+ private static final int SERVER_COUNT = 1;
+
+ @AfterAll
+ public static void stopGemFireCluster() {
+ gemFireCluster.close();
+ }
+
+ @BeforeAll
+ public static void startGemFireCluster() {
+ Ports.Binding hostPort = Ports.Binding.bindPort(HTTP_SERVICE_PORT);
+ ExposedPort exposedPort = new ExposedPort(HTTP_SERVICE_PORT);
+ PortBinding mappedPort = new PortBinding(hostPort, exposedPort);
+ gemFireCluster = new GemFireCluster("gemfire/gemfire-all:10.1-jdk17", LOCATOR_COUNT, SERVER_COUNT);
+ gemFireCluster.withConfiguration(GemFireCluster.SERVER_GLOB,
+ container -> container.withExposedPorts(HTTP_SERVICE_PORT)
+ .withCreateContainerCmdModifier(cmd -> cmd.getHostConfig().withPortBindings(mappedPort)));
+ gemFireCluster.withGemFireProperty(GemFireCluster.SERVER_GLOB, "http-service-port",
+ Integer.toString(HTTP_SERVICE_PORT));
+ gemFireCluster.acceptLicense().start();
+
+ System.setProperty("spring.data.gemfire.pool.locators",
+ String.format("localhost[%d]", gemFireCluster.getLocatorPort()));
+ }
+
List documents = List.of(
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
@@ -69,25 +99,16 @@ public static String getText(String uri) {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withUserConfiguration(TestApplication.class);
- @BeforeEach
- public void createIndex() {
- contextRunner.run(c -> c.getBean(GemFireVectorStore.class).createIndex(INDEX_NAME));
- }
-
- @AfterEach
- public void deleteIndex() {
- contextRunner.run(c -> c.getBean(GemFireVectorStore.class).deleteIndex(INDEX_NAME));
- }
-
@Test
public void addAndDeleteEmbeddingTest() {
contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
vectorStore.add(documents);
vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());
- Awaitility.await().atMost(1, MINUTES).until(() -> {
- return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3));
- }, hasSize(0));
+ Awaitility.await()
+ .atMost(1, MINUTES)
+ .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3)),
+ hasSize(0));
});
}
@@ -97,14 +118,15 @@ public void addAndSearchTest() {
VectorStore vectorStore = context.getBean(VectorStore.class);
vectorStore.add(documents);
- Awaitility.await().atMost(1, MINUTES).until(() -> {
- return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
- }, hasSize(1));
+ Awaitility.await()
+ .atMost(1, MINUTES)
+ .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)),
+ hasSize(1));
List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(5));
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
- assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+ assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939)" + " was an economic shock");
assertThat(resultDoc.getMetadata()).hasSize(2);
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
@@ -120,9 +142,10 @@ public void documentUpdateTest() {
Collections.singletonMap("meta1", "meta1"));
vectorStore.add(List.of(document));
SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5);
- Awaitility.await().atMost(1, MINUTES).until(() -> {
- return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
- }, hasSize(1));
+ Awaitility.await()
+ .atMost(1, MINUTES)
+ .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)),
+ hasSize(1));
List results = vectorStore.similaritySearch(springSearchRequest);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
@@ -131,7 +154,7 @@ public void documentUpdateTest() {
assertThat(resultDoc.getMetadata()).containsKey("distance");
Document sameIdDocument = new Document(document.getId(),
- "The World is Big and Salvation Lurks Around the Corner",
+ "The World is Big and Salvation Lurks " + "Around the Corner",
Collections.singletonMap("meta2", "meta2"));
vectorStore.add(List.of(sameIdDocument));
@@ -141,7 +164,7 @@ public void documentUpdateTest() {
assertThat(results).hasSize(1);
resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
- assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
+ assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation" + " Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
});
@@ -154,10 +177,11 @@ public void searchThresholdTest() {
VectorStore vectorStore = context.getBean(VectorStore.class);
vectorStore.add(documents);
- Awaitility.await().atMost(1, MINUTES).until(() -> {
- return vectorStore
- .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll());
- }, hasSize(3));
+ Awaitility.await()
+ .atMost(1, MINUTES)
+ .until(() -> vectorStore
+ .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll()),
+ hasSize(3));
List fullResult = vectorStore
.similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll());
@@ -173,7 +197,7 @@ public void searchThresholdTest() {
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
- assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock");
+ assertThat(resultDoc.getContent()).contains("The Great Depression " + "(1929–1939) was an economic shock");
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");
});
@@ -185,14 +209,14 @@ public static class TestApplication {
@Bean
public GemFireVectorStoreConfig gemfireVectorStoreConfig() {
- return GemFireVectorStoreConfig.builder().withHost("localhost").build();
+ return new GemFireVectorStoreConfig().setHost("localhost")
+ .setPort(HTTP_SERVICE_PORT)
+ .setIndexName(INDEX_NAME);
}
@Bean
public GemFireVectorStore vectorStore(GemFireVectorStoreConfig config, EmbeddingModel embeddingModel) {
- GemFireVectorStore gemFireVectorStore = new GemFireVectorStore(config, embeddingModel);
- gemFireVectorStore.setIndexName(INDEX_NAME);
- return gemFireVectorStore;
+ return new GemFireVectorStore(config, embeddingModel);
}
@Bean
@@ -202,4 +226,4 @@ public EmbeddingModel embeddingModel() {
}
-}
\ No newline at end of file
+}