diff --git a/README.md b/README.md index f839635cffb..c57ad89e468 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i - [Moderation](https://docs.spring.io/spring-ai/reference/api/index.html#api/moderation) * Portable API support across AI providers for both synchronous and streaming API options are supported. Access to [model-specific features](https://docs.spring.io/spring-ai/reference/api/chatmodel.html#_chat_options) is also available. * [Structured Outputs](https://docs.spring.io/spring-ai/reference/api/structured-output-converter.html) - Mapping of AI Model output to POJOs. -* Support for all major [Vector Database providers](https://docs.spring.io/spring-ai/reference/api/vectordbs.html) such as *Apache Cassandra, Azure Vector Search, Chroma, Milvus, MongoDB Atlas, Neo4j, Oracle, PostgreSQL/PGVector, PineCone, Qdrant, Redis, and Weaviate*. +* Support for all major [Vector Database providers](https://docs.spring.io/spring-ai/reference/api/vectordbs.html) such as *Apache Cassandra, Azure Vector Search, Chroma, Milvus, MongoDB Atlas, MariaDB, Neo4j, Oracle, PostgreSQL/PGVector, PineCone, Qdrant, Redis, and Weaviate*. * Portable API across Vector Store providers, including a novel SQL-like [metadata filter API](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#metadata-filters). * [Tools/Function Calling](https://docs.spring.io/spring-ai/reference/api/functions.html) - permits the model to request the execution of client-side tools and functions, thereby accessing necessary real-time information as required. * [Observability](https://docs.spring.io/spring-ai/reference/observability/index.html) - Provides insights into AI-related operations. diff --git a/pom.xml b/pom.xml index 96243bee263..3cc026aa9a6 100644 --- a/pom.xml +++ b/pom.xml @@ -52,6 +52,7 @@ vector-stores/spring-ai-elasticsearch-store vector-stores/spring-ai-gemfire-store vector-stores/spring-ai-hanadb-store + vector-stores/spring-ai-mariadb-store vector-stores/spring-ai-milvus-store vector-stores/spring-ai-mongodb-atlas-store vector-stores/spring-ai-neo4j-store @@ -73,6 +74,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store spring-ai-spring-boot-starters/spring-ai-starter-gemfire-store spring-ai-spring-boot-starters/spring-ai-starter-hanadb-store + spring-ai-spring-boot-starters/spring-ai-starter-mariadb-store spring-ai-spring-boot-starters/spring-ai-starter-milvus-store spring-ai-spring-boot-starters/spring-ai-starter-mongodb-atlas-store spring-ai-spring-boot-starters/spring-ai-starter-neo4j-store @@ -212,6 +214,7 @@ 1.9.1 0.5.0 2.10.1 + 3.5.1 0.22.0 @@ -678,13 +681,13 @@ org.springframework.ai.vectorstore**/Hana**IT.java org.springframework.ai.vectorstore**/Hana**IT.java org.springframework.ai.vectorstore**/Milvus**IT.java - org.springframework.ai.vectorstore**/Mongo**IT.java + org.springframework.ai.vectorstore**/MariaDB**IT.java org.springframework.ai.vectorstore**/Mongo**IT.java org.springframework.ai.vectorstore**/Neo4j**IT.java org.springframework.ai.vectorstore**/OpenSearch**IT.java org.springframework.ai.vectorstore**/Oracle**IT.java - org.springframework.ai.vectorstore**/Pinecone**IT.java + org.springframework.ai.vectorstore**/Pinecone**IT.java org.springframework.ai.vectorstore.qdrant/**/**IT.java org.springframework.ai.vectorstore**/Qdrant**IT.java org.springframework.ai.vectorstore**/Redis**IT.java @@ -692,14 +695,14 @@ org.springframework.ai.vectorstore**/Weaviate**IT.java - + org.springframework.ai.autoconfigure.anthropic/**/**IT.java org.springframework.ai.autoconfigure.azure/**/**IT.java org.springframework.ai.autoconfigure.bedrock/**/**IT.java org.springframework.ai.autoconfigure.huggingface/**/**IT.java - + org.springframework.ai.autoconfigure.chat/**/**IT.java org.springframework.ai.autoconfigure.embedding/**/**IT.java org.springframework.ai.autoconfigure.image/**/**IT.java @@ -713,24 +716,24 @@ org.springframework.ai.autoconfigure.postgresml/**/**IT.java org.springframework.ai.autoconfigure.qianfan/**/**IT.java - + org.springframework.ai.autoconfigure.retry/**/**IT.java - + org.springframework.ai.autoconfigure.stabilityai/**/**IT.java org.springframework.ai.autoconfigure.transformers/**/**IT.java - - org.springframework.ai.autoconfigure.vectorstore/**/**IT.java - + + org.springframework.ai.autoconfigure.vectorstore/**/**IT.java + org.springframework.ai.autoconfigure.vertexai/**/**IT.java org.springframework.ai.autoconfigure.watsonxai/**/**IT.java org.springframework.ai.autoconfigure.zhipuai/**/**IT.java - + org.springframework.ai.autoconfigure.zhipuai/**/**IT.java - - + + org.springframework.ai.testcontainers/**/**IT.java - + package org.springframework.ai.docker.compose/**/**IT.java diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index c1b53ed9b96..fb8fee83f4e 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -285,6 +285,12 @@ ${project.version} + + org.springframework.ai + spring-ai-mariadb-store + ${project.version} + + org.springframework.ai @@ -455,6 +461,12 @@ ${project.version} + + org.springframework.ai + spring-ai-mariadb-store-spring-boot-starter + ${project.version} + + org.springframework.ai spring-ai-stability-ai-spring-boot-starter diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java index f43e564fa4e..a65bb57c3aa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/VectorStoreProvider.java @@ -67,6 +67,11 @@ public enum VectorStoreProvider { */ HANA("hana"), + /** + * Vector store provided by MariaDB. + */ + MARIADB("mariadb"), + /** * Vector store provided by Milvus. */ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index b07bc2f4ae7..98ab6b8370c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -80,6 +80,7 @@ ** xref:api/vectordbs/chroma.adoc[] ** xref:api/vectordbs/elasticsearch.adoc[] ** xref:api/vectordbs/gemfire.adoc[GemFire] +** xref:api/vectordbs/mariadb.adoc[] ** xref:api/vectordbs/milvus.adoc[] ** xref:api/vectordbs/mongodb.adoc[] ** xref:api/vectordbs/neo4j.adoc[] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc index a645f6bb727..15addeb628e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc @@ -207,6 +207,7 @@ These are the available implementations of the `VectorStore` interface: * xref:api/vectordbs/chroma.adoc[Chroma Vector Store] - The https://www.trychroma.com/[Chroma] vector store. * xref:api/vectordbs/elasticsearch.adoc[Elasticsearch Vector Store] - The https://www.elastic.co/[Elasticsearch] vector store. * xref:api/vectordbs/gemfire.adoc[GemFire Vector Store] - The https://tanzu.vmware.com/content/blog/vmware-gemfire-vector-database-extension[GemFire] vector store. +* xref:api/vectordbs/mariadb.adoc[MariaDB Vector Store] - The https://mariadb.com/[MariaDB] vector store. * xref:api/vectordbs/milvus.adoc[Milvus Vector Store] - The https://milvus.io/[Milvus] vector store. * xref:api/vectordbs/mongodb.adoc[MongoDB Atlas Vector Store] - The https://www.mongodb.com/atlas/database[MongoDB Atlas] vector store. * xref:api/vectordbs/neo4j.adoc[Neo4j Vector Store] - The https://neo4j.com/[Neo4j] vector store. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mariadb.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mariadb.adoc new file mode 100644 index 00000000000..c793a8979e9 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mariadb.adoc @@ -0,0 +1,187 @@ += MariaDB Vector + +This section walks you through setting up the MariaDB `VectorStore` to store document embeddings and perform similarity searches. + +link:https://mariadb.org/projects/mariadb-vector/[MariaDB vector] is part of MariaDB 11.7 and enables storing and searching over machine learning-generated embeddings. + +== Auto-Configuration + +Add the MariaDBVectorStore boot starter dependency to your project: + +[source,xml] +---- + + org.springframework.ai + spring-ai-mariadb-store-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-mariadb-store-spring-boot-starter' +} +---- + +The vector store implementation can initialize the required schema for you, but you must opt-in by specifying the `initializeSchema` boolean in the appropriate constructor or by setting `...initialize-schema=true` in the `application.properties` file. + +The Vector Store also requires an `EmbeddingModel` instance to calculate embeddings for the documents. +You can pick one of the available xref:api/embeddings.adoc#available-implementations[EmbeddingModel Implementations]. + +For example, to use the xref:api/embeddings/openai-embeddings.adoc[OpenAI EmbeddingModel], add the following dependency to your project: + +[source,xml] +---- + + org.springframework.ai + spring-ai-openai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add Milestone and/or Snapshot Repositories to your build file. + +To connect to and configure the `MariaDBVectorStore`, you need to provide access details for your instance. +A simple configuration can be provided via Spring Boot's `application.yml`. + +[yml] +---- +spring: + datasource: + url: jdbc:mariadb://localhost/db + username: myUser + password: myPassword + ai: + vectorstore: + mariadbvector: + distance-type: COSINE + dimensions: 1536 +---- + +TIP: If you run MariaDBvector as a Spring Boot dev service via link:https://docs.spring.io/spring-boot/reference/features/dev-services.html#features.dev-services.docker-compose[Docker Compose] +or link:https://docs.spring.io/spring-boot/reference/features/dev-services.html#features.dev-services.testcontainers[Testcontainers], +you don't need to configure URL, username and password since they are autoconfigured by Spring Boot. + +TIP: Check the list of xref:#mariadbvector-properties[configuration parameters] to learn about the default values and configuration options. + +Now you can auto-wire the `MariaDBVectorStore` in your application and use it + +[source,java] +---- +@Autowired VectorStore vectorStore; + +// ... + +List documents = List.of( + new Document("Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!! Spring AI rocks!!", Map.of("meta1", "meta1")), + new Document("The World is Big and Salvation Lurks Around the Corner"), + new Document("You walk forward facing the past and you turn back toward the future.", Map.of("meta2", "meta2"))); + +// Add the documents to PGVector +vectorStore.add(documents); + +// Retrieve documents similar to a query +List results = this.vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); +---- + +[[mariadbvector-properties]] +=== Configuration properties + +You can use the following properties in your Spring Boot configuration to customize the MariaDB vector store. + +[cols="2,5,1",stripes=even] +|=== +|Property| Description | Default value + +|`spring.ai.vectorstore.mariadb.distance-type`| Search distance type. Defaults to `COSINE`. But if vectors are normalized to length 1, you can use `EUCLIDEAN` for best performance.| COSINE +|`spring.ai.vectorstore.mariadb.dimensions`| Embeddings dimension. If not specified explicitly the PgVectorStore will retrieve the dimensions form the provided `EmbeddingModel`. Dimensions are set to the embedding column the on table creation. If you change the dimensions your would have to re-create the vector_store table as well. | - +|`spring.ai.vectorstore.mariadb.remove-existing-vector-store-table` | Deletes the existing `vector_store` table on start up. | false +|`spring.ai.vectorstore.mariadb.initialize-schema` | Whether to initialize the required schema | false +|`spring.ai.vectorstore.mariadb.schema-name` | Vector store schema name | null +|`spring.ai.vectorstore.mariadb.table-name` | Vector store table name | `vector_store` +|`spring.ai.vectorstore.mariadb.schema-validation` | Enables schema and table name validation to ensure they are valid and existing objects. | false + +|=== + +TIP: If you configure a custom schema and/or table name, consider enabling schema validation by setting `spring.ai.vectorstore.mariadb.schema-validation=true`. +This ensures the correctness of the names and reduces the risk of SQL injection attacks. + +== Metadata filtering + +You can leverage the generic, portable link:https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_metadata_filters[metadata filters] with the MariaDB Vector store. + +For example, you can use either the text expression language: + +[source,java] +---- +vectorStore.similaritySearch( + SearchRequest.defaults() + .withQuery("The World") + .withTopK(TOP_K) + .withSimilarityThreshold(SIMILARITY_THRESHOLD) + .withFilterExpression("author in ['john', 'jill'] && article_type == 'blog'")); +---- + +or programmatically using the `Filter.Expression` DSL: + +[source,java] +---- +FilterExpressionBuilder b = new FilterExpressionBuilder(); + +vectorStore.similaritySearch(SearchRequest.defaults() + .withQuery("The World") + .withTopK(TOP_K) + .withSimilarityThreshold(SIMILARITY_THRESHOLD) + .withFilterExpression(b.and( + b.in("author","john", "jill"), + b.eq("article_type", "blog")).build())); +---- + +NOTE: These filter expressions are converted into the equivalent PgVector filters. + +== Manual Configuration + +Instead of using the Spring Boot auto-configuration, you can manually configure the `MariaDBVectorStore`. +For this you need to add the MariaDB connector and `JdbcTemplate` auto-configuration dependencies to your project: + +[source,xml] +---- + + org.springframework.boot + spring-boot-starter-jdbc + + + + org.mariadb.jdbc + mariadb-java-client + runtime + + + + org.springframework.ai + spring-ai-mariadb-store + +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +To configure MariaDB Vector in your application, you can use the following setup: + +[source,java] +---- +@Bean +public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return new MariaDBVectorStore(jdbcTemplate, embeddingModel); +} +---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc index 484b9a0a643..a560889d98d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/index.adoc @@ -1,5 +1,4 @@ [[introduction]] -= Introduction image::spring_ai_logo_with_text.svg[Integration Problem, width=300, align="left"] @@ -26,7 +25,7 @@ Spring AI provides the following features: ** xref:api/audio/speech.adoc[Text to Speech] ** xref:api/moderation[Moderation] * xref:api/structured-output-converter.adoc[Structured Outputs] - Mapping of AI Model output to POJOs. -* Support for all major xref:api/vectordbs.adoc[Vector Database providers] such as Apache Cassandra, Azure Cosmos DB, Azure Vector Search, Chroma, Elasticsearch, GemFire, Milvus, MongoDB Atlas, Neo4j, OpenSearch, Oracle, PostgreSQL/PGVector, PineCone, Qdrant, Redis, SAP Hana, Typesense and Weaviate. +* Support for all major xref:api/vectordbs.adoc[Vector Database providers] such as Apache Cassandra, Azure Cosmos DB, Azure Vector Search, Chroma, Elasticsearch, GemFire, MariaDB, Milvus, MongoDB Atlas, Neo4j, OpenSearch, Oracle, PostgreSQL/PGVector, PineCone, Qdrant, Redis, SAP Hana, Typesense and Weaviate. * Portable API across Vector Store providers, including a novel SQL-like metadata filter API. * xref:api/functions.adoc[Tools/Function Calling] - permits the model to request the execution of client-side tools and functions, thereby accessing necessary real-time information as required. * xref:observability/index.adoc[Observability] - Provides insights into AI-related operations. diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 4ccc3c233ef..d2abc8de56e 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -178,6 +178,20 @@ true + + + org.springframework.ai + spring-ai-mariadb-store + ${project.parent.version} + true + + + org.mariadb.jdbc + mariadb-java-client + ${mariadb.version} + true + + org.springframework.ai @@ -457,6 +471,12 @@ test + + org.testcontainers + mariadb + test + + org.testcontainers junit-jupiter diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java new file mode 100644 index 00000000000..1b940447bd4 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfiguration.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.mariadb; + +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.vectorstore.MariaDBVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.jdbc.core.JdbcTemplate; + +import javax.sql.DataSource; + +/** + * @author Diego Dupin + * @since 1.0.0 + */ +@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) +@ConditionalOnClass({ MariaDBVectorStore.class, DataSource.class, JdbcTemplate.class }) +@EnableConfigurationProperties(MariaDbStoreProperties.class) +public class MariaDbStoreAutoConfiguration { + + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy mariaDbStoreBatchingStrategy() { + return new TokenCountBatchingStrategy(); + } + + @Bean + @ConditionalOnMissingBean + public MariaDBVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, + MariaDbStoreProperties properties, ObjectProvider observationRegistry, + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { + + var initializeSchema = properties.isInitializeSchema(); + + return new MariaDBVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(properties.getSchemaName()) + .withVectorTableName(properties.getTableName()) + .withVectorTableValidationsEnabled(properties.isSchemaValidation()) + .withDimensions(properties.getDimensions()) + .withDistanceType(properties.getDistanceType()) + .withContentFieldName(properties.getContentFieldName()) + .withEmbeddingFieldName(properties.getEmbeddingFieldName()) + .withIdFieldName(properties.getIdFieldName()) + .withMetadataFieldName(properties.getMetadataFieldName()) + .withRemoveExistingVectorStoreTable(properties.isRemoveExistingVectorStoreTable()) + .withInitializeSchema(initializeSchema) + .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null)) + .withBatchingStrategy(batchingStrategy) + .withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize()) + .build(); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreProperties.java new file mode 100644 index 00000000000..1451c557d13 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreProperties.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.mariadb; + +import org.springframework.ai.autoconfigure.vectorstore.CommonVectorStoreProperties; +import org.springframework.ai.vectorstore.MariaDBVectorStore; +import org.springframework.ai.vectorstore.MariaDBVectorStore.MariaDBDistanceType; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Diego Dupin + */ +@ConfigurationProperties(MariaDbStoreProperties.CONFIG_PREFIX) +public class MariaDbStoreProperties extends CommonVectorStoreProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vectorstore.mariadb"; + + private int dimensions = MariaDBVectorStore.INVALID_EMBEDDING_DIMENSION; + + private MariaDBDistanceType distanceType = MariaDBDistanceType.COSINE; + + private boolean removeExistingVectorStoreTable = false; + + private String tableName = MariaDBVectorStore.DEFAULT_TABLE_NAME; + + private String schemaName = null; + + private String embeddingFieldName = MariaDBVectorStore.DEFAULT_COLUMN_EMBEDDING; + + private String idFieldName = MariaDBVectorStore.DEFAULT_COLUMN_ID; + + private String metadataFieldName = MariaDBVectorStore.DEFAULT_COLUMN_METADATA; + + private String contentFieldName = MariaDBVectorStore.DEFAULT_COLUMN_CONTENT; + + private boolean schemaValidation = MariaDBVectorStore.DEFAULT_SCHEMA_VALIDATION; + + private int maxDocumentBatchSize = MariaDBVectorStore.MAX_DOCUMENT_BATCH_SIZE; + + public int getDimensions() { + return this.dimensions; + } + + public void setDimensions(int dimensions) { + this.dimensions = dimensions; + } + + public MariaDBVectorStore.MariaDBDistanceType getDistanceType() { + return this.distanceType; + } + + public void setDistanceType(MariaDBDistanceType distanceType) { + this.distanceType = distanceType; + } + + public boolean isRemoveExistingVectorStoreTable() { + return this.removeExistingVectorStoreTable; + } + + public void setRemoveExistingVectorStoreTable(boolean removeExistingVectorStoreTable) { + this.removeExistingVectorStoreTable = removeExistingVectorStoreTable; + } + + public String getTableName() { + return this.tableName; + } + + public void setTableName(String vectorTableName) { + this.tableName = vectorTableName; + } + + public String getSchemaName() { + return this.schemaName; + } + + public void setSchemaName(String schemaName) { + this.schemaName = schemaName; + } + + public boolean isSchemaValidation() { + return this.schemaValidation; + } + + public void setSchemaValidation(boolean schemaValidation) { + this.schemaValidation = schemaValidation; + } + + public int getMaxDocumentBatchSize() { + return this.maxDocumentBatchSize; + } + + public void setMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + } + + public String getEmbeddingFieldName() { + return embeddingFieldName; + } + + public void setEmbeddingFieldName(String embeddingFieldName) { + this.embeddingFieldName = embeddingFieldName; + } + + public String getIdFieldName() { + return idFieldName; + } + + public void setIdFieldName(String idFieldName) { + this.idFieldName = idFieldName; + } + + public String getMetadataFieldName() { + return metadataFieldName; + } + + public void setMetadataFieldName(String metadataFieldName) { + this.metadataFieldName = metadataFieldName; + } + + public String getContentFieldName() { + return contentFieldName; + } + + public void setContentFieldName(String contentFieldName) { + this.contentFieldName = contentFieldName; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index b15feb0453f..f3e5633efc0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -47,6 +47,7 @@ org.springframework.ai.autoconfigure.vectorstore.neo4j.Neo4jVectorStoreAutoConfi org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.hanadb.HanaCloudVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.cosmosdb.CosmosDBVectorStoreAutoConfiguration +org.springframework.ai.autoconfigure.vectorstore.mariadb.MariaDbStoreAutoConfiguration org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration org.springframework.ai.autoconfigure.postgresml.PostgresMlAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfigurationIT.java new file mode 100644 index 00000000000..ad6e565b8b2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStoreAutoConfigurationIT.java @@ -0,0 +1,195 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.mariadb; + +import io.micrometer.observation.tck.TestObservationRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.transformers.TransformersEmbeddingModel; +import org.springframework.ai.vectorstore.MariaDBVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; + +/** + * @author Diego Dupin + */ +@Testcontainers +public class MariaDbStoreAutoConfigurationIT { + + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("mariadb:11.7-rc"); + + @Container + @SuppressWarnings("resource") + static MariaDBContainer mariadbContainer = new MariaDBContainer<>(DEFAULT_IMAGE); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(MariaDbStoreAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .withPropertyValues("spring.ai.vectorstore.mariadb.distanceType=COSINE", + "spring.ai.vectorstore.mariadb.initialize-schema=true", + // JdbcTemplate configuration + "spring.datasource.url=" + mariadbContainer.getJdbcUrl(), + "spring.datasource.username=" + mariadbContainer.getUsername(), + "spring.datasource.password=" + mariadbContainer.getPassword()); + + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("spring", "great")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("depression", "bad"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static boolean isFullyQualifiedTableExists(ApplicationContext context, String schemaName, + String tableName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + if (schemaName == null) { + String sqlWithoutSchema = "SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_schema = SCHEMA() AND table_name = ?) as results"; + return jdbcTemplate.queryForObject(sqlWithoutSchema, Boolean.class, tableName); + } else { + String sqlWithSchema = "SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_schema = ? AND table_name = ?) as results"; + return jdbcTemplate.queryForObject(sqlWithSchema, Boolean.class, schemaName , tableName); + } + } + + @Test + public void addAndSearch() { + + this.contextRunner.run(context -> { + + MariaDBVectorStore vectorStore = context.getBean(MariaDBVectorStore.class); + TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); + + assertThat(isFullyQualifiedTableExists(context, null, MariaDBVectorStore.DEFAULT_TABLE_NAME)).isTrue(); + + vectorStore.add(this.documents); + + assertObservationRegistry(observationRegistry, VectorStoreProvider.MARIADB, + VectorStoreObservationContext.Operation.ADD); + observationRegistry.clear(); + + List results = vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression?").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("depression", "distance"); + + assertObservationRegistry(observationRegistry, VectorStoreProvider.MARIADB, + VectorStoreObservationContext.Operation.QUERY); + observationRegistry.clear(); + + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + + assertObservationRegistry(observationRegistry, VectorStoreProvider.MARIADB, + VectorStoreObservationContext.Operation.DELETE); + + results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + assertThat(results).hasSize(0); + observationRegistry.clear(); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "test:vector_store:id:metadata:embedding:content", + "test:my_table:my_id:my_metadata:my_embedding:my_content" }) + public void customSchemaNames(String schemaTableName) { + String schemaName = schemaTableName.split(":")[0]; + String tableName = schemaTableName.split(":")[1]; + String idName = schemaTableName.split(":")[2]; + String metaName = schemaTableName.split(":")[3]; + String embeddingName = schemaTableName.split(":")[4]; + String contentName = schemaTableName.split(":")[5]; + + this.contextRunner + .withPropertyValues("spring.ai.vectorstore.mariadb.schema-name=" + schemaName, + "spring.ai.vectorstore.mariadb.table-name=" + tableName, + "spring.ai.vectorstore.mariadb.id-field-name=" + idName, + "spring.ai.vectorstore.mariadb.metadata-field-name=" + metaName, + "spring.ai.vectorstore.mariadb.embedding-field-name=" + embeddingName, + "spring.ai.vectorstore.mariadb.content-field-name=" + contentName) + .run(context -> { + assertThat(isFullyQualifiedTableExists(context, schemaName, tableName)).isTrue(); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "test:vector_store", "test:my_table" }) + public void disableSchemaInitialization(String schemaTableName) { + String schemaName = schemaTableName.split(":")[0]; + String tableName = schemaTableName.split(":")[1]; + + this.contextRunner + .withPropertyValues("spring.ai.vectorstore.mariadb.schema-name=" + schemaName, + "spring.ai.vectorstore.mariadb.table-name=" + tableName, + "spring.ai.vectorstore.mariadb.initialize-schema=false") + .run(context -> { + assertThat(isFullyQualifiedTableExists(context, schemaName, tableName)).isFalse(); + }); + } + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public EmbeddingModel embeddingModel() { + return new TransformersEmbeddingModel(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStorePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStorePropertiesTests.java new file mode 100644 index 00000000000..5e1b79e1e6c --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/mariadb/MariaDbStorePropertiesTests.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vectorstore.mariadb; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.vectorstore.mariadb.MariaDbStoreProperties; +import org.springframework.ai.vectorstore.MariaDBVectorStore; +import org.springframework.ai.vectorstore.MariaDBVectorStore.MariaDBDistanceType; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Diego Dupin + */ +public class MariaDbStorePropertiesTests { + + @Test + public void defaultValues() { + var props = new MariaDbStoreProperties(); + assertThat(props.getDimensions()).isEqualTo(MariaDBVectorStore.INVALID_EMBEDDING_DIMENSION); + assertThat(props.getDistanceType()).isEqualTo(MariaDBDistanceType.COSINE); + assertThat(props.isRemoveExistingVectorStoreTable()).isFalse(); + + assertThat(props.isSchemaValidation()).isFalse(); + assertThat(props.getSchemaName()).isNull(); + assertThat(props.getTableName()).isEqualTo(MariaDBVectorStore.DEFAULT_TABLE_NAME); + + } + + @Test + public void customValues() { + var props = new MariaDbStoreProperties(); + + props.setDimensions(1536); + props.setDistanceType(MariaDBDistanceType.EUCLIDEAN); + props.setRemoveExistingVectorStoreTable(true); + + props.setSchemaValidation(true); + props.setSchemaName("my_vector_schema"); + props.setTableName("my_vector_table"); + props.setIdFieldName("my_vector_id"); + props.setMetadataFieldName("my_vector_meta"); + props.setContentFieldName("my_vector_content"); + props.setEmbeddingFieldName("my_vector_embedding"); + props.setInitializeSchema(true); + + assertThat(props.getDimensions()).isEqualTo(1536); + assertThat(props.getDistanceType()).isEqualTo(MariaDBDistanceType.EUCLIDEAN); + assertThat(props.isRemoveExistingVectorStoreTable()).isTrue(); + + assertThat(props.isSchemaValidation()).isTrue(); + assertThat(props.getSchemaName()).isEqualTo("my_vector_schema"); + assertThat(props.getTableName()).isEqualTo("my_vector_table"); + assertThat(props.getIdFieldName()).isEqualTo("my_vector_id"); + assertThat(props.getMetadataFieldName()).isEqualTo("my_vector_meta"); + assertThat(props.getContentFieldName()).isEqualTo("my_vector_content"); + assertThat(props.getEmbeddingFieldName()).isEqualTo("my_vector_embedding"); + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-mariadb-store/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-mariadb-store/pom.xml new file mode 100644 index 00000000000..5360a7ba223 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-mariadb-store/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-mariadb-store-spring-boot-starter + jar + Spring AI Starter - MariaDB Vector Store + Spring AI MariaDB Vector Store Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-mariadb-store + ${project.parent.version} + + + + diff --git a/vector-stores/spring-ai-mariadb-store/README.md b/vector-stores/spring-ai-mariadb-store/README.md new file mode 100644 index 00000000000..175b193ba7e --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/README.md @@ -0,0 +1 @@ +[MariaDB Vector Store Documentation](https://docs.spring.io/spring-ai/reference/api/vectordbs/mariadb.html) \ No newline at end of file diff --git a/vector-stores/spring-ai-mariadb-store/pom.xml b/vector-stores/spring-ai-mariadb-store/pom.xml new file mode 100644 index 00000000000..18331999ab2 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/pom.xml @@ -0,0 +1,128 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-mariadb-store + jar + Spring AI Vector Store - MariaDB + Spring AI MariaDB Vector Store + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + com.zaxxer + HikariCP + + + + org.springframework + spring-jdbc + + + + org.mariadb.jdbc + mariadb-java-client + ${mariadb.version} + + + + + org.springframework.ai + spring-ai-openai + ${project.parent.version} + test + + + + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + testcontainers + test + + + org.testcontainers + mariadb + test + + + + org.testcontainers + junit-jupiter + test + + + + io.micrometer + micrometer-observation-test + test + + + + + + + org.apache.maven.plugins + maven-failsafe-plugin + ${maven-failsafe-plugin.version} + + ${skip.vectorstore.mariadb} + + + + + integration-test + verify + + + + + + + diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBFilterExpressionConverter.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBFilterExpressionConverter.java new file mode 100644 index 00000000000..b9e63864db3 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBFilterExpressionConverter.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.util.List; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; + +/** + * Converts {@link Expression} into JSON metadata filter expression format. + * (https://mariadb.com/kb/en/json-functions/) + * + * @author Diego Dupin + */ +public class MariaDBFilterExpressionConverter extends AbstractFilterExpressionConverter { + + private final String metadataFieldName; + + public MariaDBFilterExpressionConverter(String metadataFieldName) { + this.metadataFieldName = metadataFieldName; + } + + @Override + protected void doExpression(Expression expression, StringBuilder context) { + this.convertOperand(expression.left(), context); + context.append(getOperationSymbol(expression)); + this.convertOperand(expression.right(), context); + } + + private void convertToConditions(Expression expression, StringBuilder context) { + Filter.Value right = (Filter.Value) expression.right(); + Object value = right.value(); + if (!(value instanceof List)) { + throw new IllegalArgumentException("Expected a List, but got: " + value.getClass().getSimpleName()); + } + List values = (List) value; + for (int i = 0; i < values.size(); i++) { + this.convertOperand(expression.left(), context); + context.append(" == "); + this.doSingleValue(values.get(i), context); + if (i < values.size() - 1) { + context.append(" || "); + } + } + } + + @Override + protected void doSingleValue(Object value, StringBuilder context) { + if (value instanceof String) { + context.append(String.format("\'%s\'", value)); + } + else { + context.append(value); + } + } + + private String getOperationSymbol(Expression exp) { + return switch (exp.type()) { + case AND -> " AND "; + case OR -> " OR "; + case EQ -> " = "; + case NE -> " != "; + case LT -> " < "; + case LTE -> " <= "; + case GT -> " > "; + case GTE -> " >= "; + case IN -> " IN "; + case NOT, NIN -> " NOT IN "; + // you never know what the future might bring + default -> throw new RuntimeException("Not supported expression type: " + exp.type()); + }; + } + + @Override + protected void doKey(Key key, StringBuilder context) { + context.append("JSON_VALUE(" + this.metadataFieldName + ", '$." + key.key() + "')"); + } + + protected void doStartValueRange(Filter.Value listValue, StringBuilder context) { + context.append("("); + } + + protected void doEndValueRange(Filter.Value listValue, StringBuilder context) { + context.append(")"); + } + + @Override + protected void doStartGroup(Group group, StringBuilder context) { + context.append("("); + } + + @Override + protected void doEndGroup(Group group, StringBuilder context) { + context.append(")"); + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBSchemaValidator.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBSchemaValidator.java new file mode 100644 index 00000000000..dc013d83a16 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBSchemaValidator.java @@ -0,0 +1,164 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import org.mariadb.jdbc.Driver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.core.JdbcTemplate; + +/** + * @author Diego Dupin + * @since 1.0.0 + */ +public class MariaDBSchemaValidator { + + private static final Logger logger = LoggerFactory.getLogger(MariaDBSchemaValidator.class); + + private final JdbcTemplate jdbcTemplate; + + public MariaDBSchemaValidator(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + } + + private boolean isTableExists(String schemaName, String tableName) { + // schema and table are expected to be escaped + String sql = String.format("SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = ", + (schemaName == null) ? "SCHEMA()" : schemaName, tableName); + try { + // Query for a single integer value, if it exists, table exists + this.jdbcTemplate.queryForObject(sql, Integer.class); + return true; + } + catch (DataAccessException e) { + return false; + } + } + + void validateTableSchema(String schemaName, String tableName, String idFieldName, String contentFieldName, + String metadataFieldName, String embeddingFieldName, int embeddingDimensions) { + + if (!isTableExists(schemaName, tableName)) { + throw new IllegalStateException( + String.format("Table '%s' does not exist in schema '%s'", tableName, schemaName)); + } + + // ensure server support VECTORs + try { + // Query for a single integer value, if it exists, database support vector + this.jdbcTemplate.queryForObject("SELECT vec_distance_euclidean(x'0000803f', x'0000803f')", Integer.class, + schemaName, tableName); + } + catch (DataAccessException e) { + logger.error("Error while validating database vector support " + e.getMessage()); + logger.error("Failed to validate that database supports VECTOR.\n" + "Run the following SQL commands:\n" + + " SELECT @@version; \nAnd ensure that version is >= 11.7.1"); + throw new IllegalStateException(e); + } + + try { + logger.info("Validating MariaDBStore schema for table: {} in schema: {}", tableName, schemaName); + + List expectedColumns = new ArrayList<>(); + expectedColumns.add(idFieldName); + expectedColumns.add(contentFieldName); + expectedColumns.add(metadataFieldName); + expectedColumns.add(embeddingFieldName); + + // Query to check if the table exists with the required fields and types + // Include the schema name in the query to target the correct table + String query = "SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " + + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"; + List> columns = this.jdbcTemplate.queryForList(query, schemaName, tableName); + + if (columns.isEmpty()) { + throw new IllegalStateException("Error while validating table schema, Table " + tableName + + " does not exist in schema " + schemaName); + } + + // Check each column against expected fields + List availableColumns = new ArrayList<>(); + for (Map column : columns) { + String columnName = validateAndEnquoteIdentifier((String) column.get("COLUMN_NAME"), false); + availableColumns.add(columnName); + } + + // TODO ensure id is a primary key for batch update + + expectedColumns.removeAll(availableColumns); + + if (expectedColumns.isEmpty()) { + logger.info("MariaDB VectorStore schema validation successful"); + } + else { + throw new IllegalStateException("Missing fields " + expectedColumns); + } + + } + catch (DataAccessException | IllegalStateException e) { + logger.error("Error while validating table schema" + e.getMessage()); + logger.error("Failed to operate with the specified table in the database. To resolve this issue," + + " please ensure the following steps are completed:\n" + + "1. Verify that the table exists with the appropriate structure. If it does not" + + " exist, create it using a SQL command similar to the following:\n" + + String.format(""" + CREATE TABLE IF NOT EXISTS %s ( + %s UUID NOT NULL DEFAULT uuid() PRIMARY KEY, + %s TEXT, + %s JSON, + %s VECTOR(%d) NOT NULL, + VECTOR INDEX (%s) + ) ENGINE=InnoDB""", schemaName == null ? tableName : schemaName + "." + tableName, + idFieldName, contentFieldName, metadataFieldName, embeddingFieldName, embeddingDimensions, + embeddingFieldName) + + "\n" + "Please adjust these commands based on your specific configuration and the" + + " capabilities of your vector database system."); + throw new IllegalStateException(e); + } + } + + /** + * Escaped identifier according to MariaDB requirement. + * @param identifier identifier + * @param alwaysQuote indicate if identifier must be quoted even if not necessary. + * @return return escaped identifier, quoted when necessary or indicated with + * alwaysQuote. + * @see mariadb + * identifier name + */ + // @Override when not supporting java 8 + public static String validateAndEnquoteIdentifier(String identifier, boolean alwaysQuote) { + try { + String quotedId = Driver.enquoteIdentifier(identifier, alwaysQuote); + // force use of simple table name + if (Pattern.compile("`?[\\p{Alnum}_]*`?").matcher(identifier).matches()) + return quotedId; + throw new IllegalArgumentException(String + .format("Identifier '%s' should only contain alphanumeric characters and underscores", quotedId)); + } + catch (SQLException e) { + throw new IllegalArgumentException(e); + } + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBVectorStore.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBVectorStore.java new file mode 100644 index 00000000000..461f684efb4 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/MariaDBVectorStore.java @@ -0,0 +1,559 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import io.micrometer.observation.ObservationRegistry; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.util.JacksonUtils; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; +import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Uses the "vector_store" table to store the Spring AI vector data. The table and the + * vector index will be auto-created if not available. + * + * @author Diego Dupin + * @since 1.0.0 + */ +public class MariaDBVectorStore extends AbstractObservationVectorStore implements InitializingBean { + + public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536; + + public static final int INVALID_EMBEDDING_DIMENSION = -1; + + public static final boolean DEFAULT_SCHEMA_VALIDATION = false; + + public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000; + + private static final Logger logger = LoggerFactory.getLogger(MariaDBVectorStore.class); + + public static final String DEFAULT_TABLE_NAME = "vector_store"; + + public static final String DEFAULT_COLUMN_EMBEDDING = "embedding"; + + public static final String DEFAULT_COLUMN_METADATA = "metadata"; + + public static final String DEFAULT_COLUMN_ID = "id"; + + public static final String DEFAULT_COLUMN_CONTENT = "content"; + + private static Map SIMILARITY_TYPE_MAPPING = Map.of( + MariaDBDistanceType.COSINE, VectorStoreSimilarityMetric.COSINE, MariaDBDistanceType.EUCLIDEAN, + VectorStoreSimilarityMetric.EUCLIDEAN); + + public final FilterExpressionConverter filterExpressionConverter; + + private final String vectorTableName; + + private final JdbcTemplate jdbcTemplate; + + private final EmbeddingModel embeddingModel; + + private final String schemaName; + + private final boolean schemaValidation; + + private final boolean initializeSchema; + + private final int dimensions; + + private final String contentFieldName; + + private final String embeddingFieldName; + + private final String idFieldName; + + private final String metadataFieldName; + + private final MariaDBDistanceType distanceType; + + private final ObjectMapper objectMapper; + + private final boolean removeExistingVectorStoreTable; + + private final MariaDBSchemaValidator schemaValidator; + + private final BatchingStrategy batchingStrategy; + + private final int maxDocumentBatchSize; + + public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, MariaDBDistanceType.COSINE, false, false); + } + + public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) { + this(jdbcTemplate, embeddingModel, dimensions, MariaDBDistanceType.COSINE, false, false); + } + + public MariaDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, + MariaDBDistanceType distanceType, boolean removeExistingVectorStoreTable, boolean initializeSchema) { + + this(DEFAULT_TABLE_NAME, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, + initializeSchema); + } + + public MariaDBVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, + int dimensions, MariaDBDistanceType distanceType, boolean removeExistingVectorStoreTable, + boolean initializeSchema) { + + this(null, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions, distanceType, + removeExistingVectorStoreTable, initializeSchema); + } + + private MariaDBVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, + JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, MariaDBDistanceType distanceType, + boolean removeExistingVectorStoreTable, boolean initializeSchema) { + + this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions, + distanceType, removeExistingVectorStoreTable, initializeSchema, ObservationRegistry.NOOP, null, + new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE, DEFAULT_COLUMN_EMBEDDING, + DEFAULT_COLUMN_METADATA, DEFAULT_COLUMN_ID, DEFAULT_COLUMN_CONTENT); + } + + private MariaDBVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, + JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, MariaDBDistanceType distanceType, + boolean removeExistingVectorStoreTable, boolean initializeSchema, ObservationRegistry observationRegistry, + VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy, + int maxDocumentBatchSize, String contentFieldName, String embeddingFieldName, String idFieldName, + String metadataFieldName) { + + super(observationRegistry, customObservationConvention); + + this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); + + this.vectorTableName = (null == vectorTableName || vectorTableName.isEmpty()) ? DEFAULT_TABLE_NAME + : MariaDBSchemaValidator.validateAndEnquoteIdentifier(vectorTableName.trim(), false); + logger.info("Using the vector table name: {}. Is empty: {}", this.vectorTableName, + (vectorTableName == null || vectorTableName.isEmpty())); + + this.schemaName = schemaName == null ? null + : MariaDBSchemaValidator.validateAndEnquoteIdentifier(schemaName, false); + this.schemaValidation = vectorTableValidationsEnabled; + + this.jdbcTemplate = jdbcTemplate; + this.embeddingModel = embeddingModel; + this.dimensions = dimensions; + this.distanceType = distanceType; + this.removeExistingVectorStoreTable = removeExistingVectorStoreTable; + this.initializeSchema = initializeSchema; + this.schemaValidator = new MariaDBSchemaValidator(jdbcTemplate); + this.batchingStrategy = batchingStrategy; + this.maxDocumentBatchSize = maxDocumentBatchSize; + + this.contentFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(contentFieldName, false); + this.embeddingFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(embeddingFieldName, false); + this.idFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(idFieldName, false); + this.metadataFieldName = MariaDBSchemaValidator.validateAndEnquoteIdentifier(metadataFieldName, false); + filterExpressionConverter = new MariaDBFilterExpressionConverter(this.metadataFieldName); + } + + public MariaDBDistanceType getDistanceType() { + return this.distanceType; + } + + @Override + public void doAdd(List documents) { + // Batch the documents based on the batching strategy + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + + List> batchedDocuments = batchDocuments(documents); + batchedDocuments.forEach(this::insertOrUpdateBatch); + } + + private List> batchDocuments(List documents) { + List> batches = new ArrayList<>(); + for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) { + batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size()))); + } + return batches; + } + + private void insertOrUpdateBatch(List batch) { + String sql = String.format( + "INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?) " + + "ON DUPLICATE KEY UPDATE %s = VALUES(%s) , %s = VALUES(%s) , %s = VALUES(%s)", + getFullyQualifiedTableName(), this.idFieldName, this.contentFieldName, this.metadataFieldName, + this.embeddingFieldName, this.contentFieldName, this.contentFieldName, this.metadataFieldName, + this.metadataFieldName, this.embeddingFieldName, this.embeddingFieldName); + + this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var document = batch.get(i); + ps.setObject(1, document.getId()); + ps.setString(2, document.getContent()); + ps.setString(3, toJson(document.getMetadata())); + ps.setObject(4, document.getEmbedding()); + } + + @Override + public int getBatchSize() { + return batch.size(); + } + }); + } + + private String toJson(Map map) { + try { + return this.objectMapper.writeValueAsString(map); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public Optional doDelete(List idList) { + int updateCount = 0; + for (String id : idList) { + int count = this.jdbcTemplate.update( + String.format("DELETE FROM %s WHERE %s = ?", getFullyQualifiedTableName(), this.idFieldName), id); + updateCount = updateCount + count; + } + + return Optional.of(updateCount == idList.size()); + } + + @Override + public List doSimilaritySearch(SearchRequest request) { + + String nativeFilterExpression = (request.getFilterExpression() != null) + ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : ""; + float[] embedding = this.embeddingModel.embed(request.getQuery()); + String jsonPathFilter = ""; + + if (StringUtils.hasText(nativeFilterExpression)) { + jsonPathFilter = "and " + nativeFilterExpression + " "; + } + String distanceType = this.distanceType.name().toLowerCase(Locale.ROOT); + + double distance = 1 - request.getSimilarityThreshold(); + final String sql = String.format( + "SELECT * FROM (select %s, %s, %s, %s, vec_distance_%s(%s, ?) as distance " + + "from %s) as t where distance < ? %sorder by distance asc LIMIT ?", + this.idFieldName, this.contentFieldName, this.metadataFieldName, this.embeddingFieldName, distanceType, + this.embeddingFieldName, getFullyQualifiedTableName(), + jsonPathFilter); + + logger.debug("SQL query: " + sql); + + return this.jdbcTemplate.query(sql, new DocumentRowMapper(this.objectMapper), embedding, distance, + request.getTopK()); + } + + // --------------------------------------------------------------------------------- + // Initialize + // --------------------------------------------------------------------------------- + @Override + public void afterPropertiesSet() { + + logger.info("Initializing MariaDBVectorStore schema for table: {} in schema: {}", this.vectorTableName, + this.schemaName); + + logger.info("vectorTableValidationsEnabled {}", this.schemaValidation); + + if (this.schemaValidation) { + this.schemaValidator.validateTableSchema(this.schemaName, this.vectorTableName, idFieldName, + contentFieldName, metadataFieldName, embeddingFieldName, this.embeddingDimensions()); + } + + if (!this.initializeSchema) { + logger.debug("Skipping the schema initialization for the table: {}", this.getFullyQualifiedTableName()); + return; + } + + if (this.schemaName != null) + this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.schemaName)); + + // Remove existing VectorStoreTable + if (this.removeExistingVectorStoreTable) { + this.jdbcTemplate.execute(String.format("DROP TABLE IF EXISTS %s", this.getFullyQualifiedTableName())); + } + + this.jdbcTemplate.execute(String.format(""" + CREATE TABLE IF NOT EXISTS %s ( + %s UUID NOT NULL DEFAULT uuid() PRIMARY KEY, + %s TEXT, + %s JSON, + %s VECTOR(%d) NOT NULL, + VECTOR INDEX %s_idx (%s) + ) ENGINE=InnoDB + """, this.getFullyQualifiedTableName(), idFieldName, contentFieldName, metadataFieldName, + embeddingFieldName, this.embeddingDimensions(), + (vectorTableName + "_" + embeddingFieldName).replaceAll("[^\\n\\r\\t\\p{Print}]", ""), + embeddingFieldName)); + } + + private String getFullyQualifiedTableName() { + if (this.schemaName != null) + return this.schemaName + "." + this.vectorTableName; + return this.vectorTableName; + } + + int embeddingDimensions() { + // The manually set dimensions have precedence over the computed one. + if (this.dimensions > 0) { + return this.dimensions; + } + + try { + int embeddingDimensions = this.embeddingModel.dimensions(); + if (embeddingDimensions > 0) { + return embeddingDimensions; + } + } + catch (Exception e) { + logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to" + + " default:" + OPENAI_EMBEDDING_DIMENSION_SIZE, e); + } + return OPENAI_EMBEDDING_DIMENSION_SIZE; + } + + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + + return VectorStoreObservationContext.builder(VectorStoreProvider.MARIADB.value(), operationName) + .withCollectionName(this.vectorTableName) + .withDimensions(this.embeddingDimensions()) + .withNamespace(this.schemaName) + .withSimilarityMetric(getSimilarityMetric()); + } + + private String getSimilarityMetric() { + if (!SIMILARITY_TYPE_MAPPING.containsKey(this.getDistanceType())) { + return this.getDistanceType().name(); + } + return SIMILARITY_TYPE_MAPPING.get(this.distanceType).value(); + } + + public enum MariaDBDistanceType { + + EUCLIDEAN, COSINE; + + } + + private static class DocumentRowMapper implements RowMapper { + + private final ObjectMapper objectMapper; + + public DocumentRowMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + } + + @Override + public Document mapRow(ResultSet rs, int rowNum) throws SQLException { + String id = rs.getString(1); + String content = rs.getString(2); + Map metadata = toMap(rs.getString(3)); + float[] embedding = rs.getObject(4, float[].class); + float distance = rs.getFloat(5); + + metadata.put("distance", distance); + + Document document = new Document(id, content, metadata); + document.setEmbedding(embedding); + + return document; + } + + private Map toMap(String source) { + try { + return (Map) this.objectMapper.readValue(source, Map.class); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + } + + public static class Builder { + + private String contentFieldName = DEFAULT_COLUMN_CONTENT; + + private String embeddingFieldName = DEFAULT_COLUMN_EMBEDDING; + + private String idFieldName = DEFAULT_COLUMN_ID; + + private String metadataFieldName = DEFAULT_COLUMN_METADATA; + + private final JdbcTemplate jdbcTemplate; + + private final EmbeddingModel embeddingModel; + + private String schemaName = null; + + private String vectorTableName; + + private boolean vectorTableValidationsEnabled = MariaDBVectorStore.DEFAULT_SCHEMA_VALIDATION; + + private int dimensions = MariaDBVectorStore.INVALID_EMBEDDING_DIMENSION; + + private MariaDBDistanceType distanceType = MariaDBDistanceType.COSINE; + + private boolean removeExistingVectorStoreTable = false; + + private boolean initializeSchema; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + + private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE; + + @Nullable + private VectorStoreObservationConvention searchObservationConvention; + + // Builder constructor with mandatory parameters + public Builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + if (jdbcTemplate == null || embeddingModel == null) { + throw new IllegalArgumentException("JdbcTemplate and EmbeddingModel must not be null"); + } + this.jdbcTemplate = jdbcTemplate; + this.embeddingModel = embeddingModel; + } + + public Builder withSchemaName(String schemaName) { + this.schemaName = schemaName; + return this; + } + + public Builder withVectorTableName(String vectorTableName) { + this.vectorTableName = vectorTableName; + return this; + } + + public Builder withVectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) { + this.vectorTableValidationsEnabled = vectorTableValidationsEnabled; + return this; + } + + public Builder withDimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + + public Builder withDistanceType(MariaDBDistanceType distanceType) { + this.distanceType = distanceType; + return this; + } + + public Builder withRemoveExistingVectorStoreTable(boolean removeExistingVectorStoreTable) { + this.removeExistingVectorStoreTable = removeExistingVectorStoreTable; + return this; + } + + public Builder withInitializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + public Builder withObservationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public Builder withSearchObservationConvention(VectorStoreObservationConvention customObservationConvention) { + this.searchObservationConvention = customObservationConvention; + return this; + } + + public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) { + this.batchingStrategy = batchingStrategy; + return this; + } + + public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + return this; + } + + /** + * Configures the content field name to use. + * @param name the content field name to use + * @return this builder + */ + public Builder withContentFieldName(String name) { + this.contentFieldName = name; + return this; + } + + /** + * Configures the embedding field name to use. + * @param name the embedding field name to use + * @return this builder + */ + public Builder withEmbeddingFieldName(String name) { + this.embeddingFieldName = name; + return this; + } + + /** + * Configures the id field name to use. + * @param name the id field name to use + * @return this builder + */ + public Builder withIdFieldName(String name) { + this.idFieldName = name; + return this; + } + + /** + * Configures the metadata field name to use. + * @param name the metadata field name to use + * @return this builder + */ + public Builder withMetadataFieldName(String name) { + this.metadataFieldName = name; + return this; + } + + public MariaDBVectorStore build() { + return new MariaDBVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled, + this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType, + this.removeExistingVectorStoreTable, this.initializeSchema, this.observationRegistry, + this.searchObservationConvention, this.batchingStrategy, this.maxDocumentBatchSize, + this.contentFieldName, this.embeddingFieldName, this.idFieldName, this.metadataFieldName); + } + + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBEmbeddingDimensionsTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBEmbeddingDimensionsTests.java new file mode 100644 index 00000000000..57be05b33da --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBEmbeddingDimensionsTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.JdbcTemplate; + +/** + * @author Diego Dupin + */ +@ExtendWith(MockitoExtension.class) +public class MariaDBEmbeddingDimensionsTests { + + @Mock + private EmbeddingModel embeddingModel; + + @Mock + private JdbcTemplate jdbcTemplate; + + @Test + public void explicitlySetDimensions() { + + final int explicitDimensions = 696; + + var dim = new MariaDBVectorStore(this.jdbcTemplate, this.embeddingModel, explicitDimensions) + .embeddingDimensions(); + + assertThat(dim).isEqualTo(explicitDimensions); + verify(this.embeddingModel, never()).dimensions(); + } + + @Test + public void embeddingModelDimensions() { + when(this.embeddingModel.dimensions()).thenReturn(969); + + var dim = new MariaDBVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); + + assertThat(dim).isEqualTo(969); + + verify(this.embeddingModel, only()).dimensions(); + } + + @Test + public void fallBackToDefaultDimensions() { + + when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); + + var dim = new MariaDBVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); + + assertThat(dim).isEqualTo(MariaDBVectorStore.OPENAI_EMBEDDING_DIMENSION_SIZE); + verify(this.embeddingModel, only()).dimensions(); + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBFilterExpressionConverterTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBFilterExpressionConverterTests.java new file mode 100644 index 00000000000..019cb69a0c1 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBFilterExpressionConverterTests.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.Group; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; + +/** + * @author Diego Dupin + */ +public class MariaDBFilterExpressionConverterTests { + + FilterExpressionConverter converter = new MariaDBFilterExpressionConverter("metadata"); + + @Test + public void testEQ() { + // country == "BG" + String vectorExpr = this.converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); + assertThat(vectorExpr).isEqualTo("JSON_VALUE(metadata, '$.country') = 'BG'"); + } + + @Test + public void tesEqAndGte() { + // genre == "drama" AND year >= 2020 + String vectorExpr = this.converter + .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), + new Expression(GTE, new Key("year"), new Value(2020)))); + assertThat(vectorExpr) + .isEqualTo("JSON_VALUE(metadata, '$.genre') = 'drama' AND JSON_VALUE(metadata, '$.year') >=" + " 2020"); + } + + @Test + public void tesIn() { + // genre in ["comedy", "documentary", "drama"] + String vectorExpr = this.converter.convertExpression( + new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); + assertThat(vectorExpr).isEqualTo("JSON_VALUE(metadata, '$.genre') IN ('comedy','documentary','drama')"); + } + + @Test + public void testNe() { + // year >= 2020 OR country == "BG" AND city != "Sofia" + String vectorExpr = this.converter + .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), + new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), + new Expression(NE, new Key("city"), new Value("Sofia"))))); + assertThat(vectorExpr) + .isEqualTo("JSON_VALUE(metadata, '$.year') >= 2020 OR JSON_VALUE(metadata, '$.country') = 'BG'" + + " AND JSON_VALUE(metadata, '$.city') != 'Sofia'"); + } + + @Test + public void testGroup() { + // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] + String vectorExpr = this.converter.convertExpression(new Expression(AND, + new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), + new Expression(EQ, new Key("country"), new Value("BG")))), + new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv"))))); + assertThat(vectorExpr) + .isEqualTo("(JSON_VALUE(metadata, '$.year') >= 2020 OR JSON_VALUE(metadata, '$.country') =" + + " 'BG') AND JSON_VALUE(metadata, '$.city') NOT IN ('Sofia','Plovdiv')"); + } + + @Test + public void testBoolean() { + // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] + String vectorExpr = this.converter.convertExpression(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); + + assertThat(vectorExpr) + .isEqualTo("JSON_VALUE(metadata, '$.isOpen') = true AND JSON_VALUE(metadata, '$.year') >= 2020" + + " AND JSON_VALUE(metadata, '$.country') IN ('BG','NL','US')"); + } + + @Test + public void testDecimal() { + // temperature >= -15.6 && temperature <= +20.13 + String vectorExpr = this.converter + .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), + new Expression(LTE, new Key("temperature"), new Value(20.13)))); + + assertThat(vectorExpr).isEqualTo("JSON_VALUE(metadata, '$.temperature') >= -15.6 AND JSON_VALUE(metadata," + + " '$.temperature') <= 20.13"); + } + + @Test + public void testComplexIdentifiers() { + String vectorExpr = this.converter + .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); + assertThat(vectorExpr).isEqualTo("JSON_VALUE(metadata, '$.\"country 1 2 3\"') = 'BG'"); + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBImage.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBImage.java new file mode 100644 index 00000000000..b7267867629 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBImage.java @@ -0,0 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import org.testcontainers.utility.DockerImageName; + +/** + * @author Diego Dupin + */ +public class MariaDBImage { + + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("mariadb:11.7-rc"); + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreCustomNamesIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreCustomNamesIT.java new file mode 100644 index 00000000000..33adf505e15 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreCustomNamesIT.java @@ -0,0 +1,256 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.zaxxer.hikari.HikariDataSource; +import javax.sql.DataSource; +import org.junit.jupiter.api.Test; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +/** + * @author Diego Dupin + */ +@Testcontainers +public class MariaDBStoreCustomNamesIT { + + private static String schemaName = "testdb"; + + @Container + @SuppressWarnings("resource") + static MariaDBContainer mariadbContainer = new MariaDBContainer<>(MariaDBImage.DEFAULT_IMAGE) + .withUsername("mariadb") + .withPassword("mariadbpwd") + .withDatabaseName(schemaName); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=COSINE", + + // JdbcTemplate configuration + "app.datasource.url=" + mariadbContainer.getJdbcUrl(), + "app.datasource.username=mariadb", "app.datasource.password=mariadbpwd", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + private static void dropTableByName(ApplicationContext context, String name) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + jdbcTemplate.execute("DROP TABLE IF EXISTS " + schemaName + "." + name); + } + + private static boolean isIndexExists(ApplicationContext context, String schemaName, String tableName, + String indexName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + String sql = "SELECT EXISTS (SELECT * FROM information_schema.statistics WHERE TABLE_SCHEMA=? AND" + + " TABLE_NAME=? AND INDEX_NAME=? AND INDEX_TYPE='VECTOR')"; + return jdbcTemplate.queryForObject(sql, Boolean.class, schemaName, tableName, indexName); + } + + @SuppressWarnings("null") + private static boolean isTableExists(ApplicationContext context, String tableName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + return jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_schema= ? AND" + " table_name = ?)", + Boolean.class, schemaName, tableName); + } + + private static boolean areColumnsExisting(ApplicationContext context, String tableName, String[] fieldNames) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + for (String field : fieldNames) { + boolean fieldsExists = jdbcTemplate + .queryForObject("SELECT EXISTS (SELECT * FROM information_schema.columns WHERE table_schema= ? AND" + + " table_name = ? AND column_name = ?)", Boolean.class, schemaName, tableName, field); + if (!fieldsExists) + return false; + } + return true; + } + + private static boolean isSchemaExists(ApplicationContext context, String schemaName) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + String sql = "SELECT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = ?)"; + return jdbcTemplate.queryForObject(sql, Boolean.class, schemaName); + } + + @Test + public void shouldCreateDefaultTableAndIndexIfNotPresentInConfig() { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.schemaValidation=false") + .run(context -> { + assertThat(context).hasNotFailed(); + assertThat(isTableExists(context, "vector_store")).isTrue(); + assertThat(isSchemaExists(context, schemaName)).isTrue(); + dropTableByName(context, "vector_store"); + }); + } + + @Test + public void shouldCreateTableAndIndexIfNotPresentInDatabase() { + String tableName = "new_vector_table"; + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.vectorTableName=" + tableName) + .run(context -> { + assertThat(isTableExists(context, tableName)).isTrue(); + assertThat(isIndexExists(context, schemaName, tableName, tableName + "_embedding_idx")).isTrue(); + assertThat(isTableExists(context, "vector_store")).isFalse(); + dropTableByName(context, tableName); + }); + } + + @Test + public void shouldCreateSpecificTableAndIndexIfNotPresentInDatabase() { + String tableName = "new_vector_table2"; + String contentFieldName = "content2"; + String embeddingFieldName = "embedding2"; + String idFieldName = "id2"; + String metadataFieldName = "metadata2"; + + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.vectorTableName=" + tableName) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.contentFieldName=" + contentFieldName) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.embeddingFieldName=" + embeddingFieldName) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.idFieldName=" + idFieldName) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.metadataFieldName=" + metadataFieldName) + .run(context -> { + assertThat(isTableExists(context, tableName)).isTrue(); + assertThat(isIndexExists(context, schemaName, tableName, tableName + "_embedding_idx")).isTrue(); + assertThat(isTableExists(context, "vector_store")).isFalse(); + assertThat(areColumnsExisting(context, tableName, + new String[] { contentFieldName, embeddingFieldName, idFieldName, metadataFieldName })) + .isFalse(); + dropTableByName(context, tableName); + }); + } + + @Test + public void shouldFailWhenCustomTableIsAbsentAndValidationEnabled() { + + String tableName = "customvectortable"; + + this.contextRunner + .withPropertyValues("test.spring.ai.vectorstore.mariadb.vectorTableName=" + tableName, + "test.spring.ai.vectorstore.mariadb.schemaValidation=true") + .run(context -> { + assertThat(context).hasFailed(); + assertThat(context.getStartupFailure()).hasCauseInstanceOf(IllegalStateException.class) + .hasMessageContaining("Table 'customvectortable' does not exist in schema 'testdb'"); + }); + } + + @Test + public void shouldFailOnSQLInjectionAttemptInTableName() { + + String tableName = "users; DROP TABLE users;"; + + this.contextRunner + .withPropertyValues("test.spring.ai.vectorstore.mariadb.vectorTableName=" + tableName, + "test.spring.ai.vectorstore.mariadb.schemaValidation=true") + .run(context -> { + assertThat(context).hasFailed(); + assertThat(context.getStartupFailure().getCause()).hasCauseInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Identifier '`users; DROP TABLE users;`' should only contain alphanumeric" + + " characters and underscores"); + }); + } + + @Test + public void shouldFailOnSQLInjectionAttemptInSchemaName() { + + String schemaName = "public; DROP TABLE users;"; + String tableName = "customvectortable"; + + this.contextRunner + .withPropertyValues("test.spring.ai.vectorstore.mariadb.vectorTableName=" + tableName, + "test.spring.ai.vectorstore.mariadb.schemaName=" + schemaName, + "test.spring.ai.vectorstore.mariadb.schemaValidation=true") + .run(context -> { + assertThat(context).hasFailed(); + assertThat(context.getStartupFailure().getCause()).hasCauseInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Identifier '`public; DROP TABLE users;`' should only contain alphanumeric" + + " characters and underscores"); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestApplication { + + @Value("${test.spring.ai.vectorstore.mariadb.vectorTableName:}") + String vectorTableName; + + @Value("${test.spring.ai.vectorstore.mariadb.schemaName:testdb}") + String schemaName; + + @Value("${test.spring.ai.vectorstore.mariadb.schemaValidation:false}") + boolean schemaValidation; + + int dimensions = 1536; + + @Bean + public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + + return new MariaDBVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(this.schemaName) + .withVectorTableName(this.vectorTableName) + .withVectorTableValidationsEnabled(this.schemaValidation) + .withDimensions(this.dimensions) + .withDistanceType(MariaDBVectorStore.MariaDBDistanceType.COSINE) + .withRemoveExistingVectorStoreTable(true) + .withInitializeSchema(true) + .build(); + } + + // + @Bean + public JdbcTemplate myJdbcTemplate(DataSource dataSource) { + JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); + + return jdbcTemplate; + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + @Bean + public EmbeddingModel embeddingModel() { + return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + } + + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreIT.java new file mode 100644 index 00000000000..efc6b102d9b --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreIT.java @@ -0,0 +1,374 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.zaxxer.hikari.HikariDataSource; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Stream; +import javax.sql.DataSource; +import org.junit.Assert; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser.FilterExpressionParseException; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.util.CollectionUtils; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +/** + * @author Diego Dupin + */ +@Testcontainers +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +public class MariaDBStoreIT { + + private static String schemaName = "testdb"; + + @Container + @SuppressWarnings("resource") + static MariaDBContainer mariadbContainer = new MariaDBContainer<>(MariaDBImage.DEFAULT_IMAGE) + .withUsername("mariadb") + .withPassword("mariadbpwd") + .withDatabaseName(schemaName); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=COSINE", + + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:mariadb://%s:%d/%s?maxQuerySizeToLog=50000", mariadbContainer.getHost(), + mariadbContainer.getMappedPort(3306), schemaName), + "app.datasource.username=mariadb", "app.datasource.password=mariadbpwd", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void dropTable(ApplicationContext context) { + JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class); + jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store"); + } + + static Stream provideFilters() { + return Stream.of(Arguments.of("country in ['BG','NL']", 3), // String Filters In + Arguments.of("year in [2020]", 1), // Numeric Filters In + Arguments.of("country not in ['BG']", 1), // String Filter Not In + Arguments.of("year not in [2020]", 1) // Numeric Filter Not In + ); + } + + private static boolean isSortedByDistance(List docs) { + + List distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList(); + + if (CollectionUtils.isEmpty(distances) || distances.size() == 1) { + return true; + } + + Iterator iter = distances.iterator(); + Float current, previous = iter.next(); + while (iter.hasNext()) { + current = iter.next(); + if (previous > current) { + return false; + } + previous = current; + } + return true; + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE", "EUCLIDEAN" }) + public void addAndSearch(String distanceType) { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=" + distanceType) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(this.documents); + + List results = vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId()); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + + // Remove all documents from the store + vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); + + List results2 = vectorStore + .similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); + assertThat(results2).hasSize(0); + + dropTable(context); + }); + } + + @ParameterizedTest(name = "Filter expression {0} should return {1} records ") + @MethodSource("provideFilters") + public void searchWithInFilter(String expression, Integer expectedRecords) { + + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=COSINE").run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2020, "foo bar 1", "bar.foo")); + var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "NL")); + var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2023)); + + vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); + + SearchRequest searchRequest = SearchRequest.query("The World") + .withFilterExpression(expression) + .withTopK(5) + .withSimilarityThresholdAll(); + + List results = vectorStore.similaritySearch(searchRequest); + + assertThat(results).hasSize(expectedRecords); + + // Remove all documents from the store + dropTable(context); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE", "EUCLIDEAN" }) + public void searchWithFilters(String distanceType) { + + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=" + distanceType) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2020, "foo bar 1", "bar.foo")); + var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "NL")); + var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner", + Map.of("country", "BG", "year", 2023)); + + vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); + + SearchRequest searchRequest = SearchRequest.query("The World").withTopK(5).withSimilarityThresholdAll(); + + List results = vectorStore.similaritySearch(searchRequest); + + assertThat(results).hasSize(3); + + results = vectorStore.similaritySearch(searchRequest.withFilterExpression("country == 'NL'")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + + results = vectorStore.similaritySearch(searchRequest.withFilterExpression("country == 'BG'")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + + results = vectorStore + .similaritySearch(searchRequest.withFilterExpression("country == 'BG' && year == 2020")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + results = vectorStore.similaritySearch( + searchRequest.withFilterExpression("(country == 'BG' && year == 2020) || (country == 'NL')")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), nlDocument.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), nlDocument.getId()); + + results = vectorStore.similaritySearch(searchRequest + .withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); + + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("\"foo bar 1\" == 'bar.foo'")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + try { + vectorStore.similaritySearch(searchRequest.withFilterExpression("country == NL")); + Assert.fail("Invalid filter expression should have been cached!"); + } + catch (FilterExpressionParseException e) { + assertThat(e.getMessage()).contains("Line: 1:17, Error: no viable alternative at input 'NL'"); + } + + // Remove all documents from the store + dropTable(context); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE", "EUCLIDEAN" }) + public void documentUpdate(String distanceType) { + + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=" + distanceType) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!", + Collections.singletonMap("meta1", "meta1")); + + vectorStore.add(List.of(document)); + + List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(5)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("Spring AI rocks!!"); + assertThat(resultDoc.getMetadata()).containsKeys("meta1", "distance"); + + Document sameIdDocument = new Document(document.getId(), + "The World is Big and Salvation Lurks Around the Corner", + Collections.singletonMap("meta2", "meta2")); + + vectorStore.add(List.of(sameIdDocument)); + + results = vectorStore.similaritySearch(SearchRequest.query("FooBar").withTopK(5)); + + assertThat(results).hasSize(1); + resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(document.getId()); + assertThat(resultDoc.getContent()).isEqualTo("The World is Big and Salvation Lurks Around the Corner"); + assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance"); + + dropTable(context); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "COSINE", "EUCLIDEAN" }) + public void searchWithThreshold(String distanceType) { + + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=" + distanceType) + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(this.documents); + + List fullResult = vectorStore + .similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThresholdAll()); + + assertThat(fullResult).hasSize(3); + + assertThat(isSortedByDistance(fullResult)).isTrue(); + + List distances = fullResult.stream() + .map(doc -> (Float) doc.getMetadata().get("distance")) + .toList(); + + float threshold = (distances.get(0) + distances.get(1)) / 2; + + List results = vectorStore.similaritySearch( + SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThreshold(1 - threshold)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId()); + + dropTable(context); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestApplication { + + @Value("${test.spring.ai.vectorstore.mariadb.distanceType}") + MariaDBVectorStore.MariaDBDistanceType distanceType; + + @Bean + public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { + return new MariaDBVectorStore(jdbcTemplate, embeddingModel, MariaDBVectorStore.INVALID_EMBEDDING_DIMENSION, + this.distanceType, true, true); + } + + @Bean + public JdbcTemplate myJdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + @Bean + public EmbeddingModel embeddingModel() { + return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + } + + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreObservationIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreObservationIT.java new file mode 100644 index 00000000000..3ef7b688064 --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreObservationIT.java @@ -0,0 +1,207 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.zaxxer.hikari.HikariDataSource; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import javax.sql.DataSource; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.observation.conventions.SpringAiKind; +import org.springframework.ai.observation.conventions.VectorStoreProvider; +import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.vectorstore.observation.VectorStoreObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +/** + * @author Diego Dupin + */ +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +@Testcontainers +public class MariaDBStoreObservationIT { + + private static String schemaName = "testdb"; + + @Container + @SuppressWarnings("resource") + static MariaDBContainer mariadbContainer = new MariaDBContainer<>(MariaDBImage.DEFAULT_IMAGE) + .withUsername("mariadb") + .withPassword("mariadbpwd") + .withDatabaseName(schemaName); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(Config.class) + .withPropertyValues("test.spring.ai.vectorstore.mariadb.distanceType=COSINE", + // JdbcTemplate configuration + "app.datasource.url=" + mariadbContainer.getJdbcUrl(), + "app.datasource.username=mariadb", "app.datasource.password=mariadbpwd", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + List documents = List.of( + new Document(getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")), + new Document(getText("classpath:/test/data/time.shelter.txt")), + new Document(getText("classpath:/test/data/great.depression.txt"), Map.of("meta2", "meta2"))); + + public static String getText(String uri) { + var resource = new DefaultResourceLoader().getResource(uri); + try { + return resource.getContentAsString(StandardCharsets.UTF_8); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Test + void observationVectorStoreAddAndQueryOperations() { + + this.contextRunner.run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + + TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); + + vectorStore.add(this.documents); + + TestObservationRegistryAssert.assertThat(observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultVectorStoreObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("%s add".formatted(VectorStoreProvider.MARIADB.value())) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.DB_OPERATION_NAME.asString(), "add") + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.DB_SYSTEM.asString(), + VectorStoreProvider.MARIADB.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), + SpringAiKind.VECTOR_STORE.value()) + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_CONTENT.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_DIMENSION_COUNT.asString(), "1536") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), + MariaDBVectorStore.DEFAULT_TABLE_NAME) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_NAMESPACE.asString(), schemaName) + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString(), + VectorStoreSimilarityMetric.COSINE.value()) + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString()) + .doesNotHaveHighCardinalityKeyValueWithKey( + HighCardinalityKeyNames.DB_VECTOR_QUERY_SIMILARITY_THRESHOLD.asString()) + .hasBeenStarted() + .hasBeenStopped(); + + observationRegistry.clear(); + + List results = vectorStore + .similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); + + assertThat(results).isNotEmpty(); + + TestObservationRegistryAssert.assertThat(observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultVectorStoreObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("%s query".formatted(VectorStoreProvider.MARIADB.value())) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.DB_OPERATION_NAME.asString(), "query") + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.DB_SYSTEM.asString(), + VectorStoreProvider.MARIADB.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), + SpringAiKind.VECTOR_STORE.value()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_CONTENT.asString(), + "What is Great Depression") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_DIMENSION_COUNT.asString(), "1536") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_COLLECTION_NAME.asString(), + MariaDBVectorStore.DEFAULT_TABLE_NAME) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_NAMESPACE.asString(), schemaName) + .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.DB_VECTOR_FIELD_NAME.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_SEARCH_SIMILARITY_METRIC.asString(), + VectorStoreSimilarityMetric.COSINE.value()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_TOP_K.asString(), "1") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.DB_VECTOR_QUERY_SIMILARITY_THRESHOLD.asString(), + "0.0") + .hasBeenStarted() + .hasBeenStopped(); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, + ObservationRegistry observationRegistry) { + return new MariaDBVectorStore.Builder(jdbcTemplate, embeddingModel).withSchemaName(schemaName) + .withDistanceType(MariaDBVectorStore.MariaDBDistanceType.COSINE) + .withObservationRegistry(observationRegistry) + .withInitializeSchema(true) + .build(); + } + + @Bean + public JdbcTemplate myJdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + @Bean + public EmbeddingModel embeddingModel() { + return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + } + + } + +} diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreTests.java new file mode 100644 index 00000000000..6f8ab4fc6ed --- /dev/null +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/MariaDBStoreTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.util.Collections; + +import org.junit.Assert; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.ArgumentCaptor; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * @author Diego Dupin + */ +public class MariaDBStoreTests { + + @ParameterizedTest(name = "{0} - enquote identifier validation") + @CsvSource({ + // Standard valid cases + "customvectorstore, true, `customvectorstore`", "user_data, true, `user_data`", "test123, true, `test123`", + "valid_table_name, true, `valid_table_name`", "customvectorstore, false, customvectorstore", + "user_data, false, user_data", "test123, false, test123", "valid_table_name, false, valid_table_name", + "1234567890123456789012345678901234567890123456789012345678901234, false, `1234567890123456789012345678901234567890123456789012345678901234`" }) + void enquoteIdentifier(String tableName, boolean alwaysQuote, String expected) { + assertThat(MariaDBSchemaValidator.validateAndEnquoteIdentifier(tableName, alwaysQuote)); + } + + @ParameterizedTest(name = "{0} - error identifier validation") + @CsvSource({ "12345678901234567890123456789012345678901234567890123456789012345, false", + "12345678901234567890123456789012345678901234567890123456789012345, true", + "customvectorstore;drop table users;, false", "some\u0000notpossibleValue, true" }) + void enquoteIdentifierThrow(String tableName, boolean alwaysQuote) { + Assert.assertThrows(IllegalArgumentException.class, + () -> MariaDBSchemaValidator.validateAndEnquoteIdentifier(tableName, alwaysQuote)); + } + + @Test + void shouldAddDocumentsInBatchesAndEmbedOnce() { + // Given + var jdbcTemplate = mock(JdbcTemplate.class); + var embeddingModel = mock(EmbeddingModel.class); + var mariadbVectorStore = new MariaDBVectorStore.Builder(jdbcTemplate, embeddingModel) + .withMaxDocumentBatchSize(1000) + .build(); + + // Testing with 9989 documents + var documents = Collections.nCopies(9989, new Document("foo")); + + // When + mariadbVectorStore.doAdd(documents); + + // Then + verify(embeddingModel, only()).embed(eq(documents), any(), any()); + + var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class); + verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture()); + + assertThat(batchUpdateCaptor.getAllValues()).hasSize(10) + .allSatisfy(BatchPreparedStatementSetter::getBatchSize) + .satisfies(batches -> { + for (int i = 0; i < 9; i++) { + assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 10", i) + .isEqualTo(1000); + } + assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989); + }); + } + +}