diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatProperties.java index 56d38928e97..83cba8cb8f5 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatProperties.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatProperties.java @@ -16,8 +16,8 @@ package org.springframework.ai.model.ollama.autoconfigure; +import org.springframework.ai.ollama.api.OllamaChatOptions; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -38,7 +38,7 @@ public class OllamaChatProperties { * generative's defaults. 
*/ @NestedConfigurationProperty - private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build(); + private OllamaChatOptions options = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build(); public String getModel() { return this.options.getModel(); @@ -48,7 +48,7 @@ public void setModel(String model) { this.options.setModel(model); } - public OllamaOptions getOptions() { + public OllamaChatOptions getOptions() { return this.options; } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java index 2351c9be8b9..57342b8f57f 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java @@ -16,8 +16,8 @@ package org.springframework.ai.model.ollama.autoconfigure; +import org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -38,7 +38,9 @@ public class OllamaEmbeddingProperties { * generative's defaults. 
*/ @NestedConfigurationProperty - private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build(); + private OllamaEmbeddingOptions options = OllamaEmbeddingOptions.builder() + .model(OllamaModel.MXBAI_EMBED_LARGE.id()) + .build(); public String getModel() { return this.options.getModel(); @@ -48,7 +50,7 @@ public void setModel(String model) { this.options.setModel(model); } - public OllamaOptions getOptions() { + public OllamaEmbeddingOptions getOptions() { return this.options; } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java index 2490e5258b6..fba4b5eb67c 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java @@ -38,9 +38,7 @@ public void propertiesTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off "spring.ai.ollama.base-url=TEST_BASE_URL", - "spring.ai.ollama.embedding.options.model=MODEL_XYZ", - "spring.ai.ollama.embedding.options.temperature=0.13", - "spring.ai.ollama.embedding.options.topK=13" + "spring.ai.ollama.embedding.options.model=MODEL_XYZ" // @formatter:on ) @@ -52,9 +50,6 @@ public void propertiesTest() { assertThat(embeddingProperties.getModel()).isEqualTo("MODEL_XYZ"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); - assertThat(embeddingProperties.getOptions().toMap()).containsKeys("temperature"); - 
assertThat(embeddingProperties.getOptions().toMap().get("temperature")).isEqualTo(0.13); - assertThat(embeddingProperties.getOptions().getTopK()).isEqualTo(13); }); } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java index f9e366d8fb0..e32fc73d80a 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java @@ -23,6 +23,8 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -70,7 +72,7 @@ void functionCallTest() { UserMessage userMessage = new UserMessage( "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); - var promptOptions = OllamaOptions.builder() + var promptOptions = OllamaChatOptions.builder() .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") @@ -95,7 +97,7 @@ void streamingFunctionCallTest() { UserMessage userMessage = new UserMessage( "What are the weather conditions in San Francisco, Tokyo, and Paris? 
Find the temperature in Celsius for each of the three locations."); - var promptOptions = OllamaOptions.builder() + var promptOptions = OllamaChatOptions.builder() .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java index d24df5fb89d..aaeaf505f85 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.ollama.api.OllamaChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; @@ -35,7 +36,6 @@ import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.OllamaChatModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -94,7 +94,7 @@ void functionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? 
Find the temperature in Celsius for each of the three locations."); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build())); + .call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build())); logger.info("Response: " + response); @@ -112,7 +112,7 @@ void streamFunctionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); Flux response = chatModel - .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build())); + .stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build())); String content = response.collectList() .block() diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java index 5922b3f9db8..327e8e314b7 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.ollama.api.OllamaChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -36,7 +37,6 @@ import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.OllamaChatModel; 
import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.support.ToolCallbacks; import org.springframework.ai.tool.annotation.Tool; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -85,7 +86,7 @@ void toolCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OllamaOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build())); + OllamaChatOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build())); logger.info("Response: {}", response); @@ -104,7 +105,7 @@ void functionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build())); + .call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build())); logger.info("Response: {}", response); @@ -122,7 +123,7 @@ void streamFunctionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? 
Find the temperature in Celsius for each of the three locations."); Flux response = chatModel - .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build())); + .stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build())); String content = response.collectList() .block() diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt index a9495617f17..1ef6be1d79e 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt @@ -26,7 +26,7 @@ import org.springframework.ai.model.ollama.autoconfigure.BaseOllamaIT import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration import org.springframework.ai.model.tool.ToolCallingChatOptions import org.springframework.ai.ollama.OllamaChatModel -import org.springframework.ai.ollama.api.OllamaOptions +import org.springframework.ai.ollama.api.OllamaChatOptions import org.springframework.boot.autoconfigure.AutoConfigurations import org.springframework.boot.test.context.runner.ApplicationContextRunner import org.springframework.context.annotation.Bean @@ -68,7 +68,7 @@ class FunctionCallbackResolverKotlinIT : BaseOllamaIT() { "What are the weather conditions in San Francisco, Tokyo, and Paris? 
Find the temperature in Celsius for each of the three locations.") val response = chatModel - .call(Prompt(listOf(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build())) + .call(Prompt(listOf(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build())) logger.info("Response: $response") diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 32f5457ba69..a25e216fc1a 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -28,6 +28,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.ollama.api.OllamaChatOptions; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -116,7 +117,7 @@ public class OllamaChatModel implements ChatModel { private final OllamaApi chatApi; - private final OllamaOptions defaultOptions; + private final OllamaChatOptions defaultOptions; private final ObservationRegistry observationRegistry; @@ -134,13 +135,13 @@ public class OllamaChatModel implements ChatModel { private final RetryTemplate retryTemplate; - public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, + public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions, new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } - public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, 
ToolCallingManager toolCallingManager, + public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) { @@ -388,21 +389,25 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh Prompt buildRequestPrompt(Prompt prompt) { // Process runtime options - OllamaOptions runtimeOptions = null; + OllamaChatOptions runtimeOptions = null; if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + if (prompt.getOptions() instanceof OllamaOptions ollamaOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(OllamaChatOptions.fromOptions(ollamaOptions), + OllamaChatOptions.class, OllamaChatOptions.class); + } + else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, - OllamaOptions.class); + OllamaChatOptions.class); } else { runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, - OllamaOptions.class); + OllamaChatOptions.class); } } // Define request options by merging runtime options and default options - OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, - OllamaOptions.class); + OllamaChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + OllamaChatOptions.class); // Merge @JsonIgnore-annotated options explicitly since they are ignored by // Jackson, used by ModelOptionsUtils. 
if (runtimeOptions != null) { @@ -474,7 +479,13 @@ else if (message instanceof ToolResponseMessage toolMessage) { throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); }).flatMap(List::stream).toList(); - OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions(); + OllamaChatOptions requestOptions = null; + if (prompt.getOptions() instanceof OllamaChatOptions) { + requestOptions = (OllamaChatOptions) prompt.getOptions(); + } + else { + requestOptions = OllamaChatOptions.fromOptions((OllamaOptions) prompt.getOptions()); + } OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel()) .stream(stream) @@ -520,7 +531,7 @@ private List getTools(List toolDefinitions) { @Override public ChatOptions getDefaultOptions() { - return OllamaOptions.fromOptions(this.defaultOptions); + return OllamaChatOptions.fromOptions(this.defaultOptions); } /** @@ -545,7 +556,7 @@ public static final class Builder { private OllamaApi ollamaApi; - private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build(); + private OllamaChatOptions defaultOptions = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build(); private ToolCallingManager toolCallingManager; @@ -565,7 +576,13 @@ public Builder ollamaApi(OllamaApi ollamaApi) { return this; } + @Deprecated public Builder defaultOptions(OllamaOptions defaultOptions) { + this.defaultOptions = OllamaChatOptions.fromOptions(defaultOptions); + return this; + } + + public Builder defaultOptions(OllamaChatOptions defaultOptions) { this.defaultOptions = defaultOptions; return this; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index a505d370e7e..4ee2751c582 100644 --- 
a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -41,6 +41,7 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; +import org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.common.OllamaApiConstants; @@ -69,7 +70,7 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel { private final OllamaApi ollamaApi; - private final OllamaOptions defaultOptions; + private final OllamaEmbeddingOptions defaultOptions; private final ObservationRegistry observationRegistry; @@ -77,7 +78,7 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel { private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, + public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingOptions defaultOptions, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "options must not be null"); @@ -146,15 +147,15 @@ private DefaultUsage getDefaultUsage(OllamaApi.EmbeddingsResponse response) { EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { // Process runtime options - OllamaOptions runtimeOptions = null; + OllamaEmbeddingOptions runtimeOptions = null; if (embeddingRequest.getOptions() != null) { runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, - OllamaOptions.class); + OllamaEmbeddingOptions.class); } // Define request options 
by merging runtime options and default options - OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, - OllamaOptions.class); + OllamaEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + OllamaEmbeddingOptions.class); // Validate request options if (!StringUtils.hasText(requestOptions.getModel())) { @@ -168,10 +169,17 @@ EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { * Package access for testing. */ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingRequest) { - OllamaOptions requestOptions = (OllamaOptions) embeddingRequest.getOptions(); + OllamaEmbeddingOptions requestOptions = null; + if (embeddingRequest.getOptions() instanceof OllamaEmbeddingOptions) { + requestOptions = (OllamaEmbeddingOptions) embeddingRequest.getOptions(); + } + else { + requestOptions = OllamaEmbeddingOptions.fromOptions((OllamaOptions) embeddingRequest.getOptions()); + } + return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(), DurationParser.parse(requestOptions.getKeepAlive()), - OllamaOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate()); + OllamaEmbeddingOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate()); } /** @@ -227,7 +235,7 @@ public static final class Builder { private OllamaApi ollamaApi; - private OllamaOptions defaultOptions = OllamaOptions.builder() + private OllamaEmbeddingOptions defaultOptions = OllamaEmbeddingOptions.builder() .model(OllamaModel.MXBAI_EMBED_LARGE.id()) .build(); @@ -243,7 +251,13 @@ public Builder ollamaApi(OllamaApi ollamaApi) { return this; } + @Deprecated public Builder defaultOptions(OllamaOptions defaultOptions) { + this.defaultOptions = OllamaEmbeddingOptions.fromOptions(defaultOptions); + return this; + } + + public Builder defaultOptions(OllamaEmbeddingOptions defaultOptions) { this.defaultOptions 
= defaultOptions; return this; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index 48f2e6b9ad6..4cbf1357c59 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -384,8 +384,8 @@ public Message build() { * @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m). * @param tools List of tools the model has access to. * @param options Model-specific options. For example, "temperature" can be set through this field, if the model supports it. - * You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map. + * You can use the {@link OllamaChatOptions} builder to create the options then {@link OllamaChatOptions#toMap()} to convert the options into a map. * @param think Think controls whether thinking/reasoning models will think before responding. 
* * @see Chat @@ -514,14 +514,21 @@ public Builder options(Map options) { return this; } + public Builder think(Boolean think) { + this.think = think; + return this; + } + + @Deprecated public Builder options(OllamaOptions options) { Objects.requireNonNull(options, "The options can not be null."); this.options = OllamaOptions.filterNonSupportedFields(options.toMap()); return this; } - public Builder think(Boolean think) { - this.think = think; + public Builder options(OllamaChatOptions options) { + Objects.requireNonNull(options, "The options can not be null."); + this.options = OllamaChatOptions.filterNonSupportedFields(options.toMap()); return this; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaChatOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaChatOptions.java new file mode 100644 index 00000000000..ecfd4411e1d --- /dev/null +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaChatOptions.java @@ -0,0 +1,1076 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.ollama.api; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.*; +import java.util.stream.Collectors; + +/** + * Helper class for creating strongly-typed Ollama options. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @author Ilayaperumal Gopinathan + * @since 0.8.0 + * @see Ollama + * Valid Parameters and Values + * @see Ollama Types + */ +@JsonInclude(Include.NON_NULL) +public class OllamaChatOptions implements ToolCallingChatOptions { + + private static final List NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate"); + + // Following fields are options which must be set when the model is loaded into + // memory. + // See: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md + + // @formatter:off + + /** + * Whether to use NUMA. (Default: false) + */ + @JsonProperty("numa") + private Boolean useNUMA; + + /** + * Sets the size of the context window used to generate the next token. (Default: 2048) + */ + @JsonProperty("num_ctx") + private Integer numCtx; + + /** + * Prompt processing maximum batch size. (Default: 512) + */ + @JsonProperty("num_batch") + private Integer numBatch; + + /** + * The number of layers to send to the GPU(s). On macOS, it defaults to 1 + * to enable metal support, 0 to disable. 
+ * (Default: -1, which indicates that numGPU should be set dynamically) + */ + @JsonProperty("num_gpu") + private Integer numGPU; + + /** + * When using multiple GPUs this option controls which GPU is used + * for small tensors for which the overhead of splitting the computation + * across all GPUs is not worthwhile. The GPU in question will use slightly + * more VRAM to store a scratch buffer for temporary results. + * By default, GPU 0 is used. + */ + @JsonProperty("main_gpu") + private Integer mainGPU; + + /** + * (Default: false) + */ + @JsonProperty("low_vram") + private Boolean lowVRAM; + + /** + * (Default: true) + */ + @JsonProperty("f16_kv") + private Boolean f16KV; + + /** + * Return logits for all the tokens, not just the last one. + * To enable completions to return logprobs, this must be true. + */ + @JsonProperty("logits_all") + private Boolean logitsAll; + + /** + * Load only the vocabulary, not the weights. + */ + @JsonProperty("vocab_only") + private Boolean vocabOnly; + + /** + * By default, models are mapped into memory, which allows the system to load only the necessary parts + * of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low + * on available memory, using mmap might increase the risk of pageouts, negatively impacting performance. + * Disabling mmap results in slower load times but may reduce pageouts if you're not using mlock. + * Note that if the model is larger than the total amount of RAM, turning off mmap would prevent + * the model from loading at all. + * (Default: null) + */ + @JsonProperty("use_mmap") + private Boolean useMMap; + + /** + * Lock the model in memory, preventing it from being swapped out when memory-mapped. + * This can improve performance but trades away some of the advantages of memory-mapping + * by requiring more RAM to run and potentially slowing down load times as the model loads into RAM. 
+ * (Default: false) + */ + @JsonProperty("use_mlock") + private Boolean useMLock; + + /** + * Set the number of threads to use during generation. For optimal performance, it is recommended to set this value + * to the number of physical CPU cores your system has (as opposed to the logical number of cores). + * Using the correct number of threads can greatly improve performance. + * By default, Ollama will detect this value for optimal performance. + */ + @JsonProperty("num_thread") + private Integer numThread; + + // Following fields are predict options used at runtime. + + /** + * (Default: 4) + */ + @JsonProperty("num_keep") + private Integer numKeep; + + /** + * Sets the random number seed to use for generation. Setting this to a + * specific number will make the model generate the same text for the same prompt. + * (Default: -1) + */ + @JsonProperty("seed") + private Integer seed; + + /** + * Maximum number of tokens to predict when generating text. + * (Default: 128, -1 = infinite generation, -2 = fill context) + */ + @JsonProperty("num_predict") + private Integer numPredict; + + /** + * Reduces the probability of generating nonsense. A higher value (e.g. + * 100) will give more diverse answers, while a lower value (e.g. 10) will be more + * conservative. (Default: 40) + */ + @JsonProperty("top_k") + private Integer topK; + + /** + * Works together with top-k. A higher value (e.g., 0.95) will lead to + * more diverse text, while a lower value (e.g., 0.5) will generate more focused and + * conservative text. (Default: 0.9) + */ + @JsonProperty("top_p") + private Double topP; + + /** + * Alternative to the top_p, and aims to ensure a balance of quality and variety. + * The parameter p represents the minimum probability for a token to be considered, + * relative to the probability of the most likely token. For example, with p=0.05 and + * the most likely token having a probability of 0.9, logits with a value + * less than 0.045 are filtered out. 
(Default: 0.0) + */ + @JsonProperty("min_p") + private Double minP; + + /** + * Tail free sampling is used to reduce the impact of less probable tokens + * from the output. A higher value (e.g., 2.0) will reduce the impact more, while a + * value of 1.0 disables this setting. (default: 1) + */ + @JsonProperty("tfs_z") + private Float tfsZ; + + /** + * (Default: 1.0) + */ + @JsonProperty("typical_p") + private Float typicalP; + + /** + * Sets how far back for the model to look back to prevent + * repetition. (Default: 64, 0 = disabled, -1 = num_ctx) + */ + @JsonProperty("repeat_last_n") + private Integer repeatLastN; + + /** + * The temperature of the model. Increasing the temperature will + * make the model answer more creatively. (Default: 0.8) + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * Sets how strongly to penalize repetitions. A higher value + * (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., + * 0.9) will be more lenient. (Default: 1.1) + */ + @JsonProperty("repeat_penalty") + private Double repeatPenalty; + + /** + * (Default: 0.0) + */ + @JsonProperty("presence_penalty") + private Double presencePenalty; + + /** + * (Default: 0.0) + */ + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + + /** + * Enable Mirostat sampling for controlling perplexity. (default: 0, 0 + * = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + */ + @JsonProperty("mirostat") + private Integer mirostat; + + /** + * Controls the balance between coherence and diversity of the output. + * A lower value will result in more focused and coherent text. (Default: 5.0) + */ + @JsonProperty("mirostat_tau") + private Float mirostatTau; + + /** + * Influences how quickly the algorithm responds to feedback from the generated text. + * A lower learning rate will result in slower adjustments, while a higher learning rate + * will make the algorithm more responsive. 
(Default: 0.1) + */ + @JsonProperty("mirostat_eta") + private Float mirostatEta; + + /** + * (Default: true) + */ + @JsonProperty("penalize_newline") + private Boolean penalizeNewline; + + /** + * Sets the stop sequences to use. When this pattern is encountered the + * LLM will stop generating text and return. Multiple stop patterns may be set by + * specifying multiple separate stop parameters in a modelfile. + */ + @JsonProperty("stop") + private List stop; + + + // Following fields are not part of the Ollama Options API but part of the Request. + + /** + * NOTE: Synthetic field not part of the official Ollama API. + * Used to allow overriding the model name with prompt options. + * Part of Chat completion parameters. + */ + @JsonProperty("model") + private String model; + + /** + * Sets the desired format of output from the LLM. The only valid values are null or "json". + * Part of Chat completion advanced parameters. + */ + @JsonProperty("format") + private Object format; + + /** + * Sets the length of time for Ollama to keep the model loaded. Valid values for this + * setting are parsed by ParseDuration in Go. + * Part of Chat completion advanced parameters. + */ + @JsonProperty("keep_alive") + private String keepAlive; + + /** + * Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. + * Defaults to true. + */ + @JsonProperty("truncate") + private Boolean truncate; + + @JsonIgnore + private Boolean internalToolExecutionEnabled; + + /** + * Tool Function Callbacks to register with the ChatModel. + * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. + * For Default Options the toolCallbacks are registered but disabled by default. Use the enableFunctions to set the functions + * from the registry to be used by the ChatModel chat completion requests. 
+ */ + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. + * Functions with those names must exist in the toolCallbacks registry. + * The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. + * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. + * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. + */ + @JsonIgnore + private Set toolNames = new HashSet<>(); + + @JsonIgnore + private Map toolContext = new HashMap<>(); + + public static Builder builder() { + return new Builder(); + } + + /** + * Filter out the non-supported fields from the options. + * @param options The options to filter. + * @return The filtered options. 
+ */ + public static Map filterNonSupportedFields(Map options) { + return options.entrySet().stream() + .filter(e -> !NON_SUPPORTED_FIELDS.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static OllamaChatOptions fromOptions(OllamaChatOptions fromOptions) { + return builder() + .model(fromOptions.getModel()) + .format(fromOptions.getFormat()) + .keepAlive(fromOptions.getKeepAlive()) + .truncate(fromOptions.getTruncate()) + .useNUMA(fromOptions.getUseNUMA()) + .numCtx(fromOptions.getNumCtx()) + .numBatch(fromOptions.getNumBatch()) + .numGPU(fromOptions.getNumGPU()) + .mainGPU(fromOptions.getMainGPU()) + .lowVRAM(fromOptions.getLowVRAM()) + .f16KV(fromOptions.getF16KV()) + .logitsAll(fromOptions.getLogitsAll()) + .vocabOnly(fromOptions.getVocabOnly()) + .useMMap(fromOptions.getUseMMap()) + .useMLock(fromOptions.getUseMLock()) + .numThread(fromOptions.getNumThread()) + .numKeep(fromOptions.getNumKeep()) + .seed(fromOptions.getSeed()) + .numPredict(fromOptions.getNumPredict()) + .topK(fromOptions.getTopK()) + .topP(fromOptions.getTopP()) + .minP(fromOptions.getMinP()) + .tfsZ(fromOptions.getTfsZ()) + .typicalP(fromOptions.getTypicalP()) + .repeatLastN(fromOptions.getRepeatLastN()) + .temperature(fromOptions.getTemperature()) + .repeatPenalty(fromOptions.getRepeatPenalty()) + .presencePenalty(fromOptions.getPresencePenalty()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .mirostat(fromOptions.getMirostat()) + .mirostatTau(fromOptions.getMirostatTau()) + .mirostatEta(fromOptions.getMirostatEta()) + .penalizeNewline(fromOptions.getPenalizeNewline()) + .stop(fromOptions.getStop()) + .toolNames(fromOptions.getToolNames()) + .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolContext(fromOptions.getToolContext()).build(); + } + + public static OllamaChatOptions fromOptions(OllamaOptions fromOptions) { + return builder() + 
.model(fromOptions.getModel()) + .format(fromOptions.getFormat()) + .keepAlive(fromOptions.getKeepAlive()) + .truncate(fromOptions.getTruncate()) + .useNUMA(fromOptions.getUseNUMA()) + .numCtx(fromOptions.getNumCtx()) + .numBatch(fromOptions.getNumBatch()) + .numGPU(fromOptions.getNumGPU()) + .mainGPU(fromOptions.getMainGPU()) + .lowVRAM(fromOptions.getLowVRAM()) + .f16KV(fromOptions.getF16KV()) + .logitsAll(fromOptions.getLogitsAll()) + .vocabOnly(fromOptions.getVocabOnly()) + .useMMap(fromOptions.getUseMMap()) + .useMLock(fromOptions.getUseMLock()) + .numThread(fromOptions.getNumThread()) + .numKeep(fromOptions.getNumKeep()) + .seed(fromOptions.getSeed()) + .numPredict(fromOptions.getNumPredict()) + .topK(fromOptions.getTopK()) + .topP(fromOptions.getTopP()) + .minP(fromOptions.getMinP()) + .tfsZ(fromOptions.getTfsZ()) + .typicalP(fromOptions.getTypicalP()) + .repeatLastN(fromOptions.getRepeatLastN()) + .temperature(fromOptions.getTemperature()) + .repeatPenalty(fromOptions.getRepeatPenalty()) + .presencePenalty(fromOptions.getPresencePenalty()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .mirostat(fromOptions.getMirostat()) + .mirostatTau(fromOptions.getMirostatTau()) + .mirostatEta(fromOptions.getMirostatEta()) + .penalizeNewline(fromOptions.getPenalizeNewline()) + .stop(fromOptions.getStop()) + .toolNames(fromOptions.getToolNames()) + .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolCallbacks(fromOptions.getToolCallbacks()) + .toolContext(fromOptions.getToolContext()).build(); + } + + // ------------------- + // Getters and Setters + // ------------------- + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public Object getFormat() { + return this.format; + } + + public void setFormat(Object format) { + this.format = format; + } + + public String getKeepAlive() { + return this.keepAlive; + } + + public void setKeepAlive(String 
keepAlive) { + this.keepAlive = keepAlive; + } + + public Boolean getUseNUMA() { + return this.useNUMA; + } + + public void setUseNUMA(Boolean useNUMA) { + this.useNUMA = useNUMA; + } + + public Integer getNumCtx() { + return this.numCtx; + } + + public void setNumCtx(Integer numCtx) { + this.numCtx = numCtx; + } + + public Integer getNumBatch() { + return this.numBatch; + } + + public void setNumBatch(Integer numBatch) { + this.numBatch = numBatch; + } + + public Integer getNumGPU() { + return this.numGPU; + } + + public void setNumGPU(Integer numGPU) { + this.numGPU = numGPU; + } + + public Integer getMainGPU() { + return this.mainGPU; + } + + public void setMainGPU(Integer mainGPU) { + this.mainGPU = mainGPU; + } + + public Boolean getLowVRAM() { + return this.lowVRAM; + } + + public void setLowVRAM(Boolean lowVRAM) { + this.lowVRAM = lowVRAM; + } + + public Boolean getF16KV() { + return this.f16KV; + } + + public void setF16KV(Boolean f16kv) { + this.f16KV = f16kv; + } + + public Boolean getLogitsAll() { + return this.logitsAll; + } + + public void setLogitsAll(Boolean logitsAll) { + this.logitsAll = logitsAll; + } + + public Boolean getVocabOnly() { + return this.vocabOnly; + } + + public void setVocabOnly(Boolean vocabOnly) { + this.vocabOnly = vocabOnly; + } + + public Boolean getUseMMap() { + return this.useMMap; + } + + public void setUseMMap(Boolean useMMap) { + this.useMMap = useMMap; + } + + public Boolean getUseMLock() { + return this.useMLock; + } + + public void setUseMLock(Boolean useMLock) { + this.useMLock = useMLock; + } + + public Integer getNumThread() { + return this.numThread; + } + + public void setNumThread(Integer numThread) { + this.numThread = numThread; + } + + public Integer getNumKeep() { + return this.numKeep; + } + + public void setNumKeep(Integer numKeep) { + this.numKeep = numKeep; + } + + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + @Override + @JsonIgnore 
+ public Integer getMaxTokens() { + return getNumPredict(); + } + + @JsonIgnore + public void setMaxTokens(Integer maxTokens) { + setNumPredict(maxTokens); + } + + public Integer getNumPredict() { + return this.numPredict; + } + + public void setNumPredict(Integer numPredict) { + this.numPredict = numPredict; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + public Double getMinP() { + return this.minP; + } + + public void setMinP(Double minP) { + this.minP = minP; + } + + public Float getTfsZ() { + return this.tfsZ; + } + + public void setTfsZ(Float tfsZ) { + this.tfsZ = tfsZ; + } + + public Float getTypicalP() { + return this.typicalP; + } + + public void setTypicalP(Float typicalP) { + this.typicalP = typicalP; + } + + public Integer getRepeatLastN() { + return this.repeatLastN; + } + + public void setRepeatLastN(Integer repeatLastN) { + this.repeatLastN = repeatLastN; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Double getRepeatPenalty() { + return this.repeatPenalty; + } + + public void setRepeatPenalty(Double repeatPenalty) { + this.repeatPenalty = repeatPenalty; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getMirostat() { + return this.mirostat; + } + + public void setMirostat(Integer mirostat) { + this.mirostat = mirostat; + } + + public 
Float getMirostatTau() { + return this.mirostatTau; + } + + public void setMirostatTau(Float mirostatTau) { + this.mirostatTau = mirostatTau; + } + + public Float getMirostatEta() { + return this.mirostatEta; + } + + public void setMirostatEta(Float mirostatEta) { + this.mirostatEta = mirostatEta; + } + + public Boolean getPenalizeNewline() { + return this.penalizeNewline; + } + + public void setPenalizeNewline(Boolean penalizeNewline) { + this.penalizeNewline = penalizeNewline; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return getStop(); + } + + @JsonIgnore + public void setStopSequences(List stopSequences) { + setStop(stopSequences); + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + public Boolean getTruncate() { + return this.truncate; + } + + public void setTruncate(Boolean truncate) { + this.truncate = truncate; + } + + @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean getInternalToolExecutionEnabled() { + return this.internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + 
this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public Map getToolContext() { + return this.toolContext; + } + + @Override + @JsonIgnore + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + /** + * Convert the {@link OllamaChatOptions} object to a {@link Map} of key/value pairs. + * @return The {@link Map} of key/value pairs. + */ + public Map toMap() { + return ModelOptionsUtils.objectToMap(this); + } + + @Override + public OllamaChatOptions copy() { + return fromOptions(this); + } + // @formatter:on + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + OllamaChatOptions that = (OllamaChatOptions) o; + return Objects.equals(this.model, that.model) && Objects.equals(this.format, that.format) + && Objects.equals(this.keepAlive, that.keepAlive) && Objects.equals(this.truncate, that.truncate) + && Objects.equals(this.useNUMA, that.useNUMA) && Objects.equals(this.numCtx, that.numCtx) + && Objects.equals(this.numBatch, that.numBatch) && Objects.equals(this.numGPU, that.numGPU) + && Objects.equals(this.mainGPU, that.mainGPU) && Objects.equals(this.lowVRAM, that.lowVRAM) + && Objects.equals(this.f16KV, that.f16KV) && Objects.equals(this.logitsAll, that.logitsAll) + && Objects.equals(this.vocabOnly, that.vocabOnly) && Objects.equals(this.useMMap, that.useMMap) + && Objects.equals(this.useMLock, that.useMLock) && Objects.equals(this.numThread, that.numThread) + && Objects.equals(this.numKeep, that.numKeep) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.numPredict, that.numPredict) && Objects.equals(this.topK, that.topK) + && Objects.equals(this.topP, that.topP) && Objects.equals(this.minP, that.minP) + && Objects.equals(this.tfsZ, that.tfsZ) && Objects.equals(this.typicalP, that.typicalP) + && Objects.equals(this.repeatLastN, that.repeatLastN) + && 
Objects.equals(this.temperature, that.temperature) + && Objects.equals(this.repeatPenalty, that.repeatPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.mirostatEta, that.mirostatEta) + && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, + this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, + this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, + this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, + this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, + this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, + this.toolContext); + } + + public static class Builder { + + private final OllamaChatOptions options = new OllamaChatOptions(); + + public Builder model(String model) { + this.options.model = model; + return this; + } + + public Builder model(OllamaModel model) { + this.options.model = model.getName(); + return this; + } + + public Builder format(Object format) { + this.options.format = format; + return this; + } + + public Builder keepAlive(String keepAlive) { + this.options.keepAlive = keepAlive; + return this; + } + + public Builder truncate(Boolean truncate) { + 
this.options.truncate = truncate; + return this; + } + + public Builder useNUMA(Boolean useNUMA) { + this.options.useNUMA = useNUMA; + return this; + } + + public Builder numCtx(Integer numCtx) { + this.options.numCtx = numCtx; + return this; + } + + public Builder numBatch(Integer numBatch) { + this.options.numBatch = numBatch; + return this; + } + + public Builder numGPU(Integer numGPU) { + this.options.numGPU = numGPU; + return this; + } + + public Builder mainGPU(Integer mainGPU) { + this.options.mainGPU = mainGPU; + return this; + } + + public Builder lowVRAM(Boolean lowVRAM) { + this.options.lowVRAM = lowVRAM; + return this; + } + + public Builder f16KV(Boolean f16KV) { + this.options.f16KV = f16KV; + return this; + } + + public Builder logitsAll(Boolean logitsAll) { + this.options.logitsAll = logitsAll; + return this; + } + + public Builder vocabOnly(Boolean vocabOnly) { + this.options.vocabOnly = vocabOnly; + return this; + } + + public Builder useMMap(Boolean useMMap) { + this.options.useMMap = useMMap; + return this; + } + + public Builder useMLock(Boolean useMLock) { + this.options.useMLock = useMLock; + return this; + } + + public Builder numThread(Integer numThread) { + this.options.numThread = numThread; + return this; + } + + public Builder numKeep(Integer numKeep) { + this.options.numKeep = numKeep; + return this; + } + + public Builder seed(Integer seed) { + this.options.seed = seed; + return this; + } + + public Builder numPredict(Integer numPredict) { + this.options.numPredict = numPredict; + return this; + } + + public Builder topK(Integer topK) { + this.options.topK = topK; + return this; + } + + public Builder topP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder minP(Double minP) { + this.options.minP = minP; + return this; + } + + public Builder tfsZ(Float tfsZ) { + this.options.tfsZ = tfsZ; + return this; + } + + public Builder typicalP(Float typicalP) { + this.options.typicalP = typicalP; + return this; + } 
+ + public Builder repeatLastN(Integer repeatLastN) { + this.options.repeatLastN = repeatLastN; + return this; + } + + public Builder temperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder repeatPenalty(Double repeatPenalty) { + this.options.repeatPenalty = repeatPenalty; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder mirostat(Integer mirostat) { + this.options.mirostat = mirostat; + return this; + } + + public Builder mirostatTau(Float mirostatTau) { + this.options.mirostatTau = mirostatTau; + return this; + } + + public Builder mirostatEta(Float mirostatEta) { + this.options.mirostatEta = mirostatEta; + return this; + } + + public Builder penalizeNewline(Boolean penalizeNewline) { + this.options.penalizeNewline = penalizeNewline; + return this; + } + + public Builder stop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + public Builder toolCallbacks(ToolCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); + return this; + } + + public Builder toolNames(Set toolNames) { + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... 
toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); + return this; + } + + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + public Builder toolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public OllamaChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaEmbeddingOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaEmbeddingOptions.java new file mode 100644 index 00000000000..5863b90a084 --- /dev/null +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaEmbeddingOptions.java @@ -0,0 +1,208 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.ai.ollama.api;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.model.ModelOptionsUtils;

/**
 * Strongly-typed options for Ollama embedding requests. Only the request-level
 * parameters supported by the embeddings endpoint are exposed here; use
 * {@link OllamaChatOptions} for chat completion requests.
 *
 * @author Christian Tzolov
 * @author Thomas Vitale
 * @author Ilayaperumal Gopinathan
 * @see <a href=
 * "https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama
 * Valid Parameters and Values</a>
 * @see <a href="https://github.com/ollama/ollama/blob/main/api/types.go">Ollama Types</a>
 */
@JsonInclude(Include.NON_NULL)
public class OllamaEmbeddingOptions implements EmbeddingOptions {

	// Request-level fields that must not be forwarded as model "options".
	private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "keep_alive", "truncate");

	// @formatter:off

	// Following fields are not part of the Ollama Options API but part of the Request.

	/**
	 * NOTE: Synthetic field, not part of the official Ollama API.
	 * Used to allow overriding the model name with per-request options.
	 */
	@JsonProperty("model")
	private String model;

	/**
	 * Sets the length of time for Ollama to keep the model loaded. Valid values
	 * for this setting are parsed by ParseDuration in Go.
	 */
	@JsonProperty("keep_alive")
	private String keepAlive;

	/**
	 * Truncates the end of each input to fit within context length. Returns an
	 * error if false and context length is exceeded. Defaults to true.
	 */
	@JsonProperty("truncate")
	private Boolean truncate;

	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Filter out the non-supported fields from the options.
	 * @param options the options to filter
	 * @return the filtered options
	 */
	public static Map<String, Object> filterNonSupportedFields(Map<String, Object> options) {
		return options.entrySet()
			.stream()
			.filter(e -> !NON_SUPPORTED_FIELDS.contains(e.getKey()))
			.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
	}

	/**
	 * Create a new {@link OllamaEmbeddingOptions} instance from the legacy
	 * {@link OllamaOptions} type, copying the embedding-relevant fields.
	 * @param fromOptions the legacy options to copy
	 * @return a new options instance with the same values
	 */
	public static OllamaEmbeddingOptions fromOptions(OllamaOptions fromOptions) {
		return builder()
			.model(fromOptions.getModel())
			.keepAlive(fromOptions.getKeepAlive())
			.truncate(fromOptions.getTruncate())
			.build();
	}

	/**
	 * Create a new {@link OllamaEmbeddingOptions} instance copying every field
	 * from the given options.
	 * @param fromOptions the options to copy
	 * @return a new options instance with the same values
	 */
	public static OllamaEmbeddingOptions fromOptions(OllamaEmbeddingOptions fromOptions) {
		return builder()
			.model(fromOptions.getModel())
			.keepAlive(fromOptions.getKeepAlive())
			.truncate(fromOptions.getTruncate())
			.build();
	}

	// -------------------
	// Getters and Setters
	// -------------------

	@Override
	public String getModel() {
		return this.model;
	}

	public void setModel(String model) {
		this.model = model;
	}

	public String getKeepAlive() {
		return this.keepAlive;
	}

	public void setKeepAlive(String keepAlive) {
		this.keepAlive = keepAlive;
	}

	public Boolean getTruncate() {
		return this.truncate;
	}

	public void setTruncate(Boolean truncate) {
		this.truncate = truncate;
	}

	/**
	 * Always {@code null}: the Ollama API does not support requesting a specific
	 * number of output dimensions.
	 */
	@Override
	@JsonIgnore
	public Integer getDimensions() {
		return null;
	}

	/**
	 * Convert the {@link OllamaEmbeddingOptions} object to a {@link Map} of
	 * key/value pairs.
	 * @return the {@link Map} of key/value pairs
	 */
	public Map<String, Object> toMap() {
		return ModelOptionsUtils.objectToMap(this);
	}

	public OllamaEmbeddingOptions copy() {
		return fromOptions(this);
	}
	// @formatter:on

	@Override
	public boolean equals(Object o) {
		if (this == o) {
			return true;
		}
		if (o == null || getClass() != o.getClass()) {
			return false;
		}
		OllamaEmbeddingOptions that = (OllamaEmbeddingOptions) o;
		return Objects.equals(this.model, that.model) && Objects.equals(this.keepAlive, that.keepAlive)
				&& Objects.equals(this.truncate, that.truncate);
	}

	@Override
	public int hashCode() {
		return Objects.hash(this.model, this.keepAlive, this.truncate);
	}

	/**
	 * Fluent builder for {@link OllamaEmbeddingOptions}.
	 */
	public static class Builder {

		private final OllamaEmbeddingOptions options = new OllamaEmbeddingOptions();

		public Builder model(String model) {
			this.options.model = model;
			return this;
		}

		public Builder model(OllamaModel model) {
			this.options.model = model.getName();
			return this;
		}

		public Builder keepAlive(String keepAlive) {
			this.options.keepAlive = keepAlive;
			return this;
		}

		public Builder truncate(Boolean truncate) {
			this.options.truncate = truncate;
			return this;
		}

		public OllamaEmbeddingOptions build() {
			return this.options;
		}

	}

}
*/ @JsonInclude(Include.NON_NULL) +@Deprecated public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { private static final List NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate"); @@ -944,6 +946,7 @@ public int hashCode() { this.toolContext); } + @Deprecated public static class Builder { private final OllamaOptions options = new OllamaOptions(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index 4bc9ef3438d..e75ab37c676 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -23,6 +23,8 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -62,7 +64,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = OllamaOptions.builder() + var promptOptions = OllamaChatOptions.builder() .model(MODEL) .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( @@ -85,7 +87,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = OllamaOptions.builder() + var promptOptions = OllamaChatOptions.builder() .model(MODEL) .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( @@ -121,7 +123,7 @@ public OllamaApi ollamaApi() { public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { return 
OllamaChatModel.builder() .ollamaApi(ollamaApi) - .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .defaultOptions(OllamaChatOptions.builder().model(MODEL).temperature(0.9).build()) .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 1ac31830bb1..3362a733878 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -46,8 +46,8 @@ import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaChatOptions; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; @@ -82,7 +82,7 @@ void autoPullModelTest() { String joke = ChatClient.create(this.chatModel) .prompt("Tell me a joke") - .options(OllamaOptions.builder().model(ADDITIONAL_MODEL).build()) + .options(OllamaChatOptions.builder().model(ADDITIONAL_MODEL).build()) .call() .content(); @@ -111,7 +111,7 @@ void roleTest() { assertThat(response.getResult().getOutput().getText()).contains("Blackbeard"); // ollama specific options - var ollamaOptions = OllamaOptions.builder().lowVRAM(true).build(); + var ollamaOptions = OllamaChatOptions.builder().lowVRAM(true).build(); response = this.chatModel.call(new Prompt(List.of(systemMessage, userMessage), ollamaOptions)); 
assertThat(response.getResult().getOutput().getText()).contains("Blackbeard"); @@ -260,7 +260,10 @@ void jsonSchemaFormatStructuredOutput() { """); Map model = Map.of("country", "denmark"); var prompt = userPromptTemplate.create(model, - OllamaOptions.builder().model(MODEL).format(outputConverter.getJsonSchemaMap()).build()); + OllamaChatOptions.builder() + .model(OllamaModel.LLAMA3_2.getName()) + .format(outputConverter.getJsonSchemaMap()) + .build()); var chatResponse = this.chatModel.call(prompt); @@ -362,7 +365,7 @@ public OllamaApi ollamaApi() { public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { return OllamaChatModel.builder() .ollamaApi(ollamaApi) - .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .defaultOptions(OllamaChatOptions.builder().model(MODEL).temperature(0.9).build()) .modelManagementOptions(ModelManagementOptions.builder() .pullModelStrategy(PullModelStrategy.WHEN_MISSING) .additionalModels(List.of(ADDITIONAL_MODEL)) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 1bcc41f4061..33fb5fec2b7 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -28,7 +28,7 @@ import org.springframework.ai.content.Media; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.api.OllamaChatOptions; import org.springframework.ai.retry.TransientAiException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -108,7 +108,7 @@ public void onError(RetryContext context .build(); return 
OllamaChatModel.builder() .ollamaApi(ollamaApi) - .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .defaultOptions(OllamaChatOptions.builder().model(MODEL).temperature(0.9).build()) .retryTemplate(retryTemplate) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index b6d9948dd4f..845f3f772df 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -23,6 +23,7 @@ import io.micrometer.observation.tck.TestObservationRegistryAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.ollama.api.OllamaChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -33,7 +34,6 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -68,7 +68,7 @@ void beforeEach() { @Test void observationForChatOperation() { - var options = OllamaOptions.builder() + var options = OllamaChatOptions.builder() .model(MODEL) .frequencyPenalty(0.0) .numPredict(2048) @@ -92,7 +92,7 @@ void observationForChatOperation() { @Test void observationForStreamingChatOperation() { - var options = OllamaOptions.builder() + var options = OllamaChatOptions.builder() .model(MODEL) .frequencyPenalty(0.0) .numPredict(2048) diff --git 
a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java index bd8d83e5a7c..688235309ff 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -34,8 +34,8 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaChatOptions; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.retry.RetryUtils; @@ -62,7 +62,7 @@ class OllamaChatModelTests { void buildOllamaChatModelWithDeprecatedConstructor() { ChatModel chatModel = OllamaChatModel.builder() .ollamaApi(this.ollamaApi) - .defaultOptions(OllamaOptions.builder().model(OllamaModel.MISTRAL).build()) + .defaultOptions(OllamaChatOptions.builder().model(OllamaModel.MISTRAL).build()) .observationRegistry(ObservationRegistry.NOOP) .build(); assertThat(chatModel).isNotNull(); @@ -71,7 +71,7 @@ void buildOllamaChatModelWithDeprecatedConstructor() { @Test void buildOllamaChatModelWithConstructor() { ChatModel chatModel = new OllamaChatModel(this.ollamaApi, - OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), ToolCallingManager.builder().build(), + OllamaChatOptions.builder().model(OllamaModel.MISTRAL).build(), ToolCallingManager.builder().build(), ObservationRegistry.NOOP, ModelManagementOptions.builder().build()); assertThat(chatModel).isNotNull(); } @@ -87,7 +87,7 @@ void buildOllamaChatModel() { Exception exception = assertThrows(IllegalArgumentException.class, () -> OllamaChatModel.builder() 
.ollamaApi(this.ollamaApi) - .defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build()) + .defaultOptions(OllamaChatOptions.builder().model(OllamaModel.LLAMA2).build()) .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .modelManagementOptions(null) .build()); @@ -185,7 +185,11 @@ void buildOllamaChatModelWithNullOllamaApi() { @Test void buildOllamaChatModelWithAllBuilderOptions() { - OllamaOptions options = OllamaOptions.builder().model(OllamaModel.CODELLAMA).temperature(0.7).topK(50).build(); + OllamaChatOptions options = OllamaChatOptions.builder() + .model(OllamaModel.CODELLAMA) + .temperature(0.7) + .topK(50) + .build(); ToolCallingManager toolManager = ToolCallingManager.builder().build(); ModelManagementOptions managementOptions = ModelManagementOptions.builder().build(); @@ -244,7 +248,7 @@ void buildChatResponseMetadataAggregationWithNullPrevious() { @ValueSource(strings = { "LLAMA2", "MISTRAL", "CODELLAMA", "LLAMA3", "GEMMA" }) void buildOllamaChatModelWithDifferentModels(String modelName) { OllamaModel model = OllamaModel.valueOf(modelName); - OllamaOptions options = OllamaOptions.builder().model(model).build(); + OllamaChatOptions options = OllamaChatOptions.builder().model(model).build(); ChatModel chatModel = OllamaChatModel.builder().ollamaApi(this.ollamaApi).defaultOptions(options).build(); @@ -313,7 +317,7 @@ void buildChatResponseMetadataAggregationOverflowHandling() { @Test void buildOllamaChatModelImmutability() { // Test that the builder creates immutable instances - OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MISTRAL).temperature(0.5).build(); + OllamaChatOptions options = OllamaChatOptions.builder().model(OllamaModel.MISTRAL).temperature(0.5).build(); ChatModel chatModel1 = OllamaChatModel.builder().ollamaApi(this.ollamaApi).defaultOptions(options).build(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java 
b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index d03de073b7e..27085bfc17d 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -24,6 +24,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaChatOptions; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.ToolCallback; @@ -41,13 +42,13 @@ class OllamaChatRequestTests { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(OllamaApi.builder().build()) - .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) + .defaultOptions(OllamaChatOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); @Test void whenToolRuntimeOptionsThenMergeWithDefaults() { - OllamaOptions defaultOptions = OllamaOptions.builder() + OllamaChatOptions defaultOptions = OllamaChatOptions.builder() .model("MODEL_NAME") .internalToolExecutionEnabled(true) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) @@ -59,7 +60,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { .defaultOptions(defaultOptions) .build(); - OllamaOptions runtimeOptions = OllamaOptions.builder() + OllamaChatOptions runtimeOptions = OllamaChatOptions.builder() .internalToolExecutionEnabled(false) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") @@ -115,6 +116,27 @@ void createRequestWithPromptOllamaOptions() { // promptOptions. 
} + @Test + void createRequestWithPromptOllamaChatOptions() { + // Runtime options should override the default options. + OllamaChatOptions promptOptions = OllamaChatOptions.builder().temperature(0.8).topP(0.5).numGPU(2).build(); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions)); + + var request = this.chatModel.ollamaChatRequest(prompt, true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("MODEL_NAME"); + assertThat(request.options().get("temperature")).isEqualTo(0.8); + assertThat(request.options().get("top_k")).isEqualTo(99); // still the default + // value. + assertThat(request.options().get("num_gpu")).isEqualTo(2); + assertThat(request.options().get("top_p")).isEqualTo(0.5); // new field introduced + // by the + // promptOptions. + } + @Test public void createRequestWithPromptPortableChatOptions() { // Ollama runtime options. @@ -136,7 +158,7 @@ public void createRequestWithPromptPortableChatOptions() { @Test public void createRequestWithPromptOptionsModelOverride() { // Ollama runtime options. 
- OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build(); + OllamaChatOptions promptOptions = OllamaChatOptions.builder().model("PROMPT_MODEL").build(); var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions)); var request = this.chatModel.ollamaChatRequest(prompt, true); @@ -148,7 +170,7 @@ public void createRequestWithPromptOptionsModelOverride() { public void createRequestWithDefaultOptionsModelOverride() { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(OllamaApi.builder().build()) - .defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) + .defaultOptions(OllamaChatOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); @@ -167,6 +189,29 @@ public void createRequestWithDefaultOptionsModelOverride() { assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } + @Test + public void createRequestWithDefaultOptionsModelChatOptionsOverride() { + OllamaChatModel chatModel = OllamaChatModel.builder() + .ollamaApi(OllamaApi.builder().build()) + .defaultOptions(OllamaChatOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); + + var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content")); + + var request = chatModel.ollamaChatRequest(prompt1, true); + + assertThat(request.model()).isEqualTo("DEFAULT_OPTIONS_MODEL"); + + // Prompt options should override the default options. 
+ OllamaChatOptions promptOptions = OllamaChatOptions.builder().model("PROMPT_MODEL").build(); + var prompt2 = chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions)); + + request = chatModel.ollamaChatRequest(prompt2, true); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java index 950d7dbe5d3..bb48b01d222 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java @@ -23,8 +23,8 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; @@ -52,7 +52,7 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT { void embeddings() { assertThat(this.embeddingModel).isNotNull(); EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest( - List.of("Hello World", "Something else"), OllamaOptions.builder().truncate(false).build())); + List.of("Hello World", "Something else"), OllamaEmbeddingOptions.builder().build())); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); 
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); @@ -73,9 +73,8 @@ void autoPullModelAtStartupTime() { var modelManager = new OllamaModelManager(this.ollamaApi); assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue(); - EmbeddingResponse embeddingResponse = this.embeddingModel - .call(new EmbeddingRequest(List.of("Hello World", "Something else"), - OllamaOptions.builder().model(model).truncate(false).build())); + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest( + List.of("Hello World", "Something else"), OllamaEmbeddingOptions.builder().model(model).build())); assertThat(embeddingResponse.getResults()).hasSize(2); assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); @@ -103,7 +102,7 @@ public OllamaApi ollamaApi() { public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) { return OllamaEmbeddingModel.builder() .ollamaApi(ollamaApi) - .defaultOptions(OllamaOptions.builder().model(MODEL).build()) + .defaultOptions(OllamaEmbeddingOptions.builder().model(MODEL).build()) .modelManagementOptions(ModelManagementOptions.builder() .pullModelStrategy(PullModelStrategy.WHEN_MISSING) .additionalModels(List.of(ADDITIONAL_MODEL)) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java index a94dbbe6312..51d8566eb3c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelObservationIT.java @@ -31,8 +31,8 @@ import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.ollama.api.OllamaApi; +import 
org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -58,7 +58,7 @@ public class OllamaEmbeddingModelObservationIT extends BaseOllamaIT { @Test void observationForEmbeddingOperation() { - var options = OllamaOptions.builder().model(OllamaModel.NOMIC_EMBED_TEXT.getName()).build(); + var options = OllamaEmbeddingOptions.builder().model(OllamaModel.NOMIC_EMBED_TEXT.getName()).build(); EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java index 6295d833d38..4826478ab2f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java @@ -16,7 +16,6 @@ package org.springframework.ai.ollama; -import java.time.Duration; import java.util.List; import java.util.Map; @@ -34,7 +33,7 @@ import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsRequest; import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; -import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -64,7 +63,7 @@ public void options() { List.of(new float[] { 7f, 8f, 9f }, new float[] { 10f, 11f, 12f }), 0L, 0L, 0)); // Tests default options - var 
defaultOptions = OllamaOptions.builder().model("DEFAULT_MODEL").build(); + var defaultOptions = OllamaEmbeddingOptions.builder().model("DEFAULT_MODEL").build(); var embeddingModel = OllamaEmbeddingModel.builder() .ollamaApi(this.ollamaApi) @@ -90,12 +89,7 @@ public void options() { assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("DEFAULT_MODEL"); // Tests runtime options - var runtimeOptions = OllamaOptions.builder() - .model("RUNTIME_MODEL") - .keepAlive("10m") - .truncate(false) - .mainGPU(666) - .build(); + var runtimeOptions = OllamaEmbeddingOptions.builder().model("RUNTIME_MODEL").build(); response = embeddingModel.call(new EmbeddingRequest(List.of("Input4", "Input5", "Input6"), runtimeOptions)); @@ -108,10 +102,7 @@ public void options() { assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY); assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2"); - assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofMinutes(10)); - assertThat(this.embeddingsRequestCaptor.getValue().truncate()).isFalse(); assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Input4", "Input5", "Input6")); - assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of("main_gpu", 666)); assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("RUNTIME_MODEL"); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java index 4269bae3ceb..4e03b56cc1f 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java @@ -26,7 +26,7 @@ import org.springframework.ai.embedding.EmbeddingRequest; import 
org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import static org.assertj.core.api.Assertions.assertThat; @@ -43,7 +43,7 @@ public class OllamaEmbeddingRequestTests { public void setUp() { this.embeddingModel = OllamaEmbeddingModel.builder() .ollamaApi(OllamaApi.builder().build()) - .defaultOptions(OllamaOptions.builder().model("DEFAULT_MODEL").mainGPU(11).useMMap(true).numGPU(1).build()) + .defaultOptions(OllamaEmbeddingOptions.builder().model("DEFAULT_MODEL").build()) .build(); } @@ -53,19 +53,13 @@ public void ollamaEmbeddingRequestDefaultOptions() { var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest); assertThat(ollamaRequest.model()).isEqualTo("DEFAULT_MODEL"); - assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(1); - assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(11); - assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true); assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello")); } @Test public void ollamaEmbeddingRequestRequestOptions() { - var promptOptions = OllamaOptions.builder()// + var promptOptions = OllamaEmbeddingOptions.builder()// .model("PROMPT_MODEL")// - .mainGPU(22)// - .useMMap(true)// - .numGPU(2) .build(); var embeddingRequest = this.embeddingModel @@ -73,15 +67,12 @@ public void ollamaEmbeddingRequestRequestOptions() { var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest); assertThat(ollamaRequest.model()).isEqualTo("PROMPT_MODEL"); - assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(2); - assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(22); - assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true); assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello")); } @Test public void ollamaEmbeddingRequestWithNegativeKeepAlive() { - var promptOptions = 
OllamaOptions.builder().model("PROMPT_MODEL").keepAlive("-1m").build(); + var promptOptions = OllamaEmbeddingOptions.builder().model("PROMPT_MODEL").keepAlive("-1m").build(); var embeddingRequest = this.embeddingModel .buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions)); @@ -112,12 +103,7 @@ public void ollamaEmbeddingRequestWithMultipleInputs() { @Test public void ollamaEmbeddingRequestOptionsOverrideDefaults() { - var requestOptions = OllamaOptions.builder() - .model("OVERRIDE_MODEL") - .mainGPU(99) - .useMMap(false) - .numGPU(8) - .build(); + var requestOptions = OllamaEmbeddingOptions.builder().model("OVERRIDE_MODEL").build(); var embeddingRequest = this.embeddingModel .buildEmbeddingRequest(new EmbeddingRequest(List.of("Override test"), requestOptions)); @@ -125,22 +111,19 @@ public void ollamaEmbeddingRequestOptionsOverrideDefaults() { // Request options should override defaults assertThat(ollamaRequest.model()).isEqualTo("OVERRIDE_MODEL"); - assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(8); - assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(99); - assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(false); } @Test public void ollamaEmbeddingRequestWithDifferentKeepAliveFormats() { // Test seconds format - var optionsSeconds = OllamaOptions.builder().keepAlive("30s").build(); + var optionsSeconds = OllamaEmbeddingOptions.builder().keepAlive("30s").build(); var requestSeconds = this.embeddingModel .buildEmbeddingRequest(new EmbeddingRequest(List.of("Test"), optionsSeconds)); var ollamaRequestSeconds = this.embeddingModel.ollamaEmbeddingRequest(requestSeconds); assertThat(ollamaRequestSeconds.keepAlive()).isEqualTo(Duration.ofSeconds(30)); // Test hours format - var optionsHours = OllamaOptions.builder().keepAlive("2h").build(); + var optionsHours = OllamaEmbeddingOptions.builder().keepAlive("2h").build(); var requestHours = this.embeddingModel .buildEmbeddingRequest(new 
EmbeddingRequest(List.of("Test"), optionsHours)); var ollamaRequestHours = this.embeddingModel.ollamaEmbeddingRequest(requestHours); @@ -152,7 +135,7 @@ public void ollamaEmbeddingRequestWithMinimalDefaults() { // Create model with minimal defaults var minimalModel = OllamaEmbeddingModel.builder() .ollamaApi(OllamaApi.builder().build()) - .defaultOptions(OllamaOptions.builder().model("MINIMAL_MODEL").build()) + .defaultOptions(OllamaEmbeddingOptions.builder().model("MINIMAL_MODEL").build()) .build(); var embeddingRequest = minimalModel.buildEmbeddingRequest(new EmbeddingRequest(List.of("Minimal test"), null)); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java index f3702be26c1..323c969c6fa 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java @@ -29,6 +29,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaChatOptions; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.retry.NonTransientAiException; @@ -74,7 +75,7 @@ public void beforeEach() { this.chatModel = OllamaChatModel.builder() .ollamaApi(this.ollamaApi) - .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .defaultOptions(OllamaChatOptions.builder().model(MODEL).temperature(0.9).build()) .retryTemplate(this.retryTemplate) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java 
index 13b11fbefca..cc62628a12b 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/aot/OllamaRuntimeHintsTests.java @@ -22,7 +22,8 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.ollama.api.OllamaChatOptions; +import org.springframework.ai.ollama.api.OllamaEmbeddingOptions; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; @@ -50,7 +51,8 @@ void registerHints() { assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.class))).isTrue(); assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.Tool.class))).isTrue(); assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.Message.class))).isTrue(); - assertThat(registeredTypes.contains(TypeReference.of(OllamaOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OllamaChatOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OllamaEmbeddingOptions.class))).isTrue(); } @Test @@ -101,7 +103,7 @@ void verifyMainApiClassesRegistered() { // Verify that the main classes we already know exist are registered assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.ChatRequest.class))).isTrue(); assertThat(registeredTypes.contains(TypeReference.of(OllamaApi.Message.class))).isTrue(); - assertThat(registeredTypes.contains(TypeReference.of(OllamaOptions.class))).isTrue(); + assertThat(registeredTypes.contains(TypeReference.of(OllamaChatOptions.class))).isTrue(); } @Test diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index 176c6d3c5b5..b31ba5365f8 100644 --- 
a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -63,7 +63,7 @@ public void chat() { .content("What is the capital of Bulgaria and what is the size? " + "What it the national anthem?") .build())) - .options(OllamaOptions.builder().temperature(0.9).build()) + .options(OllamaChatOptions.builder().temperature(0.9).build()) .build(); ChatResponse response = getOllamaApi().chat(request); @@ -84,7 +84,7 @@ public void streamingChat() { .messages(List.of(Message.builder(Role.USER) .content("What is the capital of Bulgaria and what is the size? " + "What it the national anthem?") .build())) - .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .options(OllamaChatOptions.builder().temperature(0.9).build().toMap()) .build(); Flux response = getOllamaApi().streamingChat(request); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java index 3a4d985d91c..e50ace4a094 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java @@ -33,7 +33,7 @@ public class OllamaModelOptionsTests { @Test public void testBasicOptions() { - var options = OllamaOptions.builder().temperature(3.14).topK(30).stop(List.of("a", "b", "c")).build(); + var options = OllamaChatOptions.builder().temperature(3.14).topK(30).stop(List.of("a", "b", "c")).build(); var optionsMap = options.toMap(); assertThat(optionsMap).containsEntry("temperature", 3.14); @@ -43,7 +43,7 @@ public void testBasicOptions() { @Test public void testAllNumericOptions() { - var options = OllamaOptions.builder() + var options = OllamaChatOptions.builder() 
.numCtx(2048) .numBatch(512) .numGPU(1) @@ -91,7 +91,7 @@ public void testAllNumericOptions() { @Test public void testBooleanOptions() { - var options = OllamaOptions.builder() + var options = OllamaChatOptions.builder() .truncate(true) .useNUMA(true) .lowVRAM(false) @@ -117,7 +117,7 @@ public void testBooleanOptions() { @Test public void testModelAndFormat() { - var options = OllamaOptions.builder().model("llama2").format("json").build(); + var options = OllamaChatOptions.builder().model("llama2").format("json").build(); var optionsMap = options.toMap(); assertThat(optionsMap).containsEntry("model", "llama2"); @@ -126,7 +126,7 @@ public void testModelAndFormat() { @Test public void testFunctionAndToolOptions() { - var options = OllamaOptions.builder() + var options = OllamaChatOptions.builder() .toolNames("function1") .toolNames("function2") .toolNames("function3") @@ -150,21 +150,21 @@ public void testFunctionOptionsWithMutableSet() { functionSet.add("function1"); functionSet.add("function2"); - var options = OllamaOptions.builder().toolNames(functionSet).toolNames("function3").build(); + var options = OllamaChatOptions.builder().toolNames(functionSet).toolNames("function3").build(); assertThat(options.getToolNames()).containsExactlyInAnyOrder("function1", "function2", "function3"); } @Test public void testFromOptions() { - var originalOptions = OllamaOptions.builder() + var originalOptions = OllamaChatOptions.builder() .model("llama2") .temperature(0.7) .topK(40) .toolNames(Set.of("function1")) .build(); - var copiedOptions = OllamaOptions.fromOptions(originalOptions); + var copiedOptions = OllamaChatOptions.fromOptions(originalOptions); // Test the copied options directly rather than through toMap() assertThat(copiedOptions.getModel()).isEqualTo("llama2"); @@ -175,7 +175,7 @@ public void testFromOptions() { @Test public void testFunctionOptionsNotInMap() { - var options = OllamaOptions.builder().model("llama2").toolNames(Set.of("function1")).build(); + var 
options = OllamaChatOptions.builder().model("llama2").toolNames(Set.of("function1")).build(); var optionsMap = options.toMap(); @@ -190,10 +190,14 @@ public void testFunctionOptionsNotInMap() { assertThat(options.getToolNames()).containsExactly("function1"); } - @SuppressWarnings("deprecation") @Test public void testDeprecatedMethods() { - var options = OllamaOptions.builder().model("llama2").temperature(0.7).topK(40).toolNames("function1").build(); + var options = OllamaChatOptions.builder() + .model("llama2") + .temperature(0.7) + .topK(40) + .toolNames("function1") + .build(); var optionsMap = options.toMap(); assertThat(optionsMap).containsEntry("model", "llama2"); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index b8b9895ab3c..ef877838202 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -158,7 +158,7 @@ TIP: All properties prefixed with `spring.ai.ollama.chat.options` can be overrid == Runtime Options [[chat-options]] -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java[OllamaOptions.java] class provides model configurations, such as the model to use, the temperature, etc. +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaChatOptions.java[OllamaChatOptions.java] class provides model configurations, such as the model to use, the temperature, etc. On start-up, the default options can be configured with the `OllamaChatModel(api, options)` constructor or the `spring.ai.ollama.chat.options.*` properties. 
@@ -170,14 +170,14 @@ For example, to override the default model and temperature for a specific reques ChatResponse response = chatModel.call( new Prompt( "Generate the names of 5 famous pirates.", - OllamaOptions.builder() + OllamaChatOptions.builder() .model(OllamaModel.LLAMA3_1) .temperature(0.4) .build() )); ---- -TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java[OllamaOptions] you can use a portable link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java[ChatOptions#builder()]. +TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaChatOptions.java[OllamaChatOptions] you can use a portable link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java[ChatOptions#builder()]. 
[[auto-pulling-models]]
== Auto-pulling Models @@ -271,7 +271,7 @@ var userMessage = new UserMessage("Explain what do you see on this picture?", new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource)); ChatResponse response = chatModel.call(new Prompt(this.userMessage, - OllamaOptions.builder().model(OllamaModel.LLAVA)).build()); + OllamaChatOptions.builder().model(OllamaModel.LLAVA).build())); ---- The example shows a model taking as an input the `multimodal.test.png` image: @@ -295,11 +295,11 @@ In addition to the existing Spring AI model-agnostic xref::api/structured-output === Configuration -Spring AI allows you to configure your response format programmatically using the `OllamaOptions` builder. +Spring AI allows you to configure your response format programmatically using the `OllamaChatOptions` builder. ==== Using the Chat Options Builder -You can set the response format programmatically with the `OllamaOptions` builder as shown below: +You can set the response format programmatically with the `OllamaChatOptions` builder as shown below: [source,java] ---- @@ -327,7 +327,7 @@ String jsonSchema = """ """; Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", - OllamaOptions.builder() + OllamaChatOptions.builder() .model(OllamaModel.LLAMA3_2.getName()) .format(new ObjectMapper().readValue(jsonSchema, Map.class)) .build()); @@ -358,7 +358,7 @@ record MathReasoning( var outputConverter = new BeanOutputConverter<>(MathReasoning.class); Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", - OllamaOptions.builder() + OllamaChatOptions.builder() .model(OllamaModel.LLAMA3_2.getName()) .format(outputConverter.getJsonSchemaMap()) .build()); @@ -488,7 +488,7 @@ var ollamaApi = OllamaApi.builder().build(); var chatModel = OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions( - OllamaOptions.builder() + OllamaChatOptions.builder() .model(OllamaModel.MISTRAL) .temperature(0.9) .build()) @@ -502,7 +502,7 @@ Flux response = this.chatModel.stream( new 
Prompt("Generate the names of 5 famous pirates.")); ---- -The `OllamaOptions` provides the configuration information for all chat requests. +The `OllamaChatOptions` provides the configuration information for all chat requests. == Low-level OllamaApi Client [[low-level-api]] @@ -531,7 +531,7 @@ var request = ChatRequest.builder("orca-mini") .content("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?") .build())) - .options(OllamaOptions.builder().temperature(0.9).build()) + .options(OllamaChatOptions.builder().temperature(0.9).build()) .build(); ChatResponse response = this.ollamaApi.chat(this.request); @@ -542,7 +542,7 @@ var request2 = ChatRequest.builder("orca-mini") .messages(List.of(Message.builder(Role.USER) .content("What is the capital of Bulgaria and what is the size? " + "What is the national anthem?") .build())) - .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .options(OllamaChatOptions.builder().temperature(0.9).build().toMap()) .build(); Flux streamingResponse = this.ollamaApi.streamingChat(this.request2); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc index e9a3f16401e..8cc0c1867c1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/ollama-embeddings.adoc @@ -162,12 +162,12 @@ TIP: All properties prefixed with `spring.ai.ollama.embedding.options` can be ov == Runtime Options [[embedding-options]] -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java[OllamaOptions.java] provides the Ollama configurations, such as the model to use, the low level GPU and CPU tuning, etc. 
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaEmbeddingOptions.java[OllamaEmbeddingOptions.java] provides the Ollama configurations, such as the model to use, the low level GPU and CPU tuning, etc. The default options can be configured using the `spring.ai.ollama.embedding.options` properties as well. -At start-time use the `OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions)` to configure the default options used for all embedding requests. -At run-time you can override the default options, using a `OllamaOptions` instance as part of your `EmbeddingRequest`. +At start-time use the `OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingOptions defaultOptions)` to configure the default options used for all embedding requests. +At run-time you can override the default options, using an `OllamaEmbeddingOptions` instance as part of your `EmbeddingRequest`. For example to override the default model name for a specific request: @@ -175,7 +175,7 @@ For example to override the default model name for a specific request: ---- EmbeddingResponse embeddingResponse = embeddingModel.call( new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), - OllamaOptions.builder() + OllamaEmbeddingOptions.builder() .model("Different-Embedding-Model-Deployment-Name")) .truncates(false) .build()); @@ -322,16 +322,16 @@ Next, create an `OllamaEmbeddingModel` instance and use it to compute the embedd var ollamaApi = OllamaApi.builder().build(); var embeddingModel = new OllamaEmbeddingModel(this.ollamaApi, - OllamaOptions.builder() + OllamaEmbeddingOptions.builder() .model(OllamaModel.MISTRAL.id()) .build()); EmbeddingResponse embeddingResponse = this.embeddingModel.call( new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), - OllamaOptions.builder() + OllamaEmbeddingOptions.builder() .model("chroma/all-minilm-l6-v2-f32")) 
.truncate(false) .build()); ---- -The `OllamaOptions` provides the configuration information for all embedding requests. +The `OllamaEmbeddingOptions` provides the configuration information for all embedding requests. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/testing.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/testing.adoc index 2b7910018cd..b2795a3e36f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/testing.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/testing.adoc @@ -157,7 +157,7 @@ void testFactChecking() { OllamaApi ollamaApi = new OllamaApi("http://localhost:11434"); ChatModel chatModel = new OllamaChatModel(ollamaApi, - OllamaOptions.builder().model(BESPOKE_MINICHECK).numPredict(2).temperature(0.0d).build()) + OllamaChatOptions.builder().model(BESPOKE_MINICHECK).numPredict(2).temperature(0.0d).build()) // Create the FactCheckingEvaluator diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreWithOllamaIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreWithOllamaIT.java index d53969bacb0..ba73ea06257 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreWithOllamaIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreWithOllamaIT.java @@ -34,15 +34,13 @@ import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.httpclient5.ApacheHttpClient5TransportBuilder; import org.opensearch.testcontainers.OpensearchContainer; +import org.springframework.ai.ollama.api.*; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.springframework.ai.document.Document; import 
org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.ollama.OllamaEmbeddingModel; -import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaModel; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; @@ -201,12 +199,7 @@ public OpenSearchVectorStore anotherVectorStore(EmbeddingModel embeddingModel) { public EmbeddingModel embeddingModel() { return OllamaEmbeddingModel.builder() .ollamaApi(OllamaApi.builder().build()) - .defaultOptions(OllamaOptions.builder() - .model(OllamaModel.MXBAI_EMBED_LARGE) - .mainGPU(11) - .useMMap(true) - .numGPU(1) - .build()) + .defaultOptions(OllamaEmbeddingOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE).build()) .build(); }