Skip to content

Commit ecae493

Browse files
authored
Fix handling of nullable fields in OllamaChatModel deprecated constructor (#2174)
Fix gh-2172 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 1c6132c commit ecae493

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.model.tool.ToolExecutionResult;
3535
import org.springframework.ai.tool.definition.ToolDefinition;
3636
import org.springframework.ai.util.json.JsonParser;
37+
import org.springframework.lang.Nullable;
3738
import reactor.core.publisher.Flux;
3839

3940
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -125,8 +126,9 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
125126

126127
@Deprecated
127128
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
128-
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
129-
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
129+
@Nullable FunctionCallbackResolver functionCallbackResolver,
130+
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
131+
ModelManagementOptions modelManagementOptions) {
130132
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
131133
Assert.notNull(ollamaApi, "ollamaApi must not be null");
132134
Assert.notNull(defaultOptions, "defaultOptions must not be null");

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,17 +20,21 @@
2020
import java.time.Instant;
2121
import java.util.List;
2222

23+
import io.micrometer.observation.ObservationRegistry;
2324
import org.junit.jupiter.api.Test;
2425
import org.junit.jupiter.api.extension.ExtendWith;
2526
import org.mockito.Mock;
2627
import org.mockito.junit.jupiter.MockitoExtension;
2728

2829
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
2930
import org.springframework.ai.chat.metadata.DefaultUsage;
31+
import org.springframework.ai.chat.model.ChatModel;
3032
import org.springframework.ai.chat.model.ChatResponse;
33+
import org.springframework.ai.model.tool.ToolCallingManager;
3134
import org.springframework.ai.ollama.api.OllamaApi;
3235
import org.springframework.ai.ollama.api.OllamaModel;
3336
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.ollama.management.ModelManagementOptions;
3438

3539
import static org.assertj.core.api.Assertions.assertThat;
3640
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -40,16 +44,39 @@
4044
* @author Jihoon Kim
4145
* @author Christian Tzolov
4246
* @author Alexandros Pappas
47+
* @author Thomas Vitale
4348
* @since 1.0.0
4449
*/
4550
@ExtendWith(MockitoExtension.class)
46-
public class OllamaChatModelTests {
51+
class OllamaChatModelTests {
4752

4853
@Mock
4954
OllamaApi ollamaApi;
5055

5156
@Test
52-
public void buildOllamaChatModel() {
57+
void buildOllamaChatModelWithDeprecatedConstructor() {
58+
ChatModel chatModel = new OllamaChatModel(this.ollamaApi,
59+
OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), null, null, ObservationRegistry.NOOP,
60+
ModelManagementOptions.builder().build());
61+
assertThat(chatModel).isNotNull();
62+
}
63+
64+
@Test
65+
void buildOllamaChatModelWithConstructor() {
66+
ChatModel chatModel = new OllamaChatModel(this.ollamaApi,
67+
OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), ToolCallingManager.builder().build(),
68+
ObservationRegistry.NOOP, ModelManagementOptions.builder().build());
69+
assertThat(chatModel).isNotNull();
70+
}
71+
72+
@Test
73+
void buildOllamaChatModelWithBuilder() {
74+
ChatModel chatModel = OllamaChatModel.builder().ollamaApi(ollamaApi).build();
75+
assertThat(chatModel).isNotNull();
76+
}
77+
78+
@Test
79+
void buildOllamaChatModel() {
5380
Exception exception = assertThrows(IllegalArgumentException.class,
5481
() -> OllamaChatModel.builder()
5582
.ollamaApi(this.ollamaApi)
@@ -60,7 +87,7 @@ public void buildOllamaChatModel() {
6087
}
6188

6289
@Test
63-
public void buildChatResponseMetadata() {
90+
void buildChatResponseMetadata() {
6491

6592
Long evalDuration = 1000L;
6693
Integer evalCount = 101;
@@ -83,7 +110,7 @@ public void buildChatResponseMetadata() {
83110
}
84111

85112
@Test
86-
public void buildChatResponseMetadataAggregationWithNonEmptyMetadata() {
113+
void buildChatResponseMetadataAggregationWithNonEmptyMetadata() {
87114

88115
Long evalDuration = 1000L;
89116
Integer evalCount = 101;

spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
@Deprecated
5656
public class LegacyToolCallingManager implements ToolCallingManager {
5757

58+
@Nullable
5859
private final FunctionCallbackResolver functionCallbackResolver;
5960

6061
private final Map<String, FunctionCallback> functionCallbacks = new HashMap<>();
@@ -64,11 +65,11 @@ public class LegacyToolCallingManager implements ToolCallingManager {
6465
.build();
6566

6667
public LegacyToolCallingManager(@Nullable FunctionCallbackResolver functionCallbackResolver,
67-
List<FunctionCallback> functionCallbacks) {
68-
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
69-
Assert.noNullElements(functionCallbacks.toArray(), "functionCallbacks cannot contain null elements");
68+
@Nullable List<FunctionCallback> functionCallbacks) {
7069
this.functionCallbackResolver = functionCallbackResolver;
71-
functionCallbacks.forEach(toolCallback -> this.functionCallbacks.put(toolCallback.getName(), toolCallback));
70+
if (functionCallbacks != null) {
71+
functionCallbacks.forEach(toolCallback -> this.functionCallbacks.put(toolCallback.getName(), toolCallback));
72+
}
7273
}
7374

7475
@Override

0 commit comments

Comments
 (0)