Skip to content

Commit f08e9d9

Browse files
committed
feat(ollama): add retry template integration to OllamaChatModel
Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 092bbae commit f08e9d9

File tree

11 files changed

+54
-20
lines changed

11 files changed

+54
-20
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
import org.springframework.ai.ollama.management.ModelManagementOptions;
7070
import org.springframework.ai.ollama.management.OllamaModelManager;
7171
import org.springframework.ai.ollama.management.PullModelStrategy;
72+
import org.springframework.ai.retry.RetryUtils;
73+
import org.springframework.retry.support.RetryTemplate;
7274
import org.springframework.util.Assert;
7375
import org.springframework.util.CollectionUtils;
7476
import org.springframework.util.StringUtils;
@@ -122,6 +124,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
122124

123125
private final ToolCallingManager toolCallingManager;
124126

127+
private final RetryTemplate retryTemplate;
128+
125129
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
126130

127131
@Deprecated
@@ -130,14 +134,14 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
130134
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
131135
ModelManagementOptions modelManagementOptions) {
132136
this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks),
133-
observationRegistry, modelManagementOptions);
137+
observationRegistry, modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
134138

135139
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
136140
+ "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
137141
}
138142

139143
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
140-
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
144+
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, RetryTemplate retryTemplate) {
141145
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
142146
// because it modifies them. We are using ToolCallingManager instead,
143147
// so we just pass empty options here.
@@ -147,11 +151,13 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCa
147151
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
148152
Assert.notNull(observationRegistry, "observationRegistry must not be null");
149153
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
154+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
150155
this.chatApi = ollamaApi;
151156
this.defaultOptions = defaultOptions;
152157
this.toolCallingManager = toolCallingManager;
153158
this.observationRegistry = observationRegistry;
154159
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
160+
this.retryTemplate = retryTemplate;
155161
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
156162
}
157163

@@ -237,7 +243,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
237243
this.observationRegistry)
238244
.observe(() -> {
239245

240-
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
246+
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
241247

242248
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
243249
: ollamaResponse.message()
@@ -543,6 +549,8 @@ public static final class Builder {
543549

544550
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
545551

552+
// Defaults to the shared retry template so that build() still succeeds for callers
// that never invoke retryTemplate(...); the constructor asserts the value is non-null.
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
553+
546554
private Builder() {
547555
}
548556

@@ -583,6 +591,11 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti
583591
return this;
584592
}
585593

594+
/**
 * Sets the {@link RetryTemplate} that {@code build()} hands to the
 * {@code OllamaChatModel} constructor. The constructor asserts that the
 * value is not null.
 * @param retryTemplate the retry template to use
 * @return this builder, for call chaining
 */
public Builder retryTemplate(RetryTemplate retryTemplate) {
595+
this.retryTemplate = retryTemplate;
596+
return this;
597+
}
598+
586599
public OllamaChatModel build() {
587600
if (toolCallingManager != null) {
588601
Assert.isNull(functionCallbackResolver,
@@ -591,7 +604,7 @@ public OllamaChatModel build() {
591604
"toolFunctionCallbacks must not be set when toolCallingManager is set");
592605

593606
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
594-
this.observationRegistry, this.modelManagementOptions);
607+
this.observationRegistry, this.modelManagementOptions, this.retryTemplate);
595608
}
596609

597610
if (functionCallbackResolver != null) {
@@ -604,7 +617,7 @@ public OllamaChatModel build() {
604617
}
605618

606619
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
607-
this.observationRegistry, this.modelManagementOptions);
620+
this.observationRegistry, this.modelManagementOptions, this.retryTemplate);
608621
}
609622

