Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@
package org.springframework.ai.ollama;

import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.metadata.OllamaUsage;
import org.springframework.util.Assert;
Expand All @@ -54,7 +64,7 @@
* @author luocongqiu
* @since 0.8.0
*/
public class OllamaChatModel implements ChatModel {
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {

/**
* Low-level Ollama API library.
Expand All @@ -71,6 +81,12 @@ public OllamaChatModel(OllamaApi chatApi) {
}

/**
 * Creates an {@link OllamaChatModel} with the given API client and default options.
 * Delegates to the three-argument constructor with a {@code null}
 * {@code FunctionCallbackContext}, so no function-callback resolution context is
 * configured for this instance.
 * @param chatApi the low-level Ollama API client; must not be {@code null}
 * @param defaultOptions the default chat options merged into every request; must not be
 * {@code null}
 */
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions) {
this(chatApi, defaultOptions, null);
}

public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
Assert.notNull(chatApi, "OllamaApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.chatApi = chatApi;
Expand Down Expand Up @@ -100,11 +116,32 @@ public ChatResponse call(Prompt prompt) {

OllamaApi.ChatResponse response = this.chatApi.chat(ollamaChatRequest(prompt, false));

var generator = new Generation(response.message().content());
List<AssistantMessage.ToolCall> toolCalls = response.message().toolCalls() == null ? List.of()
: response.message()
.toolCalls()
.stream()
.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
.toList();

var assistantMessage = new AssistantMessage(response.message().content(), Map.of(), toolCalls);

ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
if (response.promptEvalCount() != null && response.evalCount() != null) {
generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
generationMetadata = ChatGenerationMetadata.from("DONE", null);
}

var generator = new Generation(assistantMessage, generationMetadata);
var chatResponse = new ChatResponse(List.of(generator), from(response));

if (isToolCall(chatResponse, Set.of("DONE"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}
return new ChatResponse(List.of(generator), from(response));

return chatResponse;
}

public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
Expand All @@ -126,15 +163,39 @@ public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
@Override
public Flux<ChatResponse> stream(Prompt prompt) {

Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));
Flux<OllamaApi.ChatResponse> olamaResponse = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));

Flux<ChatResponse> chatResponse = olamaResponse.map(chunk -> {
String content = (chunk.message() != null) ? chunk.message().content() : "";
List<AssistantMessage.ToolCall> toolCalls = chunk.message().toolCalls() == null ? List.of()
: chunk.message()
.toolCalls()
.stream()
.map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
.toList();

var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);

ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
generationMetadata = ChatGenerationMetadata.from("DONE", null);
}

var generator = new Generation(assistantMessage, generationMetadata);
return new ChatResponse(List.of(generator), from(chunk));
});

return response.map(chunk -> {
Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
: new Generation("");
if (Boolean.TRUE.equals(chunk.done())) {
generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
return chatResponse.flatMap(response -> {
if (isToolCall(response, Set.of("DONE"))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
return Flux.just(response);
}
return new ChatResponse(List.of(generation), from(chunk));
});
}

Expand All @@ -147,28 +208,61 @@ OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
.stream()
.filter(message -> message.getMessageType() == MessageType.USER
|| message.getMessageType() == MessageType.ASSISTANT
|| message.getMessageType() == MessageType.SYSTEM)
.map(m -> {
var messageBuilder = OllamaApi.Message.builder(toRole(m)).withContent(m.getContent());
if (m instanceof UserMessage userMessage) {
|| message.getMessageType() == MessageType.SYSTEM || message.getMessageType() == MessageType.TOOL)
.map(message -> {
if (message instanceof UserMessage userMessage) {
var messageBuilder = OllamaApi.Message.builder(Role.USER).withContent(message.getContent());
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
messageBuilder.withImages(userMessage.getMedia()
.stream()
.map(media -> this.fromMediaData(media.getData()))
.toList());
}
return List.of(messageBuilder.build());
}
else if (message instanceof SystemMessage systemMessage) {
return List
.of(OllamaApi.Message.builder(Role.SYSTEM).withContent(systemMessage.getContent()).build());
}
else if (message instanceof AssistantMessage assistantMessage) {
List<ToolCall> toolCalls = null;
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
var function = new ToolCallFunction(toolCall.name(),
ModelOptionsUtils.jsonToMap(toolCall.arguments()));
return new ToolCall(function);
}).toList();
}
return List.of(OllamaApi.Message.builder(Role.ASSISTANT)
.withContent(assistantMessage.getContent())
.withToolCalls(toolCalls)
.build());
}
else if (message instanceof ToolResponseMessage toolMessage) {

List<OllamaApi.Message> responseMessages = toolMessage.getResponses()
.stream()
.map(tr -> OllamaApi.Message.builder(Role.TOOL).withContent(tr.responseData()).build())
.toList();

return responseMessages;
}
return messageBuilder.build();
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
})
.flatMap(List::stream)
.toList();

Set<String> functionsForThisRequest = new HashSet<>();

// runtime options
OllamaOptions runtimeOptions = null;
if (prompt.getOptions() != null) {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OllamaOptions.class);
functionsForThisRequest.addAll(this.handleFunctionCallbackConfigurations(runtimeOptions, IS_RUNTIME_CALL));
}

functionsForThisRequest.addAll(this.handleFunctionCallbackConfigurations(this.defaultOptions, IS_RUNTIME_CALL));
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

// Override the model.
Expand All @@ -190,6 +284,11 @@ OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
requestBuilder.withKeepAlive(mergedOptions.getKeepAlive());
}

// Add the enabled functions definitions to the request's tools parameter.
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
}

return requestBuilder.build();
}

Expand All @@ -206,18 +305,12 @@ else if (mediaData instanceof String text) {

}

private OllamaApi.Message.Role toRole(Message message) {

switch (message.getMessageType()) {
case USER:
return Role.USER;
case ASSISTANT:
return Role.ASSISTANT;
case SYSTEM:
return Role.SYSTEM;
default:
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
/**
 * Builds the Ollama tool definitions for the given set of enabled function names.
 * Each name is resolved to its registered callback, which is then described to the
 * model by its name, description and JSON input-type schema.
 * @param functionNames the names of the functions enabled for the current request
 * @return the tool definitions to attach to the chat request
 */
private List<ChatRequest.Tool> getFunctionTools(Set<String> functionNames) {
	return this.resolveFunctionCallbacks(functionNames)
		.stream()
		.map(callback -> new ChatRequest.Tool(new ChatRequest.Tool.Function(callback.getName(),
				callback.getDescription(), callback.getInputTypeSchema())))
		.toList();
}

@Override
Expand Down
Loading