diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java index 13a3f59cf6c..14222c0864b 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java @@ -102,6 +102,16 @@ * } * *
+ * AWS OpenSearch Serverless usage example: + *
+ *{@code + * OpenSearchVectorStore vectorStore = OpenSearchVectorStore.builder(openSearchClient, embeddingModel) + * .initializeSchema(true) + * .manageDocumentIds(false) // Required for AWS OpenSearch Serverless + * .build(); + * }+ * + *
* Advanced configuration example: *
*{@code @@ -137,6 +147,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author inpink + * @author Sanghun Lee * @since 1.0.0 */ public class OpenSearchVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -170,6 +181,8 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem private String similarityFunction; + private final boolean manageDocumentIds; + /** * Creates a new OpenSearchVectorStore using the builder pattern. * @param builder The configured builder instance @@ -187,6 +200,7 @@ protected OpenSearchVectorStore(Builder builder) { // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces this.similarityFunction = builder.similarityFunction; this.initializeSchema = builder.initializeSchema; + this.manageDocumentIds = builder.manageDocumentIds; } /** @@ -210,14 +224,27 @@ public void doAdd(Listdocuments) { for (Document document : documents) { OpenSearchDocument openSearchDocument = new OpenSearchDocument(document.getId(), document.getText(), document.getMetadata(), embedding.get(documents.indexOf(document))); - bulkRequestBuilder.operations(op -> op - .index(idx -> idx.index(this.index).id(openSearchDocument.id()).document(openSearchDocument))); + + // Conditionally set document ID based on manageDocumentIds flag + if (this.manageDocumentIds) { + bulkRequestBuilder.operations(op -> op + .index(idx -> idx.index(this.index).id(openSearchDocument.id()).document(openSearchDocument))); + } + else { + bulkRequestBuilder + .operations(op -> op.index(idx -> idx.index(this.index).document(openSearchDocument))); + } } bulkRequest(bulkRequestBuilder.build()); } @Override public void doDelete(List idList) { + if (!this.manageDocumentIds) { + logger.warn("Document ID management is disabled. Delete operations may not work as expected " + + "since document IDs are auto-generated by OpenSearch. Consider using filter-based deletion instead."); + } + BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (String id : idList) { bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.index).id(id))); @@ -417,6 +444,8 @@ public static class Builder extends AbstractVectorStoreBuilder { private String similarityFunction = COSINE_SIMILARITY_FUNCTION; + private boolean manageDocumentIds = true; + /** * Sets the OpenSearch client. * @param openSearchClient The OpenSearch client to use @@ -488,6 +517,28 @@ public Builder similarityFunction(String similarityFunction) { return this; } + /** + * Sets whether to manage document IDs during indexing operations. + * + * When set to {@code true} (default), document IDs will be explicitly set during + * indexing operations. When set to {@code false}, OpenSearch will auto-generate + * document IDs, which is required for AWS OpenSearch Serverless vector search + * collections. + *
+ *+ * Note: When document ID management is disabled, the {@link #doDelete(List)} + * method may not work as expected since document IDs are auto-generated by + * OpenSearch. + *
+ * @param manageDocumentIds true to manage document IDs (default), false to let + * OpenSearch auto-generate IDs + * @return The builder instance + */ + public Builder manageDocumentIds(boolean manageDocumentIds) { + this.manageDocumentIds = manageDocumentIds; + return this; + } + /** * Builds a new OpenSearchVectorStore instance with the configured properties. * @return A new OpenSearchVectorStore instance diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java index 380d434c63b..47d04d4be0d 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java @@ -564,6 +564,161 @@ void getNativeClientTest() { }); } + @ParameterizedTest(name = "manageDocumentIds={0}") + @ValueSource(booleans = { true, false }) + void testManageDocumentIdsSetting(boolean manageDocumentIds) { + getContextRunner().run(context -> { + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); + + // Create a new vector store with specific manageDocumentIds setting + OpenSearchVectorStore testVectorStore = OpenSearchVectorStore + .builder((OpenSearchClient) vectorStore.getNativeClient().orElseThrow(), + context.getBean(EmbeddingModel.class)) + .manageDocumentIds(manageDocumentIds) + .index("test_manage_document_ids_" + manageDocumentIds) + .initializeSchema(true) + .build(); + + // Test documents + ListtestDocuments = List.of(new Document("doc1", "Test content 1", Map.of("key1", "value1")), + new Document("doc2", "Test content 2", Map.of("key2", "value2"))); + + // Add documents + testVectorStore.add(testDocuments); + + // Wait for indexing + Awaitility.await() + .until(() -> testVectorStore + .similaritySearch(SearchRequest.builder().query("Test content").topK(2).build()), hasSize(2)); + + // Search and verify results + List results = testVectorStore + .similaritySearch(SearchRequest.builder().query("Test content").topK(2).build()); + + assertThat(results).hasSize(2); + + // Verify document content and metadata are preserved + assertThat(results.stream().map(Document::getText).toList()).containsExactlyInAnyOrder("Test content 1", + "Test content 2"); + + assertThat(results.stream().map(doc -> doc.getMetadata().get("key1")).toList()).contains("value1"); + assertThat(results.stream().map(doc -> doc.getMetadata().get("key2")).toList()).contains("value2"); + + // Clean up + testVectorStore.delete(testDocuments.stream().map(Document::getId).toList()); + }); + } + + @Test + void testManageDocumentIdsFalseForAWSOpenSearchServerless() { + getContextRunner().run(context -> { + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); + + // Create vector store with manageDocumentIds=false (AWS OpenSearch Serverless + // mode) + OpenSearchVectorStore awsCompatibleVectorStore = OpenSearchVectorStore + .builder((OpenSearchClient) vectorStore.getNativeClient().orElseThrow(), + context.getBean(EmbeddingModel.class)) + .manageDocumentIds(false) + .index("test_aws_serverless_compatible") + .initializeSchema(true) + .build(); + + // Test documents with IDs (these should be ignored when + // manageDocumentIds=false) + List testDocuments = List.of( + new Document("custom-id-1", "AWS Serverless content 1", Map.of("env", "aws-serverless")), + new Document("custom-id-2", "AWS Serverless content 2", Map.of("env", "aws-serverless"))); + + // Add documents - should work without explicit document ID errors + awsCompatibleVectorStore.add(testDocuments); + + // Wait for indexing + Awaitility.await() + .until(() -> awsCompatibleVectorStore + .similaritySearch(SearchRequest.builder().query("AWS Serverless").topK(2).build()), hasSize(2)); + + // Search and verify results + List results = awsCompatibleVectorStore + .similaritySearch(SearchRequest.builder().query("AWS Serverless").topK(2).build()); + + assertThat(results).hasSize(2); + + // Verify content is preserved + assertThat(results.stream().map(Document::getText).toList()) + .containsExactlyInAnyOrder("AWS Serverless content 1", "AWS Serverless content 2"); + + // Verify metadata is preserved + assertThat(results.stream().map(doc -> doc.getMetadata().get("env")).toList()) + .containsOnly("aws-serverless"); + + // Clean up + awsCompatibleVectorStore.delete(List.of("_all")); + }); + } + + @Test + void testManageDocumentIdsTrueWithExplicitIds() { + getContextRunner().run(context -> { + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); + + // Create vector store with manageDocumentIds=true (default behavior) + OpenSearchVectorStore explicitIdVectorStore = OpenSearchVectorStore + .builder((OpenSearchClient) vectorStore.getNativeClient().orElseThrow(), + context.getBean(EmbeddingModel.class)) + .manageDocumentIds(true) + .index("test_explicit_ids") + .initializeSchema(true) + .build(); + + // Test documents with specific IDs + List testDocuments = List.of( + new Document("explicit-id-1", "Explicit ID content 1", Map.of("type", "explicit")), + new Document("explicit-id-2", "Explicit ID content 2", Map.of("type", "explicit"))); + + // Add documents + explicitIdVectorStore.add(testDocuments); + + // Wait for indexing + Awaitility.await() + .until(() -> explicitIdVectorStore + .similaritySearch(SearchRequest.builder().query("Explicit ID").topK(2).build()), hasSize(2)); + + // Search and verify results + List results = explicitIdVectorStore + .similaritySearch(SearchRequest.builder().query("Explicit ID").topK(2).build()); + + assertThat(results).hasSize(2); + + // Verify document IDs are preserved + assertThat(results.stream().map(Document::getId).toList()).containsExactlyInAnyOrder("explicit-id-1", + "explicit-id-2"); + + // Verify content and metadata + assertThat(results.stream().map(Document::getText).toList()) + .containsExactlyInAnyOrder("Explicit ID content 1", "Explicit ID content 2"); + + assertThat(results.stream().map(doc -> doc.getMetadata().get("type")).toList()).containsOnly("explicit"); + + // Test deletion by specific IDs + explicitIdVectorStore.delete(List.of("explicit-id-1")); + + Awaitility.await() + .until(() -> explicitIdVectorStore + .similaritySearch(SearchRequest.builder().query("Explicit ID").topK(2).build()), hasSize(1)); + + // Verify only one document remains + results = explicitIdVectorStore + .similaritySearch(SearchRequest.builder().query("Explicit ID").topK(2).build()); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo("explicit-id-2"); + + // Clean up + explicitIdVectorStore.delete(List.of("explicit-id-2")); + }); + } + @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreTest.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreTest.java new file mode 100644 index 00000000000..e39b6c3205e --- /dev/null +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreTest.java @@ -0,0 +1,206 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.opensearch; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch.core.BulkRequest; +import org.opensearch.client.opensearch.core.BulkResponse; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; + +/** + * Unit tests for OpenSearchVectorStore.doAdd() method. + * + * Focuses on testing the manageDocumentIds functionality and document ID handling. + */ +@ExtendWith(MockitoExtension.class) +@DisplayName("OpenSearchVectorStore.doAdd() Tests") +class OpenSearchVectorStoreTest { + + @Mock + private OpenSearchClient mockOpenSearchClient; + + @Mock + private EmbeddingModel mockEmbeddingModel; + + @Mock + private BulkResponse mockBulkResponse; + + @BeforeEach + void setUp() throws IOException { + // Use lenient to avoid UnnecessaryStubbingException + lenient().when(mockEmbeddingModel.dimensions()).thenReturn(3); + lenient().when(mockOpenSearchClient.bulk(any(BulkRequest.class))).thenReturn(mockBulkResponse); + lenient().when(mockBulkResponse.errors()).thenReturn(false); + } + + @ParameterizedTest(name = "manageDocumentIds={0}") + @ValueSource(booleans = { true, false }) + @DisplayName("Should handle document ID management setting correctly") + void shouldHandleDocumentIdManagementSetting(boolean manageDocumentIds) throws IOException { + // Given + when(mockEmbeddingModel.embed(any(), any(), any())) + .thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.4f, 0.5f, 0.6f })); + + OpenSearchVectorStore vectorStore = createVectorStore(manageDocumentIds); + List documents = List.of(new Document("doc1", "content1", Map.of()), + new Document("doc2", "content2", Map.of())); + + // When + vectorStore.add(documents); + + // Then + BulkRequest capturedRequest = captureBulkRequest(); + assertThat(capturedRequest.operations()).hasSize(2); + + verifyDocumentIdHandling(capturedRequest, manageDocumentIds); + } + + @Test + @DisplayName("Should handle single document correctly") + void shouldHandleSingleDocumentCorrectly() throws IOException { + // Given + when(mockEmbeddingModel.embed(any(), any(), any())).thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f })); + + OpenSearchVectorStore vectorStore = createVectorStore(true); + Document document = new Document("test-id", "test content", Map.of("key", "value")); + + // When + vectorStore.add(List.of(document)); + + // Then + BulkRequest capturedRequest = captureBulkRequest(); + var operation = capturedRequest.operations().get(0); + + assertThat(operation.isIndex()).isTrue(); + assertThat(operation.index().id()).isEqualTo("test-id"); + assertThat(operation.index().document()).isNotNull(); + } + + @Test + @DisplayName("Should handle multiple documents with explicit IDs") + void shouldHandleMultipleDocumentsWithExplicitIds() throws IOException { + // Given + when(mockEmbeddingModel.embed(any(), any(), any())).thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, + new float[] { 0.4f, 0.5f, 0.6f }, new float[] { 0.7f, 0.8f, 0.9f })); + + OpenSearchVectorStore vectorStore = createVectorStore(true); + List documents = List.of(new Document("doc1", "content1", Map.of()), + new Document("doc2", "content2", Map.of()), new Document("doc3", "content3", Map.of())); + + // When + vectorStore.add(documents); + + // Then + BulkRequest capturedRequest = captureBulkRequest(); + assertThat(capturedRequest.operations()).hasSize(3); + + for (int i = 0; i < 3; i++) { + var operation = capturedRequest.operations().get(i); + assertThat(operation.isIndex()).isTrue(); + assertThat(operation.index().id()).isEqualTo("doc" + (i + 1)); + } + } + + @Test + @DisplayName("Should handle multiple documents without explicit IDs") + void shouldHandleMultipleDocumentsWithoutExplicitIds() throws IOException { + // Given + when(mockEmbeddingModel.embed(any(), any(), any())) + .thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.4f, 0.5f, 0.6f })); + + OpenSearchVectorStore vectorStore = createVectorStore(false); + List documents = List.of(new Document("doc1", "content1", Map.of()), + new Document("doc2", "content2", Map.of())); + + // When + vectorStore.add(documents); + + // Then + BulkRequest capturedRequest = captureBulkRequest(); + assertThat(capturedRequest.operations()).hasSize(2); + + for (var operation : capturedRequest.operations()) { + assertThat(operation.isIndex()).isTrue(); + assertThat(operation.index().id()).isNull(); + } + } + + @Test + @DisplayName("Should handle embedding model error") + void shouldHandleEmbeddingModelError() { + // Given + when(mockEmbeddingModel.embed(any(), any(), any())).thenThrow(new RuntimeException("Embedding failed")); + + OpenSearchVectorStore vectorStore = createVectorStore(true); + List documents = List.of(new Document("doc1", "content", Map.of())); + + // When & Then + assertThatThrownBy(() -> vectorStore.add(documents)).isInstanceOf(RuntimeException.class) + .hasMessageContaining("Embedding failed"); + } + + // Helper methods + + private OpenSearchVectorStore createVectorStore(boolean manageDocumentIds) { + return OpenSearchVectorStore.builder(mockOpenSearchClient, mockEmbeddingModel) + .manageDocumentIds(manageDocumentIds) + .build(); + } + + private BulkRequest captureBulkRequest() throws IOException { + ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); + verify(mockOpenSearchClient).bulk(captor.capture()); + return captor.getValue(); + } + + private void verifyDocumentIdHandling(BulkRequest request, boolean shouldHaveExplicitIds) { + for (int i = 0; i < request.operations().size(); i++) { + var operation = request.operations().get(i); + assertThat(operation.isIndex()).isTrue(); + + if (shouldHaveExplicitIds) { + assertThat(operation.index().id()).isEqualTo("doc" + (i + 1)); + } + else { + assertThat(operation.index().id()).isNull(); + } + } + } + +} \ No newline at end of file