Skip to content

Commit 00e6c87

Browse files
committed
Update Ollama APIs and fix multimodality test
* Add doneReason to ChatResponse and update OllamaChatModel accordingly
* Add missing descriptions to Ollama options
* Consolidate Ollama Testcontainers image setup
* Fix multimodality Ollama test

Signed-off-by: Thomas Vitale <[email protected]>
1 parent 554fbcd commit 00e6c87

File tree

11 files changed

+179
-108
lines changed

11 files changed

+179
-108
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
*
6060
* @author Christian Tzolov
6161
* @author luocongqiu
62+
* @author Thomas Vitale
6263
* @since 1.0.0
6364
*/
6465
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -125,13 +126,13 @@ public ChatResponse call(Prompt prompt) {
125126

126127
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
127128
if (response.promptEvalCount() != null && response.evalCount() != null) {
128-
generationMetadata = ChatGenerationMetadata.from("DONE", null);
129+
generationMetadata = ChatGenerationMetadata.from(response.doneReason(), null);
129130
}
130131

131132
var generator = new Generation(assistantMessage, generationMetadata);
132133
var chatResponse = new ChatResponse(List.of(generator), from(response));
133134

134-
if (isToolCall(chatResponse, Set.of("DONE"))) {
135+
if (isToolCall(chatResponse, Set.of("stop"))) {
135136
var toolCallConversation = handleToolCalls(prompt, chatResponse);
136137
// Recursively call the call method with the tool call message
137138
// conversation that contains the call responses.
@@ -176,7 +177,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
176177

177178
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
178179
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
179-
generationMetadata = ChatGenerationMetadata.from("DONE", null);
180+
generationMetadata = ChatGenerationMetadata.from(chunk.doneReason(), null);
180181
}
181182

182183
var generator = new Generation(assistantMessage, generationMetadata);
@@ -201,53 +202,43 @@ public Flux<ChatResponse> stream(Prompt prompt) {
201202
*/
202203
OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
203204

204-
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions()
205-
.stream()
206-
.filter(message -> message.getMessageType() == MessageType.USER
207-
|| message.getMessageType() == MessageType.ASSISTANT
208-
|| message.getMessageType() == MessageType.SYSTEM || message.getMessageType() == MessageType.TOOL)
209-
.map(message -> {
210-
if (message instanceof UserMessage userMessage) {
211-
var messageBuilder = OllamaApi.Message.builder(Role.USER).withContent(message.getContent());
212-
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
213-
messageBuilder.withImages(userMessage.getMedia()
214-
.stream()
215-
.map(media -> this.fromMediaData(media.getData()))
216-
.toList());
217-
}
218-
return List.of(messageBuilder.build());
205+
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
206+
if (message instanceof UserMessage userMessage) {
207+
var messageBuilder = OllamaApi.Message.builder(Role.USER).withContent(message.getContent());
208+
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
209+
messageBuilder.withImages(
210+
userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
219211
}
220-
else if (message instanceof SystemMessage systemMessage) {
221-
return List
222-
.of(OllamaApi.Message.builder(Role.SYSTEM).withContent(systemMessage.getContent()).build());
223-
}
224-
else if (message instanceof AssistantMessage assistantMessage) {
225-
List<ToolCall> toolCalls = null;
226-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
227-
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
228-
var function = new ToolCallFunction(toolCall.name(),
229-
ModelOptionsUtils.jsonToMap(toolCall.arguments()));
230-
return new ToolCall(function);
231-
}).toList();
232-
}
233-
return List.of(OllamaApi.Message.builder(Role.ASSISTANT)
234-
.withContent(assistantMessage.getContent())
235-
.withToolCalls(toolCalls)
236-
.build());
212+
return List.of(messageBuilder.build());
213+
}
214+
else if (message instanceof SystemMessage systemMessage) {
215+
return List.of(OllamaApi.Message.builder(Role.SYSTEM).withContent(systemMessage.getContent()).build());
216+
}
217+
else if (message instanceof AssistantMessage assistantMessage) {
218+
List<ToolCall> toolCalls = null;
219+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
220+
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
221+
var function = new ToolCallFunction(toolCall.name(),
222+
ModelOptionsUtils.jsonToMap(toolCall.arguments()));
223+
return new ToolCall(function);
224+
}).toList();
237225
}
238-
else if (message instanceof ToolResponseMessage toolMessage) {
226+
return List.of(OllamaApi.Message.builder(Role.ASSISTANT)
227+
.withContent(assistantMessage.getContent())
228+
.withToolCalls(toolCalls)
229+
.build());
230+
}
231+
else if (message instanceof ToolResponseMessage toolMessage) {
239232

240-
List<OllamaApi.Message> responseMessages = toolMessage.getResponses()
241-
.stream()
242-
.map(tr -> OllamaApi.Message.builder(Role.TOOL).withContent(tr.responseData()).build())
243-
.toList();
233+
List<OllamaApi.Message> responseMessages = toolMessage.getResponses()
234+
.stream()
235+
.map(tr -> OllamaApi.Message.builder(Role.TOOL).withContent(tr.responseData()).build())
236+
.toList();
244237

245-
return responseMessages;
246-
}
247-
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
248-
})
249-
.flatMap(List::stream)
250-
.toList();
238+
return responseMessages;
239+
}
240+
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
241+
}).flatMap(List::stream).toList();
251242

