|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright 2023-2024 the original author or authors. | 
|  | 3 | + * | 
|  | 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 5 | + * you may not use this file except in compliance with the License. | 
|  | 6 | + * You may obtain a copy of the License at | 
|  | 7 | + * | 
|  | 8 | + *      https://www.apache.org/licenses/LICENSE-2.0 | 
|  | 9 | + * | 
|  | 10 | + * Unless required by applicable law or agreed to in writing, software | 
|  | 11 | + * distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 13 | + * See the License for the specific language governing permissions and | 
|  | 14 | + * limitations under the License. | 
|  | 15 | + */ | 
|  | 16 | + | 
|  | 17 | +package org.springframework.ai.vectorstore; | 
|  | 18 | + | 
|  | 19 | +import java.io.IOException; | 
|  | 20 | +import java.net.URISyntaxException; | 
|  | 21 | +import java.nio.charset.StandardCharsets; | 
|  | 22 | +import java.time.Duration; | 
|  | 23 | +import java.util.List; | 
|  | 24 | +import java.util.Map; | 
|  | 25 | +import java.util.concurrent.TimeUnit; | 
|  | 26 | + | 
|  | 27 | +import org.apache.hc.core5.http.HttpHost; | 
|  | 28 | +import org.awaitility.Awaitility; | 
|  | 29 | +import org.junit.jupiter.api.BeforeAll; | 
|  | 30 | +import org.junit.jupiter.api.BeforeEach; | 
|  | 31 | +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; | 
|  | 32 | +import org.junit.jupiter.params.ParameterizedTest; | 
|  | 33 | +import org.junit.jupiter.params.provider.ValueSource; | 
|  | 34 | +import org.opensearch.client.opensearch.OpenSearchClient; | 
|  | 35 | +import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; | 
|  | 36 | +import org.opensearch.testcontainers.OpensearchContainer; | 
|  | 37 | +import org.testcontainers.junit.jupiter.Container; | 
|  | 38 | +import org.testcontainers.junit.jupiter.Testcontainers; | 
|  | 39 | + | 
|  | 40 | +import org.springframework.ai.document.Document; | 
|  | 41 | +import org.springframework.ai.embedding.EmbeddingModel; | 
|  | 42 | +import org.springframework.ai.ollama.OllamaEmbeddingModel; | 
|  | 43 | +import org.springframework.ai.ollama.api.OllamaApi; | 
|  | 44 | +import org.springframework.ai.ollama.api.OllamaModel; | 
|  | 45 | +import org.springframework.ai.ollama.api.OllamaOptions; | 
|  | 46 | +import org.springframework.beans.factory.annotation.Qualifier; | 
|  | 47 | +import org.springframework.boot.SpringBootConfiguration; | 
|  | 48 | +import org.springframework.boot.test.context.runner.ApplicationContextRunner; | 
|  | 49 | +import org.springframework.context.annotation.Bean; | 
|  | 50 | +import org.springframework.core.io.DefaultResourceLoader; | 
|  | 51 | + | 
|  | 52 | +import static org.assertj.core.api.Assertions.assertThat; | 
|  | 53 | +import static org.hamcrest.Matchers.hasSize; | 
|  | 54 | + | 
|  | 55 | +@Testcontainers | 
|  | 56 | +@EnabledIfEnvironmentVariable(named = "OLLAMA_TESTS_ENABLED", matches = "true") | 
|  | 57 | +class OpenSearchVectorStoreWithOllamaIT { | 
|  | 58 | + | 
|  | 59 | +	@Container | 
|  | 60 | +	private static final OpensearchContainer<?> opensearchContainer = new OpensearchContainer<>( | 
|  | 61 | +			OpenSearchImage.DEFAULT_IMAGE); | 
|  | 62 | + | 
|  | 63 | +	private static final String DEFAULT = "cosinesimil"; | 
|  | 64 | + | 
|  | 65 | +	private List<Document> documents = List.of( | 
|  | 66 | +			new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), | 
|  | 67 | +			new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()), | 
|  | 68 | +			new Document("3", getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); | 
|  | 69 | + | 
|  | 70 | +	@BeforeAll | 
|  | 71 | +	public static void beforeAll() { | 
|  | 72 | +		Awaitility.setDefaultPollInterval(2, TimeUnit.SECONDS); | 
|  | 73 | +		Awaitility.setDefaultPollDelay(Duration.ZERO); | 
|  | 74 | +		Awaitility.setDefaultTimeout(Duration.ofMinutes(1)); | 
|  | 75 | +	} | 
|  | 76 | + | 
|  | 77 | +	private String getText(String uri) { | 
|  | 78 | +		var resource = new DefaultResourceLoader().getResource(uri); | 
|  | 79 | +		try { | 
|  | 80 | +			return resource.getContentAsString(StandardCharsets.UTF_8); | 
|  | 81 | +		} | 
|  | 82 | +		catch (IOException e) { | 
|  | 83 | +			throw new RuntimeException(e); | 
|  | 84 | +		} | 
|  | 85 | +	} | 
|  | 86 | + | 
|  | 87 | +	private ApplicationContextRunner getContextRunner() { | 
|  | 88 | +		return new ApplicationContextRunner().withUserConfiguration(TestApplication.class); | 
|  | 89 | +	} | 
|  | 90 | + | 
|  | 91 | +	@BeforeEach | 
|  | 92 | +	void cleanDatabase() { | 
|  | 93 | +		getContextRunner().run(context -> { | 
|  | 94 | +			VectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); | 
|  | 95 | +			vectorStore.delete(List.of("_all")); | 
|  | 96 | + | 
|  | 97 | +			VectorStore anotherVectorStore = context.getBean("anotherVectorStore", OpenSearchVectorStore.class); | 
|  | 98 | +			anotherVectorStore.delete(List.of("_all")); | 
|  | 99 | +		}); | 
|  | 100 | +	} | 
|  | 101 | + | 
|  | 102 | +	@ParameterizedTest(name = "{0} : {displayName} ") | 
|  | 103 | +	@ValueSource(strings = { DEFAULT, "l1", "l2", "linf" }) | 
|  | 104 | +	public void addAndSearchTest(String similarityFunction) { | 
|  | 105 | + | 
|  | 106 | +		getContextRunner().run(context -> { | 
|  | 107 | +			OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); | 
|  | 108 | + | 
|  | 109 | +			if (!DEFAULT.equals(similarityFunction)) { | 
|  | 110 | +				vectorStore.withSimilarityFunction(similarityFunction); | 
|  | 111 | +			} | 
|  | 112 | + | 
|  | 113 | +			vectorStore.add(this.documents); | 
|  | 114 | + | 
|  | 115 | +			Awaitility.await() | 
|  | 116 | +				.until(() -> vectorStore | 
|  | 117 | +					.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), | 
|  | 118 | +						hasSize(1)); | 
|  | 119 | + | 
|  | 120 | +			List<Document> results = vectorStore | 
|  | 121 | +				.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)); | 
|  | 122 | + | 
|  | 123 | +			assertThat(results).hasSize(1); | 
|  | 124 | +			Document resultDoc = results.get(0); | 
|  | 125 | +			assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); | 
|  | 126 | +			assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); | 
|  | 127 | +			assertThat(resultDoc.getMetadata()).hasSize(2); | 
|  | 128 | +			assertThat(resultDoc.getMetadata()).containsKey("meta2"); | 
|  | 129 | +			assertThat(resultDoc.getMetadata()).containsKey("distance"); | 
|  | 130 | + | 
|  | 131 | +			// Remove all documents from the store | 
|  | 132 | +			vectorStore.delete(this.documents.stream().map(Document::getId).toList()); | 
|  | 133 | + | 
|  | 134 | +			Awaitility.await() | 
|  | 135 | +				.until(() -> vectorStore | 
|  | 136 | +					.similaritySearch(SearchRequest.query("Great Depression").withTopK(1).withSimilarityThreshold(0)), | 
|  | 137 | +						hasSize(0)); | 
|  | 138 | +		}); | 
|  | 139 | +	} | 
|  | 140 | + | 
|  | 141 | +	@SpringBootConfiguration | 
|  | 142 | +	public static class TestApplication { | 
|  | 143 | + | 
|  | 144 | +		@Bean | 
|  | 145 | +		@Qualifier("vectorStore") | 
|  | 146 | +		public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) { | 
|  | 147 | +			try { | 
|  | 148 | +				return new OpenSearchVectorStore(new OpenSearchClient(ApacheHttpClient5TransportBuilder | 
|  | 149 | +					.builder(HttpHost.create(opensearchContainer.getHttpHostAddress())) | 
|  | 150 | +					.build()), embeddingModel, true); | 
|  | 151 | +			} | 
|  | 152 | +			catch (URISyntaxException e) { | 
|  | 153 | +				throw new RuntimeException(e); | 
|  | 154 | +			} | 
|  | 155 | +		} | 
|  | 156 | + | 
|  | 157 | +		@Bean | 
|  | 158 | +		@Qualifier("anotherVectorStore") | 
|  | 159 | +		public OpenSearchVectorStore anotherVectorStore(EmbeddingModel embeddingModel) { | 
|  | 160 | +			try { | 
|  | 161 | +				return new OpenSearchVectorStore("another_index", | 
|  | 162 | +						new OpenSearchClient(ApacheHttpClient5TransportBuilder | 
|  | 163 | +							.builder(HttpHost.create(opensearchContainer.getHttpHostAddress())) | 
|  | 164 | +							.build()), | 
|  | 165 | +						embeddingModel, OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION, | 
|  | 166 | +						true); | 
|  | 167 | +			} | 
|  | 168 | +			catch (URISyntaxException e) { | 
|  | 169 | +				throw new RuntimeException(e); | 
|  | 170 | +			} | 
|  | 171 | +		} | 
|  | 172 | + | 
|  | 173 | +		@Bean | 
|  | 174 | +		public EmbeddingModel embeddingModel() { | 
|  | 175 | +			return OllamaEmbeddingModel.builder() | 
|  | 176 | +				.withOllamaApi(new OllamaApi()) | 
|  | 177 | +				.withDefaultOptions(OllamaOptions.create() | 
|  | 178 | +					.withModel(OllamaModel.MXBAI_EMBED_LARGE) | 
|  | 179 | +					.withMainGPU(11) | 
|  | 180 | +					.withUseMMap(true) | 
|  | 181 | +					.withNumGPU(1)) | 
|  | 182 | +				.build(); | 
|  | 183 | +		} | 
|  | 184 | + | 
|  | 185 | +	} | 
|  | 186 | + | 
|  | 187 | +} | 
0 commit comments