Skip to content

Commit f908aa1

Browse files
garethjevansilayaperumalg
authored andcommitted
chore: separate out OllamaOptions into Chat & Embedding Options
Signed-off-by: Gareth Evans <[email protected]>
1 parent 1b3705f commit f908aa1

File tree

31 files changed

+1528
-172
lines changed

31 files changed

+1528
-172
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatProperties.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package org.springframework.ai.model.ollama.autoconfigure;
1818

19+
import org.springframework.ai.ollama.api.OllamaChatOptions;
1920
import org.springframework.ai.ollama.api.OllamaModel;
20-
import org.springframework.ai.ollama.api.OllamaOptions;
2121
import org.springframework.boot.context.properties.ConfigurationProperties;
2222
import org.springframework.boot.context.properties.NestedConfigurationProperty;
2323

@@ -38,7 +38,7 @@ public class OllamaChatProperties {
3838
* generative's defaults.
3939
*/
4040
@NestedConfigurationProperty
41-
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
41+
private OllamaChatOptions options = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build();
4242

4343
public String getModel() {
4444
return this.options.getModel();
@@ -48,7 +48,7 @@ public void setModel(String model) {
4848
this.options.setModel(model);
4949
}
5050

51-
public OllamaOptions getOptions() {
51+
public OllamaChatOptions getOptions() {
5252
return this.options;
5353
}
5454

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/main/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingProperties.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package org.springframework.ai.model.ollama.autoconfigure;
1818

19+
import org.springframework.ai.ollama.api.OllamaEmbeddingOptions;
1920
import org.springframework.ai.ollama.api.OllamaModel;
20-
import org.springframework.ai.ollama.api.OllamaOptions;
2121
import org.springframework.boot.context.properties.ConfigurationProperties;
2222
import org.springframework.boot.context.properties.NestedConfigurationProperty;
2323

@@ -38,7 +38,9 @@ public class OllamaEmbeddingProperties {
3838
* generative's defaults.
3939
*/
4040
@NestedConfigurationProperty
41-
private OllamaOptions options = OllamaOptions.builder().model(OllamaModel.MXBAI_EMBED_LARGE.id()).build();
41+
private OllamaEmbeddingOptions options = OllamaEmbeddingOptions.builder()
42+
.model(OllamaModel.MXBAI_EMBED_LARGE.id())
43+
.build();
4244

4345
public String getModel() {
4446
return this.options.getModel();
@@ -48,7 +50,7 @@ public void setModel(String model) {
4850
this.options.setModel(model);
4951
}
5052

51-
public OllamaOptions getOptions() {
53+
public OllamaEmbeddingOptions getOptions() {
5254
return this.options;
5355
}
5456

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ public void propertiesTest() {
3838
new ApplicationContextRunner().withPropertyValues(
3939
// @formatter:off
4040
"spring.ai.ollama.base-url=TEST_BASE_URL",
41-
"spring.ai.ollama.embedding.options.model=MODEL_XYZ",
42-
"spring.ai.ollama.embedding.options.temperature=0.13",
43-
"spring.ai.ollama.embedding.options.topK=13"
41+
"spring.ai.ollama.embedding.options.model=MODEL_XYZ"
4442
// @formatter:on
4543
)
4644

@@ -52,9 +50,6 @@ public void propertiesTest() {
5250

5351
assertThat(embeddingProperties.getModel()).isEqualTo("MODEL_XYZ");
5452
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
55-
assertThat(embeddingProperties.getOptions().toMap()).containsKeys("temperature");
56-
assertThat(embeddingProperties.getOptions().toMap().get("temperature")).isEqualTo(0.13);
57-
assertThat(embeddingProperties.getOptions().getTopK()).isEqualTo(13);
5853
});
5954
}
6055

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
import org.springframework.ai.model.ollama.autoconfigure.BaseOllamaIT;
3434
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration;
3535
import org.springframework.ai.ollama.OllamaChatModel;
36+
import org.springframework.ai.ollama.api.OllamaChatOptions;
3637
import org.springframework.ai.ollama.api.OllamaModel;
37-
import org.springframework.ai.ollama.api.OllamaOptions;
3838
import org.springframework.ai.tool.function.FunctionToolCallback;
3939
import org.springframework.boot.autoconfigure.AutoConfigurations;
4040
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -70,7 +70,7 @@ void functionCallTest() {
7070
UserMessage userMessage = new UserMessage(
7171
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
7272

73-
var promptOptions = OllamaOptions.builder()
73+
var promptOptions = OllamaChatOptions.builder()
7474
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
7575
.description(
7676
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
@@ -95,7 +95,7 @@ void streamingFunctionCallTest() {
9595
UserMessage userMessage = new UserMessage(
9696
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
9797

98-
var promptOptions = OllamaOptions.builder()
98+
var promptOptions = OllamaChatOptions.builder()
9999
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
100100
.description(
101101
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration;
3636
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3737
import org.springframework.ai.ollama.OllamaChatModel;
38-
import org.springframework.ai.ollama.api.OllamaOptions;
38+
import org.springframework.ai.ollama.api.OllamaChatOptions;
3939
import org.springframework.ai.tool.ToolCallback;
4040
import org.springframework.ai.tool.function.FunctionToolCallback;
4141
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -94,7 +94,7 @@ void functionCallTest() {
9494
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
9595

9696
ChatResponse response = chatModel
97-
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build()));
97+
.call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build()));
9898

9999
logger.info("Response: " + response);
100100

@@ -112,7 +112,7 @@ void streamFunctionCallTest() {
112112
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
113113

114114
Flux<ChatResponse> response = chatModel
115-
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build()));
115+
.stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("WeatherInfo").build()));
116116

117117
String content = response.collectList()
118118
.block()

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration;
3636
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3737
import org.springframework.ai.ollama.OllamaChatModel;
38+
import org.springframework.ai.ollama.api.OllamaChatOptions;
3839
import org.springframework.ai.ollama.api.OllamaModel;
39-
import org.springframework.ai.ollama.api.OllamaOptions;
4040
import org.springframework.ai.support.ToolCallbacks;
4141
import org.springframework.ai.tool.annotation.Tool;
4242
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -85,7 +85,7 @@ void toolCallTest() {
8585
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
8686

8787
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
88-
OllamaOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build()));
88+
OllamaChatOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build()));
8989

9090
logger.info("Response: {}", response);
9191

@@ -104,7 +104,7 @@ void functionCallTest() {
104104
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
105105

106106
ChatResponse response = chatModel
107-
.call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()));
107+
.call(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()));
108108

