3
3
import java .time .Duration ;
4
4
import java .util .Arrays ;
5
5
import java .util .Map ;
6
+ import java .util .concurrent .ConcurrentHashMap ;
6
7
import java .util .stream .Collectors ;
7
8
8
9
import org .springframework .ai .azure .openai .AzureOpenAiEmbeddingModel ;
46
47
public class EmbeddingModelFactory {
47
48
private final AIRedisOMProperties properties ;
48
49
private final SpringAiProperties springAiProperties ;
50
+ private final Map <String , Object > modelCache = new ConcurrentHashMap <>();
51
+
49
52
private final RestClient .Builder restClientBuilder ;
50
53
private final WebClient .Builder webClientBuilder ;
51
54
private final ResponseErrorHandler responseErrorHandler ;
@@ -62,7 +65,60 @@ public EmbeddingModelFactory(AIRedisOMProperties properties, SpringAiProperties
62
65
this .observationRegistry = observationRegistry ;
63
66
}
64
67
68
+ /**
69
+ * Generates a cache key for a model based on its type and parameters
70
+ *
71
+ * @param modelType The type of the model
72
+ * @param params Parameters that uniquely identify the model configuration
73
+ * @return A string key for caching
74
+ */
75
+ private String generateCacheKey (String modelType , String ... params ) {
76
+ StringBuilder keyBuilder = new StringBuilder (modelType );
77
+ for (String param : params ) {
78
+ keyBuilder .append (":" ).append (param );
79
+ }
80
+ return keyBuilder .toString ();
81
+ }
82
+
83
+ /**
84
+ * Clears the model cache, forcing new models to be created on next request.
85
+ * This can be useful when configuration changes or to free up resources.
86
+ */
87
+ public void clearCache () {
88
+ modelCache .clear ();
89
+ }
90
+
91
+ /**
92
+ * Removes a specific model from the cache.
93
+ *
94
+ * @param modelType The type of the model (e.g., "openai", "transformers")
95
+ * @param params Parameters that were used to create the model
96
+ * @return true if a model was removed, false otherwise
97
+ */
98
+ public boolean removeFromCache (String modelType , String ... params ) {
99
+ String cacheKey = generateCacheKey (modelType , params );
100
+ return modelCache .remove (cacheKey ) != null ;
101
+ }
102
+
103
+ /**
104
+ * Returns the current number of models in the cache.
105
+ *
106
+ * @return The number of cached models
107
+ */
108
+ public int getCacheSize () {
109
+ return modelCache .size ();
110
+ }
111
+
65
112
public TransformersEmbeddingModel createTransformersEmbeddingModel (Vectorize vectorize ) {
113
+ String cacheKey = generateCacheKey ("transformers" , vectorize .transformersModel (), vectorize .transformersTokenizer (),
114
+ vectorize .transformersResourceCacheConfiguration (), String .join ("," , vectorize .transformersTokenizerOptions ()));
115
+
116
+ TransformersEmbeddingModel cachedModel = (TransformersEmbeddingModel ) modelCache .get (cacheKey );
117
+
118
+ if (cachedModel != null ) {
119
+ return cachedModel ;
120
+ }
121
+
66
122
TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel ();
67
123
68
124
if (!vectorize .transformersModel ().isEmpty ()) {
@@ -89,6 +145,8 @@ public TransformersEmbeddingModel createTransformersEmbeddingModel(Vectorize vec
89
145
throw new RuntimeException ("Error initializing TransformersEmbeddingModel" , e );
90
146
}
91
147
148
+ modelCache .put (cacheKey , embeddingModel );
149
+
92
150
return embeddingModel ;
93
151
}
94
152
@@ -97,6 +155,13 @@ public OpenAiEmbeddingModel createOpenAiEmbeddingModel(EmbeddingModel model) {
97
155
}
98
156
99
157
public OpenAiEmbeddingModel createOpenAiEmbeddingModel (String model ) {
158
+ String cacheKey = generateCacheKey ("openai" , model , properties .getOpenAi ().getApiKey ());
159
+ OpenAiEmbeddingModel cachedModel = (OpenAiEmbeddingModel ) modelCache .get (cacheKey );
160
+
161
+ if (cachedModel != null ) {
162
+ return cachedModel ;
163
+ }
164
+
100
165
String apiKey = properties .getOpenAi ().getApiKey ();
101
166
if (!StringUtils .hasText (apiKey )) {
102
167
apiKey = springAiProperties .getOpenai ().getApiKey ();
@@ -109,8 +174,11 @@ public OpenAiEmbeddingModel createOpenAiEmbeddingModel(String model) {
109
174
OpenAiApi openAiApi = OpenAiApi .builder ().apiKey (properties .getOpenAi ().getApiKey ()).restClientBuilder (RestClient
110
175
.builder ().requestFactory (factory )).build ();
111
176
112
- return new OpenAiEmbeddingModel (openAiApi , MetadataMode .EMBED , OpenAiEmbeddingOptions .builder ().model (model )
113
- .build (), RetryUtils .DEFAULT_RETRY_TEMPLATE );
177
+ OpenAiEmbeddingModel embeddingModel = new OpenAiEmbeddingModel (openAiApi , MetadataMode .EMBED , OpenAiEmbeddingOptions
178
+ .builder ().model (model ).build (), RetryUtils .DEFAULT_RETRY_TEMPLATE );
179
+
180
+ modelCache .put (cacheKey , embeddingModel );
181
+ return embeddingModel ;
114
182
}
115
183
116
184
private OpenAIClient getOpenAIClient () {
@@ -126,6 +194,16 @@ private OpenAIClient getOpenAIClient() {
126
194
}
127
195
128
196
public AzureOpenAiEmbeddingModel createAzureOpenAiEmbeddingModel (String deploymentName ) {
197
+ String cacheKey = generateCacheKey ("azure-openai" , deploymentName , properties .getAzure ().getOpenAi ().getApiKey (),
198
+ properties .getAzure ().getOpenAi ().getEndpoint (), String .valueOf (properties .getAzure ().getEntraId ()
199
+ .isEnabled ()));
200
+
201
+ AzureOpenAiEmbeddingModel cachedModel = (AzureOpenAiEmbeddingModel ) modelCache .get (cacheKey );
202
+
203
+ if (cachedModel != null ) {
204
+ return cachedModel ;
205
+ }
206
+
129
207
String apiKey = properties .getAzure ().getOpenAi ().getApiKey ();
130
208
if (!StringUtils .hasText (apiKey )) {
131
209
apiKey = springAiProperties .getAzure ().getApiKey (); // Fallback to Spring AI property
@@ -142,10 +220,23 @@ public AzureOpenAiEmbeddingModel createAzureOpenAiEmbeddingModel(String deployme
142
220
143
221
AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions .builder ().deploymentName (deploymentName ).build ();
144
222
145
- return new AzureOpenAiEmbeddingModel (openAIClient , MetadataMode .EMBED , options );
223
+ AzureOpenAiEmbeddingModel embeddingModel = new AzureOpenAiEmbeddingModel (openAIClient , MetadataMode .EMBED , options );
224
+
225
+ modelCache .put (cacheKey , embeddingModel );
226
+
227
+ return embeddingModel ;
146
228
}
147
229
148
230
public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel (String model ) {
231
+ String cacheKey = generateCacheKey ("vertex-ai" , model , properties .getVertexAi ().getApiKey (), properties
232
+ .getVertexAi ().getEndpoint (), properties .getVertexAi ().getProjectId (), properties .getVertexAi ().getLocation ());
233
+
234
+ VertexAiTextEmbeddingModel cachedModel = (VertexAiTextEmbeddingModel ) modelCache .get (cacheKey );
235
+
236
+ if (cachedModel != null ) {
237
+ return cachedModel ;
238
+ }
239
+
149
240
String apiKey = properties .getVertexAi ().getApiKey ();
150
241
if (!StringUtils .hasText (apiKey )) {
151
242
apiKey = springAiProperties .getVertexAi ().getApiKey (); // Fallback to Spring AI property
@@ -183,16 +274,32 @@ public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel(String model)
183
274
184
275
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions .builder ().model (model ).build ();
185
276
186
- return new VertexAiTextEmbeddingModel (connectionDetails , options );
277
+ VertexAiTextEmbeddingModel embeddingModel = new VertexAiTextEmbeddingModel (connectionDetails , options );
278
+
279
+ modelCache .put (cacheKey , embeddingModel );
280
+
281
+ return embeddingModel ;
187
282
}
188
283
189
284
public OllamaEmbeddingModel createOllamaEmbeddingModel (String model ) {
285
+ String cacheKey = generateCacheKey ("ollama" , model , properties .getOllama ().getBaseUrl ());
286
+
287
+ OllamaEmbeddingModel cachedModel = (OllamaEmbeddingModel ) modelCache .get (cacheKey );
288
+
289
+ if (cachedModel != null ) {
290
+ return cachedModel ;
291
+ }
292
+
190
293
OllamaApi api = OllamaApi .builder ().baseUrl (properties .getOllama ().getBaseUrl ()).restClientBuilder (
191
294
restClientBuilder ).webClientBuilder (webClientBuilder ).responseErrorHandler (responseErrorHandler ).build ();
192
295
193
296
OllamaOptions options = OllamaOptions .builder ().model (model ).truncate (false ).build ();
194
297
195
- return OllamaEmbeddingModel .builder ().ollamaApi (api ).defaultOptions (options ).build ();
298
+ OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel .builder ().ollamaApi (api ).defaultOptions (options ).build ();
299
+
300
+ modelCache .put (cacheKey , embeddingModel );
301
+
302
+ return embeddingModel ;
196
303
}
197
304
198
305
private AwsCredentials getAwsCredentials () {
@@ -218,6 +325,16 @@ private AwsCredentials getAwsCredentials() {
218
325
}
219
326
220
327
public BedrockCohereEmbeddingModel createCohereEmbeddingModel (String model ) {
328
+ String cacheKey = generateCacheKey ("bedrock-cohere" , model , properties .getAws ().getAccessKey (), properties .getAws ()
329
+ .getSecretKey (), properties .getAws ().getRegion (), String .valueOf (properties .getAws ().getBedrockCohere ()
330
+ .getResponseTimeOut ()));
331
+
332
+ BedrockCohereEmbeddingModel cachedModel = (BedrockCohereEmbeddingModel ) modelCache .get (cacheKey );
333
+
334
+ if (cachedModel != null ) {
335
+ return cachedModel ;
336
+ }
337
+
221
338
String region = properties .getAws ().getRegion ();
222
339
if (!StringUtils .hasText (region )) {
223
340
region = springAiProperties .getBedrock ().getAws ().getRegion (); // Fallback to Spring AI property
@@ -228,10 +345,24 @@ public BedrockCohereEmbeddingModel createCohereEmbeddingModel(String model) {
228
345
properties .getAws ().getRegion (), ModelOptionsUtils .OBJECT_MAPPER , Duration .ofMinutes (properties .getAws ()
229
346
.getBedrockCohere ().getResponseTimeOut ()));
230
347
231
- return new BedrockCohereEmbeddingModel (cohereEmbeddingApi );
348
+ BedrockCohereEmbeddingModel embeddingModel = new BedrockCohereEmbeddingModel (cohereEmbeddingApi );
349
+
350
+ modelCache .put (cacheKey , embeddingModel );
351
+
352
+ return embeddingModel ;
232
353
}
233
354
234
355
public BedrockTitanEmbeddingModel createTitanEmbeddingModel (String model ) {
356
+ String cacheKey = generateCacheKey ("bedrock-titan" , model , properties .getAws ().getAccessKey (), properties .getAws ()
357
+ .getSecretKey (), properties .getAws ().getRegion (), String .valueOf (properties .getAws ().getBedrockTitan ()
358
+ .getResponseTimeOut ()));
359
+
360
+ BedrockTitanEmbeddingModel cachedModel = (BedrockTitanEmbeddingModel ) modelCache .get (cacheKey );
361
+
362
+ if (cachedModel != null ) {
363
+ return cachedModel ;
364
+ }
365
+
235
366
String region = properties .getAws ().getRegion ();
236
367
if (!StringUtils .hasText (region )) {
237
368
region = springAiProperties .getBedrock ().getAws ().getRegion (); // Fallback to Spring AI property
@@ -242,6 +373,10 @@ public BedrockTitanEmbeddingModel createTitanEmbeddingModel(String model) {
242
373
properties .getAws ().getRegion (), ModelOptionsUtils .OBJECT_MAPPER , Duration .ofMinutes (properties .getAws ()
243
374
.getBedrockTitan ().getResponseTimeOut ()));
244
375
245
- return new BedrockTitanEmbeddingModel (titanEmbeddingApi , observationRegistry );
376
+ BedrockTitanEmbeddingModel embeddingModel = new BedrockTitanEmbeddingModel (titanEmbeddingApi , observationRegistry );
377
+
378
+ modelCache .put (cacheKey , embeddingModel );
379
+
380
+ return embeddingModel ;
246
381
}
247
382
}
0 commit comments