Skip to content

Commit 2df3f3f

Browse files
committed
Add unit tests for MilvusVectorStore and MilvusSearchRequest
Introduce comprehensive unit tests to validate the functionality of MilvusVectorStore and MilvusSearchRequest, including scenarios for native and filter expressions. Refactor MilvusVectorStore to improve filter expression handling by introducing a helper method for converted expressions.
1 parent bc43a9b commit 2df3f3f

File tree

3 files changed

+217
-8
lines changed

3 files changed

+217
-8
lines changed

vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -326,16 +326,14 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
326326
String nativeFilterExpressions = "";
327327
String searchParamsJson = null;
328328
if (request instanceof MilvusSearchRequest milvusReq) {
329-
if (milvusReq.getNativeExpression() != null && !milvusReq.getNativeExpression().isEmpty()) {
330-
nativeFilterExpressions = milvusReq.getNativeExpression();
331-
}
332-
if (milvusReq.getSearchParamsJson() != null && !milvusReq.getSearchParamsJson().isEmpty()) {
333-
searchParamsJson = milvusReq.getSearchParamsJson();
334-
}
329+
nativeFilterExpressions = StringUtils.hasText(milvusReq.getNativeExpression())
330+
? milvusReq.getNativeExpression() : getConvertedFilterExpression(request);
331+
332+
searchParamsJson = StringUtils.hasText(milvusReq.getSearchParamsJson()) ? milvusReq.getSearchParamsJson()
333+
: null;
335334
}
336335
else {
337-
nativeFilterExpressions = (request.getFilterExpression() != null)
338-
? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
336+
nativeFilterExpressions = getConvertedFilterExpression(request);
339337
}
340338

341339
Assert.notNull(request.getQuery(), "Query string must not be null");
@@ -400,6 +398,11 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
400398
.toList();
401399
}
402400

