Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public class PostgresMlEmbeddingClient extends AbstractEmbeddingClient implement

private final JdbcTemplate jdbcTemplate;

private final boolean skipCreateExtension;

public enum VectorType {

PG_ARRAY("", null, (rs, i) -> {
Expand Down Expand Up @@ -84,15 +86,16 @@ public enum VectorType {
* @param jdbcTemplate JdbcTemplate
*/
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate) {
this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build());
this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build(), false);
}

/**
* a PostgresMlEmbeddingClient constructor
* @param jdbcTemplate JdbcTemplate to use to interact with the database.
* @param options PostgresMlEmbeddingOptions to configure the client.
*/
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options) {
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options,
boolean skipCreateExtension) {
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
Assert.notNull(options, "options must not be null.");
Assert.notNull(options.getTransformer(), "transformer must not be null.");
Expand All @@ -102,6 +105,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingO

this.jdbcTemplate = jdbcTemplate;
this.defaultOptions = options;
this.skipCreateExtension = skipCreateExtension;
}

/**
Expand All @@ -123,7 +127,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer)
*/
@Deprecated(since = "0.8.0", forRemoval = true)
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType) {
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED);
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED, false);
}

/**
Expand All @@ -136,7 +140,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer,
*/
@Deprecated(since = "0.8.0", forRemoval = true)
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType,
Map<String, Object> kwargs, MetadataMode metadataMode) {
Map<String, Object> kwargs, MetadataMode metadataMode, boolean skipCreateExtension) {
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
Assert.notNull(transformer, "transformer must not be null.");
Assert.notNull(vectorType, "vectorType must not be null.");
Expand All @@ -151,6 +155,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer,
.withMetadataMode(metadataMode)
.withKwargs(ModelOptionsUtils.toJsonString(kwargs))
.build();
this.skipCreateExtension = skipCreateExtension;
}

