Skip to content

Commit 6b2f371

Browse files
committed
refactor: upgrade to SpringAI 1.0.0-M6
1 parent 21c059b commit 6b2f371

File tree

12 files changed

+176
-86
lines changed

12 files changed

+176
-86
lines changed

demos/roms-vss/src/main/java/com/redis/om/vss/domain/Product.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public class Product {
7979
schemaFieldType = SchemaFieldType.VECTOR, //
8080
algorithm = VectorAlgorithm.HNSW, //
8181
type = VectorType.FLOAT32, //
82-
dimension = 768, //
82+
dimension = 384, //
8383
distanceMetric = DistanceMetric.COSINE, //
8484
initialCapacity = 10
8585
)

redis-om-spring/pom.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@
7878
<elementary.version>2.0.1</elementary.version>
7979
<gson.version>2.10.1</gson.version>
8080
<djl.starter.version>0.26</djl.starter.version>
81-
<djl.version>0.27.0</djl.version>
81+
<djl.version>0.30.0</djl.version>
8282
<junit-bom.version>5.10.2</junit-bom.version>
83-
<spring-ai.version>1.0.0-M2</spring-ai.version>
83+
<spring-ai.version>1.0.0-M6</spring-ai.version>
8484
</properties>
8585

8686
<dependencyManagement>
@@ -200,7 +200,7 @@
200200
</dependency>
201201
<dependency>
202202
<groupId>org.springframework.ai</groupId>
203-
<artifactId>spring-ai-vertex-ai-palm2</artifactId>
203+
<artifactId>spring-ai-vertex-ai-embedding</artifactId>
204204
<version>${spring-ai.version}</version>
205205
<optional>true</optional>
206206
</dependency>

redis-om-spring/src/main/java/com/redis/om/spring/RedisAiConfiguration.java

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,29 @@
3535
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
3636
import org.springframework.ai.openai.api.OpenAiApi;
3737
import org.springframework.ai.retry.RetryUtils;
38-
import org.springframework.ai.vertexai.palm2.VertexAiPaLm2EmbeddingModel;
39-
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
38+
import org.springframework.ai.transformers.TransformersEmbeddingModel;
39+
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
40+
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel;
41+
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions;
4042
import org.springframework.beans.factory.annotation.Qualifier;
4143
import org.springframework.beans.factory.annotation.Value;
4244
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
4345
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
4446
import org.springframework.boot.context.properties.EnableConfigurationProperties;
4547
import org.springframework.context.ApplicationContext;
46-
import org.springframework.context.annotation.*;
48+
import org.springframework.context.annotation.Bean;
49+
import org.springframework.context.annotation.Configuration;
50+
import org.springframework.context.annotation.Primary;
4751
import org.springframework.lang.Nullable;
4852
import org.springframework.util.StringUtils;
49-
import org.springframework.web.client.RestClient;
5053
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
5154
import software.amazon.awssdk.auth.credentials.AwsCredentials;
5255
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
5356
import software.amazon.awssdk.regions.Region;
5457

5558
import java.io.IOException;
5659
import java.net.InetAddress;
57-
import java.time.*;
60+
import java.time.Duration;
5861
import java.util.Map;
5962

6063
@ConditionalOnProperty(name = "redis.om.spring.ai.enabled")
@@ -70,8 +73,8 @@ public ImageFactory imageFactory() {
7073
}
7174

