Commit 98cfb16

Update Ollama APIs and fix multimodality test
* Add doneReason to ChatResponse and update OllamaChatModel accordingly
* Add missing descriptions to Ollama options
* Consolidate Ollama Testcontainers image setup
* Fix multimodality Ollama test

Signed-off-by: Thomas Vitale <[email protected]>
1 parent 554fbcd commit 98cfb16

11 files changed: +145 additions, -68 deletions

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 4 additions & 6 deletions
@@ -59,6 +59,7 @@
  *
  * @author Christian Tzolov
  * @author luocongqiu
+ * @author Thomas Vitale
  * @since 1.0.0
  */
 public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -125,13 +126,13 @@ public ChatResponse call(Prompt prompt) {

         ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
         if (response.promptEvalCount() != null && response.evalCount() != null) {
-            generationMetadata = ChatGenerationMetadata.from("DONE", null);
+            generationMetadata = ChatGenerationMetadata.from(response.doneReason(), null);
         }

         var generator = new Generation(assistantMessage, generationMetadata);
         var chatResponse = new ChatResponse(List.of(generator), from(response));

-        if (isToolCall(chatResponse, Set.of("DONE"))) {
+        if (isToolCall(chatResponse, Set.of("stop"))) {
             var toolCallConversation = handleToolCalls(prompt, chatResponse);
             // Recursively call the call method with the tool call message
             // conversation that contains the call responses.
@@ -176,7 +177,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {

         ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
         if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
-            generationMetadata = ChatGenerationMetadata.from("DONE", null);
+            generationMetadata = ChatGenerationMetadata.from(chunk.doneReason(), null);
         }

         var generator = new Generation(assistantMessage, generationMetadata);
@@ -203,9 +204,6 @@ OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {

         List<OllamaApi.Message> ollamaMessages = prompt.getInstructions()
             .stream()
-            .filter(message -> message.getMessageType() == MessageType.USER
-                    || message.getMessageType() == MessageType.ASSISTANT
-                    || message.getMessageType() == MessageType.SYSTEM || message.getMessageType() == MessageType.TOOL)
             .map(message -> {
                 if (message instanceof UserMessage userMessage) {
                     var messageBuilder = OllamaApi.Message.builder(Role.USER).withContent(message.getContent());
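
With this change, the finish reason surfaced through ChatGenerationMetadata mirrors Ollama's done_reason field (typically "stop") instead of the hard-coded "DONE" marker, and the tool-call check matches on "stop" accordingly. A minimal usage sketch, assuming an already configured OllamaChatModel named chatModel and the 1.0 package layout for the imports:

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

// Call the model and read back the finish reason.
ChatResponse response = chatModel.call(new Prompt("Why is the sky blue?"));

// Now populated from Ollama's done_reason (e.g. "stop"), not "DONE".
String finishReason = response.getResult().getMetadata().getFinishReason();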

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 36 additions & 26 deletions
@@ -47,6 +47,7 @@
  * Java Client for the Ollama API. <a href="https://ollama.ai/">https://ollama.ai</a>
  *
  * @author Christian Tzolov
+ * @author Thomas Vitale
  * @since 0.8.0
  */
 // @formatter:off
@@ -454,15 +455,20 @@ public Message build() {
     /**
      * Chat request object.
      *
-     * @param model The model to use for completion.
-     * @param messages The list of messages to chat with.
-     * @param stream Whether to stream the response.
-     * @param format The format to return the response in. Currently, the only accepted
-     * value is "json".
-     * @param keepAlive The duration to keep the model loaded in ollama while idle.
-     * @param options Additional model parameters. You can use the {@link OllamaOptions} builder
-     * to create the options then {@link OllamaOptions#toMap()} to convert the options into a
-     * map.
+     * @param model The model to use for completion. It should be a name familiar to Ollama from the <a href="https://ollama.com/library">Library</a>.
+     * @param messages The list of messages in the chat. This can be used to keep a chat memory.
+     * @param stream Whether to stream the response. If false, the response will be returned as a single response object rather than a stream of objects.
+     * @param format The format to return the response in. Currently, the only accepted value is "json".
+     * @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m).
+     * @param tools List of tools the model has access to.
+     * @param options Model-specific options. For example, "temperature" can be set through this field, if the model supports it.
+     * You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map.
+     *
+     * @see <a href=
+     * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Chat
+     * Completion API</a>
+     * @see <a href="https://github.com/ollama/ollama/blob/main/api/types.go">Ollama
+     * Types</a>
     */
    @JsonInclude(Include.NON_NULL)
    public record ChatRequest(
@@ -471,9 +477,9 @@ public record ChatRequest(
            @JsonProperty("stream") Boolean stream,
            @JsonProperty("format") String format,
            @JsonProperty("keep_alive") String keepAlive,
-           @JsonProperty("options") Map<String, Object> options,
-           @JsonProperty("tools") List<Tool> tools) {
-
+           @JsonProperty("tools") List<Tool> tools,
+           @JsonProperty("options") Map<String, Object> options
+   ) {

        /**
         * Represents a tool the model may call. Currently, only functions are supported as a tool.
@@ -544,8 +550,8 @@ public static class Builder {
        private boolean stream = false;
        private String format;
        private String keepAlive;
-       private Map<String, Object> options = Map.of();
        private List<Tool> tools = List.of();
+       private Map<String, Object> options = Map.of();

        public Builder(String model) {
            Assert.notNull(model, "The model can not be null.");
@@ -572,6 +578,11 @@ public Builder withKeepAlive(String keepAlive) {
            return this;
        }

+       public Builder withTools(List<Tool> tools) {
+           this.tools = tools;
+           return this;
+       }
+
        public Builder withOptions(Map<String, Object> options) {
            Objects.requireNonNull(options, "The options can not be null.");

@@ -585,33 +596,30 @@ public Builder withOptions(OllamaOptions options) {
            return this;
        }

-       public Builder withTools(List<Tool> tools) {
-           this.tools = tools;
-           return this;
-       }
-
        public ChatRequest build() {
-           return new ChatRequest(model, messages, stream, format, keepAlive, options, tools);
+           return new ChatRequest(model, messages, stream, format, keepAlive, tools, options);
        }
    }
 }

 /**
  * Ollama chat response object.
  *
- * @param model The model name used for completion.
- * @param createdAt When the request was made.
+ * @param model The model used for generating the response.
+ * @param createdAt The timestamp of the response generation.
  * @param message The response {@link Message} with {@link Message.Role#ASSISTANT}.
+ * @param doneReason The reason the model stopped generating text.
  * @param done Whether this is the final response. For streaming response only the
  * last message is marked as done. If true, this response may be followed by another
  * response with the following, additional fields: context, prompt_eval_count,
  * prompt_eval_duration, eval_count, eval_duration.
  * @param totalDuration Time spent generating the response.
  * @param loadDuration Time spent loading the model.
- * @param promptEvalCount number of tokens in the prompt.(*)
- * @param promptEvalDuration time spent evaluating the prompt.
- * @param evalCount number of tokens in the response.
- * @param evalDuration time spent generating the response.
+ * @param promptEvalCount Number of tokens in the prompt.
+ * @param promptEvalDuration Time spent evaluating the prompt.
+ * @param evalCount Number of tokens in the response.
+ * @param evalDuration Time spent generating the response.
+ *
  * @see <a href=
  * "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Chat
  * Completion API</a>
@@ -623,13 +631,15 @@ public record ChatResponse(
        @JsonProperty("model") String model,
        @JsonProperty("created_at") Instant createdAt,
        @JsonProperty("message") Message message,
+       @JsonProperty("done_reason") String doneReason,
        @JsonProperty("done") Boolean done,
        @JsonProperty("total_duration") Duration totalDuration,
        @JsonProperty("load_duration") Duration loadDuration,
        @JsonProperty("prompt_eval_count") Integer promptEvalCount,
        @JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
        @JsonProperty("eval_count") Integer evalCount,
-       @JsonProperty("eval_duration") Duration evalDuration) {
+       @JsonProperty("eval_duration") Duration evalDuration
+   ) {
 }

 /**
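
Note that tools now precedes options in both the ChatRequest record and the builder's field order, and withTools sits with the other fluent setters. A sketch of composing a request against the reordered API; the builder(String) factory and the withMessages/withStream setters are assumed from the parts of the Builder class this excerpt does not show:

import java.util.List;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;

var request = OllamaApi.ChatRequest.builder("mistral")
    .withMessages(List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.USER)
        .withContent("What is the capital of Bulgaria?")
        .build()))
    .withStream(false)      // return a single response object
    .withKeepAlive("5m")    // keep the model loaded for five minutes
    .withTools(List.of())   // tools the model may call (none here)
    .withOptions(OllamaOptions.create().withTemperature(0.7f))
    .build();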

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java

Lines changed: 17 additions & 0 deletions
@@ -21,6 +21,7 @@
  * Helper class for common Ollama models.
  *
  * @author Siarhei Blashuk
+ * @author Thomas Vitale
  * @since 0.8.1
  */
 public enum OllamaModel implements ChatModelDescription {
@@ -35,11 +36,27 @@ public enum OllamaModel implements ChatModelDescription {
     */
    LLAMA3("llama3"),

+   /**
+    * The 8B language model from Meta.
+    */
+   LLAMA3_1("llama3.1"),
+
    /**
     * The 7B parameters model
     */
    MISTRAL("mistral"),

+   /**
+    * A 12B model with 128k context length, built by Mistral AI in collaboration with
+    * NVIDIA.
+    */
+   MISTRAL_NEMO("mistral-nemo"),
+
+   /**
+    * A small vision language model designed to run efficiently on edge devices.
+    */
+   MOONDREAM("moondream"),
+
    /**
     * The 2.7B uncensored Dolphin model
     */
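
Each constant carries the Ollama library tag, so it can replace raw model strings in configuration. A short sketch; OllamaOptions.create() and withModel are assumed from the existing options API:

import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;

// Resolve the library tags of the newly added models.
String visionModel = OllamaModel.MOONDREAM.getName();   // "moondream"

OllamaOptions options = OllamaOptions.create()
    .withModel(OllamaModel.MISTRAL_NEMO.getName());     // "mistral-nemo"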

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 38 additions & 20 deletions
@@ -40,6 +40,7 @@
  * Helper class for creating strongly-typed Ollama options.
  *
  * @author Christian Tzolov
+ * @author Thomas Vitale
  * @since 0.8.0
  * @see <a href=
  * "https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama
@@ -53,11 +54,13 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed

    private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive");

-   // Following fields are ptions which must be set when the model is loaded into memory.
+   // Following fields are options which must be set when the model is loaded into memory.
+   // See: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md

    // @formatter:off
+
    /**
-    * useNUMA Whether to use NUMA.
+    * Whether to use NUMA. (Default: false)
     */
    @JsonProperty("numa") private Boolean useNUMA;

@@ -67,63 +70,78 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
    @JsonProperty("num_ctx") private Integer numCtx;

    /**
-    * ???
+    * Prompt processing maximum batch size. (Default: 512)
     */
    @JsonProperty("num_batch") private Integer numBatch;

    /**
     * The number of layers to send to the GPU(s). On macOS, it defaults to 1
     * to enable metal support, 0 to disable.
-    */
+    * (Default: -1, which indicates that numGPU should be set dynamically)
+    */
    @JsonProperty("num_gpu") private Integer numGPU;

    /**
-    * ???
+    * When using multiple GPUs this option controls which GPU is used
+    * for small tensors for which the overhead of splitting the computation
+    * across all GPUs is not worthwhile. The GPU in question will use slightly
+    * more VRAM to store a scratch buffer for temporary results.
+    * By default, GPU 0 is used.
     */
    @JsonProperty("main_gpu")private Integer mainGPU;

    /**
-    * ???
+    * (Default: false)
     */
    @JsonProperty("low_vram") private Boolean lowVRAM;

    /**
-    * ???
+    * (Default: true)
     */
    @JsonProperty("f16_kv") private Boolean f16KV;

    /**
-    * ???
+    * Return logits for all the tokens, not just the last one.
+    * To enable completions to return logprobs, this must be true.
     */
    @JsonProperty("logits_all") private Boolean logitsAll;

    /**
-    * ???
+    * Load only the vocabulary, not the weights.
     */
    @JsonProperty("vocab_only") private Boolean vocabOnly;

    /**
-    * ???
+    * By default, models are mapped into memory, which allows the system to load only the necessary parts
+    * of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low
+    * on available memory, using mmap might increase the risk of pageouts, negatively impacting performance.
+    * Disabling mmap results in slower load times but may reduce pageouts if you're not using mlock.
+    * Note that if the model is larger than the total amount of RAM, turning off mmap would prevent
+    * the model from loading at all.
+    * (Default: null)
     */
    @JsonProperty("use_mmap") private Boolean useMMap;

    /**
-    * ???
+    * Lock the model in memory, preventing it from being swapped out when memory-mapped.
+    * This can improve performance but trades away some of the advantages of memory-mapping
+    * by requiring more RAM to run and potentially slowing down load times as the model loads into RAM.
+    * (Default: false)
     */
    @JsonProperty("use_mlock") private Boolean useMLock;

    /**
-    * Sets the number of threads to use during computation. By default,
-    * Ollama will detect this for optimal performance. It is recommended to set this
-    * value to the number of physical CPU cores your system has (as opposed to the
-    * logical number of cores).
+    * Set the number of threads to use during generation. For optimal performance, it is recommended to set this value
+    * to the number of physical CPU cores your system has (as opposed to the logical number of cores).
+    * Using the correct number of threads can greatly improve performance.
+    * By default, Ollama will detect this value for optimal performance.
     */
    @JsonProperty("num_thread") private Integer numThread;

    // Following fields are predict options used at runtime.

    /**
-    * ???
+    * (Default: 4)
     */
    @JsonProperty("num_keep") private Integer numKeep;

@@ -162,7 +180,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
    @JsonProperty("tfs_z") private Float tfsZ;

    /**
-    * ???
+    * (Default: 1.0)
     */
    @JsonProperty("typical_p") private Float typicalP;

@@ -186,12 +204,12 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
    @JsonProperty("repeat_penalty") private Float repeatPenalty;

    /**
-    * ???
+    * (Default: 0.0)
     */
    @JsonProperty("presence_penalty") private Float presencePenalty;

    /**
-    * ???
+    * (Default: 0.0)
     */
    @JsonProperty("frequency_penalty") private Float frequencyPenalty;

@@ -215,7 +233,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
    @JsonProperty("mirostat_eta") private Float mirostatEta;

    /**
-    * ???
+    * (Default: true)
     */
    @JsonProperty("penalize_newline") private Boolean penalizeNewline;
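
Most of the newly documented fields are llama.cpp load-time options, so they only take effect when the model is (re)loaded; the predict options below them apply per request. A hedged sketch combining both kinds, assuming the class's usual fluent with* setters (withUseMMap, withNumThread, ...), which this hunk does not show:

import java.util.Map;
import org.springframework.ai.ollama.api.OllamaOptions;

OllamaOptions options = OllamaOptions.create()
    .withUseMMap(true)      // load-time: map the model file into memory
    .withNumThread(8)       // load-time: physical core count, per the guidance above
    .withTemperature(0.7f); // runtime (predict) option

// toMap() produces the map accepted by ChatRequest.Builder#withOptions.
Map<String, Object> optionsMap = options.toMap();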

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.model.function.FunctionCallbackWrapper;
 import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaModel;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.ollama.api.tool.MockWeatherService;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -55,10 +56,10 @@ class OllamaChatModelFunctionCallingIT {

    private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);

-   private static String MODEL = "mistral";
+   private static final String MODEL = OllamaModel.MISTRAL.getName();

    @Container
-   static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.2.8");
+   static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.DEFAULT_IMAGE);

    static String baseUrl = "http://localhost:11434";

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ class OllamaChatModelIT {
    private static final Log logger = LogFactory.getLog(OllamaChatModelIT.class);

    @Container
-   static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.2.8");
+   static OllamaContainer ollamaContainer = new OllamaContainer(OllamaImage.DEFAULT_IMAGE);

    static String baseUrl = "http://localhost:11434";
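
Both integration tests now take the container tag from a shared OllamaImage constant instead of repeating "ollama/ollama:0.2.8". The OllamaImage class itself is not part of this excerpt; a minimal sketch of what such a holder could look like, where only the DEFAULT_IMAGE name comes from the test code and the tag value is an assumption:

// Hypothetical shape of the shared Testcontainers image holder.
public final class OllamaImage {

    // Assumed tag; the real value lives in the actual OllamaImage class.
    public static final String DEFAULT_IMAGE = "ollama/ollama:0.3.6";

    private OllamaImage() {
    }

}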
