Skip to content

Commit 8222f78

Browse files
committed
code refactor after rebase
Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 201edde commit 8222f78

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,10 @@ public OllamaChatModel build() {
617617
toolCallbacks, this.observationRegistry, this.modelManagementOptions);
618618
}
619619

620+
if (this.retryTemplate == null) {
621+
this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
622+
}
623+
620624
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
621625
this.observationRegistry, this.modelManagementOptions, this.retryTemplate);
622626
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package org.springframework.ai.ollama;
2+
3+
import java.time.Instant;
4+
5+
import org.junit.jupiter.api.BeforeEach;
6+
import org.junit.jupiter.api.Test;
7+
import org.junit.jupiter.api.extension.ExtendWith;
8+
import org.mockito.Mock;
9+
import org.mockito.junit.jupiter.MockitoExtension;
10+
11+
import org.springframework.ai.chat.prompt.Prompt;
12+
import org.springframework.ai.ollama.api.OllamaApi;
13+
import org.springframework.ai.ollama.api.OllamaModel;
14+
import org.springframework.ai.ollama.api.OllamaOptions;
15+
import org.springframework.ai.retry.RetryUtils;
16+
import org.springframework.ai.retry.TransientAiException;
17+
import org.springframework.retry.RetryCallback;
18+
import org.springframework.retry.RetryContext;
19+
import org.springframework.retry.RetryListener;
20+
import org.springframework.retry.support.RetryTemplate;
21+
22+
import static org.assertj.core.api.Assertions.assertThat;
23+
import static org.mockito.ArgumentMatchers.isA;
24+
import static org.mockito.Mockito.when;
25+
26+
/**
27+
* Tests for the OllamaRetryTests class.
28+
*
29+
* @author Alexandros Pappas
30+
*/
31+
@ExtendWith(MockitoExtension.class)
32+
class OllamaRetryTests {
33+
34+
private static final String MODEL = OllamaModel.LLAMA3_2.getName();
35+
36+
private TestRetryListener retryListener;
37+
38+
private RetryTemplate retryTemplate;
39+
40+
@Mock
41+
private OllamaApi ollamaApi;
42+
43+
private OllamaChatModel chatModel;
44+
45+
@BeforeEach
46+
public void beforeEach() {
47+
this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE;
48+
this.retryListener = new TestRetryListener();
49+
this.retryTemplate.registerListener(this.retryListener);
50+
51+
this.chatModel = OllamaChatModel.builder()
52+
.ollamaApi(this.ollamaApi)
53+
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
54+
.retryTemplate(this.retryTemplate)
55+
.build();
56+
}
57+
58+
@Test
59+
void ollamaChatTransientError() {
60+
String promptText = "What is the capital of Bulgaria and what is the size? What it the national anthem?";
61+
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
62+
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Response").build(), null, true,
63+
null, null, null, null, null, null);
64+
65+
when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
66+
.thenThrow(new TransientAiException("Transient Error 1"))
67+
.thenThrow(new TransientAiException("Transient Error 2"))
68+
.thenReturn(expectedChatResponse);
69+
70+
var result = this.chatModel.call(new Prompt(promptText));
71+
72+
assertThat(result).isNotNull();
73+
assertThat(result.getResult().getOutput().getText()).isSameAs("Response");
74+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
75+
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
76+
}
77+
78+
private static class TestRetryListener implements RetryListener {
79+
80+
int onErrorRetryCount = 0;
81+
82+
int onSuccessRetryCount = 0;
83+
84+
@Override
85+
public <T, E extends Throwable> void onSuccess(RetryContext context, RetryCallback<T, E> callback, T result) {
86+
this.onSuccessRetryCount = context.getRetryCount();
87+
}
88+
89+
@Override
90+
public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback,
91+
Throwable throwable) {
92+
this.onErrorRetryCount = context.getRetryCount();
93+
}
94+
95+
}
96+
97+
}

0 commit comments

Comments
 (0)