109109
logger.info("Response: {}", response);
110110

@@ -122,7 +122,7 @@ void streamFunctionCallTest() {
122122
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
123123

124124
Flux<ChatResponse> response = chatModel
125-
.stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()));
125+
.stream(new Prompt(List.of(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()));
126126

127127
String content = response.collectList()
128128
.block()

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.springframework.ai.model.ollama.autoconfigure.BaseOllamaIT
2626
import org.springframework.ai.model.ollama.autoconfigure.OllamaChatAutoConfiguration
2727
import org.springframework.ai.model.tool.ToolCallingChatOptions
2828
import org.springframework.ai.ollama.OllamaChatModel
29-
import org.springframework.ai.ollama.api.OllamaOptions
29+
import org.springframework.ai.ollama.api.OllamaChatOptions
3030
import org.springframework.boot.autoconfigure.AutoConfigurations
3131
import org.springframework.boot.test.context.runner.ApplicationContextRunner
3232
import org.springframework.context.annotation.Bean
@@ -68,7 +68,7 @@ class FunctionCallbackResolverKotlinIT : BaseOllamaIT() {
6868
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.")
6969

7070
val response = chatModel
71-
.call(Prompt(listOf(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build()))
71+
.call(Prompt(listOf(userMessage), OllamaChatOptions.builder().toolNames("weatherInfo").build()))
7272

7373
logger.info("Response: $response")
7474

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
6161
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
6262
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
63+
import org.springframework.ai.ollama.api.OllamaChatOptions;
6364
import org.springframework.ai.ollama.api.OllamaModel;
6465
import org.springframework.ai.ollama.api.OllamaOptions;
6566
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
@@ -116,7 +117,7 @@ public class OllamaChatModel implements ChatModel {
116117

117118
private final OllamaApi chatApi;
118119

119-
private final OllamaOptions defaultOptions;
120+
private final OllamaChatOptions defaultOptions;
120121

121122
private final ObservationRegistry observationRegistry;
122123

@@ -134,13 +135,13 @@ public class OllamaChatModel implements ChatModel {
134135

135136
private final RetryTemplate retryTemplate;
136137

137-
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
138+
public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager,
138139
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
139140
this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
140141
new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
141142
}
142143

143-
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
144+
public OllamaChatModel(OllamaApi ollamaApi, OllamaChatOptions defaultOptions, ToolCallingManager toolCallingManager,
144145
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
145146
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {
146147

@@ -396,21 +397,25 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
396397

397398
Prompt buildRequestPrompt(Prompt prompt) {
398399
// Process runtime options
399-
OllamaOptions runtimeOptions = null;
400+
OllamaChatOptions runtimeOptions = null;
400401
if (prompt.getOptions() != null) {
401-
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
402+
if (prompt.getOptions() instanceof OllamaOptions ollamaOptions) {
403+
runtimeOptions = ModelOptionsUtils.copyToTarget(OllamaChatOptions.fromOptions(ollamaOptions),
404+
OllamaChatOptions.class, OllamaChatOptions.class);
405+
}
406+
else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
402407
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
403-
OllamaOptions.class);
408+
OllamaChatOptions.class);
404409
}
405410
else {
406411
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
407-
OllamaOptions.class);
412+
OllamaChatOptions.class);
408413
}
409414
}
410415

411416
// Define request options by merging runtime options and default options
412-
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
413-
OllamaOptions.class);
417+
OllamaChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
418+
OllamaChatOptions.class);
414419
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
415420
// Jackson, used by ModelOptionsUtils.
416421
if (runtimeOptions != null) {
@@ -489,7 +494,13 @@ else if (message.getMessageType() == MessageType.TOOL) {
489494
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
490495
}).flatMap(List::stream).toList();
491496

