Skip to content

Commit d48f4f6

Browse files
dev-jonghoonparkilayaperumalg
authored andcommitted
apply builder pattern to OllamaApi
Signed-off-by: jonghoon park <[email protected]>
1 parent 53406c1 commit d48f4f6

File tree

10 files changed

+122
-57
lines changed

10 files changed

+122
-57
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
6262
import org.springframework.ai.ollama.api.OllamaModel;
6363
import org.springframework.ai.ollama.api.OllamaOptions;
64+
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
6465
import org.springframework.ai.ollama.management.ModelManagementOptions;
6566
import org.springframework.ai.ollama.management.OllamaModelManager;
6667
import org.springframework.ai.ollama.management.PullModelStrategy;
@@ -224,7 +225,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
224225

225226
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
226227
.prompt(prompt)
227-
.provider(OllamaApi.PROVIDER_NAME)
228+
.provider(OllamaApiConstants.PROVIDER_NAME)
228229
.build();
229230

230231
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
@@ -294,7 +295,7 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
294295

295296
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
296297
.prompt(prompt)
297-
.provider(OllamaApi.PROVIDER_NAME)
298+
.provider(OllamaApiConstants.PROVIDER_NAME)
298299
.build();
299300

300301
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
@@ -343,8 +344,7 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
343344
return Flux.just(ChatResponse.builder().from(response)
344345
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
345346
.build());
346-
}
347-
else {
347+
} else {
348348
// Send the tool execution result back to the model.
349349
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
350350
response);

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
4444
import org.springframework.ai.ollama.api.OllamaModel;
4545
import org.springframework.ai.ollama.api.OllamaOptions;
46+
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
4647
import org.springframework.ai.ollama.management.ModelManagementOptions;
4748
import org.springframework.ai.ollama.management.OllamaModelManager;
4849
import org.springframework.ai.ollama.management.PullModelStrategy;
@@ -112,7 +113,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
112113

113114
var observationContext = EmbeddingModelObservationContext.builder()
114115
.embeddingRequest(request)
115-
.provider(OllamaApi.PROVIDER_NAME)
116+
.provider(OllamaApiConstants.PROVIDER_NAME)
116117
.build();
117118

118119
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 68 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -30,18 +30,17 @@
3030
import com.fasterxml.jackson.annotation.JsonProperty;
3131
import org.apache.commons.logging.Log;
3232
import org.apache.commons.logging.LogFactory;
33+
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
34+
import org.springframework.ai.retry.RetryUtils;
3335
import reactor.core.publisher.Flux;
3436
import reactor.core.publisher.Mono;
3537

