@@ -36,6 +36,8 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
@@ -75,6 +77,7 @@
* @author Grogdunn
* @author Thomas Vitale
* @author luocongqiu
* @author Ilayaperumal Gopinathan
* @since 1.0.0
*/
public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -156,8 +159,22 @@ public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
.build();
}

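/**
 * Builds {@link ChatResponseMetadata} from the completion, attaching the supplied
 * usage (for example, a cumulative total across tool-call rounds) instead of the
 * usage reported on the completion itself.
 */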
public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result, Usage usage) {
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
.withUsage(usage)
.withKeyValue("created", result.created())
.build();
}

@Override
public ChatResponse call(Prompt prompt) {
return this.internalCall(prompt, null);
}

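/**
 * Internal variant of {@link #call(Prompt)} that carries the previous response
 * through recursive tool-call rounds so token usage can be accumulated across them.
 */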
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false);

@@ -193,7 +210,10 @@ public ChatResponse call(Prompt prompt) {
return buildGeneration(choice, metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
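// Fold this round's token usage into the running total carried over from any
// previous tool-call rounds, so the response metadata reports cumulative usage.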
MistralAiUsage usage = MistralAiUsage.from(completionEntity.getBody().usage());
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), cumulativeUsage));

observationContext.setResponse(chatResponse);

@@ -206,14 +226,18 @@ && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALL
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call internalCall with the tool-call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}

return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.internalStream(prompt, null);
}

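/**
 * Streaming variant of {@link #internalCall(Prompt, ChatResponse)}: the previous
 * response supplies the usage baseline that streamed responses accumulate onto.
 */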
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
var request = createRequest(prompt, true);

@@ -259,7 +283,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// @formatter:on

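// Usage typically arrives only on the final streamed chunk; when present,
// merge it with usage carried over from earlier tool-call rounds.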
if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
MistralAiUsage usage = MistralAiUsage.from(chatCompletion2.usage());
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
}
else {
return new ChatResponse(generations);
@@ -277,7 +303,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call internalStream with the tool-call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
else {
return Flux.just(response);
@@ -314,7 +340,8 @@ private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs()))
.toList();

return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
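// Propagate the chunk's usage into the synthesized ChatCompletion so that
// streaming responses can report token counts.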
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices,
chunk.usage());
}

/**
@@ -209,7 +209,8 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
return !isInsideTool.get();
})
.concatMapIterable(window -> {
Mono<ChatCompletionChunk> mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null),
Mono<ChatCompletionChunk> mono1 = window.reduce(
new ChatCompletionChunk(null, null, null, null, null, null),
(previous, current) -> this.chunkMerger.merge(previous, current));
return List.of(mono1);
})
@@ -934,6 +935,7 @@ public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("lo
* @param model The model used for the chat completion.
* @param choices A list of chat completion choices. Can be more than one if n is
* greater than 1.
* @param usage Usage metrics for the chat completion.
*/
@JsonInclude(Include.NON_NULL)
public record ChatCompletionChunk(
@@ -942,7 +944,8 @@ public record ChatCompletionChunk(
@JsonProperty("object") String object,
@JsonProperty("created") Long created,
@JsonProperty("model") String model,
@JsonProperty("choices") List<ChunkChoice> choices) {
@JsonProperty("choices") List<ChunkChoice> choices,
@JsonProperty("usage") Usage usage) {
// @formatter:on

/**
@@ -63,7 +63,9 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu

ChunkChoice choice = merge(previousChoice0, currentChoice0);

return new ChatCompletionChunk(id, object, created, model, List.of(choice));
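// Prefer the current chunk's usage, falling back to the previous chunk's,
// since usually only one chunk in the stream carries usage data.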
MistralAiApi.Usage usage = (current.usage() != null ? current.usage() : previous.usage());

return new ChatCompletionChunk(id, object, created, model, List.of(choice), usage);
}

private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
@@ -205,6 +205,9 @@ void functionCallTest() {
logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getText()).containsAnyOf("30.0", "30");
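// The reported usage should cover the whole tool-calling exchange, not just
// the final completion.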
assertThat(response.getMetadata()).isNotNull();
assertThat(response.getMetadata().getUsage()).isNotNull();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(800);
}

@Test
@@ -238,6 +241,32 @@ void streamFunctionCallTest() {
assertThat(content).containsAnyOf("10.0", "10");
}

@Test
void streamFunctionCallUsageTest() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Response in Celsius");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = MistralAiChatOptions.builder()
.withModel(MistralAiApi.ChatModel.SMALL.getValue())
.withFunctionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build();

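// Drain the stream; the last emitted response carries the cumulative usage.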
Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(messages, promptOptions));
ChatResponse chatResponse = response.last().block();

logger.info("Response: {}", chatResponse);
assertThat(chatResponse.getMetadata()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(800);
}

record ActorsFilmsRecord(String actor, List<String> movies) {

}
@@ -124,7 +124,7 @@ public void mistralAiChatStreamTransientError() {
var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
ChatCompletionFinishReason.STOP, null);
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L,
"model", List.of(choice));
"model", List.of(choice), null);

given(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
.willThrow(new TransientAiException("Transient Error 1"))