diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 0f41afbdd33..9fd7dff2599 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -36,6 +36,8 @@ import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.metadata.UsageUtils; import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -75,6 +77,7 @@ * * @author Geng Rong * @author Alexandros Pappas + * @author Ilayaperumal Gopinathan */ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel { @@ -180,6 +183,10 @@ private static Generation buildGeneration(Choice choice, Map met @Override public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt, null); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { ChatCompletionRequest request = createRequest(prompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -218,8 +225,11 @@ public ChatResponse call(Prompt prompt) { // @formatter:on return buildGeneration(choice, metadata); }).toList(); - - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + MoonshotApi.Usage usage = completionEntity.getBody().usage(); + Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage(); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + from(completionEntity.getBody(), cumulativeUsage)); observationContext.setResponse(chatResponse); @@ -232,7 +242,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. - return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); } return response; } @@ -244,6 +254,10 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return this.internalStream(prompt, null); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -287,8 +301,11 @@ public Flux stream(Prompt prompt) { // @formatter:on return buildGeneration(choice, metadata); }).toList(); + MoonshotApi.Usage usage = chatCompletion2.usage(); + Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage(); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); - return new ChatResponse(generations, from(chatCompletion2)); + return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); } catch (Exception e) { logger.error("Error processing chat completion", e); @@ -303,7 +320,7 @@ public Flux stream(Prompt prompt) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); } return Flux.just(response); }) @@ -325,6 +342,16 @@ private ChatResponseMetadata from(ChatCompletion result) { .build(); } + private ChatResponseMetadata from(ChatCompletion result, Usage usage) { + Assert.notNull(result, "Moonshot ChatCompletionResult must not be null"); + return ChatResponseMetadata.builder() + .id(result.id() != null ? result.id() : "") + .usage(usage) + .model(result.model() != null ? result.model() : "") + .keyValue("created", result.created() != null ? result.created() : 0L) + .build(); + } + /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert @@ -336,10 +363,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) { if (delta == null) { delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT); } - return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason()); + return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage()); }).toList(); - - return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null); + // Get the usage from the latest choice + MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage(); + return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage); } /** diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java index b4a2162e28b..532fb851b8b 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java @@ -532,7 +532,8 @@ public record Choice( // @formatter:off @JsonProperty("index") Integer index, @JsonProperty("message") ChatCompletionMessage message, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("usage") Usage usage) { // @formatter:on } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java index 06f1dc7655d..df03cbb8015 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java @@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { : previous.finishReason()); Integer index = (current.index() != null ? current.index() : previous.index()); + MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage(); + ChatCompletionMessage message = merge(previous.delta(), current.delta()); - return new ChunkChoice(index, message, finishReason, null); + return new ChunkChoice(index, message, finishReason, usage); } private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java index 33ef4855623..af8f4c71319 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java @@ -80,7 +80,7 @@ public void beforeEach() { public void moonshotChatTransientError() { var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), - ChatCompletionFinishReason.STOP); + ChatCompletionFinishReason.STOP, null); ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model", List.of(choice), new MoonshotApi.Usage(10, 10, 10)); diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java index 6b4d5ba19b7..f24600653a4 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.ai.moonshot.chat; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -53,6 +54,33 @@ class MoonshotChatModelFunctionCallingIT { @Autowired ChatModel chatModel; + private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool( + MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """)); + @Test void functionCallTest() { @@ -89,6 +117,7 @@ void streamFunctionCallTest() { .functionCallbacks(List.of(FunctionCallback.builder() .function("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) .build())) .build(); @@ -108,4 +137,47 @@ void streamFunctionCallTest() { assertThat(content).contains("30", "10", "15"); } + @Test + public void toolFunctionCallWithUsage() { + var promptOptions = MoonshotChatOptions.builder() + .model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .tools(Arrays.asList(FUNCTION_TOOL)) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput()); + assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco"); + assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0"); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + + @Test + public void testStreamFunctionCallUsage() { + var promptOptions = MoonshotChatOptions.builder() + .model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue()) + .tools(Arrays.asList(FUNCTION_TOOL)) + .functionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.", + promptOptions); + + ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280); + } + }