Skip to content

Commit bfbc64b

Browse files
committed
feat(ollama): add retry template integration to OllamaChatModel
Signed-off-by: Alexandros Pappas <[email protected]>
1 parent b77e084 commit bfbc64b

File tree

11 files changed: +58 additions, -16 deletions

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,8 @@
6161
import org.springframework.ai.ollama.management.OllamaModelManager;
6262
import org.springframework.ai.ollama.management.PullModelStrategy;
6363
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
64+
import org.springframework.ai.retry.RetryUtils;
65+
import org.springframework.retry.support.RetryTemplate;
6466
import org.springframework.util.Assert;
6567
import org.springframework.util.CollectionUtils;
6668
import org.springframework.util.StringUtils;
@@ -108,20 +110,32 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
108110

109111
private final OllamaModelManager modelManager;
110112

113+
private final RetryTemplate retryTemplate;
114+
111115
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
112116

113117
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
114118
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
115119
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
120+
this(ollamaApi, defaultOptions, functionCallbackResolver, toolFunctionCallbacks, observationRegistry,
121+
modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
122+
}
123+
124+
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
125+
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
126+
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
127+
RetryTemplate retryTemplate) {
116128
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
117129
Assert.notNull(ollamaApi, "ollamaApi must not be null");
118130
Assert.notNull(defaultOptions, "defaultOptions must not be null");
119131
Assert.notNull(observationRegistry, "observationRegistry must not be null");
120132
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
133+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
121134
this.chatApi = ollamaApi;
122135
this.defaultOptions = defaultOptions;
123136
this.observationRegistry = observationRegistry;
124137
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
138+
this.retryTemplate = retryTemplate;
125139
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
126140
}
127141

@@ -199,7 +213,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
199213
this.observationRegistry)
200214
.observe(() -> {
201215

202-
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
216+
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
203217

204218
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
205219
: ollamaResponse.message()
@@ -471,6 +485,8 @@ public static final class Builder {
471485

472486
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
473487

488+
private RetryTemplate retryTemplate;
489+
474490
private Builder() {
475491
}
476492

@@ -504,9 +520,15 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti
504520
return this;
505521
}
506522

523+
public Builder retryTemplate(RetryTemplate retryTemplate) {
524+
this.retryTemplate = retryTemplate;
525+
return this;
526+
}
527+
507528
public OllamaChatModel build() {
508529
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
509-
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
530+
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions,
531+
this.retryTemplate);
510532
}
511533

512534
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
5151
*
5252
* @author Christian Tzolov
5353
* @author Thomas Vitale
54+
* @author Alexandros Pappas
5455
* @since 0.8.0
5556
*/
5657
// @formatter:off
@@ -64,8 +65,6 @@ public class OllamaApi {
6465

6566
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
6667

67-
private final ResponseErrorHandler responseErrorHandler;
68-
6968
private final RestClient restClient;
7069

7170
private final WebClient webClient;
@@ -93,14 +92,16 @@ public OllamaApi(String baseUrl) {
9392
*/
9493
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
9594

96-
this.responseErrorHandler = new OllamaResponseErrorHandler();
95+
ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler();
9796

9897
Consumer<HttpHeaders> defaultHeaders = headers -> {
9998
headers.setContentType(MediaType.APPLICATION_JSON);
10099
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
101100
};
102101

103-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
102+
this.restClient = restClientBuilder.baseUrl(baseUrl)
103+
.defaultStatusHandler(responseErrorHandler)
104+
.defaultHeaders(defaultHeaders).build();
104105

105106
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
106107
}
@@ -121,7 +122,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
121122
.uri("/api/chat")
122123
.body(chatRequest)
123124
.retrieve()
124-
.onStatus(this.responseErrorHandler)
125125
.body(ChatResponse.class);
126126
}
127127

@@ -188,7 +188,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
188188
.uri("/api/embed")
189189
.body(embeddingsRequest)
190190
.retrieve()
191-
.onStatus(this.responseErrorHandler)
192191
.body(EmbeddingsResponse.class);
193192
}
194193

@@ -199,7 +198,6 @@ public ListModelResponse listModels() {
199198
return this.restClient.get()
200199
.uri("/api/tags")
201200
.retrieve()
202-
.onStatus(this.responseErrorHandler)
203201
.body(ListModelResponse.class);
204202
}
205203

@@ -212,7 +210,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
212210
.uri("/api/show")
213211
.body(showModelRequest)
214212
.retrieve()
215-
.onStatus(this.responseErrorHandler)
216213
.body(ShowModelResponse.class);
217214
}
218215

@@ -225,7 +222,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
225222
.uri("/api/copy")
226223
.body(copyModelRequest)
227224
.retrieve()
228-
.onStatus(this.responseErrorHandler)
229225
.toBodilessEntity();
230226
}
231227