610623
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
*
5252
* @author Christian Tzolov
5353
* @author Thomas Vitale
54+
* @author Alexandros Pappas
5455
* @since 0.8.0
5556
*/
5657
// @formatter:off
@@ -64,8 +65,6 @@ public class OllamaApi {
6465

6566
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
6667

67-
private final ResponseErrorHandler responseErrorHandler;
68-
6968
private final RestClient restClient;
7069

7170
private final WebClient webClient;
@@ -93,14 +92,16 @@ public OllamaApi(String baseUrl) {
9392
*/
9493
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
9594

96-
this.responseErrorHandler = new OllamaResponseErrorHandler();
95+
ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler();
9796

9897
Consumer<HttpHeaders> defaultHeaders = headers -> {
9998
headers.setContentType(MediaType.APPLICATION_JSON);
10099
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
101100
};
102101

103-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
102+
this.restClient = restClientBuilder.baseUrl(baseUrl)
103+
.defaultStatusHandler(responseErrorHandler)
104+
.defaultHeaders(defaultHeaders).build();
104105

105106
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
106107
}
@@ -121,7 +122,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
121122
.uri("/api/chat")
122123
.body(chatRequest)
123124
.retrieve()
124-
.onStatus(this.responseErrorHandler)
125125
.body(ChatResponse.class);
126126
}
127127

@@ -188,7 +188,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
188188
.uri("/api/embed")
189189
.body(embeddingsRequest)
190190
.retrieve()
191-
.onStatus(this.responseErrorHandler)
192191
.body(EmbeddingsResponse.class);
193192
}
194193

@@ -199,7 +198,6 @@ public ListModelResponse listModels() {
199198
return this.restClient.get()
200199
.uri("/api/tags")
201200
.retrieve()
202-
.onStatus(this.responseErrorHandler)
203201
.body(ListModelResponse.class);
204202
}
205203

@@ -212,7 +210,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
212210
.uri("/api/show")
213211
.body(showModelRequest)
214212
.retrieve()
215-
.onStatus(this.responseErrorHandler)
216213
.body(ShowModelResponse.class);
217214
}
218215

@@ -225,7 +222,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
225222
.uri("/api/copy")
226223
.body(copyModelRequest)
227224
.retrieve()
228-
.onStatus(this.responseErrorHandler)
229225
.toBodilessEntity();
230226
}
231227

@@ -238,7 +234,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
238234
.uri("/api/delete")
239235
.body(deleteModelRequest)
240236
.retrieve()
241-
.onStatus(this.responseErrorHandler)
242237
.toBodilessEntity();
243238
}
244239

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.ai.ollama.api.OllamaOptions;
3737
import org.springframework.ai.ollama.api.tool.MockWeatherService;
3838
import org.springframework.ai.tool.function.FunctionToolCallback;
39+
import org.springframework.ai.retry.RetryUtils;
3940
import org.springframework.beans.factory.annotation.Autowired;
4041
import org.springframework.boot.SpringBootConfiguration;
4142
import org.springframework.boot.test.context.SpringBootTest;
@@ -120,6 +121,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
120121
return OllamaChatModel.builder()
121122
.ollamaApi(ollamaApi)
122123
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
124+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
123125
.build();
124126
}
125127

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.ai.ollama.management.ModelManagementOptions;
4545
import org.springframework.ai.ollama.management.OllamaModelManager;
4646
import org.springframework.ai.ollama.management.PullModelStrategy;
47+
import org.springframework.ai.retry.RetryUtils;
4748
import org.springframework.beans.factory.annotation.Autowired;
4849
import org.springframework.boot.SpringBootConfiguration;
4950
import org.springframework.boot.test.context.SpringBootTest;
@@ -277,6 +278,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
277278
.pullModelStrategy(PullModelStrategy.WHEN_MISSING)
278279
.additionalModels(List.of(ADDITIONAL_MODEL))
279280
.build())
281+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
280282
.build();
281283
}
282284

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.Media;
2828
import org.springframework.ai.ollama.api.OllamaApi;
2929
import org.springframework.ai.ollama.api.OllamaOptions;
30+
import org.springframework.ai.retry.RetryUtils;
3031
import org.springframework.beans.factory.annotation.Autowired;
3132
import org.springframework.boot.SpringBootConfiguration;
3233
import org.springframework.boot.test.context.SpringBootTest;
@@ -85,6 +86,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
8586
return OllamaChatModel.builder()
8687
.ollamaApi(ollamaApi)
8788
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
89+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8890
.build();
8991
}
9092

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.ollama.api.OllamaApi;
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.retry.RetryUtils;
3738
import org.springframework.beans.factory.annotation.Autowired;
3839
import org.springframework.boot.SpringBootConfiguration;
3940
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@
4748
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
4849
*
4950
* @author Thomas Vitale
51+
* @author Alexandros Pappas
5052
*/
5153
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
5254
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -169,7 +171,11 @@ public OllamaApi openAiApi() {
169171

170172
@Bean
171173
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
172-
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
174+
return OllamaChatModel.builder()
175+
.ollamaApi(ollamaApi)
176+
.observationRegistry(observationRegistry)
177+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
178+
.build();
173179
}
174180

