diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java index f14167b6b17..ebf6e4cde31 100644 --- a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java @@ -42,6 +42,8 @@ public class PostgresMlEmbeddingClient extends AbstractEmbeddingClient implement private final MetadataMode metadataMode; + private final boolean skipCreateExtension; + public enum VectorType { PG_ARRAY("", null, (rs, i) -> { @@ -91,7 +93,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer) * @param vectorType vector type in PostgreSQL */ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType) { - this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED); + this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED, false); } /** @@ -100,9 +102,10 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, * @param transformer huggingface sentence-transformer name * @param vectorType vector type in PostgreSQL * @param kwargs optional arguments + * @param skipCreateExtension whether to skip "CREATE EXTENSION" */ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType, - Map kwargs, MetadataMode metadataMode) { + Map kwargs, MetadataMode metadataMode, boolean skipCreateExtension) { Assert.notNull(jdbcTemplate, "jdbc template must not be null."); Assert.notNull(transformer, "transformer must not be null."); Assert.notNull(vectorType, "vectorType must not be null."); @@ -119,6 +122,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, catch (JsonProcessingException e) { throw new IllegalArgumentException(e); } + this.skipCreateExtension = skipCreateExtension; } @Override @@ -174,6 +178,9 @@ public EmbeddingResponse call(EmbeddingRequest request) { @Override public void afterPropertiesSet() { + if (this.skipCreateExtension) { + return; + } this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pgml"); if (StringUtils.hasText(this.vectorType.extensionName)) { this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS " + this.vectorType.extensionName); diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java index 18a26372d59..ce1dd7cf274 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java +++ b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java @@ -6,7 +6,7 @@ import java.util.Map; import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -42,7 +42,7 @@ class PostgresMlEmbeddingClientIT { @Container @ServiceConnection static PostgreSQLContainer postgres = new PostgreSQLContainer<>( - DockerImageName.parse("ghcr.io/postgresml/postgresml:2.7.3").asCompatibleSubstituteFor("postgres")) + DockerImageName.parse("ghcr.io/postgresml/postgresml:2.7.13").asCompatibleSubstituteFor("postgres")) .withCommand("sleep", "infinity") .withLabel("org.springframework.boot.service-connection", "postgres") .withUsername("postgresml") @@ -54,7 +54,7 @@ class PostgresMlEmbeddingClientIT { @Autowired JdbcTemplate jdbcTemplate; - @AfterEach + @BeforeEach void dropPgmlExtension() { this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS pgml"); } @@ -65,7 +65,6 @@ void embed() { embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed("Hello World!"); assertThat(embed).hasSize(768); - // embeddingClient.dropPgmlExtension(); } @Test @@ -75,7 +74,6 @@ void embedWithPgVector() { embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed(new Document("Hello World!")); assertThat(embed).hasSize(768); - // embeddingClient.dropPgmlExtension(); } @Test @@ -85,18 +83,16 @@ void embedWithDifferentModel() { embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed(new Document("Hello World!")); assertThat(embed).hasSize(384); - // embeddingClient.dropPgmlExtension(); } @Test void embedWithKwargs() { PostgresMlEmbeddingClient embeddingClient = new PostgresMlEmbeddingClient(this.jdbcTemplate, "distilbert-base-uncased", PostgresMlEmbeddingClient.VectorType.PG_ARRAY, Map.of("device", "cpu"), - MetadataMode.EMBED); + MetadataMode.EMBED, false); embeddingClient.afterPropertiesSet(); List embed = embeddingClient.embed(new Document("Hello World!")); assertThat(embed).hasSize(768); - // embeddingClient.dropPgmlExtension(); } @ParameterizedTest @@ -117,7 +113,6 @@ void embedForResponse(String vectorType) { assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768); assertThat(embeddingResponse.getResults().get(2).getIndex()).isEqualTo(2); assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(768); - // embeddingClient.dropPgmlExtension(); } @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java index 003670cc96f..1949a798d33 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlAutoConfiguration.java @@ -39,7 +39,7 @@ public EmbeddingClient postgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlProperties postgresMlProperties) { return new PostgresMlEmbeddingClient(jdbcTemplate, postgresMlProperties.getEmbedding().getTransformer(), postgresMlProperties.getEmbedding().getVectorType(), postgresMlProperties.getEmbedding().getKwargs(), - postgresMlProperties.getEmbedding().getMetadataMode()); + postgresMlProperties.getEmbedding().getMetadataMode(), postgresMlProperties.isSkipCreateExtension()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlProperties.java index f1a3d5614ff..514ba05385f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlProperties.java @@ -42,6 +42,8 @@ public class PostgresMlProperties { private MetadataMode metadataMode = MetadataMode.EMBED; + private boolean skipCreateExtension = false; + public PostgresMlProperties.Embedding getEmbedding() { return this.embedding; } @@ -78,6 +80,15 @@ public void setMetadataMode(MetadataMode metadataMode) { this.metadataMode = metadataMode; } + public boolean isSkipCreateExtension() { + return skipCreateExtension; + } + + public PostgresMlProperties setSkipCreateExtension(boolean skipCreateExtension) { + this.skipCreateExtension = skipCreateExtension; + return this; + } + public static class Embedding { private PostgresMlProperties postgresMlProperties; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlPropertiesTests.java index eacbee53608..67dd965a80e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/postgresml/PostgresMlPropertiesTests.java @@ -35,7 +35,8 @@ * @author Utkarsh Srivastava */ @SpringBootTest(properties = { "spring.ai.postgresml.metadata-mode=all", "spring.ai.postgresml.kwargs.key1=value1", - "spring.ai.postgresml.kwargs.key2=value2", "spring.ai.postgresml.embedding.transformer=abc123" }) + "spring.ai.postgresml.kwargs.key2=value2", "spring.ai.postgresml.embedding.transformer=abc123", + "spring.ai.postgresml.skip-create-extension=true" }) class PostgresMlPropertiesTests { @Autowired @@ -48,6 +49,7 @@ void postgresMlPropertiesAreCorrect() { assertThat(this.postgresMlProperties.getVectorType()).isEqualTo(PostgresMlEmbeddingClient.VectorType.PG_ARRAY); assertThat(this.postgresMlProperties.getKwargs()).isEqualTo(Map.of("key1", "value1", "key2", "value2")); assertThat(this.postgresMlProperties.getMetadataMode()).isEqualTo(MetadataMode.ALL); + assertThat(this.postgresMlProperties.isSkipCreateExtension()).isTrue(); PostgresMlProperties.Embedding embedding = this.postgresMlProperties.getEmbedding();