@@ -238,7 +234,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
238234
.uri("/api/delete")
239235
.body(deleteModelRequest)
240236
.retrieve()
241-
.onStatus(this.responseErrorHandler)
242237
.toBodilessEntity();
243238
}
244239

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.ai.ollama.api.OllamaApi;
3737
import org.springframework.ai.ollama.api.OllamaOptions;
3838
import org.springframework.ai.ollama.api.tool.MockWeatherService;
39+
import org.springframework.ai.retry.RetryUtils;
3940
import org.springframework.beans.factory.annotation.Autowired;
4041
import org.springframework.boot.SpringBootConfiguration;
4142
import org.springframework.boot.test.context.SpringBootTest;
@@ -122,6 +123,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
122123
return OllamaChatModel.builder()
123124
.ollamaApi(ollamaApi)
124125
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
126+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
125127
.build();
126128
}
127129

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.ai.ollama.management.ModelManagementOptions;
4545
import org.springframework.ai.ollama.management.OllamaModelManager;
4646
import org.springframework.ai.ollama.management.PullModelStrategy;
47+
import org.springframework.ai.retry.RetryUtils;
4748
import org.springframework.beans.factory.annotation.Autowired;
4849
import org.springframework.boot.SpringBootConfiguration;
4950
import org.springframework.boot.test.context.SpringBootTest;
@@ -277,6 +278,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
277278
.pullModelStrategy(PullModelStrategy.WHEN_MISSING)
278279
.additionalModels(List.of(ADDITIONAL_MODEL))
279280
.build())
281+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
280282
.build();
281283
}
282284

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.Media;
2828
import org.springframework.ai.ollama.api.OllamaApi;
2929
import org.springframework.ai.ollama.api.OllamaOptions;
30+
import org.springframework.ai.retry.RetryUtils;
3031
import org.springframework.beans.factory.annotation.Autowired;
3132
import org.springframework.boot.SpringBootConfiguration;
3233
import org.springframework.boot.test.context.SpringBootTest;
@@ -84,6 +85,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
8485
return OllamaChatModel.builder()
8586
.ollamaApi(ollamaApi)
8687
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
88+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8789
.build();
8890
}
8991

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.ollama.api.OllamaApi;
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.retry.RetryUtils;
3738
import org.springframework.beans.factory.annotation.Autowired;
3839
import org.springframework.boot.SpringBootConfiguration;
3940
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@
4748
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
4849
*
4950
* @author Thomas Vitale
51+
* @author Alexandros Pappas
5052
*/
5153
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
5254
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -169,7 +171,11 @@ public OllamaApi openAiApi() {
169171

170172
@Bean
171173
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
172-
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
174+
return OllamaChatModel.builder()
175+
.ollamaApi(ollamaApi)
176+
.observationRegistry(observationRegistry)
177+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
178+
.build();
173179
}
174180

175181
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.ollama.api.OllamaApi;
3232
import org.springframework.ai.ollama.api.OllamaModel;
3333
import org.springframework.ai.ollama.api.OllamaOptions;
34+
import org.springframework.ai.retry.RetryUtils;
3435

3536
import static org.assertj.core.api.Assertions.assertThat;
3637
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -54,6 +55,7 @@ public void buildOllamaChatModel() {
5455
() -> OllamaChatModel.builder()
5556
.ollamaApi(this.ollamaApi)
5657
.defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build())
58+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
5759
.modelManagementOptions(null)
5860
.build());
5961
assertEquals("modelManagementOptions must not be null", exception.getMessage());

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,18 +22,21 @@
2222
import org.springframework.ai.chat.prompt.Prompt;
2323
import org.springframework.ai.ollama.api.OllamaApi;
2424
import org.springframework.ai.ollama.api.OllamaOptions;
25+
import org.springframework.ai.retry.RetryUtils;
2526

2627
import static org.assertj.core.api.Assertions.assertThat;
2728

2829
/**
2930
* @author Christian Tzolov
3031
* @author Thomas Vitale
32+
* @author Alexandros Pappas
3133
*/
3234
public class OllamaChatRequestTests {
3335

3436
OllamaChatModel chatModel = OllamaChatModel.builder()
3537
.ollamaApi(new OllamaApi())
3638
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
39+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
3740
.build();
3841

3942
@Test
@@ -107,6 +110,7 @@ public void createRequestWithDefaultOptionsModelOverride() {
107110
OllamaChatModel chatModel = OllamaChatModel.builder()
108111
.ollamaApi(new OllamaApi())
109112
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
113+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
110114
.build();
111115

112116
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.boot.context.properties.EnableConfigurationProperties;
4242
import org.springframework.context.ApplicationContext;
4343
import org.springframework.context.annotation.Bean;
44+
import org.springframework.retry.support.RetryTemplate;
4445
import org.springframework.web.client.RestClient;
4546
import org.springframework.web.reactive.function.client.WebClient;
4647

@@ -82,7 +83,7 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
8283
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
8384
OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
8485
FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
85-
ObjectProvider<ChatModelObservationConvention> observationConvention) {
86+
ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
8687
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
8788
: PullModelStrategy.NEVER;
8889

@@ -95,6 +96,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties
9596
.modelManagementOptions(
9697
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
9798
initProperties.getTimeout(), initProperties.getMaxRetries()))
99+
.retryTemplate(retryTemplate)
98100
.build();
99101

100102
observationConvention.ifAvailable(chatModel::setObservationConvention);

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
2122
import org.springframework.boot.autoconfigure.AutoConfigurations;
2223
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
2324
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,7 +42,8 @@ public void propertiesTest() {
4142
"spring.ai.ollama.chat.options.topP=0.56",
4243
"spring.ai.ollama.chat.options.topK=123")
4344
// @formatter:on
44-
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
45+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
46+
RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
4547
.run(context -> {
4648
var chatProperties = context.getBean(OllamaChatProperties.class);
4749
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

0 commit comments

Comments (0)