Skip to content

Commit 62506a5

Browse files
committed
chore: separate out OllamaOptions into Chat & Embedding Options
Signed-off-by: Gareth Evans <[email protected]>
1 parent b059cdf commit 62506a5

File tree

31 files changed

+1515
-163
lines changed

31 files changed

+1515
-163
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatProperties.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package org.springframework.ai.model.ollama.autoconfigure;
1818

19+
import org.springframework.ai.ollama.api.OllamaChatOptions;
1920
import org.springframework.ai.ollama.api.OllamaModel;
20-
import org.springframework.ai.ollama.api.OllamaOptions;
2121
import org.springframework.boot.context.properties.ConfigurationProperties;
2222
import org.springframework.boot.context.properties.NestedConfigurationProperty;
2323

@@ -38,7 +38,7 @@ public class OllamaChatProperties {
3838
* generative's defaults.
3939
*/
4040
@NestedConfigurationProperty
41-
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
41+
private OllamaChatOptions options = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build();
4242

4343
public String getModel() {
4444
return this.options.getModel();
@@ -48,7 +48,7 @@ public void setModel(String model) {
4848
this.options.setModel(model);
4949
}
5050

51-
public OllamaOptions getOptions() {
51+
public OllamaChatOptions getOptions() {
5252
return this.options;
5353
}
5454

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package org.springframework.ai.model.ollama.autoconfigure;
1818

19+
import org.springframework.ai.ollama.api.OllamaEmbeddingOptions;
1920
import org.springframework.ai.ollama.api.OllamaModel;
20-
import org.springframework.ai.ollama.api.OllamaOptions;
2121
import org.springframework.boot.context.properties.ConfigurationProperties;
2222
import org.springframework.boot.context.properties.NestedConfigurationProperty;
2323

@@ -38,7 +38,9 @@ public class OllamaEmbeddingProperties {
3838
* generative's defaults.
3939
*/
4040
@NestedConfigurationProperty
41-
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();
41+
private OllamaEmbeddingOptions options = OllamaEmbeddingOptions.builder()
42+
.model(OllamaModel.MXBAI_EMBED_LARGE.id())
43+
.build();
4244

4345
public String getModel() {
4446
return this.options.getModel();
@@ -48,7 +50,7 @@ public void setModel(String model) {
4850
this.options.setModel(model);
4951
}
5052

51-
public OllamaOptions getOptions() {
53+
public OllamaEmbeddingOptions getOptions() {
5254
return this.options;
5355
}
5456

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ public void propertiesTest() {
3838
new ApplicationContextRunner().withPropertyValues(
3939
// @formatter:off
4040
"spring.ai.ollama.base-url=TEST_BASE_URL",
41-
"spring.ai.ollama.embedding.options.model=MODEL_XYZ",
42-
"spring.ai.ollama.embedding.options.temperature=0.13",
43-
"spring.ai.ollama.embedding.options.topK=13"
41+
"spring.ai.ollama.embedding.options.model=MODEL_XYZ"
4442
// @formatter:on
4543
)
4644

@@ -52,9 +50,6 @@ public void propertiesTest() {
5250

5351
assertThat(embeddingProperties.getModel()).isEqualTo("MODEL_XYZ");
5452
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
55-
assertThat(embeddingProperties.getOptions().toMap()).containsKeys("temperature");
56-
assertThat(embeddingProperties.getOptions().toMap().get("temperature")).isEqualTo(0.13);
57-
assertThat(embeddingProperties.getOptions().getTopK()).isEqualTo(13);
5853
});
5954
}
6055

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.junit.jupiter.api.Test;
2424
import org.slf4j.Logger;
2525
import org.slf4j.LoggerFactory;
26+
import org.springframework.ai.ollama.api.OllamaModel;
27+
import org.springframework.ai.ollama.api.OllamaChatOptions;
2628
import reactor.core.publisher.Flux;
2729

2830
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -70,7 +72,7 @@ void functionCallTest() {
7072
UserMessage userMessage = new UserMessage(
7173
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
7274

73-
var promptOptions = OllamaOptions.builder()
75+
var promptOptions = OllamaChatOptions.builder()
7476
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
7577
.description(
7678
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
@@ -95,7 +97,7 @@ void streamingFunctionCallTest() {
9597
UserMessage userMessage = new UserMessage(
9698
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
9799

98-
var promptOptions = OllamaOptions.builder()
100+
var promptOptions = OllamaChatOptions.builder()
99101
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
100102
.description(
101103
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.junit.jupiter.api.Test;
2424
import org.slf4j.Logger;
2525
import org.slf4j.LoggerFactory;
26+
import org.springframework.ai.ollama.api.OllamaChatOptions;
2627
import reactor.core.publisher.Flux;
2728

2829
import org.springframework.ai.chat.client.ChatClient;
@@ -35,7 +36,6 @@
3536
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration;
3637
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3738
import org.springframework.ai.ollama.OllamaChatModel;
38-
import org.springframework.ai.ollama.api.OllamaOptions;
3939
import org.springframework.ai.tool.ToolCallback;
4040
import org.springframework.ai.tool.function.FunctionToolCallback;
4141
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -94,7 +94,7 @@ void functionCallTest() {
9494
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
9595

9696
ChatResponse response = chatModel
97-
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build()));
97+
.call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build()));
9898

9999
logger.info("Response: " + response);
100100

@@ -112,7 +112,7 @@ void streamFunctionCallTest() {
112112
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
113113

114114
Flux<ChatResponse> response = chatModel
115-
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build()));
115+
.stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build()));
116116

117117
String content = response.collectList()
118118
.block()

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import org.junit.jupiter.api.Test;
2525
import org.slf4j.Logger;
2626
import org.slf4j.LoggerFactory;
27+
import org.springframework.ai.ollama.api.OllamaModel;
28+
import org.springframework.ai.ollama.api.OllamaChatOptions;
2729
import reactor.core.publisher.Flux;
2830

2931
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -36,7 +38,6 @@
3638
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3739
import org.springframework.ai.ollama.OllamaChatModel;
3840
import org.springframework.ai.ollama.api.OllamaModel;
39-
import org.springframework.ai.ollama.api.OllamaOptions;
4041
import org.springframework.ai.support.ToolCallbacks;
4142
import org.springframework.ai.tool.annotation.Tool;
4243
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -85,7 +86,7 @@ void toolCallTest() {
8586
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
8687

8788
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
88-
OllamaOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build()));
89+
OllamaChatOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build()));
8990

9091
logger.info("Response: {}", response);
9192

@@ -104,7 +105,7 @@ void functionCallTest() {
104105
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
105106

106107
ChatResponse response = chatModel
107-
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()));
108+
.call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()));
108109

