Skip to content

Commit 8e04a24

Browse files
committed
Neo4j module: Determine default embedding dimension from model.
In cases where no custom size is set, derive the size by the given embedding model. Had to migrate the embeddingDimension Spring Boot property from int to Integer to introduce the null check for the fluent config. Everything else would have been just noisy. Closes #977 Signed-off-by: Gerrit Meier <[email protected]>
1 parent dee4d27 commit 8e04a24

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreAutoConfiguration.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel
6464
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
6565
.batchingStrategy(batchingStrategy)
6666
.databaseName(properties.getDatabaseName())
67-
.embeddingDimension(properties.getEmbeddingDimension())
67+
.embeddingDimension(properties.getEmbeddingDimension() != null ? properties.getEmbeddingDimension()
68+
: embeddingModel.dimensions())
6869
.distanceType(properties.getDistanceType())
6970
.label(properties.getLabel())
7071
.embeddingProperty(properties.getEmbeddingProperty())

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/neo4j/Neo4jVectorStoreProperties.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class Neo4jVectorStoreProperties extends CommonVectorStoreProperties {
3333

3434
private String databaseName;
3535

36-
private int embeddingDimension = Neo4jVectorStore.DEFAULT_EMBEDDING_DIMENSION;
36+
private Integer embeddingDimension;
3737

3838
private Neo4jVectorStore.Neo4jDistanceType distanceType = Neo4jVectorStore.Neo4jDistanceType.COSINE;
3939

@@ -55,7 +55,7 @@ public void setDatabaseName(String databaseName) {
5555
this.databaseName = databaseName;
5656
}
5757

58-
public int getEmbeddingDimension() {
58+
public Integer getEmbeddingDimension() {
5959
return this.embeddingDimension;
6060
}
6161

vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
*/
134134
public class Neo4jVectorStore extends AbstractObservationVectorStore implements InitializingBean {
135135

136+
@Deprecated(forRemoval = true)
136137
public static final int DEFAULT_EMBEDDING_DIMENSION = 1536;
137138

138139
public static final int DEFAULT_TRANSACTION_SIZE = 10_000;
@@ -182,7 +183,7 @@ protected Neo4jVectorStore(Builder builder) {
182183

183184
this.driver = builder.driver;
184185
this.sessionConfig = builder.sessionConfig;
185-
this.embeddingDimension = builder.embeddingDimension;
186+
this.embeddingDimension = builder.embeddingDimension.orElseGet(() -> builder.getEmbeddingModel().dimensions());
186187
this.distanceType = builder.distanceType;
187188
this.embeddingProperty = SchemaNames.sanitize(builder.embeddingProperty).orElseThrow();
188189
this.label = SchemaNames.sanitize(builder.label).orElseThrow();
@@ -372,7 +373,7 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {
372373

373374
private SessionConfig sessionConfig = SessionConfig.defaultConfig();
374375

375-
private int embeddingDimension = DEFAULT_EMBEDDING_DIMENSION;
376+
private Optional<Integer> embeddingDimension = Optional.empty();
376377

377378
private Neo4jDistanceType distanceType = Neo4jDistanceType.COSINE;
378379

@@ -425,7 +426,7 @@ public Builder sessionConfig(SessionConfig sessionConfig) {
425426
*/
426427
public Builder embeddingDimension(int dimension) {
427428
Assert.isTrue(dimension >= 1, "Dimension has to be positive");
428-
this.embeddingDimension = dimension;
429+
this.embeddingDimension = Optional.of(dimension);
429430
return this;
430431
}
431432

vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreIT.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.driver.AuthTokens;
2929
import org.neo4j.driver.Driver;
3030
import org.neo4j.driver.GraphDatabase;
31+
import org.springframework.context.annotation.Primary;
3132
import org.testcontainers.containers.Neo4jContainer;
3233
import org.testcontainers.junit.jupiter.Container;
3334
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -301,16 +302,43 @@ void ensureIdIndexGetsCreated() {
301302
.isTrue());
302303
}
303304

305+
@Test
306+
void vectorIndexDimensionsDefaultAndOverwriteWorks() {
307+
this.contextRunner.run(context -> {
308+
var result = context.getBean(Driver.class)
309+
.executableQuery(
310+
"SHOW VECTOR INDEXES yield name, options return name, options['indexConfig']['vector.dimensions'] as dimensions")
311+
.execute()
312+
.records()
313+
.stream()
314+
.map(r -> r.get("name").asString() + r.get("dimensions").asInt())
315+
.toList();
316+
assertThat(result).containsExactlyInAnyOrder("secondIndex123", "spring-ai-document-index1536");
317+
});
318+
}
319+
304320
@SpringBootConfiguration
305321
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
306322
public static class TestApplication {
307323

308324
@Bean
325+
@Primary
309326
public VectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel) {
310327

311328
return Neo4jVectorStore.builder(driver, embeddingModel).initializeSchema(true).build();
312329
}
313330

331+
@Bean
332+
public VectorStore vectorStoreWithCustomDimension(Driver driver, EmbeddingModel embeddingModel) {
333+
334+
return Neo4jVectorStore.builder(driver, embeddingModel)
335+
.initializeSchema(true)
336+
.indexName("secondIndex")
337+
.embeddingProperty("somethingElse")
338+
.embeddingDimension(123)
339+
.build();
340+
}
341+
314342
@Bean
315343
public Driver driver() {
316344
return GraphDatabase.driver(neo4jContainer.getBoltUrl(),

0 commit comments

Comments
 (0)