diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java index 677a7fba4ee..88444c18d26 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java @@ -52,6 +52,8 @@ public class PostgresMlEmbeddingClient extends AbstractEmbeddingClient implement private final JdbcTemplate jdbcTemplate; + private final boolean skipCreateExtension; + public enum VectorType { PG_ARRAY("", null, (rs, i) -> { @@ -84,7 +86,7 @@ public enum VectorType { * @param jdbcTemplate JdbcTemplate */ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate) { - this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build()); + this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build(), false); } /** @@ -92,7 +94,8 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate) { * @param jdbcTemplate JdbcTemplate to use to interact with the database. * @param options PostgresMlEmbeddingOptions to configure the client. */ - public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options) { + public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options, + boolean skipCreateExtension) { Assert.notNull(jdbcTemplate, "jdbc template must not be null."); Assert.notNull(options, "options must not be null."); Assert.notNull(options.getTransformer(), "transformer must not be null."); @@ -102,6 +105,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingO this.jdbcTemplate = jdbcTemplate; this.defaultOptions = options; + this.skipCreateExtension = skipCreateExtension; } /** @@ -123,7 +127,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer) */ @Deprecated(since = "0.8.0", forRemoval = true) public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType) { - this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED); + this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED, false); } /** @@ -136,7 +140,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, */ @Deprecated(since = "0.8.0", forRemoval = true) public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType, - Map kwargs, MetadataMode metadataMode) { + Map kwargs, MetadataMode metadataMode, boolean skipCreateExtension) { Assert.notNull(jdbcTemplate, "jdbc template must not be null."); Assert.notNull(transformer, "transformer must not be null."); Assert.notNull(vectorType, "vectorType must not be null."); @@ -151,6 +155,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, .withMetadataMode(metadataMode) .withKwargs(ModelOptionsUtils.toJsonString(kwargs)) .build(); + this.skipCreateExtension = skipCreateExtension; } @SuppressWarnings("null") @@ -226,6 +231,10 @@ PostgresMlEmbeddingOptions mergeOptions(EmbeddingOptions requestOptions) { @Override public void afterPropertiesSet() { + if (this.skipCreateExtension) { + return; + } + this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pgml"); this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore"); if (StringUtils.hasText(this.defaultOptions.getVectorType().extensionName)) { diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java index 7a4539e0e4e..4549f3ef6b4 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java @@ -21,17 +21,10 @@ import java.util.Map; import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.ai.embedding.EmbeddingRequest; -import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.postgresml.PostgresMlEmbeddingClient.VectorType; - import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; import org.testcontainers.junit.jupiter.Container; @@ -40,6 +33,10 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.postgresml.PostgresMlEmbeddingClient.VectorType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase; @@ -55,13 +52,14 @@ @JdbcTest(properties = "logging.level.sql=TRACE") @AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE) @Testcontainers -@Disabled("Disabled from automatic execution, as it requires an excessive amount of memory (over 9GB)!") +// @Disabled("Disabled from automatic execution, as it requires an excessive amount of +// memory (over 9GB)!") class PostgresMlEmbeddingClientIT { @Container @ServiceConnection static PostgreSQLContainer postgres = new PostgreSQLContainer<>( - DockerImageName.parse("ghcr.io/postgresml/postgresml:2.8.1").asCompatibleSubstituteFor("postgres")) + DockerImageName.parse("ghcr.io/postgresml/postgresml:2.8.2").asCompatibleSubstituteFor("postgres")) .withCommand("sleep", "infinity") .withLabel("org.springframework.boot.service-connection", "postgres") .withUsername("postgresml") @@ -73,9 +71,11 @@ class PostgresMlEmbeddingClientIT { @Autowired JdbcTemplate jdbcTemplate; - @AfterEach + @BeforeEach void dropPgmlExtension() { this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS pgml"); + this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS hstore"); + this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS vector"); } @Test @@ -94,7 +94,8 @@ void embedWithPgVector() { PostgresMlEmbeddingOptions.builder() .withTransformer("distilbert-base-uncased") .withVectorType(PostgresMlEmbeddingClient.VectorType.PG_VECTOR) - .build()); + .build(), + false); embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed(new Document("Hello World!")); @@ -105,7 +106,7 @@ void embedWithPgVector() { @Test void embedWithDifferentModel() { PostgresMlEmbeddingClient embeddingClient = new PostgresMlEmbeddingClient(this.jdbcTemplate, - PostgresMlEmbeddingOptions.builder().withTransformer("intfloat/e5-small").build()); + PostgresMlEmbeddingOptions.builder().withTransformer("intfloat/e5-small").build(), false); embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed(new Document("Hello World!")); @@ -121,7 +122,8 @@ void embedWithKwargs() { .withVectorType(PostgresMlEmbeddingClient.VectorType.PG_ARRAY) .withKwargs(Map.of("device", "cpu")) .withMetadataMode(MetadataMode.EMBED) - .build()); + .build(), + false); embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed(new Document("Hello World!")); @@ -136,7 +138,8 @@ void embedForResponse(String vectorType) { PostgresMlEmbeddingOptions.builder() .withTransformer("distilbert-base-uncased") .withVectorType(VectorType.valueOf(vectorType)) - .build()); + .build(), + false); embeddingClient.afterPropertiesSet(); EmbeddingResponse embeddingResponse = embeddingClient @@ -161,7 +164,8 @@ void embedCallWithRequestOptionsOverride() { PostgresMlEmbeddingOptions.builder() .withTransformer("distilbert-base-uncased") .withVectorType(VectorType.PG_VECTOR) - .build()); + .build(), + false); embeddingClient.afterPropertiesSet(); var request1 = new EmbeddingRequest(List.of("Hello World!", "Spring AI!", "LLM!"), EmbeddingOptions.EMPTY); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java index 78ac62dc4ea..c2d90857685 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java @@ -43,7 +43,8 @@ public class PostgresMlAutoConfiguration { public PostgresMlEmbeddingClient postgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingProperties embeddingProperties) { - return new PostgresMlEmbeddingClient(jdbcTemplate, embeddingProperties.getOptions()); + return new PostgresMlEmbeddingClient(jdbcTemplate, embeddingProperties.getOptions(), + embeddingProperties.isSkipCreateExtension()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java index c0b13b5404c..7a3904f1d82 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingProperties.java @@ -40,6 +40,8 @@ public class PostgresMlEmbeddingProperties { */ private boolean enabled = true; + private boolean skipCreateExtension = false; + @NestedConfigurationProperty private PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder() .withTransformer(PostgresMlEmbeddingClient.DEFAULT_TRANSFORMER_MODEL) @@ -70,4 +72,12 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } + public boolean isSkipCreateExtension() { + return skipCreateExtension; + } + + public void setSkipCreateExtension(boolean skipCreateExtension) { + this.skipCreateExtension = skipCreateExtension; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java index 6f5a84d2f7b..f30131de771 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlEmbeddingPropertiesTests.java @@ -37,7 +37,8 @@ @SpringBootTest(properties = { "spring.ai.postgresml.embedding.options.metadata-mode=all", "spring.ai.postgresml.embedding.options.kwargs.key1=value1", "spring.ai.postgresml.embedding.options.kwargs.key2=value2", - "spring.ai.postgresml.embedding.options.transformer=abc123" }) + "spring.ai.postgresml.embedding.options.transformer=abc123", + "spring.ai.postgresml.skip-create-extension=true" }) class PostgresMlEmbeddingPropertiesTests { @Autowired @@ -52,6 +53,7 @@ void postgresMlPropertiesAreCorrect() { assertThat(this.postgresMlProperties.getOptions().getKwargs()) .isEqualTo(Map.of("key1", "value1", "key2", "value2")); assertThat(this.postgresMlProperties.getOptions().getMetadataMode()).isEqualTo(MetadataMode.ALL); + assertThat(this.postgresMlProperties.isSkipCreateExtension()).isTrue(); } @SpringBootConfiguration