35
35
import org .springframework .ai .openai .OpenAiEmbeddingOptions ;
36
36
import org .springframework .ai .openai .api .OpenAiApi ;
37
37
import org .springframework .ai .retry .RetryUtils ;
38
- import org .springframework .ai .vertexai .palm2 .VertexAiPaLm2EmbeddingModel ;
39
- import org .springframework .ai .vertexai .palm2 .api .VertexAiPaLm2Api ;
38
+ import org .springframework .ai .transformers .TransformersEmbeddingModel ;
39
+ import org .springframework .ai .vertexai .embedding .VertexAiEmbeddingConnectionDetails ;
40
+ import org .springframework .ai .vertexai .embedding .text .VertexAiTextEmbeddingModel ;
41
+ import org .springframework .ai .vertexai .embedding .text .VertexAiTextEmbeddingOptions ;
40
42
import org .springframework .beans .factory .annotation .Qualifier ;
41
43
import org .springframework .beans .factory .annotation .Value ;
42
44
import org .springframework .boot .autoconfigure .condition .ConditionalOnMissingBean ;
43
45
import org .springframework .boot .autoconfigure .condition .ConditionalOnProperty ;
44
46
import org .springframework .boot .context .properties .EnableConfigurationProperties ;
45
47
import org .springframework .context .ApplicationContext ;
46
- import org .springframework .context .annotation .*;
48
+ import org .springframework .context .annotation .Bean ;
49
+ import org .springframework .context .annotation .Configuration ;
50
+ import org .springframework .context .annotation .Primary ;
47
51
import org .springframework .lang .Nullable ;
48
52
import org .springframework .util .StringUtils ;
49
- import org .springframework .web .client .RestClient ;
50
53
import software .amazon .awssdk .auth .credentials .AwsBasicCredentials ;
51
54
import software .amazon .awssdk .auth .credentials .AwsCredentials ;
52
55
import software .amazon .awssdk .auth .credentials .StaticCredentialsProvider ;
53
56
import software .amazon .awssdk .regions .Region ;
54
57
55
58
import java .io .IOException ;
56
59
import java .net .InetAddress ;
57
- import java .time .* ;
60
+ import java .time .Duration ;
58
61
import java .util .Map ;
59
62
60
63
@ ConditionalOnProperty (name = "redis.om.spring.ai.enabled" )
@@ -70,8 +73,8 @@ public ImageFactory imageFactory() {
70
73
}
71
74
72
75
@ Bean (name = "djlImageEmbeddingModelCriteria" )
73
- public Criteria <Image , byte []> imageEmbeddingModelCriteria (RedisOMAiProperties properties ) {
74
- return Criteria .builder ().setTypes (Image .class , byte [].class ) //
76
+ public Criteria <Image , float []> imageEmbeddingModelCriteria (RedisOMAiProperties properties ) {
77
+ return Criteria .builder ().setTypes (Image .class , float [].class ) //
75
78
.optEngine (properties .getDjl ().getImageEmbeddingModelEngine ()) //
76
79
.optModelUrls (properties .getDjl ().getImageEmbeddingModelModelUrls ()) //
77
80
.build ();
@@ -123,7 +126,8 @@ public Criteria<Image, float[]> faceEmbeddingModelCriteria( //
123
126
RedisOMAiProperties properties ) {
124
127
125
128
return Criteria .builder () //
126
- .setTypes (Image .class , float [].class ).optModelUrls (properties .getDjl ().getFaceEmbeddingModelModelUrls ()) //
129
+ .setTypes (Image .class , float [].class ) //
130
+ .optModelUrls (properties .getDjl ().getFaceEmbeddingModelModelUrls ()) //
127
131
.optModelName (properties .getDjl ().getFaceEmbeddingModelName ()) //
128
132
.optTranslator (translator ) //
129
133
.optEngine (properties .getDjl ().getFaceEmbeddingModelEngine ()) //
@@ -142,8 +146,9 @@ public ZooModel<Image, float[]> faceEmbeddingModel(
142
146
}
143
147
144
148
@ Bean (name = "djlImageEmbeddingModel" )
145
- public ZooModel <Image , byte []> imageModel (
146
- @ Nullable @ Qualifier ("djlImageEmbeddingModelCriteria" ) Criteria <Image , byte []> criteria ) throws MalformedModelException , ModelNotFoundException , IOException {
149
+ public ZooModel <Image , float []> imageModel (
150
+ @ Nullable @ Qualifier ("djlImageEmbeddingModelCriteria" ) Criteria <Image , float []> criteria )
151
+ throws MalformedModelException , ModelNotFoundException , IOException {
147
152
return criteria != null ? ModelZoo .loadModel (criteria ) : null ;
148
153
}
149
154
@@ -178,6 +183,28 @@ public HuggingFaceTokenizer sentenceTokenizer(RedisOMAiProperties properties) {
178
183
}
179
184
}
180
185
186
+ @ Bean (name = "transformersEmbeddingModel" )
187
+ public TransformersEmbeddingModel transformersEmbeddingModel (RedisOMAiProperties properties ) {
188
+ TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel ();
189
+ if (properties .getTransformers ().getTokenizerResource () != null ) {
190
+ embeddingModel .setTokenizerResource (properties .getTransformers ().getTokenizerResource ());
191
+ }
192
+
193
+ if (properties .getTransformers ().getModelResource () != null ) {
194
+ embeddingModel .setModelResource (properties .getTransformers ().getModelResource ());
195
+ }
196
+
197
+ if (properties .getTransformers ().getResourceCacheDirectory () != null ) {
198
+ embeddingModel .setResourceCacheDirectory (properties .getTransformers ().getResourceCacheDirectory ());
199
+ }
200
+
201
+ if (!properties .getTransformers ().getTokenizerOptions ().isEmpty ()) {
202
+ embeddingModel .setTokenizerOptions (properties .getTransformers ().getTokenizerOptions ());
203
+ }
204
+
205
+ return embeddingModel ;
206
+ }
207
+
181
208
@ ConditionalOnMissingBean
182
209
@ Bean
183
210
public OpenAiEmbeddingModel openAITextVectorizer (RedisOMAiProperties properties ,
@@ -204,7 +231,7 @@ public OpenAiEmbeddingModel openAITextVectorizer(RedisOMAiProperties properties,
204
231
205
232
// Rest of the configuration
206
233
return new OpenAiEmbeddingModel (openAiApi , MetadataMode .EMBED ,
207
- OpenAiEmbeddingOptions .builder ().withModel ("text-embedding-ada-002" ).build (),
234
+ OpenAiEmbeddingOptions .builder ().model ("text-embedding-ada-002" ).build (),
208
235
RetryUtils .DEFAULT_RETRY_TEMPLATE );
209
236
} else {
210
237
return null ;
@@ -251,7 +278,7 @@ public OpenAIClient azureOpenAIClient(RedisOMAiProperties properties, //
251
278
252
279
@ ConditionalOnMissingBean
253
280
@ Bean
254
- VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel (RedisOMAiProperties properties , //
281
+ VertexAiTextEmbeddingModel vertexAiEmbeddingModel (RedisOMAiProperties properties , //
255
282
@ Value ("${spring.ai.vertex.ai.api-key:}" ) String apiKey ,
256
283
@ Value ("${spring.ai.vertex.ai.ai.base-url:}" ) String baseUrl ) {
257
284
if (!StringUtils .hasText (apiKey )) {
@@ -281,9 +308,15 @@ VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel(RedisOMAiProperties prop
281
308
}
282
309
283
310
if (StringUtils .hasText (apiKey ) && StringUtils .hasText (baseUrl )) {
284
- VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api (baseUrl , apiKey , VertexAiPaLm2Api .DEFAULT_GENERATE_MODEL ,
285
- VertexAiPaLm2Api .DEFAULT_EMBEDDING_MODEL , RestClient .builder ());
286
- return new VertexAiPaLm2EmbeddingModel (vertexAiApi );
311
+
312
+ VertexAiEmbeddingConnectionDetails connectionDetails = VertexAiEmbeddingConnectionDetails .builder ()
313
+ .projectId (System .getenv ("VERTEX_AI_GEMINI_PROJECT_ID" )).location (System .getenv ("VERTEX_AI_GEMINI_LOCATION" ))
314
+ .build ();
315
+
316
+ VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions .builder ()
317
+ .model (VertexAiTextEmbeddingOptions .DEFAULT_MODEL_NAME ).build ();
318
+
319
+ return new VertexAiTextEmbeddingModel (connectionDetails , options );
287
320
} else {
288
321
return null ;
289
322
}
@@ -346,7 +379,7 @@ BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel(RedisOMAiProperties prop
346
379
if (!StringUtils .hasText (model )) {
347
380
model = properties .getBedrockCohere ().getModel ();
348
381
if (!StringUtils .hasText (model )) {
349
- model = CohereEmbeddingModel .COHERE_EMBED_MULTILINGUAL_V1 .id ();
382
+ model = CohereEmbeddingModel .COHERE_EMBED_MULTILINGUAL_V3 .id ();
350
383
properties .getBedrockCohere ().setModel (model );
351
384
}
352
385
}
@@ -439,19 +472,19 @@ BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel(RedisOMAiProperties proper
439
472
@ Primary
440
473
@ Bean (name = "featureExtractor" )
441
474
public Embedder featureExtractor (
442
- @ Nullable @ Qualifier ("djlImageEmbeddingModel" ) ZooModel <Image , byte []> imageEmbeddingModel ,
475
+ @ Nullable @ Qualifier ("djlImageEmbeddingModel" ) ZooModel <Image , float []> imageEmbeddingModel ,
443
476
@ Nullable @ Qualifier ("djlFaceEmbeddingModel" ) ZooModel <Image , float []> faceEmbeddingModel ,
444
477
@ Nullable @ Qualifier ("djlImageFactory" ) ImageFactory imageFactory ,
445
478
@ Nullable @ Qualifier ("djlDefaultImagePipeline" ) Pipeline defaultImagePipeline ,
446
- @ Nullable @ Qualifier ("djlSentenceTokenizer " ) HuggingFaceTokenizer sentenceTokenizer ,
479
+ @ Nullable @ Qualifier ("transformersEmbeddingModel " ) TransformersEmbeddingModel transformersEmbeddingModel ,
447
480
@ Nullable OpenAiEmbeddingModel openAITextVectorizer , @ Nullable OpenAIClient azureOpenAIClient ,
448
- @ Nullable VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel ,
481
+ @ Nullable VertexAiTextEmbeddingModel vertexAiTextEmbeddingModel ,
449
482
@ Nullable BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel ,
450
483
@ Nullable BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel ,
451
484
RedisOMAiProperties properties ,
452
485
ApplicationContext ac ) {
453
486
return new DefaultEmbedder (ac , imageEmbeddingModel , faceEmbeddingModel , imageFactory , defaultImagePipeline ,
454
- sentenceTokenizer , openAITextVectorizer , azureOpenAIClient , vertexAiPaLm2EmbeddingModel ,
487
+ transformersEmbeddingModel , openAITextVectorizer , azureOpenAIClient , vertexAiTextEmbeddingModel ,
455
488
bedrockCohereEmbeddingModel , bedrockTitanEmbeddingModel , properties );
456
489
}
457
490
}
0 commit comments