109110
logger.info("Response: {}", response);
110111

@@ -122,7 +123,7 @@ void streamFunctionCallTest() {
122123
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
123124

124125
Flux<ChatResponse> response = chatModel
125-
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()));
126+
.stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()));
126127

127128
String content = response.collectList()
128129
.block()

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.springframework.ai.model.ollama.autoconfigure.BaseOllamaIT
2626
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration
2727
import org.springframework.ai.model.tool.ToolCallingChatOptions
2828
import org.springframework.ai.ollama.OllamaChatModel
29-
import org.springframework.ai.ollama.api.OllamaOptions
29+
import org.springframework.ai.ollama.api.OllamaChatOptions
3030
import org.springframework.boot.autoconfigure.AutoConfigurations
3131
import org.springframework.boot.test.context.runner.ApplicationContextRunner
3232
import org.springframework.context.annotation.Bean
@@ -68,7 +68,7 @@ class FunctionCallbackResolverKotlinIT : BaseOllamaIT() {
6868
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.")
6969

7070
val response = chatModel
71-
.call(Prompt(listOf(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()))
71+
.call(Prompt(listOf(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()))
7272

7373
logger.info("Response: $response")
7474

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
2929
import org.slf4j.Logger;
3030
import org.slf4j.LoggerFactory;
31+
import org.springframework.ai.ollama.api.OllamaChatOptions;
3132
import reactor.core.publisher.Flux;
3233
import reactor.core.scheduler.Schedulers;
3334

@@ -116,7 +117,7 @@ public class OllamaChatModel implements ChatModel {
116117

117118
private final OllamaApi chatApi;
118119

119-
private final OllamaOptions defaultOptions;
120+
private final OllamaChatOptions defaultOptions;
120121

121122
private final ObservationRegistry observationRegistry;
122123

@@ -134,13 +135,13 @@ public class OllamaChatModel implements ChatModel {
134135

135136
private final RetryTemplate retryTemplate;
136137

137-
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
138+
public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager,
138139
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
139140
this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
140141
new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
141142
}
142143

143-
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
144+
public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager,
144145
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
145146
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {
146147

@@ -388,21 +389,25 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
388389

389390
Prompt buildRequestPrompt(Prompt prompt) {
390391
// Process runtime options
391-
OllamaOptions runtimeOptions = null;
392+
OllamaChatOptions runtimeOptions = null;
392393
if (prompt.getOptions() != null) {
393-
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
394+
if (prompt.getOptions() instanceof OllamaOptions ollamaOptions) {
395+
runtimeOptions = ModelOptionsUtils.copyToTarget(OllamaChatOptions.fromOptions(ollamaOptions),
396+
OllamaChatOptions.class, OllamaChatOptions.class);
397+
}
398+
else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
394399
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
395-
OllamaOptions.class);
400+
OllamaChatOptions.class);
396401
}
397402
else {
398403
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
399-
OllamaOptions.class);
404+
OllamaChatOptions.class);
400405
}
401406
}
402407

403408
// Define request options by merging runtime options and default options
404-
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
405-
OllamaOptions.class);
409+
OllamaChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
410+
OllamaChatOptions.class);
406411
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
407412
// Jackson, used by ModelOptionsUtils.
408413
if (runtimeOptions != null) {
@@ -474,7 +479,13 @@ else if (message instanceof ToolResponseMessage toolMessage) {
474479
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
475480
}).flatMap(List::stream).toList();
476481