492-
OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions();
497+
OllamaChatOptions requestOptions = null;
498+
if (prompt.getOptions() instanceof OllamaChatOptions) {
499+
requestOptions = (OllamaChatOptions) prompt.getOptions();
500+
}
501+
else {
502+
requestOptions = OllamaChatOptions.fromOptions((OllamaOptions) prompt.getOptions());
503+
}
493504

494505
OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
495506
.stream(stream)
@@ -535,7 +546,7 @@ private List<ChatRequest.Tool> getTools(List<ToolDefinition> toolDefinitions) {
535546

536547
@Override
537548
public ChatOptions getDefaultOptions() {
538-
return OllamaOptions.fromOptions(this.defaultOptions);
549+
return OllamaChatOptions.fromOptions(this.defaultOptions);
539550
}
540551

541552
/**
@@ -560,7 +571,7 @@ public static final class Builder {
560571

561572
private OllamaApi ollamaApi;
562573

563-
private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();
574+
private OllamaChatOptions defaultOptions = OllamaChatOptions.builder().model(OllamaModel.MISTRAL.id()).build();
564575

565576
private ToolCallingManager toolCallingManager;
566577

@@ -580,7 +591,13 @@ public Builder ollamaApi(OllamaApi ollamaApi) {
580591
return this;
581592
}
582593

594+
@Deprecated
583595
public Builder defaultOptions(OllamaOptions defaultOptions) {
596+
this.defaultOptions = OllamaChatOptions.fromOptions(defaultOptions);
597+
return this;
598+
}
599+
600+
public Builder defaultOptions(OllamaChatOptions defaultOptions) {
584601
this.defaultOptions = defaultOptions;
585602
return this;
586603
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.ai.model.ModelOptionsUtils;
4242
import org.springframework.ai.ollama.api.OllamaApi;
4343
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
44+
import org.springframework.ai.ollama.api.OllamaEmbeddingOptions;
4445
import org.springframework.ai.ollama.api.OllamaModel;
4546
import org.springframework.ai.ollama.api.OllamaOptions;
4647
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
@@ -69,15 +70,15 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
6970

7071
private final OllamaApi ollamaApi;
7172

72-
private final OllamaOptions defaultOptions;
73+
private final OllamaEmbeddingOptions defaultOptions;
7374

7475
private final ObservationRegistry observationRegistry;
7576

7677
private final OllamaModelManager modelManager;
7778

7879
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
7980

80-
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
81+
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaEmbeddingOptions defaultOptions,
8182
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
8283
Assert.notNull(ollamaApi, "ollamaApi must not be null");
8384
Assert.notNull(defaultOptions, "options must not be null");
@@ -146,15 +147,15 @@ private DefaultUsage getDefaultUsage(OllamaApi.EmbeddingsResponse response) {
146147

147148
EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
148149
// Process runtime options
149-
OllamaOptions runtimeOptions = null;
150+
OllamaEmbeddingOptions runtimeOptions = null;
150151
if (embeddingRequest.getOptions() != null) {
151152
runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
152-
OllamaOptions.class);
153+
OllamaEmbeddingOptions.class);
153154
}
154155

155156
// Define request options by merging runtime options and default options
156-
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
157-
OllamaOptions.class);
157+
OllamaEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
158+
OllamaEmbeddingOptions.class);
158159

159160
// Validate request options
160161
if (!StringUtils.hasText(requestOptions.getModel())) {
@@ -168,10 +169,17 @@ EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
168169
* Package access for testing.
169170
*/
170171
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingRequest) {
171-
OllamaOptions requestOptions = (OllamaOptions) embeddingRequest.getOptions();
172+
OllamaEmbeddingOptions requestOptions = null;
173+
if (embeddingRequest.getOptions() instanceof OllamaEmbeddingOptions) {
174+
requestOptions = (OllamaEmbeddingOptions) embeddingRequest.getOptions();
175+
}
176+
else {
177+
requestOptions = OllamaEmbeddingOptions.fromOptions((OllamaOptions) embeddingRequest.getOptions());
178+
}
179+
172180
return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(),
173181
DurationParser.parse(requestOptions.getKeepAlive()),
174-
OllamaOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate());
182+
OllamaEmbeddingOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate());
175183
}
176184

177185
/**
@@ -227,7 +235,7 @@ public static final class Builder {
227235

228236
private OllamaApi ollamaApi;
229237

230-
private OllamaOptions defaultOptions = OllamaOptions.builder()
238+
private OllamaEmbeddingOptions defaultOptions = OllamaEmbeddingOptions.builder()
231239
.model(OllamaModel.MXBAI_EMBED_LARGE.id())
232240
.build();
233241

@@ -243,7 +251,13 @@ public Builder ollamaApi(OllamaApi ollamaApi) {
243251
return this;
244252
}
245253

254+
@Deprecated
246255
public Builder defaultOptions(OllamaOptions defaultOptions) {
256+
this.defaultOptions = OllamaEmbeddingOptions.fromOptions(defaultOptions);
257+
return this;
258+
}
259+
260+
public Builder defaultOptions(OllamaEmbeddingOptions defaultOptions) {
247261
this.defaultOptions = defaultOptions;
248262
return this;
249263
}

0 commit comments

Comments
 (0)