Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* Provides the API for embedding observations.
*/
@NonNullApi
@NonNullFields
package org.springframework.ai.embedding;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,18 @@
public abstract class AbstractVectorStoreBuilder<T extends AbstractVectorStoreBuilder<T>>
implements VectorStore.Builder<T> {

protected EmbeddingModel embeddingModel;
protected final EmbeddingModel embeddingModel;

protected ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

@Nullable
protected VectorStoreObservationConvention customObservationConvention;

public AbstractVectorStoreBuilder(EmbeddingModel embeddingModel) {
Assert.notNull(embeddingModel, "EmbeddingModel must be configured");
this.embeddingModel = embeddingModel;
}

public EmbeddingModel getEmbeddingModel() {
return this.embeddingModel;
}
Expand Down Expand Up @@ -71,20 +76,9 @@ public T observationRegistry(ObservationRegistry observationRegistry) {
}

@Override
public T customObservationConvention(VectorStoreObservationConvention convention) {
public T customObservationConvention(@Nullable VectorStoreObservationConvention convention) {
this.customObservationConvention = convention;
return self();
}

@Override
public T embeddingModel(EmbeddingModel embeddingModel) {
Assert.notNull(embeddingModel, "EmbeddingModel must not be null");
this.embeddingModel = embeddingModel;
return self();
}

protected void validate() {
Assert.notNull(this.embeddingModel, "EmbeddingModel must be configured");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ public SimpleVectorStore(EmbeddingModel embeddingModel) {
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention) {
this(builder().embeddingModel(embeddingModel)
.observationRegistry(observationRegistry)
this(builder(embeddingModel).observationRegistry(observationRegistry)
.customObservationConvention(customObservationConvention));
}

Expand All @@ -106,8 +105,8 @@ protected SimpleVectorStore(SimpleVectorStoreBuilder builder) {
* Creates an instance of SimpleVectorStore builder.
* @return the SimpleVectorStore builder.
*/
public static SimpleVectorStoreBuilder builder() {
return new SimpleVectorStoreBuilder();
public static SimpleVectorStoreBuilder builder(EmbeddingModel embeddingModel) {
return new SimpleVectorStoreBuilder(embeddingModel);
}

@Override
Expand Down Expand Up @@ -297,9 +296,12 @@ public static float norm(float[] vector) {

public static final class SimpleVectorStoreBuilder extends AbstractVectorStoreBuilder<SimpleVectorStoreBuilder> {

private SimpleVectorStoreBuilder(EmbeddingModel embeddingModel) {
super(embeddingModel);
}

@Override
public SimpleVectorStore build() {
validate();
return new SimpleVectorStore(this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentWriter;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -59,6 +58,7 @@ default void accept(List<Document> documents) {
* @param idList list of document ids for which documents will be removed.
* @return Returns true if the documents were successfully deleted.
*/
@Nullable
Optional<Boolean> delete(List<String> idList);

/**
Expand All @@ -68,6 +68,7 @@ default void accept(List<Document> documents) {
* topK, similarity threshold and metadata filter expressions.
* @return Returns documents th match the query request conditions.
*/
@Nullable
List<Document> similaritySearch(SearchRequest request);

/**
Expand All @@ -77,6 +78,7 @@ default void accept(List<Document> documents) {
* @return Returns a list of documents that have embeddings similar to the query text
* embedding.
*/
@Nullable
default List<Document> similaritySearch(String query) {
return this.similaritySearch(SearchRequest.query(query));
}
Expand All @@ -90,8 +92,6 @@ default List<Document> similaritySearch(String query) {
*/
interface Builder<T extends Builder<T>> {

T embeddingModel(EmbeddingModel embeddingModel);

/**
* Sets the registry for collecting observations and metrics. Defaults to
* {@link ObservationRegistry#NOOP} if not specified.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ public abstract class AbstractObservationVectorStore implements VectorStore {
@Nullable
private final VectorStoreObservationConvention customObservationConvention;

@Nullable
protected final EmbeddingModel embeddingModel;

/**
Expand All @@ -59,8 +58,7 @@ public AbstractObservationVectorStore(ObservationRegistry observationRegistry,
this(null, observationRegistry, customObservationConvention);
}

private AbstractObservationVectorStore(@Nullable EmbeddingModel embeddingModel,
ObservationRegistry observationRegistry,
private AbstractObservationVectorStore(EmbeddingModel embeddingModel, ObservationRegistry observationRegistry,
@Nullable VectorStoreObservationConvention customObservationConvention) {
this.embeddingModel = embeddingModel;
this.observationRegistry = observationRegistry;
Expand Down Expand Up @@ -94,6 +92,7 @@ public void add(List<Document> documents) {
}

@Override
@Nullable
public Optional<Boolean> delete(List<String> deleteDocIds) {

VectorStoreObservationContext observationContext = this
Expand All @@ -107,6 +106,7 @@ public Optional<Boolean> delete(List<String> deleteDocIds) {
}

@Override
@Nullable
public List<Document> similaritySearch(SearchRequest request) {

VectorStoreObservationContext searchObservationContext = this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void setUp() {
when(this.mockEmbeddingModel.dimensions()).thenReturn(3);
when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder().embeddingModel(this.mockEmbeddingModel));
this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder(this.mockEmbeddingModel));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ public AzureVectorStore vectorStore(SearchIndexClient searchIndexClient, Embeddi
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var builder = AzureVectorStore.builder()
.searchIndexClient(searchIndexClient)
.embeddingModel(embeddingModel)
var builder = AzureVectorStore.builder(searchIndexClient, embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.filterMetadataFields(List.of())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return CassandraVectorStore.builder()
return CassandraVectorStore.builder(embeddingModel)
.session(cqlSession)
.keyspace(properties.getKeyspace())
.table(properties.getTable())
Expand All @@ -72,7 +72,6 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra
.fixedThreadPoolExecutorSize(properties.getFixedThreadPoolExecutorSize())
.disallowSchemaChanges(!properties.isInitializeSchema())
.returnEmbeddings(properties.getReturnEmbeddings())
.embeddingModel(embeddingModel)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.batchingStrategy(batchingStrategy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ public ChromaVectorStore vectorStore(EmbeddingModel embeddingModel, ChromaApi ch
ChromaVectorStoreProperties storeProperties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy chromaBatchingStrategy) {
return ChromaVectorStore.builder()
.chromaApi(chromaApi)
.embeddingModel(embeddingModel)
return ChromaVectorStore.builder(chromaApi, embeddingModel)
.collectionName(storeProperties.getCollectionName())
.initializeSchema(storeProperties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}

return ElasticsearchVectorStore.builder()
.restClient(restClient)
return ElasticsearchVectorStore.builder(restClient, embeddingModel)
.options(elasticsearchVectorStoreOptions)
.embeddingModel(embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemF
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return GemFireVectorStore.builder()
return GemFireVectorStore.builder(embeddingModel)
.host(gemFireConnectionDetails.getHost())
.port(gemFireConnectionDetails.getPort())
.indexName(properties.getIndexName())
Expand All @@ -74,7 +74,6 @@ public GemFireVectorStore gemfireVectorStore(EmbeddingModel embeddingModel, GemF
.vectorSimilarityFunction(properties.getVectorSimilarityFunction())
.fields(properties.getFields())
.sslEnabled(properties.isSslEnabled())
.embeddingModel(embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ public HanaCloudVectorStore vectorStore(HanaVectorRepository<? extends HanaVecto
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {

return HanaCloudVectorStore.builder()
.repository(repository)
.embeddingModel(embeddingModel)
return HanaCloudVectorStore.builder(repository, embeddingModel)
.tableName(properties.getTableName())
.topK(properties.getTopK())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public MariaDBVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel

var initializeSchema = properties.isInitializeSchema();

return MariaDBVectorStore.builder(jdbcTemplate)
.embeddingModel(embeddingModel)
return MariaDBVectorStore.builder(jdbcTemplate, embeddingModel)
.schemaName(properties.getSchemaName())
.vectorTableName(properties.getTableName())
.schemaValidation(properties.isSchemaValidation())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ public MilvusVectorStore vectorStore(MilvusServiceClient milvusClient, Embedding
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {

return MilvusVectorStore.builder()
.milvusClient(milvusClient)
.embeddingModel(embeddingModel)
return MilvusVectorStore.builder(milvusClient, embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.batchingStrategy(batchingStrategy)
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ MongoDBAtlasVectorStore vectorStore(MongoTemplate mongoTemplate, EmbeddingModel
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

MongoDBAtlasVectorStore.MongoDBBuilder builder = MongoDBAtlasVectorStore.builder()
.mongoTemplate(mongoTemplate)
.embeddingModel(embeddingModel)
MongoDBAtlasVectorStore.MongoDBBuilder builder = MongoDBAtlasVectorStore.builder(mongoTemplate, embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return Neo4jVectorStore.builder()
.driver(driver)
.embeddingModel(embeddingModel)
return Neo4jVectorStore.builder(driver, embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ OpenSearchVectorStore vectorStore(OpenSearchVectorStoreProperties properties, Op
var mappingJson = Optional.ofNullable(properties.getMappingJson())
.orElse(OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION);

return OpenSearchVectorStore.builder()
return OpenSearchVectorStore.builder(openSearchClient, embeddingModel)
.index(indexName)
.openSearchClient(openSearchClient)
.embeddingModel(embeddingModel)
.mappingJson(mappingJson)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ public OracleVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel e
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return OracleVectorStore.builder()
.jdbcTemplate(jdbcTemplate)
.embeddingModel(embeddingModel)
return OracleVectorStore.builder(jdbcTemplate, embeddingModel)
.tableName(properties.getTableName())
.indexType(properties.getIndexType())
.distanceType(properties.getDistanceType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed

var initializeSchema = properties.isInitializeSchema();

return PgVectorStore.builder()
.jdbcTemplate(jdbcTemplate)
.embeddingModel(embeddingModel)
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
.schemaName(properties.getSchemaName())
.vectorTableName(properties.getTableName())
.vectorTableValidationsEnabled(properties.isSchemaValidation())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,9 @@ public PineconeVectorStore vectorStore(EmbeddingModel embeddingModel, PineconeVe
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return PineconeVectorStore.builder()
.embeddingModel(embeddingModel)
.apiKey(properties.getApiKey())
.environment(properties.getEnvironment())
.projectId(properties.getProjectId())
.indexName(properties.getIndexName())
return PineconeVectorStore
.builder(embeddingModel, properties.getApiKey(), properties.getProjectId(), properties.getEnvironment(),
properties.getIndexName())
.namespace(properties.getNamespace())
.contentFieldName(properties.getContentFieldName())
.distanceMetadataFieldName(properties.getDistanceMetadataFieldName())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ public QdrantVectorStore vectorStore(EmbeddingModel embeddingModel, QdrantVector
QdrantClient qdrantClient, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
return QdrantVectorStore.builder(qdrantClient)
return QdrantVectorStore.builder(qdrantClient, embeddingModel)
.collectionName(properties.getCollectionName())
.embeddingModel(embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return RedisVectorStore.builder()
.jedis(new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()))
.embeddingModel(embeddingModel)
JedisPooled jedisPooled = new JedisPooled(jedisConnectionFactory.getHostName(),
jedisConnectionFactory.getPort());
return RedisVectorStore.builder(jedisPooled, embeddingModel)
.initializeSchema(properties.isInitializeSchema())
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel e
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

return TypesenseVectorStore.builder()
.client(typesenseClient)
.embeddingModel(embeddingModel)
return TypesenseVectorStore.builder(typesenseClient, embeddingModel)
.collectionName(properties.getCollectionName())
.embeddingDimension(properties.getEmbeddingDimension())
.initializeSchema(properties.isInitializeSchema())
Expand Down
Loading
Loading