477-
OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions();
482+
OllamaChatOptions requestOptions = null;
483+
if (prompt.getOptions() instanceof OllamaChatOptions) {
484+
requestOptions = (OllamaChatOptions) prompt.getOptions();
485+
}
486+
else {
487+
requestOptions = OllamaChatOptions.fromOptions((OllamaOptions) prompt.getOptions());
488+
}
478489

479490
OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
480491
.stream(stream)
@@ -520,7 +531,7 @@ private List<ChatRequest.Tool> getTools(List<ToolDefinition> toolDefinitions) {
520531

521532
@Override
522533
public ChatOptions getDefaultOptions() {
523-
return OllamaOptions.fromOptions(this.defaultOptions);
534+
return OllamaChatOptions.fromOptions(this.defaultOptions);
524535
}
525536

526537
/**
@@ -545,7 +556,7 @@ public static final class Builder {
545556

546557
private OllamaApi ollamaApi;
547558

548-
private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
559+
private OllamaChatOptions defaultOptions = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build();
549560

550561
private ToolCallingManager toolCallingManager;
551562

@@ -565,7 +576,13 @@ public Builder ollamaApi(OllamaApi ollamaApi) {
565576
return this;
566577
}
567578

579+
@Deprecated
568580
public Builder defaultOptions(OllamaOptions defaultOptions) {
581+
this.defaultOptions = OllamaChatOptions.fromOptions(defaultOptions);
582+
return this;
583+
}
584+
585+
public Builder defaultOptions(OllamaChatOptions defaultOptions) {
569586
this.defaultOptions = defaultOptions;
570587
return this;
571588
}

0 commit comments

Comments (0)