@SuppressWarnings("null")
Expand Down Expand Up @@ -226,6 +231,10 @@ PostgresMlEmbeddingOptions mergeOptions(EmbeddingOptions requestOptions) {

@Override
public void afterPropertiesSet() {
if (this.skipCreateExtension) {
return;
}

this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pgml");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
if (StringUtils.hasText(this.defaultOptions.getVectorType().extensionName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,10 @@
import java.util.Map;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.postgresml.PostgresMlEmbeddingClient.VectorType;

import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy;
import org.testcontainers.junit.jupiter.Container;
Expand All @@ -40,6 +33,10 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.postgresml.PostgresMlEmbeddingClient.VectorType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase;
Expand All @@ -55,13 +52,14 @@
@JdbcTest(properties = "logging.level.sql=TRACE")
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Testcontainers
@Disabled("Disabled from automatic execution, as it requires an excessive amount of memory (over 9GB)!")
// @Disabled("Disabled from automatic execution, as it requires an excessive amount of
// memory (over 9GB)!")
class PostgresMlEmbeddingClientIT {

@Container
@ServiceConnection
static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>(
DockerImageName.parse("ghcr.io/postgresml/postgresml:2.8.1").asCompatibleSubstituteFor("postgres"))
DockerImageName.parse("ghcr.io/postgresml/postgresml:2.8.2").asCompatibleSubstituteFor("postgres"))
.withCommand("sleep", "infinity")
.withLabel("org.springframework.boot.service-connection", "postgres")
.withUsername("postgresml")
Expand All @@ -73,9 +71,11 @@ class PostgresMlEmbeddingClientIT {
@Autowired
JdbcTemplate jdbcTemplate;

@AfterEach
@BeforeEach
void dropPgmlExtension() {
this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS pgml");
this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS hstore");
this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS vector");
}

@Test
Expand All @@ -94,7 +94,8 @@ void embedWithPgVector() {
PostgresMlEmbeddingOptions.builder()
.withTransformer("distilbert-base-uncased")
.withVectorType(PostgresMlEmbeddingClient.VectorType.PG_VECTOR)
.build());
.build(),
false);
embeddingClient.afterPropertiesSet();

List<Double> embed = embeddingClient.embed(new Document("Hello World!"));
Expand All @@ -105,7 +106,7 @@ void embedWithPgVector() {
@Test
void embedWithDifferentModel() {
PostgresMlEmbeddingClient embeddingClient = new PostgresMlEmbeddingClient(this.jdbcTemplate,
PostgresMlEmbeddingOptions.builder().withTransformer("intfloat/e5-small").build());
PostgresMlEmbeddingOptions.builder().withTransformer("intfloat/e5-small").build(), false);
embeddingClient.afterPropertiesSet();

List<Double> embed = embeddingClient.embed(new Document("Hello World!"));
Expand All @@ -121,7 +122,8 @@ void embedWithKwargs() {
.withVectorType(PostgresMlEmbeddingClient.VectorType.PG_ARRAY)
.withKwargs(Map.of("device", "cpu"))
.withMetadataMode(MetadataMode.EMBED)
.build());
.build(),
false);
embeddingClient.afterPropertiesSet();

List<Double> embed = embeddingClient.embed(new Document("Hello World!"));
Expand All @@ -136,7 +138,8 @@ void embedForResponse(String vectorType) {
PostgresMlEmbeddingOptions.builder()
.withTransformer("distilbert-base-uncased")
.withVectorType(VectorType.valueOf(vectorType))
.build());
.build(),
false);
embeddingClient.afterPropertiesSet();

EmbeddingResponse embeddingResponse = embeddingClient
Expand All @@ -161,7 +164,8 @@ void embedCallWithRequestOptionsOverride() {
PostgresMlEmbeddingOptions.builder()
.withTransformer("distilbert-base-uncased")
.withVectorType(VectorType.PG_VECTOR)
.build());
.build(),
false);
embeddingClient.afterPropertiesSet();

var request1 = new EmbeddingRequest(List.of("Hello World!", "Spring AI!", "LLM!"), EmbeddingOptions.EMPTY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public class PostgresMlAutoConfiguration {
public PostgresMlEmbeddingClient postgresMlEmbeddingClient(JdbcTemplate jdbcTemplate,
PostgresMlEmbeddingProperties embeddingProperties) {

return new PostgresMlEmbeddingClient(jdbcTemplate, embeddingProperties.getOptions());
return new PostgresMlEmbeddingClient(jdbcTemplate, embeddingProperties.getOptions(),
embeddingProperties.isSkipCreateExtension());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public class PostgresMlEmbeddingProperties {
*/
private boolean enabled = true;

private boolean skipCreateExtension = false;

@NestedConfigurationProperty
private PostgresMlEmbeddingOptions options = PostgresMlEmbeddingOptions.builder()
.withTransformer(PostgresMlEmbeddingClient.DEFAULT_TRANSFORMER_MODEL)
Expand Down Expand Up @@ -70,4 +72,12 @@ public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public boolean isSkipCreateExtension() {
return skipCreateExtension;
}

public void setSkipCreateExtension(boolean skipCreateExtension) {
this.skipCreateExtension = skipCreateExtension;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
@SpringBootTest(properties = { "spring.ai.postgresml.embedding.options.metadata-mode=all",
"spring.ai.postgresml.embedding.options.kwargs.key1=value1",
"spring.ai.postgresml.embedding.options.kwargs.key2=value2",
"spring.ai.postgresml.embedding.options.transformer=abc123" })
"spring.ai.postgresml.embedding.options.transformer=abc123",
"spring.ai.postgresml.skip-create-extension=true" })
class PostgresMlEmbeddingPropertiesTests {

@Autowired
Expand All @@ -52,6 +53,7 @@ void postgresMlPropertiesAreCorrect() {
assertThat(this.postgresMlProperties.getOptions().getKwargs())
.isEqualTo(Map.of("key1", "value1", "key2", "value2"));
assertThat(this.postgresMlProperties.getOptions().getMetadataMode()).isEqualTo(MetadataMode.ALL);
assertThat(this.postgresMlProperties.isSkipCreateExtension()).isTrue();
}

@SpringBootConfiguration
Expand Down