Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@
*/
package org.springframework.ai.mistralai;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand All @@ -39,24 +45,16 @@
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
import org.springframework.ai.mistralai.metadata.MistralAiUsage;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* @author Ricken Bazolo
* @author Christian Tzolov
Expand All @@ -67,7 +65,7 @@
*/
public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel {

private final Logger log = LoggerFactory.getLogger(getClass());
private final Logger logger = LoggerFactory.getLogger(getClass());

/**
* The default options used for the chat completion requests.
Expand Down Expand Up @@ -108,158 +106,133 @@ public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions option

@Override
public ChatResponse call(Prompt prompt) {
var request = createRequest(prompt, false);

return retryTemplate.execute(ctx -> {

ResponseEntity<ChatCompletion> completionEntity = this.mistralAiApi.chatCompletionEntity(request);
var request = createRequest(prompt, false);

if (this.isToolFunctionCall(completionEntity.getBody())) {
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
completionEntity.getBody());
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
}
ResponseEntity<ChatCompletion> completionEntity = retryTemplate
.execute(ctx -> this.mistralAiApi.chatCompletionEntity(request));

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
log.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}
ChatCompletion chatCompletion = completionEntity.getBody();

List<Generation> generations = chatCompletion.choices()
.stream()
.map(choice -> new Generation(choice.message().content(), toMap(chatCompletion.id(), choice))
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
.toList();
if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

return new ChatResponse(generations, from(chatCompletion));
});
}
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"index", choice.index(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

/**
 * Builds the {@link ChatResponseMetadata} for a Mistral AI chat completion.
 * <p>
 * The metadata carries the completion id, the model, the usage converted through
 * {@link MistralAiUsage#from}, and the completion's creation value stored under
 * the {@code "created"} key.
 * @param result the chat completion to read from; checked with
 * {@code Assert.notNull} before use
 * @return the metadata assembled from {@code result}
 */
public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
MistralAiUsage usage = MistralAiUsage.from(result.usage());
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
.withUsage(usage)
.withKeyValue("created", result.created())
.build();
}
// // Non function calling.
// RateLimit rateLimit =
// OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
Map<String, Object> map = new HashMap<>();
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));

var message = choice.message();
if (message.role() != null) {
map.put("role", message.role().name());
}
if (choice.finishReason() != null) {
map.put("finishReason", choice.finishReason().name());
if (isToolCall(chatResponse, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
MistralAiApi.ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}
map.put("id", id);
return map;

return chatResponse;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
var request = createRequest(prompt, true);

return retryTemplate.execute(ctx -> {

Flux<ChatCompletionChunk> completionChunks = this.mistralAiApi.chatCompletionStream(request);
Flux<ChatCompletionChunk> completionChunks = retryTemplate
.execute(ctx -> this.mistralAiApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

return completionChunks.map(this::toChatCompletion).switchMap(chatCompletion -> {
if (this.isToolFunctionCall(chatCompletion)) {
var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
chatCompletion);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
}

return Mono.just(chatCompletion).map(chatCompletion2 -> {
// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::toChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
@SuppressWarnings("null")
String id = chatCompletion2.id();

// @formatter:off
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generation = new Generation(choice.message().content(),
Map.of("id", id, "role", roleMap.get(id), "finishReason", finish));
if (choice.finishReason() != null) {
generation = generation.withGenerationMetadata(
ChatGenerationMetadata.from(choice.finishReason().name(), null));
}
return generation;
}).toList();

return new ChatResponse(generations);
});
});
});
}

/**
 * Runs the tool calls requested by the model and returns the conversation extended
 * with the outcome.
 * <p>
 * The assistant message of the first choice is translated into an
 * {@link AssistantMessage} with one {@code "function"} tool call per native tool
 * call, passed to {@code executeFunctions}, and both the assistant message and the
 * resulting {@link ToolResponseMessage} are appended to a copy of the previous
 * messages.
 * @param previousMessages the messages sent so far; copied, never modified
 * @param chatCompletion the completion that contains the tool call requests
 * @return a new list: the previous messages, then the assistant message, then the
 * tool response message
 */