252243
Set<String> functionsForThisRequest = new HashSet<>();
253244

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
* Java Client for the Ollama API. <a href="https://ollama.ai/">https://ollama.ai</a>
4848
*
4949
* @author Christian Tzolov
50+
* @author Thomas Vitale
5051
* @since 0.8.0
5152
*/
5253
// @formatter:off
@@ -454,15 +455,20 @@ public Message build() {
454455
/**
455456
* Chat request object.
456457
*
457-
* @param model The model to use for completion.
458-
* @param messages The list of messages to chat with.
459-
* @param stream Whether to stream the response.
460-
* @param format The format to return the response in. Currently, the only accepted
461-
* value is "json".
462-
* @param keepAlive The duration to keep the model loaded in ollama while idle.
463-
* @param options Additional model parameters. You can use the {@link OllamaOptions} builder
464-
* to create the options then {@link OllamaOptions#toMap()} to convert the options into a
465-
* map.
458+
* @param model The model to use for completion. It should be a name familiar to Ollama from the <a href="https://ollama.com/library">Library</a>.
459+
* @param messages The list of messages in the chat. This can be used to keep a chat memory.
460+
* @param stream Whether to stream the response. If false, the response will be returned as a single response object rather than a stream of objects.
461+
* @param format The format to return the response in. Currently, the only accepted value is "json".
462+
* @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m).
463+
* @param tools List of tools the model has access to.
464+
* @param options Model-specific options. For example, "temperature" can be set through this field, if the model supports it.
465+
* You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map.
466+
*
467+
* @see <a href=
468+
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Chat
469+
* Completion API</a>
470+
* @see <a href="https://github.com/ollama/ollama/blob/main/api/types.go">Ollama
471+
* Types</a>
466472
*/
467473
@JsonInclude(Include.NON_NULL)
468474
public record ChatRequest(
@@ -471,9 +477,9 @@ public record ChatRequest(
471477
@JsonProperty("stream") Boolean stream,
472478
@JsonProperty("format") String format,
473479
@JsonProperty("keep_alive") String keepAlive,
474-
@JsonProperty("options") Map<String, Object> options,
475-
@JsonProperty("tools") List<Tool> tools) {
476-
480+
@JsonProperty("tools") List<Tool> tools,
481+
@JsonProperty("options") Map<String, Object> options
482+
) {
477483

478484
/**
479485
* Represents a tool the model may call. Currently, only functions are supported as a tool.
@@ -544,8 +550,8 @@ public static class Builder {
544550
private boolean stream = false;
545551
private String format;
546552
private String keepAlive;
547-
private Map<String, Object> options = Map.of();
548553
private List<Tool> tools = List.of();
554+
private Map<String, Object> options = Map.of();
549555

550556
public Builder(String model) {
551557
Assert.notNull(model, "The model can not be null.");
@@ -572,6 +578,11 @@ public Builder withKeepAlive(String keepAlive) {
572578
return this;
573579
}
574580

581+
public Builder withTools(List<Tool> tools) {
582+
this.tools = tools;
583+
return this;
584+
}
585+
575586
public Builder withOptions(Map<String, Object> options) {
576587
Objects.requireNonNull(options, "The options can not be null.");
577588

@@ -585,33 +596,30 @@ public Builder withOptions(OllamaOptions options) {
585596
return this;
586597
}
587598

588-
public Builder withTools(List<Tool> tools) {
589-
this.tools = tools;
590-
return this;
591-
}
592-
593599
public ChatRequest build() {
594-
return new ChatRequest(model, messages, stream, format, keepAlive, options, tools);
600+
return new ChatRequest(model, messages, stream, format, keepAlive, tools, options);
595601
}
596602
}
597603
}
598604

599605
/**
600606
* Ollama chat response object.
601607
*
602-
* @param model The model name used for completion.
603-
* @param createdAt When the request was made.
608+
* @param model The model used for generating the response.
609+
* @param createdAt The timestamp of the response generation.
604610
* @param message The response {@link Message} with {@link Message.Role#ASSISTANT}.
611+
* @param doneReason The reason the model stopped generating text.
605612
* @param done Whether this is the final response. For streaming response only the
606613
* last message is marked as done. If true, this response may be followed by another
607614
* response with the following, additional fields: context, prompt_eval_count,
608615
* prompt_eval_duration, eval_count, eval_duration.
609616
* @param totalDuration Time spent generating the response.
610617
* @param loadDuration Time spent loading the model.
611-
* @param promptEvalCount number of tokens in the prompt.(*)
612-
* @param promptEvalDuration time spent evaluating the prompt.
613-
* @param evalCount number of tokens in the response.
614-
* @param evalDuration time spent generating the response.
618+
* @param promptEvalCount Number of tokens in the prompt.
619+
* @param promptEvalDuration Time spent evaluating the prompt.
620+
* @param evalCount Number of tokens in the response.
621+
* @param evalDuration Time spent generating the response.
622+
*
615623
* @see <a href=
616624
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Chat
617625
* Completion API</a>
@@ -623,13 +631,15 @@ public record ChatResponse(
623631
@JsonProperty("model") String model,
624632
@JsonProperty("created_at") Instant createdAt,
625633
@JsonProperty("message") Message message,
634+
@JsonProperty("done_reason") String doneReason,
626635
@JsonProperty("done") Boolean done,
627636
@JsonProperty("total_duration") Duration totalDuration,
628637
@JsonProperty("load_duration") Duration loadDuration,
629638
@JsonProperty("prompt_eval_count") Integer promptEvalCount,
630639
@JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
631640
@JsonProperty("eval_count") Integer evalCount,
632-
@JsonProperty("eval_duration") Duration evalDuration) {
641+
@JsonProperty("eval_duration") Duration evalDuration
642+
) {
633643
}
634644

635645
/**

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* Helper class for common Ollama models.
2222
*
2323
* @author Siarhei Blashuk
24+
* @author Thomas Vitale
2425
* @since 0.8.1
2526
*/
2627
public enum OllamaModel implements ChatModelDescription {
@@ -35,11 +36,27 @@ public enum OllamaModel implements ChatModelDescription {
3536
*/
3637
LLAMA3("llama3"),
3738

39+
/**
40+
* The 8B language model from Meta.
41+
*/
42+
LLAMA3_1("llama3.1"),
43+
3844
/**
3945
* The 7B parameters model
4046
*/
4147
MISTRAL("mistral"),
4248

49+
/**
50+
* A 12B model with 128k context length, built by Mistral AI in collaboration with
51+
* NVIDIA.
52+
*/
53+
MISTRAL_NEMO("mistral-nemo"),
54+
55+
/**
56+
* A small vision language model designed to run efficiently on edge devices.
57+
*/
58+
MOONDREAM("moondream"),
59+
4360
/**
4461
* The 2.7B uncensored Dolphin model
4562
*/

0 commit comments

Comments (0)