Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package org.springframework.ai.model.ollama.autoconfigure;

import org.springframework.ai.ollama.api.OllamaChatOptions;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

Expand All @@ -38,7 +38,7 @@ public class OllamaChatProperties {
* generative's defaults.
*/
@NestedConfigurationProperty
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
private OllamaChatOptions options = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build();

public String getModel() {
return this.options.getModel();
Expand All @@ -48,7 +48,7 @@ public void setModel(String model) {
this.options.setModel(model);
}

public OllamaOptions getOptions() {
public OllamaChatOptions getOptions() {
return this.options;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package org.springframework.ai.model.ollama.autoconfigure;

import org.springframework.ai.ollama.api.OllamaEmbeddingOptions;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

Expand All @@ -38,7 +38,9 @@ public class OllamaEmbeddingProperties {
* generative's defaults.
*/
@NestedConfigurationProperty
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();
private OllamaEmbeddingOptions options = OllamaEmbeddingOptions.builder()
.model(OllamaModel.MXBAI_EMBED_LARGE.id())
.build();

public String getModel() {
return this.options.getModel();
Expand All @@ -48,7 +50,7 @@ public void setModel(String model) {
this.options.setModel(model);
}

public OllamaOptions getOptions() {
public OllamaEmbeddingOptions getOptions() {
return this.options;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ public void propertiesTest() {
new ApplicationContextRunner().withPropertyValues(
// @formatter:off
"spring.ai.ollama.base-url=TEST_BASE_URL",
"spring.ai.ollama.embedding.options.model=MODEL_XYZ",
"spring.ai.ollama.embedding.options.temperature=0.13",
"spring.ai.ollama.embedding.options.topK=13"
"spring.ai.ollama.embedding.options.model=MODEL_XYZ"
// @formatter:on
)

Expand All @@ -52,9 +50,6 @@ public void propertiesTest() {

assertThat(embeddingProperties.getModel()).isEqualTo("MODEL_XYZ");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(embeddingProperties.getOptions().toMap()).containsKeys("temperature");
assertThat(embeddingProperties.getOptions().toMap().get("temperature")).isEqualTo(0.13);
assertThat(embeddingProperties.getOptions().getTopK()).isEqualTo(13);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaChatOptions;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -70,7 +72,7 @@ void functionCallTest() {
UserMessage userMessage = new UserMessage(
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

var promptOptions = OllamaOptions.builder()
var promptOptions = OllamaChatOptions.builder()
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
.description(
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
Expand All @@ -95,7 +97,7 @@ void streamingFunctionCallTest() {
UserMessage userMessage = new UserMessage(
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

var promptOptions = OllamaOptions.builder()
var promptOptions = OllamaChatOptions.builder()
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
.description(
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaChatOptions;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
Expand All @@ -35,7 +36,6 @@
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down Expand Up @@ -94,7 +94,7 @@ void functionCallTest() {
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

ChatResponse response = chatModel
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build()));
.call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build()));

logger.info("Response: " + response);

Expand All @@ -112,7 +112,7 @@ void streamFunctionCallTest() {
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

Flux<ChatResponse> response = chatModel
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build()));
.stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build()));

String content = response.collectList()
.block()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaChatOptions;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -36,7 +38,6 @@
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down Expand Up @@ -85,7 +86,7 @@ void toolCallTest() {
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
OllamaOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build()));
OllamaChatOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build()));

logger.info("Response: {}", response);

Expand All @@ -104,7 +105,7 @@ void functionCallTest() {
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

ChatResponse response = chatModel
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()));
.call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()));

logger.info("Response: {}", response);

Expand All @@ -122,7 +123,7 @@ void streamFunctionCallTest() {
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");

Flux<ChatResponse> response = chatModel
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()));
.stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()));

String content = response.collectList()
.block()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.springframework.ai.model.ollama.autoconfigure.BaseOllamaIT
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration
import org.springframework.ai.model.tool.ToolCallingChatOptions
import org.springframework.ai.ollama.OllamaChatModel
import org.springframework.ai.ollama.api.OllamaOptions
import org.springframework.ai.ollama.api.OllamaChatOptions
import org.springframework.boot.autoconfigure.AutoConfigurations
import org.springframework.boot.test.context.runner.ApplicationContextRunner
import org.springframework.context.annotation.Bean
Expand Down Expand Up @@ -68,7 +68,7 @@ class FunctionCallbackResolverKotlinIT : BaseOllamaIT() {
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.")

val response = chatModel
.call(Prompt(listOf(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()))
.call(Prompt(listOf(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()))

logger.info("Response: $response")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaChatOptions;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -116,7 +117,7 @@ public class OllamaChatModel implements ChatModel {

private final OllamaApi chatApi;

private final OllamaOptions defaultOptions;
private final OllamaChatOptions defaultOptions;

private final ObservationRegistry observationRegistry;

Expand All @@ -134,13 +135,13 @@ public class OllamaChatModel implements ChatModel {

private final RetryTemplate retryTemplate;

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {

Expand Down Expand Up @@ -388,21 +389,25 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh

Prompt buildRequestPrompt(Prompt prompt) {
// Process runtime options
OllamaOptions runtimeOptions = null;
OllamaChatOptions runtimeOptions = null;
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
if (prompt.getOptions() instanceof OllamaOptions ollamaOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(OllamaChatOptions.fromOptions(ollamaOptions),
OllamaChatOptions.class, OllamaChatOptions.class);
}
else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
OllamaOptions.class);
OllamaChatOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OllamaOptions.class);
OllamaChatOptions.class);
}
}

// Define request options by merging runtime options and default options
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
OllamaOptions.class);
OllamaChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
OllamaChatOptions.class);
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
// Jackson, used by ModelOptionsUtils.
if (runtimeOptions != null) {
Expand Down Expand Up @@ -474,7 +479,13 @@ else if (message instanceof ToolResponseMessage toolMessage) {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}).flatMap(List::stream).toList();

OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions();
OllamaChatOptions requestOptions = null;
if (prompt.getOptions() instanceof OllamaChatOptions) {
requestOptions = (OllamaChatOptions) prompt.getOptions();
}
else {
requestOptions = OllamaChatOptions.fromOptions((OllamaOptions) prompt.getOptions());
}

OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
.stream(stream)
Expand Down Expand Up @@ -520,7 +531,7 @@ private List<ChatRequest.Tool> getTools(List<ToolDefinition> toolDefinitions) {

@Override
public ChatOptions getDefaultOptions() {
return OllamaOptions.fromOptions(this.defaultOptions);
return OllamaChatOptions.fromOptions(this.defaultOptions);
}

/**
Expand All @@ -545,7 +556,7 @@ public static final class Builder {

private OllamaApi ollamaApi;

private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
private OllamaChatOptions defaultOptions = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build();

private ToolCallingManager toolCallingManager;

Expand All @@ -565,7 +576,13 @@ public Builder ollamaApi(OllamaApi ollamaApi) {
return this;
}

@Deprecated
public Builder defaultOptions(OllamaOptions defaultOptions) {
this.defaultOptions = OllamaChatOptions.fromOptions(defaultOptions);
return this;
}

public Builder defaultOptions(OllamaChatOptions defaultOptions) {
this.defaultOptions = defaultOptions;
return this;
}
Expand Down
Loading
Loading