3638
import org.springframework.ai.model.ModelOptionsUtils;
37-
import org.springframework.ai.observation.conventions.AiProvider;
3839
import org.springframework.http.HttpHeaders;
3940
import org.springframework.http.HttpMethod;
4041
import org.springframework.http.MediaType;
4142
import org.springframework.http.ResponseEntity;
42-
import org.springframework.http.client.ClientHttpResponse;
4343
import org.springframework.util.Assert;
44-
import org.springframework.util.StreamUtils;
4544
import org.springframework.web.client.ResponseErrorHandler;
4645
import org.springframework.web.client.RestClient;
4746
import org.springframework.web.reactive.function.client.WebClient;
@@ -51,58 +50,74 @@
5150
*
5251
* @author Christian Tzolov
5352
* @author Thomas Vitale
53+
* @author Jonghoon Park
5454
* @since 0.8.0
5555
*/
5656
// @formatter:off
5757
public class OllamaApi {
5858

59-
public static final String PROVIDER_NAME = AiProvider.OLLAMA.value();
59+
public static Builder builder() { return new Builder(); }
6060

6161
public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null.";
6262

6363
private static final Log logger = LogFactory.getLog(OllamaApi.class);
6464

65-
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
66-
67-
private final ResponseErrorHandler responseErrorHandler;
68-
6965
private final RestClient restClient;
7066

7167
private final WebClient webClient;
7268

7369
/**
7470
* Default constructor that uses the default localhost url.
7571
*/
72+
@Deprecated(since = "1.0.0.M8")
7673
public OllamaApi() {
77-
this(DEFAULT_BASE_URL);
74+
this(OllamaApiConstants.DEFAULT_BASE_URL);
7875
}
7976

8077
/**
8178
* Crate a new OllamaApi instance with the given base url.
8279
* @param baseUrl The base url of the Ollama server.
8380
*/
81+
@Deprecated(since = "1.0.0.M8")
8482
public OllamaApi(String baseUrl) {
85-
this(baseUrl, RestClient.builder(), WebClient.builder());
83+
this(baseUrl, RestClient.builder(), WebClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
8684
}
8785

8886
/**
8987
* Crate a new OllamaApi instance with the given base url and
9088
* {@link RestClient.Builder}.
9189
* @param baseUrl The base url of the Ollama server.
9290
* @param restClientBuilder The {@link RestClient.Builder} to use.
91+
* @param webClientBuilder The {@link WebClient.Builder} to use.
9392
*/
93+
@Deprecated(since = "1.0.0.M8")
9494
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
95+
this(baseUrl, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
96+
}
9597

96-
this.responseErrorHandler = new OllamaResponseErrorHandler();
98+
/**
99+
* Create a new OllamaApi instance
100+
* @param baseUrl The base url of the Ollama server.
101+
* @param restClientBuilder The {@link RestClient.Builder} to use.
102+
* @param webClientBuilder The {@link WebClient.Builder} to use.
103+
* @param responseErrorHandler Response error handler.
104+
*/
105+
private OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
97106

98107
Consumer<HttpHeaders> defaultHeaders = headers -> {
99108
headers.setContentType(MediaType.APPLICATION_JSON);
100109
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
101110
};
102111

103-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
112+
this.restClient = restClientBuilder.baseUrl(baseUrl)
113+
.defaultHeaders(defaultHeaders)
114+
.defaultStatusHandler(responseErrorHandler)
115+
.build();
104116

105-
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
117+
this.webClient = webClientBuilder
118+
.baseUrl(baseUrl)
119+
.defaultHeaders(defaultHeaders)
120+
.build();
106121
}
107122

108123
/**
@@ -121,7 +136,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
121136
.uri("/api/chat")
122137
.body(chatRequest)
123138
.retrieve()
124-
.onStatus(this.responseErrorHandler)
125139
.body(ChatResponse.class);
126140
}
127141

@@ -188,7 +202,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
188202
.uri("/api/embed")
189203
.body(embeddingsRequest)
190204
.retrieve()
191-
.onStatus(this.responseErrorHandler)
192205
.body(EmbeddingsResponse.class);
193206
}
194207

@@ -199,7 +212,6 @@ public ListModelResponse listModels() {
199212
return this.restClient.get()
200213
.uri("/api/tags")
201214
.retrieve()
202-
.onStatus(this.responseErrorHandler)
203215
.body(ListModelResponse.class);
204216
}
205217

@@ -212,7 +224,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
212224
.uri("/api/show")
213225
.body(showModelRequest)
214226
.retrieve()
215-
.onStatus(this.responseErrorHandler)
216227
.body(ShowModelResponse.class);
217228
}
218229

@@ -225,7 +236,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
225236
.uri("/api/copy")
226237
.body(copyModelRequest)
227238
.retrieve()
228-
.onStatus(this.responseErrorHandler)
229239
.toBodilessEntity();
230240
}
231241

@@ -238,7 +248,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
238248
.uri("/api/delete")
239249
.body(deleteModelRequest)
240250
.retrieve()
241-
.onStatus(this.responseErrorHandler)
242251
.toBodilessEntity();
243252
}
244253

@@ -261,26 +270,6 @@ public Flux<ProgressResponse> pullModel(PullModelRequest pullModelRequest) {
261270
.bodyToFlux(ProgressResponse.class);
262271
}
263272

264-
private static class OllamaResponseErrorHandler implements ResponseErrorHandler {
265-
266-
@Override
267-
public boolean hasError(ClientHttpResponse response) throws IOException {
268-
return response.getStatusCode().isError();
269-
}
270-
271-
@Override
272-
public void handleError(ClientHttpResponse response) throws IOException {
273-
if (response.getStatusCode().isError()) {
274-
int statusCode = response.getStatusCode().value();
275-
String statusText = response.getStatusText();
276-
String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8);
277-
logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message));
278-
throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message));
279-
}
280-
}
281-
282-
}
283-
284273
/**
285274
* Chat message object.
286275
*
@@ -736,5 +725,44 @@ public record ProgressResponse(
736725
@JsonProperty("completed") Long completed
737726
) { }
738727

728+
public static class Builder {
729+
730+
private String baseUrl = OllamaApiConstants.DEFAULT_BASE_URL;
731+
732+
private RestClient.Builder restClientBuilder = RestClient.builder();
733+
734+
private WebClient.Builder webClientBuilder = WebClient.builder();
735+
736+
private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
737+
738+
public Builder baseUrl(String baseUrl) {
739+
Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
740+
this.baseUrl = baseUrl;
741+
return this;
742+
}
743+
744+
public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
745+
Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
746+
this.restClientBuilder = restClientBuilder;
747+
return this;
748+
}
749+
750+
public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
751+
Assert.notNull(webClientBuilder, "webClientBuilder cannot be null");
752+
this.webClientBuilder = webClientBuilder;
753+
return this;
754+
}
755+
756+
public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
757+
Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
758+
this.responseErrorHandler = responseErrorHandler;
759+
return this;
760+
}
761+
762+
public OllamaApi build() {
763+
return new OllamaApi(this.baseUrl, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
764+
}
765+
766+
}
739767
}
740768
// @formatter:on
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.ollama.api.common;
18+
19+
import org.springframework.ai.observation.conventions.AiProvider;
20+
21+
/**
22+
* Common value constants for Ollama api.
23+
*
24+
* @author Jonghoon Park
25+
*/
26+
public final class OllamaApiConstants {
27+
28+
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
29+
30+
public static final String PROVIDER_NAME = AiProvider.OLLAMA.value();
31+
32+
private OllamaApiConstants() {
33+
34+
}
35+
36+
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -86,7 +86,7 @@ public static void tearDown() {
8686

8787
private static OllamaApi buildOllamaApiWithModel(final String model) {
8888
final String baseUrl = SKIP_CONTAINER_CREATION ? OLLAMA_LOCAL_URL : ollamaContainer.getEndpoint();
89-
final OllamaApi api = new OllamaApi(baseUrl);
89+
final OllamaApi api = OllamaApi.builder().baseUrl(baseUrl).build();
9090
ensureModelIsPresent(api, model);
9191
return api;
9292
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
class OllamaChatRequestTests {
3838

3939
OllamaChatModel chatModel = OllamaChatModel.builder()
40-
.ollamaApi(new OllamaApi())
40+
.ollamaApi(OllamaApi.builder().build())
4141
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
4242
.build();
4343

@@ -51,7 +51,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() {
5151
.toolContext(Map.of("key1", "value1", "key2", "valueA"))
5252
.build();
5353
OllamaChatModel chatModel = OllamaChatModel.builder()
54-
.ollamaApi(new OllamaApi())
54+
.ollamaApi(OllamaApi.builder().build())
5555
.defaultOptions(defaultOptions)
5656
.build();
5757

@@ -143,7 +143,7 @@ public void createRequestWithPromptOptionsModelOverride() {
143143
@Test
144144
public void createRequestWithDefaultOptionsModelOverride() {
145145
OllamaChatModel chatModel = OllamaChatModel.builder()
146-
.ollamaApi(new OllamaApi())
146+
.ollamaApi(OllamaApi.builder().build())
147147
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
148148
.build();
149149

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
public class OllamaEmbeddingRequestTests {
3636

3737
OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
38-
.ollamaApi(new OllamaApi())
38+
.ollamaApi(OllamaApi.builder().build())
3939
.defaultOptions(OllamaOptions.builder().model("DEFAULT_MODEL").mainGPU(11).useMMap(true).numGPU(1).build())
4040
.build();
4141

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ Next, create an `OllamaChatModel` instance and use it to send requests for text
483483

484484
[source,java]
485485
----
486-
var ollamaApi = new OllamaApi();
486+
var ollamaApi = OllamaApi.builder().build();
487487
488488
var chatModel = OllamaChatModel.builder()
489489
.ollamaApi(ollamaApi)

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ Next, create an `OllamaEmbeddingModel` instance and use it to compute the embedd
319319

320320
[source,java]
321321
----
322-
var ollamaApi = new OllamaApi();
322+
var ollamaApi = OllamaApi.builder().build();
323323
324324
var embeddingModel = new OllamaEmbeddingModel(this.ollamaApi,
325325
OllamaOptions.builder()

0 commit comments

Comments
 (0)