private List<Message> handleToolCallRequests(List<Message> previousMessages, ChatCompletion chatCompletion) {
	ChatCompletionMessage requested = this.extractAssistantMessage(chatCompletion);

	// Map every native tool call onto the framework's ToolCall representation.
	List<AssistantMessage.ToolCall> convertedCalls = requested.toolCalls().stream().map(nativeCall -> {
		var function = nativeCall.function();
		return new AssistantMessage.ToolCall(nativeCall.id(), "function", function.name(), function.arguments());
	}).toList();

	AssistantMessage assistantTurn = new AssistantMessage(requested.content(), Map.of(), convertedCalls);
	ToolResponseMessage toolTurn = this.executeFunctions(assistantTurn);

	// Conversation history: everything sent before, followed by the two new turns.
	List<Message> history = new ArrayList<>(previousMessages);
	history.add(assistantTurn);
	history.add(toolTurn);
	return history;
}

/**
 * Returns the message of the first choice of the given completion, filling in the
 * {@code ASSISTANT} role when the message has no role.
 * @param chatCompletion the completion whose first choice is read
 * @return the first choice's message itself when it already has a role, otherwise a
 * copy with the same content, name and tool calls and the role set to
 * {@code ASSISTANT}
 */
private ChatCompletionMessage extractAssistantMessage(ChatCompletion chatCompletion) {
	ChatCompletionMessage firstMessage = chatCompletion.choices().iterator().next().message();
	if (firstMessage.role() != null) {
		return firstMessage;
	}
	// The role is missing: rebuild the message with the assistant role added.
	return new ChatCompletionMessage(firstMessage.content(), ChatCompletionMessage.Role.ASSISTANT,
			firstMessage.name(), firstMessage.toolCalls());
}

protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage) {

List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
}
else {
return new ChatResponse(generations);
}
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
}

for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
}));

var functionName = toolCall.name();
String functionArguments = toolCall.arguments();
return chatResponse.flatMap(response -> {

if (!this.functionCallbackRegister.containsKey(functionName)) {
throw new IllegalStateException("No function callback found for function name: " + functionName);
if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
MistralAiApi.ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
return Flux.just(response);
}
});
}

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
: choice.message()
.toolCalls()
.stream()
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
toolCall.function().name(), toolCall.function().arguments()))
.toList();

toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, functionResponse));
}
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
return new Generation(assistantMessage, generationMetadata);
}

return new ToolResponseMessage(toolResponses, Map.of());
/**
 * Creates the {@link ChatResponseMetadata} describing a Mistral AI chat completion.
 * <p>
 * The result holds the completion id, the model, the usage converted through
 * {@link MistralAiUsage#from}, and the completion's creation value under the
 * {@code "created"} key.
 * @param result the chat completion to describe; checked with
 * {@code Assert.notNull} before use
 * @return the metadata assembled from {@code result}
 */
public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
	Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
	final MistralAiUsage tokenUsage = MistralAiUsage.from(result.usage());
	final var identified = ChatResponseMetadata.builder().withId(result.id()).withModel(result.model());
	return identified.withUsage(tokenUsage).withKeyValue("created", result.created()).build();
}

private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
Expand Down Expand Up @@ -358,21 +331,6 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
}).toList();
}

/**
 * Tells whether the given completion asks for tool (function) calls.
 * <p>
 * Only the first choice is inspected.
 * @param chatCompletion the completion to inspect; may be {@code null}
 * @return {@code true} when the completion is non-null, has at least one choice and
 * the first choice's message has a non-empty tool call list; {@code false}
 * otherwise
 */
protected boolean isToolFunctionCall(ChatCompletion chatCompletion) {
	if (chatCompletion == null) {
		return false;
	}
	var candidates = chatCompletion.choices();
	boolean hasChoices = !CollectionUtils.isEmpty(candidates);
	return hasChoices && !CollectionUtils.isEmpty(candidates.get(0).message().toolCalls());
}

@Override
public ChatOptions getDefaultOptions() {
return MistralAiChatOptions.fromOptions(this.defaultOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public ChatResponse call(Prompt prompt) {
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
Expand Down Expand Up @@ -215,6 +216,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
Expand All @@ -236,7 +238,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {

return chatResponse.flatMap(response -> {

if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) {
if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
Expand Down