diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java index e22d53b9950..ae97736905a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfiguration.java @@ -42,6 +42,7 @@ * @author Christian Tzolov * @author Eddú Meléndez * @author Soby Chacko + * @author Ilayaperumal Gopinathan */ @AutoConfiguration @ConditionalOnClass({ MilvusVectorStore.class, EmbeddingModel.class }) @@ -75,6 +76,11 @@ public MilvusVectorStore vectorStore(MilvusServiceClient milvusClient, Embedding .withMetricType(MetricType.valueOf(properties.getMetricType().name())) .withIndexParameters(properties.getIndexParameters()) .withEmbeddingDimension(properties.getEmbeddingDimension()) + .withIDFieldName(properties.getIdFieldName()) + .withAutoId(properties.isAutoId()) + .withContentFieldName(properties.getContentFieldName()) + .withMetadataFieldName(properties.getMetadataFieldName()) + .withEmbeddingFieldName(properties.getEmbeddingFieldName()) .build(); return new MilvusVectorStore(milvusClient, embeddingModel, config, properties.isInitializeSchema(), diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java index 9a17543b5db..13308307442 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java @@ -23,6 +23,7 @@ /** * @author Christian Tzolov + * @author Ilayaperumal Gopinathan */ @ConfigurationProperties(MilvusVectorStoreProperties.CONFIG_PREFIX) public class MilvusVectorStoreProperties extends CommonVectorStoreProperties { @@ -59,6 +60,31 @@ public class MilvusVectorStoreProperties extends CommonVectorStoreProperties { */ private String indexParameters = "{\"nlist\":1024}"; + /** + * The ID field name for the collection. + */ + private String idFieldName = MilvusVectorStore.DOC_ID_FIELD_NAME; + + /** + * Boolean flag to indicate if the auto-id is used. + */ + private boolean isAutoId = false; + + /** + * The content field name for the collection. + */ + private String contentFieldName = MilvusVectorStore.CONTENT_FIELD_NAME; + + /** + * The metadata field name for the collection. + */ + private String metadataFieldName = MilvusVectorStore.METADATA_FIELD_NAME; + + /** + * The embedding field name for the collection. + */ + private String embeddingFieldName = MilvusVectorStore.EMBEDDING_FIELD_NAME; + public String getDatabaseName() { return this.databaseName; } @@ -113,6 +139,50 @@ public void setIndexParameters(String indexParameters) { this.indexParameters = indexParameters; } + public String getIdFieldName() { + return this.idFieldName; + } + + public void setIdFieldName(String idFieldName) { + Assert.notNull(idFieldName, "idFieldName can not be null"); + this.idFieldName = idFieldName; + } + + public boolean isAutoId() { + return this.isAutoId; + } + + public void setAutoId(boolean autoId) { + this.isAutoId = autoId; + } + + public String getContentFieldName() { + return this.contentFieldName; + } + + public void setContentFieldName(String contentFieldName) { + Assert.notNull(contentFieldName, "contentFieldName can not be null"); + this.contentFieldName = contentFieldName; + } + + public String getMetadataFieldName() { + return this.metadataFieldName; + } + + public void setMetadataFieldName(String metadataFieldName) { + Assert.notNull(metadataFieldName, "metadataFieldName can not be null"); + this.metadataFieldName = metadataFieldName; + } + + public String getEmbeddingFieldName() { + return this.embeddingFieldName; + } + + public void setEmbeddingFieldName(String embeddingFieldName) { + Assert.notNull(embeddingFieldName, "embeddingFieldName can not be null"); + this.embeddingFieldName = embeddingFieldName; + } + public enum MilvusMetricType { /** diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java index 15723b9b0d2..e621a26dfd9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreAutoConfigurationIT.java @@ -46,6 +46,7 @@ * @author Eddú Meléndez * @author Soby Chacko * @author Thomas Vitale + * @author Ilayaperumal Gopinathan */ @Testcontainers public class MilvusVectorStoreAutoConfigurationIT { @@ -109,6 +110,57 @@ public void addAndSearch() { }); } + @Test + public void searchWithCustomFields() { + contextRunner + .withPropertyValues("spring.ai.vectorstore.milvus.metricType=COSINE", + "spring.ai.vectorstore.milvus.indexType=IVF_FLAT", + "spring.ai.vectorstore.milvus.embeddingDimension=384", + "spring.ai.vectorstore.milvus.collectionName=myCustomCollection", + "spring.ai.vectorstore.milvus.idFieldName=identity", + "spring.ai.vectorstore.milvus.contentFieldName=text", + "spring.ai.vectorstore.milvus.embeddingFieldName=vectors", + "spring.ai.vectorstore.milvus.metadataFieldName=meta", + "spring.ai.vectorstore.milvus.initializeSchema=true", + "spring.ai.vectorstore.milvus.client.host=" + milvus.getHost(), + "spring.ai.vectorstore.milvus.client.port=" + milvus.getMappedPort(19530)) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); + + vectorStore.add(documents); + + assertObservationRegistry(observationRegistry, VectorStoreProvider.MILVUS, + VectorStoreObservationContext.Operation.ADD); + observationRegistry.clear(); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); + + assertObservationRegistry(observationRegistry, VectorStoreProvider.MILVUS, + VectorStoreObservationContext.Operation.QUERY); + observationRegistry.clear(); + + // Remove all documents from the store + vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList()); + + results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); + assertThat(results).hasSize(0); + + assertObservationRegistry(observationRegistry, VectorStoreProvider.MILVUS, + VectorStoreObservationContext.Operation.DELETE); + observationRegistry.clear(); + + }); + } + @Configuration(proxyBeanMethods = false) static class Config { diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index fe91142344c..7aa809204a3 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -26,6 +26,7 @@ import io.micrometer.observation.ObservationRegistry; import io.milvus.client.MilvusServiceClient; import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.exception.ParamException; import io.milvus.grpc.DataType; import io.milvus.grpc.DescribeIndexResponse; import io.milvus.grpc.MutationResult; @@ -72,6 +73,7 @@ * @author Christian Tzolov * @author Soby Chacko * @author Thomas Vitale + * @author Ilayaperumal Gopinathan */ public class MilvusVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -94,9 +96,6 @@ public class MilvusVectorStore extends AbstractObservationVectorStore implements // Metadata, automatically assigned by Milvus. public static final String DISTANCE_FIELD_NAME = "distance"; - public static final List SEARCH_OUTPUT_FIELDS = List.of(DOC_ID_FIELD_NAME, CONTENT_FIELD_NAME, - METADATA_FIELD_NAME); - private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class); private static Map SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, @@ -170,10 +169,13 @@ public void doAdd(List documents) { } List fields = new ArrayList<>(); - fields.add(new InsertParam.Field(DOC_ID_FIELD_NAME, docIdArray)); - fields.add(new InsertParam.Field(CONTENT_FIELD_NAME, contentArray)); - fields.add(new InsertParam.Field(METADATA_FIELD_NAME, metadataArray)); - fields.add(new InsertParam.Field(EMBEDDING_FIELD_NAME, embeddingArray)); + // Insert ID field only if it is not auto ID + if (!this.config.isAutoId) { + fields.add(new InsertParam.Field(this.config.idFieldName, docIdArray)); + } + fields.add(new InsertParam.Field(this.config.contentFieldName, contentArray)); + fields.add(new InsertParam.Field(this.config.metadataFieldName, metadataArray)); + fields.add(new InsertParam.Field(this.config.embeddingFieldName, embeddingArray)); InsertParam insertParam = InsertParam.newBuilder() .withDatabaseName(this.config.databaseName) @@ -191,7 +193,7 @@ public void doAdd(List documents) { public Optional doDelete(List idList) { Assert.notNull(idList, "Document id list must not be null"); - String deleteExpression = String.format("%s in [%s]", DOC_ID_FIELD_NAME, + String deleteExpression = String.format("%s in [%s]", this.config.idFieldName, idList.stream().map(id -> "'" + id + "'").collect(Collectors.joining(","))); R status = this.milvusClient.delete(DeleteParam.newBuilder() @@ -214,17 +216,20 @@ public List doSimilaritySearch(SearchRequest request) { ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; Assert.notNull(request.getQuery(), "Query string must not be null"); - + List outFieldNames = new ArrayList<>(); + outFieldNames.add(this.config.idFieldName); + outFieldNames.add(this.config.contentFieldName); + outFieldNames.add(this.config.metadataFieldName); float[] embedding = this.embeddingModel.embed(request.getQuery()); var searchParamBuilder = SearchParam.newBuilder() .withCollectionName(this.config.collectionName) .withConsistencyLevel(ConsistencyLevelEnum.STRONG) .withMetricType(this.config.metricType) - .withOutFields(SEARCH_OUTPUT_FIELDS) + .withOutFields(outFieldNames) .withTopK(request.getTopK()) .withVectors(List.of(EmbeddingUtils.toList(embedding))) - .withVectorFieldName(EMBEDDING_FIELD_NAME); + .withVectorFieldName(this.config.embeddingFieldName); if (StringUtils.hasText(nativeFilterExpressions)) { searchParamBuilder.withExpr(nativeFilterExpressions); @@ -242,12 +247,19 @@ public List doSimilaritySearch(SearchRequest request) { .stream() .filter(rowRecord -> getResultSimilarity(rowRecord) >= request.getSimilarityThreshold()) .map(rowRecord -> { - String docId = (String) rowRecord.get(DOC_ID_FIELD_NAME); - String content = (String) rowRecord.get(CONTENT_FIELD_NAME); - JSONObject metadata = (JSONObject) rowRecord.get(METADATA_FIELD_NAME); - // inject the distance into the metadata. - metadata.put(DISTANCE_FIELD_NAME, 1 - getResultSimilarity(rowRecord)); - return new Document(docId, content, metadata.getInnerMap()); + String docId = String.valueOf(rowRecord.get(this.config.idFieldName)); + String content = (String) rowRecord.get(this.config.contentFieldName); + JSONObject metadata = null; + try { + metadata = (JSONObject) rowRecord.get(this.config.metadataFieldName); + // inject the distance into the metadata. + metadata.put(DISTANCE_FIELD_NAME, 1 - getResultSimilarity(rowRecord)); + } + catch (ParamException e) { + // skip the ParamException if metadata doesn't exist for the custom + // collection + } + return new Document(docId, content, (metadata != null) ? metadata.getInnerMap() : Map.of()); }) .toList(); } @@ -291,45 +303,9 @@ private boolean isDatabaseCollectionExists() { void createCollection() { if (!isDatabaseCollectionExists()) { - - FieldType docIdFieldType = FieldType.newBuilder() - .withName(DOC_ID_FIELD_NAME) - .withDataType(DataType.VarChar) - .withMaxLength(36) - .withPrimaryKey(true) - .withAutoID(false) - .build(); - FieldType contentFieldType = FieldType.newBuilder() - .withName(CONTENT_FIELD_NAME) - .withDataType(DataType.VarChar) - .withMaxLength(65535) - .build(); - FieldType metadataFieldType = FieldType.newBuilder() - .withName(METADATA_FIELD_NAME) - .withDataType(DataType.JSON) - .build(); - FieldType embeddingFieldType = FieldType.newBuilder() - .withName(EMBEDDING_FIELD_NAME) - .withDataType(DataType.FloatVector) - .withDimension(this.embeddingDimensions()) - .build(); - - CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder() - .withDatabaseName(this.config.databaseName) - .withCollectionName(this.config.collectionName) - .withDescription("Spring AI Vector Store") - .withConsistencyLevel(ConsistencyLevelEnum.STRONG) - .withShardsNum(2) - .addFieldType(docIdFieldType) - .addFieldType(contentFieldType) - .addFieldType(metadataFieldType) - .addFieldType(embeddingFieldType) - .build(); - - R collectionStatus = this.milvusClient.createCollection(createCollectionReq); - if (collectionStatus.getException() != null) { - throw new RuntimeException("Failed to create collection", collectionStatus.getException()); - } + createCollection(this.config.databaseName, this.config.collectionName, this.config.idFieldName, + this.config.isAutoId, this.config.contentFieldName, this.config.metadataFieldName, + this.config.embeddingFieldName); } R indexDescriptionResponse = this.milvusClient @@ -342,7 +318,7 @@ void createCollection() { R indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder() .withDatabaseName(this.config.databaseName) .withCollectionName(this.config.collectionName) - .withFieldName(EMBEDDING_FIELD_NAME) + .withFieldName(this.config.embeddingFieldName) .withIndexType(this.config.indexType) .withMetricType(this.config.metricType) .withExtraParam(this.config.indexParameters) @@ -364,6 +340,49 @@ void createCollection() { } } + void createCollection(String databaseName, String collectionName, String idFieldName, boolean isAutoId, + String contentFieldName, String metadataFieldName, String embeddingFieldName) { + FieldType docIdFieldType = FieldType.newBuilder() + .withName(idFieldName) + .withDataType(DataType.VarChar) + .withMaxLength(36) + .withPrimaryKey(true) + .withAutoID(isAutoId) + .build(); + FieldType contentFieldType = FieldType.newBuilder() + .withName(contentFieldName) + .withDataType(DataType.VarChar) + .withMaxLength(65535) + .build(); + FieldType metadataFieldType = FieldType.newBuilder() + .withName(metadataFieldName) + .withDataType(DataType.JSON) + .build(); + FieldType embeddingFieldType = FieldType.newBuilder() + .withName(embeddingFieldName) + .withDataType(DataType.FloatVector) + .withDimension(this.embeddingDimensions()) + .build(); + + CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder() + .withDatabaseName(databaseName) + .withCollectionName(collectionName) + .withDescription("Spring AI Vector Store") + .withConsistencyLevel(ConsistencyLevelEnum.STRONG) + .withShardsNum(2) + .addFieldType(docIdFieldType) + .addFieldType(contentFieldType) + .addFieldType(metadataFieldType) + .addFieldType(embeddingFieldType) + .build(); + + R collectionStatus = this.milvusClient.createCollection(createCollectionReq); + if (collectionStatus.getException() != null) { + throw new RuntimeException("Failed to create collection", collectionStatus.getException()); + } + + } + int embeddingDimensions() { if (this.config.embeddingDimension != INVALID_EMBEDDING_DIMENSION) { return this.config.embeddingDimension; @@ -443,6 +462,16 @@ public static class MilvusVectorStoreConfig { private final String indexParameters; + private final String idFieldName; + + private final boolean isAutoId; + + private final String contentFieldName; + + private final String metadataFieldName; + + private final String embeddingFieldName; + private MilvusVectorStoreConfig(Builder builder) { this.databaseName = builder.databaseName; this.collectionName = builder.collectionName; @@ -450,6 +479,11 @@ private MilvusVectorStoreConfig(Builder builder) { this.indexType = builder.indexType; this.metricType = builder.metricType; this.indexParameters = builder.indexParameters; + this.idFieldName = builder.idFieldName; + this.isAutoId = builder.isAutoId; + this.contentFieldName = builder.contentFieldName; + this.metadataFieldName = builder.metadataFieldName; + this.embeddingFieldName = builder.embeddingFieldName; } /** @@ -482,6 +516,16 @@ public static class Builder { private String indexParameters = "{\"nlist\":1024}"; + private String idFieldName = DOC_ID_FIELD_NAME; + + private boolean isAutoId = false; + + private String contentFieldName = CONTENT_FIELD_NAME; + + private String metadataFieldName = METADATA_FIELD_NAME; + + private String embeddingFieldName = EMBEDDING_FIELD_NAME; + private Builder() { } @@ -560,6 +604,58 @@ public Builder withEmbeddingDimension(int newEmbeddingDimension) { return this; } + /** + * Configures the ID field name. Default is {@value #DOC_ID_FIELD_NAME}. + * @param idFieldName The name for the ID field + * @return this builder + */ + public Builder withIDFieldName(String idFieldName) { + this.idFieldName = idFieldName; + return this; + } + + /** + * Configures the boolean flag if the auto-id is used. Default is false. + * @param isAutoId boolean flag to indicate if the auto-id is enabled + * @return this builder + */ + public Builder withAutoId(boolean isAutoId) { + this.isAutoId = isAutoId; + return this; + } + + /** + * Configures the content field name. Default is {@value #CONTENT_FIELD_NAME}. + * @param contentFieldName The name for the content field + * @return this builder + */ + public Builder withContentFieldName(String contentFieldName) { + this.contentFieldName = contentFieldName; + return this; + } + + /** + * Configures the metadata field name. Default is + * {@value #METADATA_FIELD_NAME}. + * @param metadataFieldName The name for the metadata field + * @return this builder + */ + public Builder withMetadataFieldName(String metadataFieldName) { + this.metadataFieldName = metadataFieldName; + return this; + } + + /** + * Configures the embedding field name. Default is + * {@value #EMBEDDING_FIELD_NAME}. + * @param embeddingFieldName The name for the embedding field + * @return this builder + */ + public Builder withEmbeddingFieldName(String embeddingFieldName) { + this.embeddingFieldName = embeddingFieldName; + return this; + } + /** * {@return the immutable configuration} */ diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java new file mode 100644 index 00000000000..f8e0e9fd8e4 --- /dev/null +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreCustomFieldNamesIT.java @@ -0,0 +1,259 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore; + +import io.milvus.client.MilvusServiceClient; +import io.milvus.param.ConnectParam; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.MilvusVectorStore.MilvusVectorStoreConfig; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.core.io.DefaultResourceLoader; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.milvus.MilvusContainer; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ilayaperumal Gopinathan + */ +@Testcontainers +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +class MilvusVectorStoreCustomFieldNamesIT { + + @Container + private static MilvusContainer milvusContainer = new MilvusContainer(MilvusImage.DEFAULT_IMAGE); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + private void resetCollection(VectorStore vectorStore) { + ((MilvusVectorStore) vectorStore).dropCollection(); + ((MilvusVectorStore) vectorStore).createCollection(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE" }) + void searchWithCustomFieldNames(String metricType) { + + contextRunner + .withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType, + "test.spring.ai.vectorstore.milvus.idFieldName=document_id", + "test.spring.ai.vectorstore.milvus.contentFieldName=text", + "test.spring.ai.vectorstore.milvus.embeddingFieldName=vector", + "test.spring.ai.vectorstore.milvus.metadataFieldName=meta") + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + resetCollection(vectorStore); + + vectorStore.add(documents); + + List fullResult = vectorStore.similaritySearch(SearchRequest.query("Spring")); + + List distances = fullResult.stream() + .map(doc -> (Float) doc.getMetadata().get("distance")) + .toList(); + + assertThat(distances).hasSize(3); + + float threshold = (distances.get(0) + distances.get(1)) / 2; + + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(String.valueOf(resultDoc.getId())).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE" }) + void searchWithoutMetadataFieldOverride(String metricType) { + + contextRunner + .withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType, + "test.spring.ai.vectorstore.milvus.idFieldName=identity", + "test.spring.ai.vectorstore.milvus.contentFieldName=text", + "test.spring.ai.vectorstore.milvus.embeddingFieldName=embed") + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + resetCollection(vectorStore); + + vectorStore.add(documents); + + List fullResult = vectorStore.similaritySearch(SearchRequest.query("Spring")); + + List distances = fullResult.stream() + .map(doc -> (Float) doc.getMetadata().get("distance")) + .toList(); + + assertThat(distances).hasSize(3); + + float threshold = (distances.get(0) + distances.get(1)) / 2; + + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(String.valueOf(resultDoc.getId())).isEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE" }) + void searchWithAutoIdEnabled(String metricType) { + + contextRunner + .withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=" + metricType, + "test.spring.ai.vectorstore.milvus.isAutoId=true", + "test.spring.ai.vectorstore.milvus.idFieldName=identity", + "test.spring.ai.vectorstore.milvus.contentFieldName=media", + "test.spring.ai.vectorstore.milvus.metadataFieldName=meta", + "test.spring.ai.vectorstore.milvus.embeddingFieldName=embed") + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + resetCollection(vectorStore); + + vectorStore.add(documents); + + List fullResult = vectorStore.similaritySearch(SearchRequest.query("Spring")); + + List distances = fullResult.stream() + .map(doc -> (Float) doc.getMetadata().get("distance")) + .toList(); + + assertThat(distances).hasSize(3); + + float threshold = (distances.get(0) + distances.get(1)) / 2; + + List results = vectorStore + .similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThreshold(1 - threshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + // Verify that the auto ID is used + assertThat(String.valueOf(resultDoc.getId())).isNotEqualTo(documents.get(0).getId()); + assertThat(resultDoc.getContent()).contains( + "Spring AI provides abstractions that serve as the foundation for developing AI applications."); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Value("${test.spring.ai.vectorstore.milvus.metricType}") + private MetricType metricType; + + @Value("${test.spring.ai.vectorstore.milvus.idFieldName}") + private String idFieldName; + + @Value("${test.spring.ai.vectorstore.milvus.isAutoId:false}") + private Boolean isAutoId; + + @Value("${test.spring.ai.vectorstore.milvus.contentFieldName}") + private String contentFieldName; + + @Value("${test.spring.ai.vectorstore.milvus.embeddingFieldName}") + private String embeddingFieldName; + + @Value("${test.spring.ai.vectorstore.milvus.metadataFieldName:metadata}") + private String metadataFieldName; + + @Bean + VectorStore vectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel) { + MilvusVectorStoreConfig config = MilvusVectorStoreConfig.builder() + .withCollectionName("test_vector_store_custom_fields") + .withDatabaseName("default") + .withIndexType(IndexType.IVF_FLAT) + .withMetricType(metricType) + .withIDFieldName(idFieldName) + .withAutoId(isAutoId) + .withContentFieldName(contentFieldName) + .withEmbeddingFieldName(embeddingFieldName) + .withMetadataFieldName(metadataFieldName) + .build(); + return new MilvusVectorStore(milvusClient, embeddingModel, config, true, new TokenCountBatchingStrategy()); + } + + @Bean + MilvusServiceClient milvusClient() { + return new MilvusServiceClient(ConnectParam.newBuilder() + .withAuthorization("minioadmin", "minioadmin") + .withUri(milvusContainer.getEndpoint()) + .build()); + } + + @Bean + EmbeddingModel embeddingModel() { + return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + } + + } + +}