Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public class PostgresMlEmbeddingClient extends AbstractEmbeddingClient implement

private final MetadataMode metadataMode;

private final boolean skipCreateExtension;

public enum VectorType {

PG_ARRAY("", null, (rs, i) -> {
Expand Down Expand Up @@ -91,7 +93,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer)
* @param vectorType vector type in PostgreSQL
*/
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType) {
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED);
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED, false);
}

/**
Expand All @@ -100,9 +102,10 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer,
* @param transformer huggingface sentence-transformer name
* @param vectorType vector type in PostgreSQL
* @param kwargs optional arguments
* @param skipCreateExtension whether to skip "CREATE EXTENSION"
*/
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType,
Map<String, Object> kwargs, MetadataMode metadataMode) {
Map<String, Object> kwargs, MetadataMode metadataMode, boolean skipCreateExtension) {
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
Assert.notNull(transformer, "transformer must not be null.");
Assert.notNull(vectorType, "vectorType must not be null.");
Expand All @@ -119,6 +122,7 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer,
catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
this.skipCreateExtension = skipCreateExtension;
}

@Override
Expand Down Expand Up @@ -174,6 +178,9 @@ public EmbeddingResponse call(EmbeddingRequest request) {

@Override
public void afterPropertiesSet() {
if (this.skipCreateExtension) {
return;
}
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pgml");
if (StringUtils.hasText(this.vectorType.extensionName)) {
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS " + this.vectorType.extensionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import java.util.Map;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -42,7 +42,7 @@ class PostgresMlEmbeddingClientIT {
@Container
@ServiceConnection
static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>(
DockerImageName.parse("ghcr.io/postgresml/postgresml:2.7.3").asCompatibleSubstituteFor("postgres"))
DockerImageName.parse("ghcr.io/postgresml/postgresml:2.7.13").asCompatibleSubstituteFor("postgres"))
.withCommand("sleep", "infinity")
.withLabel("org.springframework.boot.service-connection", "postgres")
.withUsername("postgresml")
Expand All @@ -54,7 +54,7 @@ class PostgresMlEmbeddingClientIT {
@Autowired
JdbcTemplate jdbcTemplate;

@AfterEach
@BeforeEach
void dropPgmlExtension() {
this.jdbcTemplate.execute("DROP EXTENSION IF EXISTS pgml");
}
Expand All @@ -65,7 +65,6 @@ void embed() {
embeddingClient.afterPropertiesSet();
List<Double> embed = embeddingClient.embed("Hello World!");
assertThat(embed).hasSize(768);
// embeddingClient.dropPgmlExtension();
}

@Test
Expand All @@ -75,7 +74,6 @@ void embedWithPgVector() {
embeddingClient.afterPropertiesSet();
List<Double> embed = embeddingClient.embed(new Document("Hello World!"));
assertThat(embed).hasSize(768);
// embeddingClient.dropPgmlExtension();
}

@Test
Expand All @@ -85,18 +83,16 @@ void embedWithDifferentModel() {
embeddingClient.afterPropertiesSet();
List<Double> embed = embeddingClient.embed(new Document("Hello World!"));
assertThat(embed).hasSize(384);
// embeddingClient.dropPgmlExtension();
}

@Test
void embedWithKwargs() {
PostgresMlEmbeddingClient embeddingClient = new PostgresMlEmbeddingClient(this.jdbcTemplate,
"distilbert-base-uncased", PostgresMlEmbeddingClient.VectorType.PG_ARRAY, Map.of("device", "cpu"),
MetadataMode.EMBED);
MetadataMode.EMBED, false);
embeddingClient.afterPropertiesSet();
List<Double> embed = embeddingClient.embed(new Document("Hello World!"));
assertThat(embed).hasSize(768);
// embeddingClient.dropPgmlExtension();
}

@ParameterizedTest
Expand All @@ -117,7 +113,6 @@ void embedForResponse(String vectorType) {
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768);
assertThat(embeddingResponse.getResults().get(2).getIndex()).isEqualTo(2);
assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(768);
// embeddingClient.dropPgmlExtension();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public EmbeddingClient postgresMlEmbeddingClient(JdbcTemplate jdbcTemplate,
PostgresMlProperties postgresMlProperties) {
return new PostgresMlEmbeddingClient(jdbcTemplate, postgresMlProperties.getEmbedding().getTransformer(),
postgresMlProperties.getEmbedding().getVectorType(), postgresMlProperties.getEmbedding().getKwargs(),
postgresMlProperties.getEmbedding().getMetadataMode());
postgresMlProperties.getEmbedding().getMetadataMode(), postgresMlProperties.isSkipCreateExtension());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public class PostgresMlProperties {

private MetadataMode metadataMode = MetadataMode.EMBED;

private boolean skipCreateExtension = false;

public PostgresMlProperties.Embedding getEmbedding() {
return this.embedding;
}
Expand Down Expand Up @@ -78,6 +80,15 @@ public void setMetadataMode(MetadataMode metadataMode) {
this.metadataMode = metadataMode;
}

public boolean isSkipCreateExtension() {
return skipCreateExtension;
}

public PostgresMlProperties setSkipCreateExtension(boolean skipCreateExtension) {
this.skipCreateExtension = skipCreateExtension;
return this;
}

public static class Embedding {

private PostgresMlProperties postgresMlProperties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
* @author Utkarsh Srivastava
*/
@SpringBootTest(properties = { "spring.ai.postgresml.metadata-mode=all", "spring.ai.postgresml.kwargs.key1=value1",
"spring.ai.postgresml.kwargs.key2=value2", "spring.ai.postgresml.embedding.transformer=abc123" })
"spring.ai.postgresml.kwargs.key2=value2", "spring.ai.postgresml.embedding.transformer=abc123",
"spring.ai.postgresml.skip-create-extension=true" })
class PostgresMlPropertiesTests {

@Autowired
Expand All @@ -48,6 +49,7 @@ void postgresMlPropertiesAreCorrect() {
assertThat(this.postgresMlProperties.getVectorType()).isEqualTo(PostgresMlEmbeddingClient.VectorType.PG_ARRAY);
assertThat(this.postgresMlProperties.getKwargs()).isEqualTo(Map.of("key1", "value1", "key2", "value2"));
assertThat(this.postgresMlProperties.getMetadataMode()).isEqualTo(MetadataMode.ALL);
assertThat(this.postgresMlProperties.isSkipCreateExtension()).isTrue();

PostgresMlProperties.Embedding embedding = this.postgresMlProperties.getEmbedding();

Expand Down