175181
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
3737
import org.springframework.ai.ollama.management.ModelManagementOptions;
38+
import org.springframework.ai.retry.RetryUtils;
3839

3940
import static org.assertj.core.api.Assertions.assertThat;
4041
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -65,7 +66,7 @@ void buildOllamaChatModelWithDeprecatedConstructor() {
6566
void buildOllamaChatModelWithConstructor() {
6667
ChatModel chatModel = new OllamaChatModel(this.ollamaApi,
6768
OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), ToolCallingManager.builder().build(),
68-
ObservationRegistry.NOOP, ModelManagementOptions.builder().build());
69+
ObservationRegistry.NOOP, ModelManagementOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
6970
assertThat(chatModel).isNotNull();
7071
}
7172

@@ -81,6 +82,7 @@ void buildOllamaChatModel() {
8182
() -> OllamaChatModel.builder()
8283
.ollamaApi(this.ollamaApi)
8384
.defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build())
85+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8486
.modelManagementOptions(null)
8587
.build());
8688
assertEquals("modelManagementOptions must not be null", exception.getMessage());

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.springframework.ai.ollama.api.OllamaOptions;
2727
import org.springframework.ai.tool.ToolCallback;
2828
import org.springframework.ai.tool.definition.ToolDefinition;
29+
import org.springframework.ai.retry.RetryUtils;
30+
2931

3032
import java.util.Map;
3133

@@ -34,12 +36,14 @@
3436
/**
3537
* @author Christian Tzolov
3638
* @author Thomas Vitale
39+
* @author Alexandros Pappas
3740
*/
3841
class OllamaChatRequestTests {
3942

4043
OllamaChatModel chatModel = OllamaChatModel.builder()
4144
.ollamaApi(new OllamaApi())
4245
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
46+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
4347
.build();
4448

4549
@Test
@@ -146,6 +150,7 @@ public void createRequestWithDefaultOptionsModelOverride() {
146150
OllamaChatModel chatModel = OllamaChatModel.builder()
147151
.ollamaApi(new OllamaApi())
148152
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
153+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
149154
.build();
150155

151156
var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content"));

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.springframework.boot.context.properties.EnableConfigurationProperties;
4141
import org.springframework.context.ApplicationContext;
4242
import org.springframework.context.annotation.Bean;
43+
import org.springframework.retry.support.RetryTemplate;
4344
import org.springframework.web.client.RestClient;
4445
import org.springframework.web.reactive.function.client.WebClient;
4546

@@ -82,7 +83,7 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
8283
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
8384
OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager,
8485
ObjectProvider<ObservationRegistry> observationRegistry,
85-
ObjectProvider<ChatModelObservationConvention> observationConvention) {
86+
ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
8687
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
8788
: PullModelStrategy.NEVER;
8889

@@ -94,6 +95,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties
9495
.modelManagementOptions(
9596
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
9697
initProperties.getTimeout(), initProperties.getMaxRetries()))
98+
.retryTemplate(retryTemplate)
9799
.build();
98100

99101
observationConvention.ifAvailable(chatModel::setObservationConvention);

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
2122
import org.springframework.boot.autoconfigure.AutoConfigurations;
2223
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
2324
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,7 +42,8 @@ public void propertiesTest() {
4142
"spring.ai.ollama.chat.options.topP=0.56",
4243
"spring.ai.ollama.chat.options.topK=123")
4344
// @formatter:on
44-
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
45+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
46+
RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
4547
.run(context -> {
4648
var chatProperties = context.getBean(OllamaChatProperties.class);
4749
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

0 commit comments

Comments
 (0)