diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 051dabde9b0..54bf4461187 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -61,6 +61,7 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.api.common.OllamaApiConstants; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; @@ -224,7 +225,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) - .provider(OllamaApi.PROVIDER_NAME) + .provider(OllamaApiConstants.PROVIDER_NAME) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -294,7 +295,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) - .provider(OllamaApi.PROVIDER_NAME) + .provider(OllamaApiConstants.PROVIDER_NAME) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -343,8 +344,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh return Flux.just(ChatResponse.builder().from(response) .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) .build()); - } - else { + } else { // Send the tool execution result back to the model. 
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index 4a5710c9aed..a505d370e7e 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -43,6 +43,7 @@ import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.api.common.OllamaApiConstants; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; @@ -112,7 +113,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) - .provider(OllamaApi.PROVIDER_NAME) + .provider(OllamaApiConstants.PROVIDER_NAME) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index 1a238b5d890..0ab18df9612 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,18 +30,17 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.ai.ollama.api.common.OllamaApiConstants; +import org.springframework.ai.retry.RetryUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpResponse; import org.springframework.util.Assert; -import org.springframework.util.StreamUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -51,21 +50,18 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Jonghoon Park * @since 0.8.0 */ // @formatter:off public class OllamaApi { - public static final String PROVIDER_NAME = AiProvider.OLLAMA.value(); + public static Builder builder() { return new Builder(); } public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null."; private static final Log logger = LogFactory.getLog(OllamaApi.class); - private static final String DEFAULT_BASE_URL = "http://localhost:11434"; - - private final ResponseErrorHandler responseErrorHandler; - private final RestClient restClient; private final WebClient webClient; @@ -73,16 +69,18 @@ public class OllamaApi { /** * Default constructor that uses the default localhost url. 
*/ + @Deprecated(since = "1.0.0.M7") public OllamaApi() { - this(DEFAULT_BASE_URL); + this(OllamaApiConstants.DEFAULT_BASE_URL); } /** - * Crate a new OllamaApi instance with the given base url. + * Create a new OllamaApi instance with the given base url. * @param baseUrl The base url of the Ollama server. */ + @Deprecated(since = "1.0.0.M7") public OllamaApi(String baseUrl) { - this(baseUrl, RestClient.builder(), WebClient.builder()); + this(baseUrl, RestClient.builder(), WebClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } /** @@ -90,19 +88,36 @@ public OllamaApi(String baseUrl) { * {@link RestClient.Builder}. * @param baseUrl The base url of the Ollama server. * @param restClientBuilder The {@link RestClient.Builder} to use. + * @param webClientBuilder The {@link WebClient.Builder} to use. */ + @Deprecated(since = "1.0.0.M7") public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) { + this(baseUrl, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } - this.responseErrorHandler = new OllamaResponseErrorHandler(); + /** + * Create a new OllamaApi instance with the given base url, client builders and response error handler. + * @param baseUrl The base url of the Ollama server. + * @param restClientBuilder The {@link RestClient.Builder} to use. + * @param webClientBuilder The {@link WebClient.Builder} to use. + * @param responseErrorHandler The {@link ResponseErrorHandler} registered as the default status handler of the {@link RestClient}. 
+ */ + private OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { Consumer defaultHeaders = headers -> { headers.setContentType(MediaType.APPLICATION_JSON); headers.setAccept(List.of(MediaType.APPLICATION_JSON)); }; - this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(defaultHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); - this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); + this.webClient = webClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(defaultHeaders) + .build(); } /** @@ -121,7 +136,6 @@ public ChatResponse chat(ChatRequest chatRequest) { .uri("/api/chat") .body(chatRequest) .retrieve() - .onStatus(this.responseErrorHandler) .body(ChatResponse.class); } @@ -188,7 +202,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) { .uri("/api/embed") .body(embeddingsRequest) .retrieve() - .onStatus(this.responseErrorHandler) .body(EmbeddingsResponse.class); } @@ -199,7 +212,6 @@ public ListModelResponse listModels() { return this.restClient.get() .uri("/api/tags") .retrieve() - .onStatus(this.responseErrorHandler) .body(ListModelResponse.class); } @@ -212,7 +224,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) { .uri("/api/show") .body(showModelRequest) .retrieve() - .onStatus(this.responseErrorHandler) .body(ShowModelResponse.class); } @@ -225,7 +236,6 @@ public ResponseEntity copyModel(CopyModelRequest copyModelRequest) { .uri("/api/copy") .body(copyModelRequest) .retrieve() - .onStatus(this.responseErrorHandler) .toBodilessEntity(); } @@ -238,7 +248,6 @@ public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) { .uri("/api/delete") .body(deleteModelRequest) .retrieve() - .onStatus(this.responseErrorHandler) .toBodilessEntity(); } @@ 
-261,26 +270,6 @@ public Flux pullModel(PullModelRequest pullModelRequest) { .bodyToFlux(ProgressResponse.class); } - private static class OllamaResponseErrorHandler implements ResponseErrorHandler { - - @Override - public boolean hasError(ClientHttpResponse response) throws IOException { - return response.getStatusCode().isError(); - } - - @Override - public void handleError(ClientHttpResponse response) throws IOException { - if (response.getStatusCode().isError()) { - int statusCode = response.getStatusCode().value(); - String statusText = response.getStatusText(); - String message = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8); - logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message)); - throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message)); - } - } - - } - /** * Chat message object. * @@ -736,5 +725,44 @@ public record ProgressResponse( @JsonProperty("completed") Long completed ) { } + public static class Builder { + + private String baseUrl = OllamaApiConstants.DEFAULT_BASE_URL; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private WebClient.Builder webClientBuilder = WebClient.builder(); + + private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + + public Builder baseUrl(String baseUrl) { + Assert.hasText(baseUrl, "baseUrl cannot be null or empty"); + this.baseUrl = baseUrl; + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + Assert.notNull(restClientBuilder, "restClientBuilder cannot be null"); + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "webClientBuilder cannot be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { 
+ Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null"); + this.responseErrorHandler = responseErrorHandler; + return this; + } + + public OllamaApi build() { + return new OllamaApi(this.baseUrl, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler); + } + + } } // @formatter:on diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/common/OllamaApiConstants.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/common/OllamaApiConstants.java new file mode 100644 index 00000000000..ca252582674 --- /dev/null +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/common/OllamaApiConstants.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.ollama.api.common; + +import org.springframework.ai.observation.conventions.AiProvider; + +/** + * Common constants shared by the Ollama API integration. 
+ * + * @author Jonghoon Park + */ +public final class OllamaApiConstants { + + public static final String DEFAULT_BASE_URL = "http://localhost:11434"; + + public static final String PROVIDER_NAME = AiProvider.OLLAMA.value(); + + private OllamaApiConstants() { + + } + +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java index 3ca09455a9d..d5601cd78ca 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -86,7 +86,7 @@ public static void tearDown() { private static OllamaApi buildOllamaApiWithModel(final String model) { final String baseUrl = SKIP_CONTAINER_CREATION ? 
OLLAMA_LOCAL_URL : ollamaContainer.getEndpoint(); - final OllamaApi api = new OllamaApi(baseUrl); + final OllamaApi api = OllamaApi.builder().baseUrl(baseUrl).build(); ensureModelIsPresent(api, model); return api; } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 2f8d10c9e69..423e3d3ed65 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -37,7 +37,7 @@ class OllamaChatRequestTests { OllamaChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) .build(); @@ -51,7 +51,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { .toolContext(Map.of("key1", "value1", "key2", "valueA")) .build(); OllamaChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(defaultOptions) .build(); @@ -143,7 +143,7 @@ public void createRequestWithPromptOptionsModelOverride() { @Test public void createRequestWithDefaultOptionsModelOverride() { OllamaChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) .build(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java index c1a52989fd5..57d15772c56 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java +++ 
b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java @@ -35,7 +35,7 @@ public class OllamaEmbeddingRequestTests { OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(OllamaOptions.builder().model("DEFAULT_MODEL").mainGPU(11).useMMap(true).numGPU(1).build()) .build(); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index 633250fcec4..44ac458bda3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -483,7 +483,7 @@ Next, create an `OllamaChatModel` instance and use it to send requests for text [source,java] ---- -var ollamaApi = new OllamaApi(); +var ollamaApi = OllamaApi.builder().build(); var chatModel = OllamaChatModel.builder() .ollamaApi(ollamaApi) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc index 7b9c435218f..e327abb6d33 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc @@ -319,7 +319,7 @@ Next, create an `OllamaEmbeddingModel` instance and use it to compute the embedd [source,java] ---- -var ollamaApi = new OllamaApi(); +var ollamaApi = OllamaApi.builder().build(); var embeddingModel = new OllamaEmbeddingModel(this.ollamaApi, OllamaOptions.builder() diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc index 461228ffa9e..5a8877c7698 100644 --- 
a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/tools.adoc @@ -270,7 +270,7 @@ WARNING: Default tools are shared across all the chat requests performed by that ---- ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools()); ChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(ToolCallingChatOptions.builder() .toolCallbacks(dateTimeTools) .build()) @@ -438,7 +438,7 @@ WARNING: Default tools are shared across all the chat requests performed by that ---- ToolCallback toolCallback = ... ChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(ToolCallingChatOptions.builder() .toolCallbacks(toolCallback) .build()) @@ -560,7 +560,7 @@ WARNING: Default tools are shared across all the chat requests performed by that ---- ToolCallback toolCallback = ... ChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(ToolCallingChatOptions.builder() .toolCallbacks(toolCallback) .build()) @@ -667,7 +667,7 @@ WARNING: Default tools are shared across all the chat requests performed by that [source,java] ---- ChatModel chatModel = OllamaChatModel.builder() - .ollamaApi(new OllamaApi()) + .ollamaApi(OllamaApi.builder().build()) .defaultOptions(ToolCallingChatOptions.builder() .toolNames("currentWeather") .build())