|
44 | 44 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; |
45 | 45 | import org.junit.jupiter.params.ParameterizedTest; |
46 | 46 | import org.junit.jupiter.params.provider.ValueSource; |
| 47 | +import org.springframework.ai.model.SimpleApiKey; |
47 | 48 | import org.testcontainers.elasticsearch.ElasticsearchContainer; |
48 | 49 | import org.testcontainers.junit.jupiter.Container; |
49 | 50 | import org.testcontainers.junit.jupiter.Testcontainers; |
|
54 | 55 | import org.springframework.ai.openai.OpenAiEmbeddingModel; |
55 | 56 | import org.springframework.ai.openai.api.OpenAiApi; |
56 | 57 | import org.springframework.ai.vectorstore.SearchRequest; |
57 | | -import org.springframework.ai.vectorstore.filter.Filter; |
58 | 58 | import org.springframework.ai.vectorstore.filter.Filter.Expression; |
59 | 59 | import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; |
60 | 60 | import org.springframework.ai.vectorstore.filter.Filter.Key; |
@@ -117,10 +117,11 @@ void cleanDatabase() { |
117 | 117 | }); |
118 | 118 | } |
119 | 119 |
|
120 | | - @Test |
121 | | - public void addAndDeleteDocumentsTest() { |
| 120 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 121 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 122 | + public void addAndDeleteDocumentsTest(String vectorStoreBeanName) { |
122 | 123 | getContextRunner().run(context -> { |
123 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 124 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
124 | 125 | ElasticsearchVectorStore.class); |
125 | 126 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); |
126 | 127 |
|
@@ -149,10 +150,11 @@ public void addAndDeleteDocumentsTest() { |
149 | 150 | }); |
150 | 151 | } |
151 | 152 |
|
152 | | - @Test |
153 | | - public void deleteDocumentsByFilterExpressionTest() { |
| 153 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 154 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 155 | + public void deleteDocumentsByFilterExpressionTest(String vectorStoreBeanName) { |
154 | 156 | getContextRunner().run(context -> { |
155 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 157 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
156 | 158 | ElasticsearchVectorStore.class); |
157 | 159 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); |
158 | 160 |
|
@@ -202,10 +204,11 @@ public void deleteDocumentsByFilterExpressionTest() { |
202 | 204 | }); |
203 | 205 | } |
204 | 206 |
|
205 | | - @Test |
206 | | - public void deleteWithStringFilterExpressionTest() { |
| 207 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 208 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 209 | + public void deleteWithStringFilterExpressionTest(String vectorStoreBeanName) { |
207 | 210 | getContextRunner().run(context -> { |
208 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 211 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
209 | 212 | ElasticsearchVectorStore.class); |
210 | 213 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class); |
211 | 214 |
|
@@ -234,12 +237,12 @@ public void deleteWithStringFilterExpressionTest() { |
234 | 237 | } |
235 | 238 |
|
236 | 239 | @ParameterizedTest(name = "{0} : {displayName} ") |
237 | | - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
238 | | - public void addAndSearchTest(String similarityFunction) { |
| 240 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 241 | + public void addAndSearchTest(String vectorStoreBeanName) { |
239 | 242 |
|
240 | 243 | getContextRunner().run(context -> { |
241 | 244 |
|
242 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 245 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
243 | 246 | ElasticsearchVectorStore.class); |
244 | 247 |
|
245 | 248 | vectorStore.add(this.documents); |
@@ -271,11 +274,11 @@ public void addAndSearchTest(String similarityFunction) { |
271 | 274 | } |
272 | 275 |
|
273 | 276 | @ParameterizedTest(name = "{0} : {displayName} ") |
274 | | - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
275 | | - public void searchWithFilters(String similarityFunction) { |
| 277 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 278 | + public void searchWithFilters(String vectorStoreBeanName) { |
276 | 279 |
|
277 | 280 | getContextRunner().run(context -> { |
278 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 281 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
279 | 282 | ElasticsearchVectorStore.class); |
280 | 283 |
|
281 | 284 | var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner", |
@@ -385,11 +388,11 @@ public void searchWithFilters(String similarityFunction) { |
385 | 388 | } |
386 | 389 |
|
387 | 390 | @ParameterizedTest(name = "{0} : {displayName} ") |
388 | | - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
389 | | - public void documentUpdateTest(String similarityFunction) { |
| 391 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 392 | + public void documentUpdateTest(String vectorStoreBeanName) { |
390 | 393 |
|
391 | 394 | getContextRunner().run(context -> { |
392 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 395 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
393 | 396 | ElasticsearchVectorStore.class); |
394 | 397 |
|
395 | 398 | Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", |
@@ -443,10 +446,10 @@ public void documentUpdateTest(String similarityFunction) { |
443 | 446 | } |
444 | 447 |
|
445 | 448 | @ParameterizedTest(name = "{0} : {displayName} ") |
446 | | - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
447 | | - public void searchThresholdTest(String similarityFunction) { |
| 449 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 450 | + public void searchThresholdTest(String vectorStoreBeanName) { |
448 | 451 | getContextRunner().run(context -> { |
449 | | - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 452 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
450 | 453 | ElasticsearchVectorStore.class); |
451 | 454 |
|
452 | 455 | vectorStore.add(this.documents); |
@@ -581,9 +584,20 @@ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingMo |
581 | 584 | .build(); |
582 | 585 | } |
583 | 586 |
|
| 587 | + @Bean("vectorStore_custom_embedding_field") |
| 588 | + public ElasticsearchVectorStore vectorStoreCustomField(EmbeddingModel embeddingModel, RestClient restClient) { |
| 589 | + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); |
| 590 | + options.setEmbeddingFieldName("custom_embedding_field"); |
| 591 | + return ElasticsearchVectorStore.builder(restClient, embeddingModel) |
| 592 | + .initializeSchema(true) |
| 593 | + .options(options) |
| 594 | + .build(); |
| 595 | + } |
| 596 | + |
584 | 597 | @Bean |
585 | 598 | public EmbeddingModel embeddingModel() { |
586 | | - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); |
| 599 | + return new OpenAiEmbeddingModel( |
| 600 | + OpenAiApi.builder().apiKey(new SimpleApiKey(System.getenv("OPENAI_API_KEY"))).build()); |
587 | 601 | } |
588 | 602 |
|
589 | 603 | @Bean |
|
0 commit comments