7275
@Bean(name = "djlImageEmbeddingModelCriteria")
73-
public Criteria<Image, byte[]> imageEmbeddingModelCriteria(RedisOMAiProperties properties) {
74-
return Criteria.builder().setTypes(Image.class, byte[].class) //
76+
public Criteria<Image, float[]> imageEmbeddingModelCriteria(RedisOMAiProperties properties) {
77+
return Criteria.builder().setTypes(Image.class, float[].class) //
7578
.optEngine(properties.getDjl().getImageEmbeddingModelEngine()) //
7679
.optModelUrls(properties.getDjl().getImageEmbeddingModelModelUrls()) //
7780
.build();
@@ -123,7 +126,8 @@ public Criteria<Image, float[]> faceEmbeddingModelCriteria( //
123126
RedisOMAiProperties properties) {
124127

125128
return Criteria.builder() //
126-
.setTypes(Image.class, float[].class).optModelUrls(properties.getDjl().getFaceEmbeddingModelModelUrls()) //
129+
.setTypes(Image.class, float[].class) //
130+
.optModelUrls(properties.getDjl().getFaceEmbeddingModelModelUrls()) //
127131
.optModelName(properties.getDjl().getFaceEmbeddingModelName()) //
128132
.optTranslator(translator) //
129133
.optEngine(properties.getDjl().getFaceEmbeddingModelEngine()) //
@@ -142,8 +146,9 @@ public ZooModel<Image, float[]> faceEmbeddingModel(
142146
}
143147

144148
@Bean(name = "djlImageEmbeddingModel")
145-
public ZooModel<Image, byte[]> imageModel(
146-
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, byte[]> criteria) throws MalformedModelException, ModelNotFoundException, IOException {
149+
public ZooModel<Image, float[]> imageModel(
150+
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, float[]> criteria)
151+
throws MalformedModelException, ModelNotFoundException, IOException {
147152
return criteria != null ? ModelZoo.loadModel(criteria) : null;
148153
}
149154

@@ -178,6 +183,28 @@ public HuggingFaceTokenizer sentenceTokenizer(RedisOMAiProperties properties) {
178183
}
179184
}
180185

186+
@Bean(name = "transformersEmbeddingModel")
187+
public TransformersEmbeddingModel transformersEmbeddingModel(RedisOMAiProperties properties) {
188+
TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();
189+
if (properties.getTransformers().getTokenizerResource() != null) {
190+
embeddingModel.setTokenizerResource(properties.getTransformers().getTokenizerResource());
191+
}
192+
193+
if (properties.getTransformers().getModelResource() != null) {
194+
embeddingModel.setModelResource(properties.getTransformers().getModelResource());
195+
}
196+
197+
if (properties.getTransformers().getResourceCacheDirectory() != null) {
198+
embeddingModel.setResourceCacheDirectory(properties.getTransformers().getResourceCacheDirectory());
199+
}
200+
201+
if (!properties.getTransformers().getTokenizerOptions().isEmpty()) {
202+
embeddingModel.setTokenizerOptions(properties.getTransformers().getTokenizerOptions());
203+
}
204+
205+
return embeddingModel;
206+
}
207+
181208
@ConditionalOnMissingBean
182209
@Bean
183210
public OpenAiEmbeddingModel openAITextVectorizer(RedisOMAiProperties properties,
@@ -204,7 +231,7 @@ public OpenAiEmbeddingModel openAITextVectorizer(RedisOMAiProperties properties,
204231

205232
// Rest of the configuration
206233
return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED,
207-
OpenAiEmbeddingOptions.builder().withModel("text-embedding-ada-002").build(),
234+
OpenAiEmbeddingOptions.builder().model("text-embedding-ada-002").build(),
208235
RetryUtils.DEFAULT_RETRY_TEMPLATE);
209236
} else {
210237
return null;
@@ -251,7 +278,7 @@ public OpenAIClient azureOpenAIClient(RedisOMAiProperties properties, //
251278

252279
@ConditionalOnMissingBean
253280
@Bean
254-
VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel(RedisOMAiProperties properties, //
281+
VertexAiTextEmbeddingModel vertexAiEmbeddingModel(RedisOMAiProperties properties, //
255282
@Value("${spring.ai.vertex.ai.api-key:}") String apiKey,
256283
@Value("${spring.ai.vertex.ai.ai.base-url:}") String baseUrl) {
257284
if (!StringUtils.hasText(apiKey)) {
@@ -281,9 +308,15 @@ VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel(RedisOMAiProperties prop
281308
}
282309

283310
if (StringUtils.hasText(apiKey) && StringUtils.hasText(baseUrl)) {
284-
VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(baseUrl, apiKey, VertexAiPaLm2Api.DEFAULT_GENERATE_MODEL,
285-
VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, RestClient.builder());
286-
return new VertexAiPaLm2EmbeddingModel(vertexAiApi);
311+
312+
VertexAiEmbeddingConnectionDetails connectionDetails = VertexAiEmbeddingConnectionDetails.builder()
313+
.projectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")).location(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
314+
.build();
315+
316+
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
317+
.model(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME).build();
318+
319+
return new VertexAiTextEmbeddingModel(connectionDetails, options);
287320
} else {
288321
return null;
289322
}
@@ -346,7 +379,7 @@ BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel(RedisOMAiProperties prop
346379
if (!StringUtils.hasText(model)) {
347380
model = properties.getBedrockCohere().getModel();
348381
if (!StringUtils.hasText(model)) {
349-
model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id();
382+
model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3.id();
350383
properties.getBedrockCohere().setModel(model);
351384
}
352385
}
@@ -439,19 +472,19 @@ BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel(RedisOMAiProperties proper
439472
@Primary
440473
@Bean(name = "featureExtractor")
441474
public Embedder featureExtractor(
442-
@Nullable @Qualifier("djlImageEmbeddingModel") ZooModel<Image, byte[]> imageEmbeddingModel,
475+
@Nullable @Qualifier("djlImageEmbeddingModel") ZooModel<Image, float[]> imageEmbeddingModel,
443476
@Nullable @Qualifier("djlFaceEmbeddingModel") ZooModel<Image, float[]> faceEmbeddingModel,
444477
@Nullable @Qualifier("djlImageFactory") ImageFactory imageFactory,
445478
@Nullable @Qualifier("djlDefaultImagePipeline") Pipeline defaultImagePipeline,
446-
@Nullable @Qualifier("djlSentenceTokenizer") HuggingFaceTokenizer sentenceTokenizer,
479+
@Nullable @Qualifier("transformersEmbeddingModel") TransformersEmbeddingModel transformersEmbeddingModel,
447480
@Nullable OpenAiEmbeddingModel openAITextVectorizer, @Nullable OpenAIClient azureOpenAIClient,
448-
@Nullable VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel,
481+
@Nullable VertexAiTextEmbeddingModel vertexAiTextEmbeddingModel,
449482
@Nullable BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel,
450483
@Nullable BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel,
451484
RedisOMAiProperties properties,
452485
ApplicationContext ac) {
453486
return new DefaultEmbedder(ac, imageEmbeddingModel, faceEmbeddingModel, imageFactory, defaultImagePipeline,
454-
sentenceTokenizer, openAITextVectorizer, azureOpenAIClient, vertexAiPaLm2EmbeddingModel,
487+
transformersEmbeddingModel, openAITextVectorizer, azureOpenAIClient, vertexAiTextEmbeddingModel,
455488
bedrockCohereEmbeddingModel, bedrockTitanEmbeddingModel, properties);
456489
}
457490
}

redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
66
import org.springframework.boot.context.properties.ConfigurationProperties;
77

8+
import java.util.HashMap;
9+
import java.util.Map;
10+
811
@ConditionalOnProperty(name = "redis.om.spring.ai.enabled")
912
@ConfigurationProperties(
1013
prefix = "redis.om.spring.ai", ignoreInvalidFields = true
1114
)
1215
public class RedisOMAiProperties {
1316
private boolean enabled = false;
1417
private final Djl djl = new Djl();
18+
private final Transformers transformers = new Transformers();
1519
private final OpenAi openAi = new OpenAi();
1620
private final AzureOpenAi azureOpenAi = new AzureOpenAi();
1721
private final VertexAi vertexAi = new VertexAi();
@@ -31,6 +35,10 @@ public Djl getDjl() {
3135
return djl;
3236
}
3337

38+
public Transformers getTransformers() {
39+
return transformers;
40+
}
41+
3442
public OpenAi getOpenAi() {
3543
return openAi;
3644
}
@@ -55,6 +63,30 @@ public Ollama getOllama() {
5563
return ollama;
5664
}
5765

66+
// Transformer properties
67+
public static class Transformers {
68+
private String tokenizerResource;
69+
private String modelResource;
70+
private String resourceCacheDirectory;
71+
private Map<String, String> tokenizerOptions = new HashMap<>();
72+
73+
public String getTokenizerResource() {
74+
return tokenizerResource;
75+
}
76+
77+
public String getModelResource() {
78+
return modelResource;
79+
}
80+
81+
public String getResourceCacheDirectory() {
82+
return resourceCacheDirectory;
83+
}
84+
85+
public Map<String, String> getTokenizerOptions() {
86+
return tokenizerOptions;
87+
}
88+
}
89+
5890
// DJL properties
5991
public static class Djl {
6092
private static final String DEFAULT_ENGINE = "PyTorch";
@@ -73,7 +105,7 @@ public static class Djl {
73105
@NotNull
74106
private String sentenceTokenizerModelMaxLength = "768";
75107
@NotNull
76-
private String sentenceTokenizerModel = "sentence-transformers/all-mpnet-base-v2";
108+
private String sentenceTokenizerModel = "sentence-transformers/msmarco-distilbert-dot-v5";
77109

78110
// face detection
79111
@NotNull
@@ -91,6 +123,7 @@ public static class Djl {
91123
@NotNull
92124
private String faceEmbeddingModelModelUrls = "https://resources.djl.ai/test-models/pytorch/face_feature.zip";
93125

126+
94127
public Djl() {
95128
}
96129

@@ -278,6 +311,24 @@ public static class VertexAi {
278311
private String apiKey;
279312
private String endPoint;
280313
private String model;
314+
private String projectId;
315+
private String location;
316+
317+
public String getProjectId() {
318+
return projectId;
319+
}
320+
321+
public void setProjectId(String projectId) {
322+
this.projectId = projectId;
323+
}
324+
325+
public String getLocation() {
326+
return location;
327+
}
328+
329+
public void setLocation(String location) {
330+
this.location = location;
331+
}
281332

282333
public String getApiKey() {
283334
return apiKey;

redis-om-spring/src/main/java/com/redis/om/spring/annotations/EmbeddingProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.redis.om.spring.annotations;
22

33
public enum EmbeddingProvider {
4+
TRANSFORMERS,
45
DJL,
56
OPENAI,
67
OLLAMA,

redis-om-spring/src/main/java/com/redis/om/spring/annotations/Vectorize.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel;
55
import org.springframework.ai.ollama.api.OllamaModel;
66
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingModel;
7-
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
87

98
import java.lang.annotation.*;
109

@@ -16,17 +15,17 @@
1615

1716
EmbeddingType embeddingType() default EmbeddingType.SENTENCE;
1817

19-
EmbeddingProvider provider() default EmbeddingProvider.DJL;
18+
EmbeddingProvider provider() default EmbeddingProvider.TRANSFORMERS;
2019

2120
EmbeddingModel openAiEmbeddingModel() default EmbeddingModel.TEXT_EMBEDDING_ADA_002;
2221

2322
OllamaModel ollamaEmbeddingModel() default OllamaModel.MISTRAL;
2423

2524
String azureOpenAiDeploymentName() default "text-embedding-ada-002";
2625

27-
String vertexAiPaLm2ApiModel() default VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL;
26+
String vertexAiPaLm2ApiModel() default "text-embedding-004";
2827

29-
CohereEmbeddingModel cohereEmbeddingModel() default CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1;
28+
CohereEmbeddingModel cohereEmbeddingModel() default CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3;
3029

3130
TitanEmbeddingModel titanEmbeddingModel() default TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1;
3231
}

0 commit comments

Comments
 (0)