Skip to content

Commit 7dcf381

Browse files
authored
fix: adding cache for embedding models in EmbeddingModelFactory (#601)
* [fix] adding cache for embedding models in EmbeddingModelFactory
* [fix] adding cache for embedding models in EmbeddingModelFactory
* [fix] adding cache for embedding models in EmbeddingModelFactory
1 parent 89cc1e9 commit 7dcf381

File tree

2 files changed

+143
-8
lines changed

2 files changed

+143
-8
lines changed

redis-om-spring-ai/src/main/java/com/redis/om/spring/vectorize/EmbeddingModelFactory.java

Lines changed: 142 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import java.time.Duration;
44
import java.util.Arrays;
55
import java.util.Map;
6+
import java.util.concurrent.ConcurrentHashMap;
67
import java.util.stream.Collectors;
78

89
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
@@ -46,6 +47,8 @@
4647
public class EmbeddingModelFactory {
4748
private final AIRedisOMProperties properties;
4849
private final SpringAiProperties springAiProperties;
50+
private final Map<String, Object> modelCache = new ConcurrentHashMap<>();
51+
4952
private final RestClient.Builder restClientBuilder;
5053
private final WebClient.Builder webClientBuilder;
5154
private final ResponseErrorHandler responseErrorHandler;
@@ -62,7 +65,60 @@ public EmbeddingModelFactory(AIRedisOMProperties properties, SpringAiProperties
6265
this.observationRegistry = observationRegistry;
6366
}
6467

68+
/**
69+
* Generates a cache key for a model based on its type and parameters
70+
*
71+
* @param modelType The type of the model
72+
* @param params Parameters that uniquely identify the model configuration
73+
* @return A string key for caching
74+
*/
75+
private String generateCacheKey(String modelType, String... params) {
76+
StringBuilder keyBuilder = new StringBuilder(modelType);
77+
for (String param : params) {
78+
keyBuilder.append(":").append(param);
79+
}
80+
return keyBuilder.toString();
81+
}
82+
83+
/**
84+
* Clears the model cache, forcing new models to be created on next request.
85+
* This can be useful when configuration changes or to free up resources.
86+
*/
87+
public void clearCache() {
88+
modelCache.clear();
89+
}
90+
91+
/**
92+
* Removes a specific model from the cache.
93+
*
94+
* @param modelType The type of the model (e.g., "openai", "transformers")
95+
* @param params Parameters that were used to create the model
96+
* @return true if a model was removed, false otherwise
97+
*/
98+
public boolean removeFromCache(String modelType, String... params) {
99+
String cacheKey = generateCacheKey(modelType, params);
100+
return modelCache.remove(cacheKey) != null;
101+
}
102+
103+
/**
104+
* Returns the current number of models in the cache.
105+
*
106+
* @return The number of cached models
107+
*/
108+
public int getCacheSize() {
109+
return modelCache.size();
110+
}
111+
65112
public TransformersEmbeddingModel createTransformersEmbeddingModel(Vectorize vectorize) {
113+
String cacheKey = generateCacheKey("transformers", vectorize.transformersModel(), vectorize.transformersTokenizer(),
114+
vectorize.transformersResourceCacheConfiguration(), String.join(",", vectorize.transformersTokenizerOptions()));
115+
116+
TransformersEmbeddingModel cachedModel = (TransformersEmbeddingModel) modelCache.get(cacheKey);
117+
118+
if (cachedModel != null) {
119+
return cachedModel;
120+
}
121+
66122
TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();
67123

68124
if (!vectorize.transformersModel().isEmpty()) {
@@ -89,6 +145,8 @@ public TransformersEmbeddingModel createTransformersEmbeddingModel(Vectorize vec
89145
throw new RuntimeException("Error initializing TransformersEmbeddingModel", e);
90146
}
91147

148+
modelCache.put(cacheKey, embeddingModel);
149+
92150
return embeddingModel;
93151
}
94152

@@ -97,6 +155,13 @@ public OpenAiEmbeddingModel createOpenAiEmbeddingModel(EmbeddingModel model) {
97155
}
98156

99157
public OpenAiEmbeddingModel createOpenAiEmbeddingModel(String model) {
158+
String cacheKey = generateCacheKey("openai", model, properties.getOpenAi().getApiKey());
159+
OpenAiEmbeddingModel cachedModel = (OpenAiEmbeddingModel) modelCache.get(cacheKey);
160+
161+
if (cachedModel != null) {
162+
return cachedModel;
163+
}
164+
100165
String apiKey = properties.getOpenAi().getApiKey();
101166
if (!StringUtils.hasText(apiKey)) {
102167
apiKey = springAiProperties.getOpenai().getApiKey();
@@ -109,8 +174,11 @@ public OpenAiEmbeddingModel createOpenAiEmbeddingModel(String model) {
109174
OpenAiApi openAiApi = OpenAiApi.builder().apiKey(properties.getOpenAi().getApiKey()).restClientBuilder(RestClient
110175
.builder().requestFactory(factory)).build();
111176

112-
return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder().model(model)
113-
.build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
177+
OpenAiEmbeddingModel embeddingModel = new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions
178+
.builder().model(model).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
179+
180+
modelCache.put(cacheKey, embeddingModel);
181+
return embeddingModel;
114182
}
115183

116184
private OpenAIClient getOpenAIClient() {
@@ -126,6 +194,16 @@ private OpenAIClient getOpenAIClient() {
126194
}
127195

128196
public AzureOpenAiEmbeddingModel createAzureOpenAiEmbeddingModel(String deploymentName) {
197+
String cacheKey = generateCacheKey("azure-openai", deploymentName, properties.getAzure().getOpenAi().getApiKey(),
198+
properties.getAzure().getOpenAi().getEndpoint(), String.valueOf(properties.getAzure().getEntraId()
199+
.isEnabled()));
200+
201+
AzureOpenAiEmbeddingModel cachedModel = (AzureOpenAiEmbeddingModel) modelCache.get(cacheKey);
202+
203+
if (cachedModel != null) {
204+
return cachedModel;
205+
}
206+
129207
String apiKey = properties.getAzure().getOpenAi().getApiKey();
130208
if (!StringUtils.hasText(apiKey)) {
131209
apiKey = springAiProperties.getAzure().getApiKey(); // Fallback to Spring AI property
@@ -142,10 +220,23 @@ public AzureOpenAiEmbeddingModel createAzureOpenAiEmbeddingModel(String deployme
142220

143221
AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder().deploymentName(deploymentName).build();
144222

145-
return new AzureOpenAiEmbeddingModel(openAIClient, MetadataMode.EMBED, options);
223+
AzureOpenAiEmbeddingModel embeddingModel = new AzureOpenAiEmbeddingModel(openAIClient, MetadataMode.EMBED, options);
224+
225+
modelCache.put(cacheKey, embeddingModel);
226+
227+
return embeddingModel;
146228
}
147229

148230
public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel(String model) {
231+
String cacheKey = generateCacheKey("vertex-ai", model, properties.getVertexAi().getApiKey(), properties
232+
.getVertexAi().getEndpoint(), properties.getVertexAi().getProjectId(), properties.getVertexAi().getLocation());
233+
234+
VertexAiTextEmbeddingModel cachedModel = (VertexAiTextEmbeddingModel) modelCache.get(cacheKey);
235+
236+
if (cachedModel != null) {
237+
return cachedModel;
238+
}
239+
149240
String apiKey = properties.getVertexAi().getApiKey();
150241
if (!StringUtils.hasText(apiKey)) {
151242
apiKey = springAiProperties.getVertexAi().getApiKey(); // Fallback to Spring AI property
@@ -183,16 +274,32 @@ public VertexAiTextEmbeddingModel createVertexAiTextEmbeddingModel(String model)
183274

184275
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder().model(model).build();
185276

186-
return new VertexAiTextEmbeddingModel(connectionDetails, options);
277+
VertexAiTextEmbeddingModel embeddingModel = new VertexAiTextEmbeddingModel(connectionDetails, options);
278+
279+
modelCache.put(cacheKey, embeddingModel);
280+
281+
return embeddingModel;
187282
}
188283

189284
public OllamaEmbeddingModel createOllamaEmbeddingModel(String model) {
285+
String cacheKey = generateCacheKey("ollama", model, properties.getOllama().getBaseUrl());
286+
287+
OllamaEmbeddingModel cachedModel = (OllamaEmbeddingModel) modelCache.get(cacheKey);
288+
289+
if (cachedModel != null) {
290+
return cachedModel;
291+
}
292+
190293
OllamaApi api = OllamaApi.builder().baseUrl(properties.getOllama().getBaseUrl()).restClientBuilder(
191294
restClientBuilder).webClientBuilder(webClientBuilder).responseErrorHandler(responseErrorHandler).build();
192295

193296
OllamaOptions options = OllamaOptions.builder().model(model).truncate(false).build();
194297

195-
return OllamaEmbeddingModel.builder().ollamaApi(api).defaultOptions(options).build();
298+
OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder().ollamaApi(api).defaultOptions(options).build();
299+
300+
modelCache.put(cacheKey, embeddingModel);
301+
302+
return embeddingModel;
196303
}
197304

198305
private AwsCredentials getAwsCredentials() {
@@ -218,6 +325,16 @@ private AwsCredentials getAwsCredentials() {
218325
}
219326

220327
public BedrockCohereEmbeddingModel createCohereEmbeddingModel(String model) {
328+
String cacheKey = generateCacheKey("bedrock-cohere", model, properties.getAws().getAccessKey(), properties.getAws()
329+
.getSecretKey(), properties.getAws().getRegion(), String.valueOf(properties.getAws().getBedrockCohere()
330+
.getResponseTimeOut()));
331+
332+
BedrockCohereEmbeddingModel cachedModel = (BedrockCohereEmbeddingModel) modelCache.get(cacheKey);
333+
334+
if (cachedModel != null) {
335+
return cachedModel;
336+
}
337+
221338
String region = properties.getAws().getRegion();
222339
if (!StringUtils.hasText(region)) {
223340
region = springAiProperties.getBedrock().getAws().getRegion(); // Fallback to Spring AI property
@@ -228,10 +345,24 @@ public BedrockCohereEmbeddingModel createCohereEmbeddingModel(String model) {
228345
properties.getAws().getRegion(), ModelOptionsUtils.OBJECT_MAPPER, Duration.ofMinutes(properties.getAws()
229346
.getBedrockCohere().getResponseTimeOut()));
230347

231-
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
348+
BedrockCohereEmbeddingModel embeddingModel = new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
349+
350+
modelCache.put(cacheKey, embeddingModel);
351+
352+
return embeddingModel;
232353
}
233354

234355
public BedrockTitanEmbeddingModel createTitanEmbeddingModel(String model) {
356+
String cacheKey = generateCacheKey("bedrock-titan", model, properties.getAws().getAccessKey(), properties.getAws()
357+
.getSecretKey(), properties.getAws().getRegion(), String.valueOf(properties.getAws().getBedrockTitan()
358+
.getResponseTimeOut()));
359+
360+
BedrockTitanEmbeddingModel cachedModel = (BedrockTitanEmbeddingModel) modelCache.get(cacheKey);
361+
362+
if (cachedModel != null) {
363+
return cachedModel;
364+
}
365+
235366
String region = properties.getAws().getRegion();
236367
if (!StringUtils.hasText(region)) {
237368
region = springAiProperties.getBedrock().getAws().getRegion(); // Fallback to Spring AI property
@@ -242,6 +373,10 @@ public BedrockTitanEmbeddingModel createTitanEmbeddingModel(String model) {
242373
properties.getAws().getRegion(), ModelOptionsUtils.OBJECT_MAPPER, Duration.ofMinutes(properties.getAws()
243374
.getBedrockTitan().getResponseTimeOut()));
244375

245-
return new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);
376+
BedrockTitanEmbeddingModel embeddingModel = new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);
377+
378+
modelCache.put(cacheKey, embeddingModel);
379+
380+
return embeddingModel;
246381
}
247382
}

tests/src/test/resources/vss_on.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,4 +2,4 @@ redis:
22
om:
33
spring:
44
ai:
5-
\enabled: true
5+
enabled: true

0 commit comments

Comments
 (0)