diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java index 5a91fbecc06..4eae298a86d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java @@ -33,7 +33,7 @@ * @author Thomas Vitale * @author Ilayaperumal Gopinathan */ -public final class SearchRequest { +public class SearchRequest { /** * Similarity threshold that accepts all search scores. A threshold value of 0.0 means @@ -71,6 +71,16 @@ public static Builder from(SearchRequest originalSearchRequest) { .filterExpression(originalSearchRequest.getFilterExpression()); } + public SearchRequest() { + } + + protected SearchRequest(SearchRequest original) { + this.query = original.query; + this.topK = original.topK; + this.similarityThreshold = original.similarityThreshold; + this.filterExpression = original.filterExpression; + } + public String getQuery() { return this.query; } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc index bac3c45832e..c58f1a8df4c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/milvus.adoc @@ -161,6 +161,58 @@ vectorStore.similaritySearch(SearchRequest.builder() NOTE: These filter expressions are converted into the equivalent Milvus filters. +== Using MilvusSearchRequest + +MilvusSearchRequest extends SearchRequest, allowing you to use Milvus-specific search parameters such as native expressions and search parameter JSON. + +[source,java] +---- +MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder() + .query("sample query") + .topK(5) + .similarityThreshold(0.7) + .nativeExpression("metadata[\"age\"] > 30") // Overrides filterExpression if both are set + .filterExpression("age <= 30") // Ignored if nativeExpression is set + .searchParamsJson("{\"nprobe\":128}") + .build(); +List results = vectorStore.similaritySearch(request); +---- +This allows greater flexibility when using Milvus-specific search features. + +== Importance of `nativeExpression` and `searchParamsJson` in `MilvusSearchRequest` + +These two parameters enhance Milvus search precision and ensure optimal query performance: + +*nativeExpression*: Enables additional filtering capabilities using Milvus' native filtering expressions. +https://milvus.io/docs/boolean.md[Milvus Filtering] + +Example: +[source,java] +---- +MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder() + .query("sample query") + .topK(5) + .nativeExpression("metadata['category'] == 'science'") + .build(); +---- + +*searchParamsJson*: Essential for tuning search behavior when using IVF_FLAT, Milvus' default index. +https://milvus.io/docs/index.md?tab=floating[Milvus Vector Index] + +By default, `IVF_FLAT` requires `nprobe` to be set for accurate results. If not specified, `nprobe` defaults to `1`, which can lead to poor recall or even zero search results. + +Example: +[source,java] +---- +MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder() + .query("sample query") + .topK(5) + .searchParamsJson("{\"nprobe\":128}") + .build(); +---- + +Using `nativeExpression` ensures advanced filtering, while `searchParamsJson` prevents ineffective searches caused by a low default `nprobe` value. + [[milvus-properties]] == Milvus VectorStore properties diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusSearchRequest.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusSearchRequest.java new file mode 100755 index 00000000000..25e95e6876c --- /dev/null +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusSearchRequest.java @@ -0,0 +1,164 @@ +package org.springframework.ai.vectorstore.milvus; + +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.lang.Nullable; + +/** + * A specialized {@link SearchRequest} for Milvus vector search, extending the base + * request with Milvus-specific parameters. + *

+ * This class introduces two additional fields: + *

+ *

+ * Use the {@link MilvusBuilder} to construct instances of this class. + * + * @author waileong + */ +public final class MilvusSearchRequest extends SearchRequest { + + @Nullable + private final String nativeExpression; + + @Nullable + private final String searchParamsJson; + + /** + * Private constructor to initialize a MilvusSearchRequest using the base request and + * builder. + * @param baseRequest The base {@link SearchRequest} containing standard search + * fields. + * @param builder The {@link MilvusBuilder} containing Milvus-specific parameters. + */ + private MilvusSearchRequest(SearchRequest baseRequest, MilvusBuilder builder) { + super(baseRequest); // Copy all standard fields + this.nativeExpression = builder.nativeExpression; + this.searchParamsJson = builder.searchParamsJson; + } + + /** + * Retrieves the native Milvus filter expression. + * @return A string representing the native Milvus expression, or {@code null} if not + * set. + */ + @Nullable + public String getNativeExpression() { + return this.nativeExpression; + } + + /** + * Retrieves the JSON-encoded search parameters. + * @return A JSON string containing search parameters, or {@code null} if not set. + */ + @Nullable + public String getSearchParamsJson() { + return this.searchParamsJson; + } + + /** + * Creates a new {@link MilvusBuilder} for constructing a {@link MilvusSearchRequest}. + * @return A new {@link MilvusBuilder} instance. + */ + public static MilvusBuilder milvusBuilder() { + return new MilvusBuilder(); + } + + /** + * Builder class for constructing instances of {@link MilvusSearchRequest}. + */ + public static class MilvusBuilder { + + private final SearchRequest.Builder baseBuilder = SearchRequest.builder(); + + @Nullable + private String nativeExpression; + + @Nullable + private String searchParamsJson; + + /** + * {@link Builder#query(java.lang.String)} + */ + public MilvusBuilder query(String query) { + this.baseBuilder.query(query); + return this; + } + + /** + * {@link Builder#topK(int)} + */ + public MilvusBuilder topK(int topK) { + this.baseBuilder.topK(topK); + return this; + } + + /** + * {@link Builder#similarityThreshold(double)} + */ + public MilvusBuilder similarityThreshold(double threshold) { + this.baseBuilder.similarityThreshold(threshold); + return this; + } + + /** + * {@link Builder#similarityThresholdAll()} + */ + public MilvusBuilder similarityThresholdAll() { + this.baseBuilder.similarityThresholdAll(); + return this; + } + + /** + * {@link Builder#filterExpression(String)} + */ + public MilvusBuilder filterExpression(String textExpression) { + this.baseBuilder.filterExpression(textExpression); + return this; + } + + /** + * {@link Builder#filterExpression(Filter.Expression)} + */ + public MilvusBuilder filterExpression(Filter.Expression expression) { + this.baseBuilder.filterExpression(expression); + return this; + } + + /** + * Sets the native Milvus filter expression. + * @param nativeExpression The native Milvus expression string. + * @return This builder instance. + */ + public MilvusBuilder nativeExpression(String nativeExpression) { + this.nativeExpression = nativeExpression; + return this; + } + + /** + * Sets the JSON-encoded search parameters. + * @param searchParamsJson A JSON string containing search parameters. + * @return This builder instance. + */ + public MilvusBuilder searchParamsJson(String searchParamsJson) { + this.searchParamsJson = searchParamsJson; + return this; + } + + /** + * Builds and returns a {@link MilvusSearchRequest} instance. + * @return A new {@link MilvusSearchRequest} object with the specified parameters. + */ + public MilvusSearchRequest build() { + SearchRequest parentRequest = this.baseBuilder.build(); + return new MilvusSearchRequest(parentRequest, this); + } + + } + +} diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java index 5f683b46926..0b8a938f430 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java @@ -323,9 +323,18 @@ protected void doDelete(Filter.Expression filterExpression) { @Override public List doSimilaritySearch(SearchRequest request) { + String nativeFilterExpressions = ""; + String searchParamsJson = null; + if (request instanceof MilvusSearchRequest milvusReq) { + nativeFilterExpressions = StringUtils.hasText(milvusReq.getNativeExpression()) + ? milvusReq.getNativeExpression() : getConvertedFilterExpression(request); - String nativeFilterExpressions = (request.getFilterExpression() != null) - ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; + searchParamsJson = StringUtils.hasText(milvusReq.getSearchParamsJson()) ? milvusReq.getSearchParamsJson() + : null; + } + else { + nativeFilterExpressions = getConvertedFilterExpression(request); + } Assert.notNull(request.getQuery(), "Query string must not be null"); List outFieldNames = new ArrayList<>(); @@ -348,6 +357,10 @@ public List doSimilaritySearch(SearchRequest request) { searchParamBuilder.withExpr(nativeFilterExpressions); } + if (StringUtils.hasText(searchParamsJson)) { + searchParamBuilder.withParams(searchParamsJson); + } + R respSearch = this.milvusClient.search(searchParamBuilder.build()); if (respSearch.getException() != null) { @@ -385,6 +398,11 @@ public List doSimilaritySearch(SearchRequest request) { .toList(); } + private String getConvertedFilterExpression(SearchRequest request) { + return (request.getFilterExpression() != null) + ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; + } + private float getResultSimilarity(RowRecord rowRecord) { Float score = (Float) rowRecord.get(SIMILARITY_FIELD_NAME); return (this.metricType == MetricType.IP || this.metricType == MetricType.COSINE) ? score : (1 - score); diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusSearchRequestTest.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusSearchRequestTest.java new file mode 100644 index 00000000000..a3c9555c5a9 --- /dev/null +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusSearchRequestTest.java @@ -0,0 +1,66 @@ +package org.springframework.ai.vectorstore.milvus; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.vectorstore.SearchRequest.DEFAULT_TOP_K; +import static org.springframework.ai.vectorstore.SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL; + +/** + * Test class for verifying the functionality of the {@link MilvusSearchRequest} class. + * + * @author waileong + */ +class MilvusSearchRequestTest { + + @Test + void shouldBuildMilvusSearchRequestWithNativeExpression() { + String query = "sample query"; + int topK = 10; + double similarityThreshold = 0.8; + String nativeExpression = "city LIKE 'New%'"; + String searchParamsJson = "{\"nprobe\":128}"; + + MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder() + .query(query) + .topK(topK) + .similarityThreshold(similarityThreshold) + .nativeExpression(nativeExpression) + .searchParamsJson(searchParamsJson) + .build(); + + assertThat(request.getQuery()).isEqualTo(query); + assertThat(request.getTopK()).isEqualTo(topK); + assertThat(request.getSimilarityThreshold()).isEqualTo(similarityThreshold); + assertThat(request.getNativeExpression()).isEqualTo(nativeExpression); + assertThat(request.getSearchParamsJson()).isEqualTo(searchParamsJson); + } + + @Test + void shouldBuildMilvusSearchRequestWithDefaults() { + MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder().build(); + + assertThat(request.getQuery()).isEmpty(); + assertThat(request.getTopK()).isEqualTo(DEFAULT_TOP_K); + assertThat(request.getSimilarityThreshold()).isEqualTo(SIMILARITY_THRESHOLD_ACCEPT_ALL); + assertThat(request.getNativeExpression()).isNull(); + assertThat(request.getSearchParamsJson()).isNull(); + } + + @Test + void shouldAllowSettingNativeExpressionIndependently() { + String nativeExpression = "age > 30"; + MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder().nativeExpression(nativeExpression).build(); + + assertThat(request.getNativeExpression()).isEqualTo(nativeExpression); + } + + @Test + void shouldAllowSettingSearchParamsJsonIndependently() { + String searchParamsJson = "{\"metric_type\": \"IP\"}"; + MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder().searchParamsJson(searchParamsJson).build(); + + assertThat(request.getSearchParamsJson()).isEqualTo(searchParamsJson); + } + +} diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreTest.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreTest.java new file mode 100644 index 00000000000..1e6d5d3786d --- /dev/null +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreTest.java @@ -0,0 +1,140 @@ +package org.springframework.ai.vectorstore.milvus; + +import io.milvus.client.MilvusServiceClient; +import io.milvus.grpc.SearchResultData; +import io.milvus.grpc.SearchResults; +import io.milvus.param.R; +import io.milvus.param.dml.SearchParam; +import io.milvus.response.SearchResultsWrapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.vectorstore.SearchRequest; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +/** + * Unit test class for {@link MilvusVectorStore}. + * + * @author waileong + */ +@ExtendWith(MockitoExtension.class) +class MilvusVectorStoreTest { + + @Mock + private MilvusServiceClient milvusClient; + + @Mock + private EmbeddingModel embeddingModel; + + private MilvusVectorStore vectorStore; + + @BeforeEach + void setUp() { + vectorStore = MilvusVectorStore.builder(milvusClient, embeddingModel).build(); + } + + @Test + void shouldPerformSimilaritySearchWithNativeExpression() { + try (MockedStatic mockedEmbeddingUtils = mockStatic(EmbeddingUtils.class); + MockedConstruction mockedSearchResultsWrapper = mockConstruction( + SearchResultsWrapper.class, + (mock, context) -> when(mock.getRowRecords(0)).thenReturn(List.of()))) { + + String query = "sample query"; + MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder() + .query(query) + .topK(5) + .similarityThreshold(0.7) + .nativeExpression("metadata[\"age\"] > 30") // this has higher priority + .filterExpression("age <= 30") // this will be ignored + .searchParamsJson("{\"nprobe\":128}") + .build(); + + SearchParam capturedParam = performSimilaritySearch(mockedEmbeddingUtils, request); + assertThat(capturedParam.getTopK()).isEqualTo(request.getTopK()); + assertThat(capturedParam.getExpr()).isEqualTo(request.getNativeExpression()); + assertThat(capturedParam.getParams()).isEqualTo(request.getSearchParamsJson()); + } + } + + @Test + void shouldPerformSimilaritySearchWithFilterExpression() { + try (MockedStatic mockedEmbeddingUtils = mockStatic(EmbeddingUtils.class); + MockedConstruction mockedSearchResultsWrapper = mockConstruction( + SearchResultsWrapper.class, + (mock, context) -> when(mock.getRowRecords(0)).thenReturn(List.of()))) { + + String query = "sample query"; + MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder() + .query(query) + .topK(5) + .similarityThreshold(0.7) + .filterExpression("age > 30") + .searchParamsJson("{\"nprobe\":128}") + .build(); + + SearchParam capturedParam = performSimilaritySearch(mockedEmbeddingUtils, request); + + assertThat(capturedParam.getTopK()).isEqualTo(request.getTopK()); + assertThat(capturedParam.getExpr()).isEqualTo("metadata[\"age\"] > 30"); // filter + assertThat(capturedParam.getParams()).isEqualTo(request.getSearchParamsJson()); + } + } + + @Test + void shouldPerformSimilaritySearchWithOriginalSearchRequest() { + try (MockedStatic mockedEmbeddingUtils = mockStatic(EmbeddingUtils.class); + MockedConstruction mockedSearchResultsWrapper = mockConstruction( + SearchResultsWrapper.class, + (mock, context) -> when(mock.getRowRecords(0)).thenReturn(List.of()))) { + + String query = "sample query"; + SearchRequest request = SearchRequest.builder() + .query(query) + .topK(5) + .similarityThreshold(0.7) + .filterExpression("age > 30") + .build(); + + SearchParam capturedParam = performSimilaritySearch(mockedEmbeddingUtils, request); + + assertThat(capturedParam.getTopK()).isEqualTo(request.getTopK()); + assertThat(capturedParam.getExpr()).isEqualTo("metadata[\"age\"] > 30"); // filter + assertThat(capturedParam.getParams()).isEqualTo("{}"); + } + } + + private SearchParam performSimilaritySearch(MockedStatic mockedEmbeddingUtils, + SearchRequest request) { + List mockVector = List.of(1.0f, 2.0f, 3.0f); + mockedEmbeddingUtils.when(() -> EmbeddingUtils.toList(any())).thenReturn(mockVector); + + SearchResults mockResults = mock(SearchResults.class); + when(mockResults.getResults()).thenReturn(SearchResultData.getDefaultInstance()); + + R mockResponse = R.success(mockResults); + when(milvusClient.search(any(SearchParam.class))).thenReturn(mockResponse); + + ArgumentCaptor searchParamCaptor = ArgumentCaptor.forClass(SearchParam.class); + + List results = vectorStore.doSimilaritySearch(request); + + assertThat(results).isNotNull(); + verify(milvusClient).search(searchParamCaptor.capture()); + return searchParamCaptor.getValue(); + } + +}