401+
private String getConvertedFilterExpression(SearchRequest request) {
402+
return (request.getFilterExpression() != null)
403+
? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
404+
}
405+
403406
private float getResultSimilarity(RowRecord rowRecord) {
404407
Float score = (Float) rowRecord.get(SIMILARITY_FIELD_NAME);
405408
return (this.metricType == MetricType.IP || this.metricType == MetricType.COSINE) ? score : (1 - score);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package org.springframework.ai.vectorstore.milvus;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import static org.assertj.core.api.Assertions.assertThat;
6+
import static org.springframework.ai.vectorstore.SearchRequest.DEFAULT_TOP_K;
7+
import static org.springframework.ai.vectorstore.SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL;
8+
9+
/**
10+
* Test class for verifying the functionality of the {@link MilvusSearchRequest} class.
11+
*
12+
* @author waileong
13+
*/
14+
class MilvusSearchRequestTest {
15+
16+
@Test
17+
void shouldBuildMilvusSearchRequestWithNativeExpression() {
18+
String query = "sample query";
19+
int topK = 10;
20+
double similarityThreshold = 0.8;
21+
String nativeExpression = "city LIKE 'New%'";
22+
String searchParamsJson = "{\"nprobe\":128}";
23+
24+
MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder()
25+
.query(query)
26+
.topK(topK)
27+
.similarityThreshold(similarityThreshold)
28+
.nativeExpression(nativeExpression)
29+
.searchParamsJson(searchParamsJson)
30+
.build();
31+
32+
assertThat(request.getQuery()).isEqualTo(query);
33+
assertThat(request.getTopK()).isEqualTo(topK);
34+
assertThat(request.getSimilarityThreshold()).isEqualTo(similarityThreshold);
35+
assertThat(request.getNativeExpression()).isEqualTo(nativeExpression);
36+
assertThat(request.getSearchParamsJson()).isEqualTo(searchParamsJson);
37+
}
38+
39+
@Test
40+
void shouldBuildMilvusSearchRequestWithDefaults() {
41+
MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder().build();
42+
43+
assertThat(request.getQuery()).isEmpty();
44+
assertThat(request.getTopK()).isEqualTo(DEFAULT_TOP_K);
45+
assertThat(request.getSimilarityThreshold()).isEqualTo(SIMILARITY_THRESHOLD_ACCEPT_ALL);
46+
assertThat(request.getNativeExpression()).isNull();
47+
assertThat(request.getSearchParamsJson()).isNull();
48+
}
49+
50+
@Test
51+
void shouldAllowSettingNativeExpressionIndependently() {
52+
String nativeExpression = "age > 30";
53+
MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder().nativeExpression(nativeExpression).build();
54+
55+
assertThat(request.getNativeExpression()).isEqualTo(nativeExpression);
56+
}
57+
58+
@Test
59+
void shouldAllowSettingSearchParamsJsonIndependently() {
60+
String searchParamsJson = "{\"metric_type\": \"IP\"}";
61+
MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder().searchParamsJson(searchParamsJson).build();
62+
63+
assertThat(request.getSearchParamsJson()).isEqualTo(searchParamsJson);
64+
}
65+
66+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package org.springframework.ai.vectorstore.milvus;
2+
3+
import io.milvus.client.MilvusServiceClient;
4+
import io.milvus.grpc.SearchResultData;
5+
import io.milvus.grpc.SearchResults;
6+
import io.milvus.param.R;
7+
import io.milvus.param.dml.SearchParam;
8+
import io.milvus.response.SearchResultsWrapper;
9+
import org.junit.jupiter.api.BeforeEach;
10+
import org.junit.jupiter.api.Test;
11+
import org.junit.jupiter.api.extension.ExtendWith;
12+
import org.mockito.ArgumentCaptor;
13+
import org.mockito.Mock;
14+
import org.mockito.MockedConstruction;
15+
import org.mockito.MockedStatic;
16+
import org.mockito.junit.jupiter.MockitoExtension;
17+
import org.springframework.ai.document.Document;
18+
import org.springframework.ai.embedding.EmbeddingModel;
19+
import org.springframework.ai.model.EmbeddingUtils;
20+
import org.springframework.ai.vectorstore.SearchRequest;
21+
22+
import java.util.List;
23+
24+
import static org.assertj.core.api.Assertions.assertThat;
25+
import static org.mockito.ArgumentMatchers.any;
26+
import static org.mockito.Mockito.*;
27+
28+
/**
29+
* Unit test class for {@link MilvusVectorStore}.
30+
*
31+
* @author waileong
32+
*/
33+
@ExtendWith(MockitoExtension.class)
34+
class MilvusVectorStoreTest {
35+
36+
@Mock
37+
private MilvusServiceClient milvusClient;
38+
39+
@Mock
40+
private EmbeddingModel embeddingModel;
41+
42+
private MilvusVectorStore vectorStore;
43+
44+
@BeforeEach
45+
void setUp() {
46+
vectorStore = MilvusVectorStore.builder(milvusClient, embeddingModel).build();
47+
}
48+
49+
@Test
50+
void shouldPerformSimilaritySearchWithNativeExpression() {
51+
try (MockedStatic<EmbeddingUtils> mockedEmbeddingUtils = mockStatic(EmbeddingUtils.class);
52+
MockedConstruction<SearchResultsWrapper> mockedSearchResultsWrapper = mockConstruction(
53+
SearchResultsWrapper.class,
54+
(mock, context) -> when(mock.getRowRecords(0)).thenReturn(List.of()))) {
55+
56+
String query = "sample query";
57+
MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder()
58+
.query(query)
59+
.topK(5)
60+
.similarityThreshold(0.7)
61+
.nativeExpression("metadata[\"age\"] > 30") // this has higher priority
62+
.filterExpression("age <= 30") // this will be ignored
63+
.searchParamsJson("{\"nprobe\":128}")
64+
.build();
65+
66+
SearchParam capturedParam = performSimilaritySearch(mockedEmbeddingUtils, request);
67+
assertThat(capturedParam.getTopK()).isEqualTo(request.getTopK());
68+
assertThat(capturedParam.getExpr()).isEqualTo(request.getNativeExpression());
69+
assertThat(capturedParam.getParams()).isEqualTo(request.getSearchParamsJson());
70+
}
71+
}
72+
73+
@Test
74+
void shouldPerformSimilaritySearchWithFilterExpression() {
75+
try (MockedStatic<EmbeddingUtils> mockedEmbeddingUtils = mockStatic(EmbeddingUtils.class);
76+
MockedConstruction<SearchResultsWrapper> mockedSearchResultsWrapper = mockConstruction(
77+
SearchResultsWrapper.class,
78+
(mock, context) -> when(mock.getRowRecords(0)).thenReturn(List.of()))) {
79+
80+
String query = "sample query";
81+
MilvusSearchRequest request = MilvusSearchRequest.milvusBuilder()
82+
.query(query)
83+
.topK(5)
84+
.similarityThreshold(0.7)
85+
.filterExpression("age > 30")
86+
.searchParamsJson("{\"nprobe\":128}")
87+
.build();
88+
89+
SearchParam capturedParam = performSimilaritySearch(mockedEmbeddingUtils, request);
90+
91+
assertThat(capturedParam.getTopK()).isEqualTo(request.getTopK());
92+
assertThat(capturedParam.getExpr()).isEqualTo("metadata[\"age\"] > 30"); // filter
93+
assertThat(capturedParam.getParams()).isEqualTo(request.getSearchParamsJson());
94+
}
95+
}
96+
97+
@Test
98+
void shouldPerformSimilaritySearchWithOriginalSearchRequest() {
99+
try (MockedStatic<EmbeddingUtils> mockedEmbeddingUtils = mockStatic(EmbeddingUtils.class);
100+
MockedConstruction<SearchResultsWrapper> mockedSearchResultsWrapper = mockConstruction(
101+
SearchResultsWrapper.class,
102+
(mock, context) -> when(mock.getRowRecords(0)).thenReturn(List.of()))) {
103+
104+
String query = "sample query";
105+
SearchRequest request = SearchRequest.builder()
106+
.query(query)
107+
.topK(5)
108+
.similarityThreshold(0.7)
109+
.filterExpression("age > 30")
110+
.build();
111+
112+
SearchParam capturedParam = performSimilaritySearch(mockedEmbeddingUtils, request);
113+
114+
assertThat(capturedParam.getTopK()).isEqualTo(request.getTopK());
115+
assertThat(capturedParam.getExpr()).isEqualTo("metadata[\"age\"] > 30"); // filter
116+
assertThat(capturedParam.getParams()).isEqualTo("{}");
117+
}
118+
}
119+
120+
private SearchParam performSimilaritySearch(MockedStatic<EmbeddingUtils> mockedEmbeddingUtils,
121+
SearchRequest request) {
122+
List<Float> mockVector = List.of(1.0f, 2.0f, 3.0f);
123+
mockedEmbeddingUtils.when(() -> EmbeddingUtils.toList(any())).thenReturn(mockVector);
124+
125+
SearchResults mockResults = mock(SearchResults.class);
126+
when(mockResults.getResults()).thenReturn(SearchResultData.getDefaultInstance());
127+
128+
R<SearchResults> mockResponse = R.success(mockResults);
129+
when(milvusClient.search(any(SearchParam.class))).thenReturn(mockResponse);
130+
131+
ArgumentCaptor<SearchParam> searchParamCaptor = ArgumentCaptor.forClass(SearchParam.class);
132+
133+
List<Document> results = vectorStore.doSimilaritySearch(request);
134+
135+
assertThat(results).isNotNull();
136+
verify(milvusClient).search(searchParamCaptor.capture());
137+
return searchParamCaptor.getValue();
138+
}
139+
140+
}

0 commit comments

Comments
 (0)