diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
index e9b5528eb33..94fe45595a9 100644
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
+++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
@@ -40,13 +40,13 @@
 import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
 import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
 import org.springframework.ai.anthropic.api.AnthropicApi.Role;
-import org.springframework.ai.anthropic.metadata.AnthropicUsage;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.MessageType;
 import org.springframework.ai.chat.messages.ToolResponseMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.EmptyUsage;
 import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.metadata.UsageUtils;
@@ -237,7 +237,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
 			AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
 			AnthropicApi.Usage usage = completionResponse.usage();
-			Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
+			Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(completionResponse.usage())
 					: new EmptyUsage();
 			Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
@@ -256,6 +256,11 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
 		return response;
 	}
 
+	private DefaultUsage getDefaultUsage(AnthropicApi.Usage usage) {
+		return new DefaultUsage(usage.inputTokens(), usage.outputTokens(), usage.inputTokens() + usage.outputTokens(),
+				usage);
+	}
+
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
 		return this.internalStream(prompt, null);
 	}
@@ -282,7 +287,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 		// @formatter:off
 		Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
 			AnthropicApi.Usage usage = chatCompletionResponse.usage();
-			Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
+			Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage();
 			Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
 			ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
@@ -352,7 +357,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
 	}
 
 	private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
-		return from(result, AnthropicUsage.from(result.usage()));
+		return from(result, this.getDefaultUsage(result.usage()));
 	}
 
 	private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java
deleted file mode 100644
index fbafc2297d3..00000000000
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.anthropic.metadata;
-
-import org.springframework.ai.anthropic.api.AnthropicApi;
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for {@literal AnthropicApi}.
- *
- * @author Christian Tzolov
- * @since 1.0.0
- */
-public class AnthropicUsage implements Usage {
-
-	private final AnthropicApi.Usage usage;
-
-	protected AnthropicUsage(AnthropicApi.Usage usage) {
-		Assert.notNull(usage, "AnthropicApi Usage must not be null");
-		this.usage = usage;
-	}
-
-	public static AnthropicUsage from(AnthropicApi.Usage usage) {
-		return new AnthropicUsage(usage);
-	}
-
-	protected AnthropicApi.Usage getUsage() {
-		return this.usage;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		return getUsage().inputTokens().longValue();
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		return getUsage().outputTokens().longValue();
-	}
-
-	@Override
-	public Long getTotalTokens() {
-		return this.getPromptTokens() + this.getGenerationTokens();
-	}
-
-	@Override
-	public String toString() {
-		return getUsage().toString();
-	}
-
-}
diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java
index 1381abe734f..94af5ab14fe 100644
--- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java
+++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java
@@ -84,7 +84,7 @@ private static void validateChatResponseMetadata(ChatResponse response, String m
 		assertThat(response.getMetadata().getId()).isNotEmpty();
 		assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
 		assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
-		assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
+		assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
 		assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
 	}
 
@@ -100,11 +100,11 @@ void roleTest(String modelName) {
 				AnthropicChatOptions.builder().model(modelName).build());
 		ChatResponse response = this.chatModel.call(prompt);
 		assertThat(response.getResults()).hasSize(1);
-		assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
+		assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0);
 		assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0);
 		assertThat(response.getMetadata().getUsage().getTotalTokens())
 			.isEqualTo(response.getMetadata().getUsage().getPromptTokens()
-					+ response.getMetadata().getUsage().getGenerationTokens());
+					+ response.getMetadata().getUsage().getCompletionTokens());
 		Generation generation = response.getResults().get(0);
 		assertThat(generation.getOutput().getText()).contains("Blackbeard");
 		assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn");
@@ -139,11 +139,11 @@ void streamingWithTokenUsage() {
 		var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
 
 		assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
-		assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
+		assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
 		assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
 
 		assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
-		// assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
+		// assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
 		// assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());
 	}
 
diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java
index e38fe0d548c..20a8f037aad 100644
--- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java
+++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java
@@ -147,7 +147,7 @@ private void validate(ChatResponseMetadata responseMetadata, String finishReason
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getPromptTokens()))
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-					String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+					String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getTotalTokens()))
 			.hasBeenStarted()
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
index c197e2877de..d527115da4f 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
@@ -60,13 +60,13 @@
 import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Flux;
 
-import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.ToolResponseMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.EmptyUsage;
 import org.springframework.ai.chat.metadata.PromptMetadata;
 import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
@@ -194,13 +194,14 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM
 	}
 
 	public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
-		Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
+		Usage usage = (chatCompletions.getUsage() != null) ? getDefaultUsage(chatCompletions.getUsage())
+				: new EmptyUsage();
 		return from(chatCompletions, promptFilterMetadata, usage);
 	}
 
 	public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
 			CompletionsUsage usage) {
-		return from(chatCompletions, promptFilterMetadata, AzureOpenAiUsage.from(usage));
+		return from(chatCompletions, promptFilterMetadata, getDefaultUsage(usage));
 	}
 
 	public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage) {
@@ -217,6 +218,10 @@ public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage)
 		return builder.build();
 	}
 
+	private static DefaultUsage getDefaultUsage(CompletionsUsage usage) {
+		return new DefaultUsage(usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens(), usage);
+	}
+
 	public AzureOpenAiChatOptions getDefaultOptions() {
 		return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
 	}
@@ -321,7 +326,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 			}
 			// Accumulate the usage from the previous chat response
 			CompletionsUsage usage = chatCompletion.getUsage();
-			Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage.from(usage) : new EmptyUsage();
+			Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
 			Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
 			return toChatResponse(chatCompletion, accumulatedUsage);
 		}).buffer(2, 1).map(bufferList -> {
@@ -412,7 +417,7 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatRespons
 		PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
 		Usage currentUsage = null;
 		if (chatCompletions.getUsage() != null) {
-			currentUsage = AzureOpenAiUsage.from(chatCompletions);
+			currentUsage = getDefaultUsage(chatCompletions.getUsage());
 		}
 		Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
 		return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage));
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java
index ab9d0518093..a5f5b335781 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java
@@ -23,11 +23,12 @@
 import com.azure.ai.openai.models.EmbeddingItem;
 import com.azure.ai.openai.models.Embeddings;
 import com.azure.ai.openai.models.EmbeddingsOptions;
+import com.azure.ai.openai.models.EmbeddingsUsage;
 import io.micrometer.observation.ObservationRegistry;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.document.MetadataMode;
 import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -159,10 +160,14 @@ EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
 	private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
 		List<Embedding> data = generateEmbeddingList(embeddings.getData());
 		EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
-		metadata.setUsage(AzureOpenAiEmbeddingUsage.from(embeddings.getUsage()));
+		metadata.setUsage(getDefaultUsage(embeddings.getUsage()));
 		return new EmbeddingResponse(data, metadata);
 	}
 
+	private DefaultUsage getDefaultUsage(EmbeddingsUsage usage) {
+		return new DefaultUsage(usage.getPromptTokens(), 0, usage.getTotalTokens(), usage);
+	}
+
 	private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
 		List<Embedding> data = new ArrayList<>();
 		for (EmbeddingItem nativeDatum : nativeData) {
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java
deleted file mode 100644
index 8fe0fa1e42b..00000000000
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.azure.openai.metadata;
-
-import com.azure.ai.openai.models.EmbeddingsUsage;
-
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for {@literal Microsoft Azure OpenAI Service} embedding.
- *
- * @author Thomas Vitale
- * @see EmbeddingsUsage
- */
-public class AzureOpenAiEmbeddingUsage implements Usage {
-
-	private final EmbeddingsUsage usage;
-
-	public AzureOpenAiEmbeddingUsage(EmbeddingsUsage usage) {
-		Assert.notNull(usage, "EmbeddingsUsage must not be null");
-		this.usage = usage;
-	}
-
-	public static AzureOpenAiEmbeddingUsage from(EmbeddingsUsage usage) {
-		Assert.notNull(usage, "EmbeddingsUsage must not be null");
-		return new AzureOpenAiEmbeddingUsage(usage);
-	}
-
-	protected EmbeddingsUsage getUsage() {
-		return this.usage;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		return (long) getUsage().getPromptTokens();
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		return 0L;
-	}
-
-	@Override
-	public Long getTotalTokens() {
-		return (long) getUsage().getTotalTokens();
-	}
-
-	@Override
-	public String toString() {
-		return getUsage().toString();
-	}
-
-}
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java
deleted file mode 100644
index b0dd15d1367..00000000000
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.azure.openai.metadata;
-
-import com.azure.ai.openai.models.ChatCompletions;
-import com.azure.ai.openai.models.CompletionsUsage;
-
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for {@literal Microsoft Azure OpenAI Service} chat.
- *
- * @author John Blum
- * @see com.azure.ai.openai.models.CompletionsUsage
- * @since 0.7.0
- */
-public class AzureOpenAiUsage implements Usage {
-
-	private final CompletionsUsage usage;
-
-	public AzureOpenAiUsage(CompletionsUsage usage) {
-		Assert.notNull(usage, "CompletionsUsage must not be null");
-		this.usage = usage;
-	}
-
-	public static AzureOpenAiUsage from(ChatCompletions chatCompletions) {
-		Assert.notNull(chatCompletions, "ChatCompletions must not be null");
-		return from(chatCompletions.getUsage());
-	}
-
-	public static AzureOpenAiUsage from(CompletionsUsage usage) {
-		return new AzureOpenAiUsage(usage);
-	}
-
-	protected CompletionsUsage getUsage() {
-		return this.usage;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		return (long) getUsage().getPromptTokens();
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		return (long) getUsage().getCompletionTokens();
-	}
-
-	@Override
-	public Long getTotalTokens() {
-		return (long) getUsage().getTotalTokens();
-	}
-
-	@Override
-	public String toString() {
-		return getUsage().toString();
-	}
-
-}
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java
index 8f0e4379fc0..9310c746dc6 100644
--- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java
@@ -166,7 +166,7 @@ private void validate(ChatResponseMetadata responseMetadata, boolean checkModel)
 					String.valueOf(responseMetadata.getUsage().getPromptTokens()))
 			.hasHighCardinalityKeyValue(
 					ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-					String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+					String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
 			.hasHighCardinalityKeyValue(
 					ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getTotalTokens()))
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java
index d10bd355bf0..e07b0ff6a0f 100644
--- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java
@@ -115,7 +115,7 @@ private void assertGenerationMetadata(ChatResponse response) {
 		assertThat(usage).isNotNull();
 		assertThat(usage.getPromptTokens()).isEqualTo(58);
-		assertThat(usage.getGenerationTokens()).isEqualTo(68);
+		assertThat(usage.getCompletionTokens()).isEqualTo(68);
 		assertThat(usage.getTotalTokens()).isEqualTo(126);
 	}
 
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
index 8648ec91796..37afa84de43 100644
--- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
@@ -547,15 +547,15 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
 			allGenerations.add(toolCallGeneration);
 		}
 
-		Long promptTokens = response.usage().inputTokens().longValue();
-		Long generationTokens = response.usage().outputTokens().longValue();
-		Long totalTokens = response.usage().totalTokens().longValue();
+		Integer promptTokens = response.usage().inputTokens();
+		Integer generationTokens = response.usage().outputTokens();
+		int totalTokens = response.usage().totalTokens();
 
 		if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
 				&& perviousChatResponse.getMetadata().getUsage() != null) {
 
 			promptTokens += perviousChatResponse.getMetadata().getUsage().getPromptTokens();
-			generationTokens += perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
+			generationTokens += perviousChatResponse.getMetadata().getUsage().getCompletionTokens();
 			totalTokens += perviousChatResponse.getMetadata().getUsage().getTotalTokens();
 		}
 
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/BedrockUsage.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/BedrockUsage.java
deleted file mode 100644
index ac58a7ca502..00000000000
--- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/BedrockUsage.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.bedrock.converse.api;
-
-import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
-
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for Bedrock Converse API.
- *
- * @author Christian Tzolov
- * @author Wei Jiang
- * @since 1.0.0
- */
-public class BedrockUsage implements Usage {
-
-	public static BedrockUsage from(TokenUsage usage) {
-		Assert.notNull(usage, "'TokenUsage' must not be null.");
-
-		return new BedrockUsage(usage.inputTokens().longValue(), usage.outputTokens().longValue());
-	}
-
-	private final Long inputTokens;
-
-	private final Long outputTokens;
-
-	protected BedrockUsage(Long inputTokens, Long outputTokens) {
-		this.inputTokens = inputTokens;
-		this.outputTokens = outputTokens;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		return this.inputTokens;
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		return this.outputTokens;
-	}
-
-	@Override
-	public String toString() {
-		return "BedrockUsage [inputTokens=" + this.inputTokens + ", outputTokens=" + this.outputTokens + "]";
-	}
-
-}
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
index b43ce292f81..ed44a72f8bc 100644
--- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
@@ -122,9 +122,9 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo
 				List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
 
-				Long promptTokens = 0L;
-				Long generationTokens = 0L;
-				Long totalTokens = 0L;
+				Integer promptTokens = 0;
+				Integer generationTokens = 0;
+				Integer totalTokens = 0;
 
 				for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) {
 					var functionCallId = toolUseEntry.id();
@@ -135,7 +135,7 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo
 					if (toolUseEntry.usage() != null) {
 						promptTokens += toolUseEntry.usage().getPromptTokens();
-						generationTokens += toolUseEntry.usage().getGenerationTokens();
+						generationTokens += toolUseEntry.usage().getCompletionTokens();
 						totalTokens += toolUseEntry.usage().getTotalTokens();
 					}
 				}
@@ -207,9 +207,8 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
 				Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields();
 
 				ConverseStreamMetrics metrics = metadataEvent.metrics();
-				DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
-						metadataEvent.usage().outputTokens().longValue(),
-						metadataEvent.usage().totalTokens().longValue());
+				DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens(),
+						metadataEvent.usage().outputTokens(), metadataEvent.usage().totalTokens());
 
 				var chatResponseMetaData = ChatResponseMetadata.builder().usage(usage).build();
@@ -231,9 +230,9 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
 				var metadataBuilder = ChatResponseMetadata.builder();
 
-				Long promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens();
-				Long generationTokens = perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
-				Long totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens();
+				Integer promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens();
+				Integer generationTokens = perviousChatResponse.getMetadata().getUsage().getCompletionTokens();
+				int totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens();
 
 				if (chatResponse.getMetadata() != null) {
 					metadataBuilder.id(chatResponse.getMetadata().getId());
@@ -244,7 +243,7 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
 					if (chatResponse.getMetadata().getUsage() != null) {
 						promptTokens = promptTokens + chatResponse.getMetadata().getUsage().getPromptTokens();
 						generationTokens = generationTokens
-								+ chatResponse.getMetadata().getUsage().getGenerationTokens();
+								+ chatResponse.getMetadata().getUsage().getCompletionTokens();
 						totalTokens = totalTokens + chatResponse.getMetadata().getUsage().getTotalTokens();
 					}
 				}
@@ -290,8 +289,8 @@ else if (event.sdkEventType() == EventType.MESSAGE_STOP) {
 		}
 		else if (event.sdkEventType() == EventType.METADATA) {
 			ConverseStreamMetadataEvent metadataEvent = (ConverseStreamMetadataEvent) event;
-			DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
-					metadataEvent.usage().outputTokens().longValue(), metadataEvent.usage().totalTokens().longValue());
+			DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens(),
+					metadataEvent.usage().outputTokens(), metadataEvent.usage().totalTokens());
 			toolUseEventAggregator.withUsage(usage);
 
 			if (!toolUseEventAggregator.isEmpty()) {
diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java
index 86febbe1f0a..80f6c935134 100644
--- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java
+++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java
@@ -249,11 +249,11 @@ void functionCallWithUsageMetadataTest() {
 		assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(500);
 		assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);
 
-		assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
-		assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);
+		assertThat(metadata.getUsage().getCompletionTokens()).isGreaterThan(0);
+		assertThat(metadata.getUsage().getCompletionTokens()).isLessThan(1500);
 
 		assertThat(metadata.getUsage().getTotalTokens())
-			.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());
+			.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getCompletionTokens());
 
 		logger.info("Response: {}", response);
@@ -330,11 +330,11 @@ void streamFunctionCallTest() {
 		assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500);
 		assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);
 
-		assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
-		assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);
+		assertThat(metadata.getUsage().getCompletionTokens()).isGreaterThan(0);
+		assertThat(metadata.getUsage().getCompletionTokens()).isLessThan(1500);
 
 		assertThat(metadata.getUsage().getTotalTokens())
-			.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());
+			.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getCompletionTokens());
 
 		String content = chatResponses.stream()
 			.filter(cr -> cr.getResult() != null)
diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java
index f319d442389..bdbe95bda48 100644
--- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java
+++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java
@@ -85,7 +85,7 @@ public void call() {
 		assertThat(result.getResult().getOutput().getText()).isSameAs("Response Content Block");
 
 		assertThat(result.getMetadata().getUsage().getPromptTokens()).isEqualTo(16);
-		assertThat(result.getMetadata().getUsage().getGenerationTokens()).isEqualTo(14);
+		assertThat(result.getMetadata().getUsage().getCompletionTokens()).isEqualTo(14);
 		assertThat(result.getMetadata().getUsage().getTotalTokens()).isEqualTo(30);
 	}
 
@@ -151,7 +151,7 @@ public void callWithToolUse() {
 			.isSameAs(converseResponseFinal.output().message().content().get(0).text());
 
 		assertThat(result.getMetadata().getUsage().getPromptTokens()).isEqualTo(445 + 540);
-		assertThat(result.getMetadata().getUsage().getGenerationTokens()).isEqualTo(119 + 106);
+		assertThat(result.getMetadata().getUsage().getCompletionTokens()).isEqualTo(119 + 106);
 		assertThat(result.getMetadata().getUsage().getTotalTokens()).isEqualTo(564 + 646);
 	}
 
diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java
index 6a4b84c0f23..f4ffbcbdd21 100644
--- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java
+++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java
@@ -77,7 +77,7 @@ private static void validateChatResponseMetadata(ChatResponse response, String m
 		// assertThat(response.getMetadata().getId()).isNotEmpty();
 		// assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
 		assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
-		assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
+		assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
 		assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
 	}
 
@@ -93,11 +93,11 @@ void roleTest(String modelName) {
 				FunctionCallingOptions.builder().model(modelName).build());
 		ChatResponse response = this.chatModel.call(prompt);
 		assertThat(response.getResults()).hasSize(1);
-		assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
+		assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0);
 		assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0);
 		assertThat(response.getMetadata().getUsage().getTotalTokens())
 			.isEqualTo(response.getMetadata().getUsage().getPromptTokens()
-					+ response.getMetadata().getUsage().getGenerationTokens());
+					+ response.getMetadata().getUsage().getCompletionTokens());
 		Generation generation = response.getResults().get(0);
 		assertThat(generation.getOutput().getText()).contains("Blackbeard");
 		assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn");
@@ -133,11 +133,11 @@ void streamingWithTokenUsage() {
 		var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
 
 		assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
-		assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
+		assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
 		assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
 
 		assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
-		assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
+		assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
 		assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());
 	}
 
diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java
index 88ed54fc4b8..aefcf6a728b 100644
--- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java
+++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java
@@ -149,7 +149,7 @@ private void validate(ChatResponseMetadata responseMetadata, String finishReason
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getPromptTokens()))
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-					String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+					String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getTotalTokens()))
 			.hasBeenStarted()
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java
deleted file mode 100644
index 6394090443a..00000000000
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.bedrock;
-
-import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics;
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for Bedrock API.
- *
- * @author Christian Tzolov
- * @since 0.8.0
- */
-public class BedrockUsage implements Usage {
-
-	private final AmazonBedrockInvocationMetrics usage;
-
-	protected BedrockUsage(AmazonBedrockInvocationMetrics usage) {
-		Assert.notNull(usage, "Bedrock Usage must not be null");
-		this.usage = usage;
-	}
-
-	public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) {
-		return new BedrockUsage(usage);
-	}
-
-	protected AmazonBedrockInvocationMetrics getUsage() {
-		return this.usage;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		return getUsage().inputTokenCount().longValue();
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		return getUsage().outputTokenCount().longValue();
-	}
-
-	@Override
-	public String toString() {
-		return getUsage().toString();
-	}
-
-}
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java
index f5c3bbed14e..977539fccf3 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java
@@ -132,8 +132,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 	}
 
 	protected Usage extractUsage(AnthropicChatResponse response) {
-		return new DefaultUsage(response.usage().inputTokens().longValue(),
-				response.usage().outputTokens().longValue());
+		return new DefaultUsage(response.usage().inputTokens(), response.usage().outputTokens());
 	}
 
 	/**
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java
index 7596a9edb45..5506245932a 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java
@@ -20,13 +20,14 @@
 
 import reactor.core.publisher.Flux;
 
-import org.springframework.ai.bedrock.BedrockUsage;
 import org.springframework.ai.bedrock.MessageToPromptConverter;
+import org.springframework.ai.bedrock.api.AbstractBedrockApi;
 import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
 import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
 import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -80,7 +81,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 		return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(g -> {
 			if (g.isFinished()) {
 				String finishReason = g.finishReason().name();
-				Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
+				Usage usage = getDefaultUsage(g.amazonBedrockInvocationMetrics());
 				return new ChatResponse(List.of(new Generation(new AssistantMessage(""),
 						ChatGenerationMetadata.builder().finishReason(finishReason).metadata("usage", usage).build())));
 			}
@@ -88,6 +89,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 		});
 	}
 
+	private DefaultUsage getDefaultUsage(AbstractBedrockApi.AmazonBedrockInvocationMetrics usage) {
+		return new DefaultUsage(usage.inputTokenCount().intValue(), usage.outputTokenCount().intValue(),
+				usage.inputTokenCount().intValue() + usage.outputTokenCount().intValue(), usage);
+	}
+
 	/**
 	 * Test access.
 	 */
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java
index 852e829538f..d76b5d3920e 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java
@@ -16,7 +16,9 @@
 
 package org.springframework.ai.bedrock.llama;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import reactor.core.publisher.Flux;
 
@@ -100,13 +102,22 @@ private Usage extractUsage(LlamaChatResponse response) {
 		return new Usage() {
 
 			@Override
-			public Long getPromptTokens() {
-				return response.promptTokenCount().longValue();
+			public Integer getPromptTokens() {
+				return response.promptTokenCount();
 			}
 
 			@Override
-			public Long getGenerationTokens() {
-				return response.generationTokenCount().longValue();
+			public Integer getCompletionTokens() {
+				return response.generationTokenCount();
+			}
+
+			@Override
+			public Map<String, Integer> getNativeUsage() {
+				Map<String, Integer> usage = new HashMap<>();
+				usage.put("promptTokens", getPromptTokens());
+				usage.put("completionTokens", getCompletionTokens());
+				usage.put("totalTokens", getTotalTokens());
+				return usage;
 			}
 		};
 	}
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java
index bda3747595c..41eb229f214 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java
@@ -16,7 +16,9 @@
 
 package org.springframework.ai.bedrock.titan;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import reactor.core.publisher.Flux;
 
@@ -140,13 +142,22 @@ private Usage extractUsage(TitanChatResponseChunk response) {
 		return new Usage() {
 
 			@Override
-			public Long getPromptTokens() {
-				return response.inputTextTokenCount().longValue();
+			public Integer getPromptTokens() {
+				return response.inputTextTokenCount();
 			}
 
 			@Override
-			public Long getGenerationTokens() {
-				return response.totalOutputTextTokenCount().longValue();
+			public Integer getCompletionTokens() {
+				return response.totalOutputTextTokenCount();
+			}
+
+			@Override
+			public Map<String, Integer> getNativeUsage() {
+				Map<String, Integer> usage = new HashMap<>();
+				usage.put("promptTokens", getPromptTokens());
+				usage.put("completionTokens", getCompletionTokens());
+				usage.put("totalTokens", getTotalTokens());
+				return usage;
 			}
 		};
 	}
diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java
index 3c7eb1fff98..e1660d18ae1 100644
--- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java
+++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java
@@ -36,6 +36,7 @@
 import org.springframework.ai.chat.messages.ToolResponseMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.EmptyUsage;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
@@ -60,7 +61,6 @@
 import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
 import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest;
 import org.springframework.ai.minimax.api.MiniMaxApiConstants;
-import org.springframework.ai.minimax.metadata.MiniMaxUsage;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.model.function.FunctionCallback;
 import org.springframework.ai.model.function.FunctionCallbackResolver;
@@ -388,13 +388,17 @@ private ChatResponseMetadata from(ChatCompletion result) {
 		Assert.notNull(result, "MiniMax ChatCompletionResult must not be null");
 		return ChatResponseMetadata.builder()
 			.id(result.id() != null ? result.id() : "")
-			.usage(result.usage() != null ? MiniMaxUsage.from(result.usage()) : new EmptyUsage())
+			.usage(result.usage() != null ? getDefaultUsage(result.usage()) : new EmptyUsage())
 			.model(result.model() != null ? result.model() : "")
 			.keyValue("created", result.created() != null ? result.created() : 0L)
 			.keyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "")
 			.build();
 	}
 
+	private DefaultUsage getDefaultUsage(MiniMaxApi.Usage usage) {
+		return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
+	}
+
 	private Generation buildGeneration(ChatCompletionMessage message, ChatCompletionFinishReason completionFinishReason,
 			Map<String, Object> metadata) {
 		if (message == null || message.role() == Role.TOOL) {
diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java
index 340c5ddc3fb..fec3b0c310b 100644
--- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java
+++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java
@@ -23,6 +23,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.document.MetadataMode;
 import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -37,7 +38,6 @@
 import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
 import org.springframework.ai.minimax.api.MiniMaxApi;
 import org.springframework.ai.minimax.api.MiniMaxApiConstants;
-import org.springframework.ai.minimax.metadata.MiniMaxUsage;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.lang.Nullable;
@@ -171,8 +171,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
 				return new EmbeddingResponse(List.of());
 			}
 
-			var metadata = new EmbeddingResponseMetadata(apiRequest.model(),
-					MiniMaxUsage.from(new MiniMaxApi.Usage(0, 0, apiEmbeddingResponse.totalTokens())));
+			var metadata = new EmbeddingResponseMetadata(apiRequest.model(), getDefaultUsage(apiEmbeddingResponse));
 
 			List<Embedding> embeddings = new ArrayList<>();
 			for (int i = 0; i < apiEmbeddingResponse.vectors().size(); i++) {
@@ -185,6 +184,10 @@ public EmbeddingResponse call(EmbeddingRequest request) {
 		});
 	}
 
+	private DefaultUsage getDefaultUsage(MiniMaxApi.EmbeddingList apiEmbeddingList) {
+		return new DefaultUsage(0, 0, apiEmbeddingList.totalTokens());
+	}
+
 	/**
 	 * Merge runtime and default {@link EmbeddingOptions} to compute the final options to
 	 * use in the request.
diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java
deleted file mode 100644
index cb8a5a74a0b..00000000000
--- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.minimax.metadata;
-
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.ai.minimax.api.MiniMaxApi;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for {@literal MiniMax}.
- *
- * @author Thomas Vitale
- */
-public class MiniMaxUsage implements Usage {
-
-	private final MiniMaxApi.Usage usage;
-
-	protected MiniMaxUsage(MiniMaxApi.Usage usage) {
-		Assert.notNull(usage, "MiniMax Usage must not be null");
-		this.usage = usage;
-	}
-
-	public static MiniMaxUsage from(MiniMaxApi.Usage usage) {
-		return new MiniMaxUsage(usage);
-	}
-
-	protected MiniMaxApi.Usage getUsage() {
-		return this.usage;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		Integer promptTokens = getUsage().promptTokens();
-		return promptTokens != null ? promptTokens.longValue() : 0;
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		Integer generationTokens = getUsage().completionTokens();
-		return generationTokens != null ? generationTokens.longValue() : 0;
-	}
-
-	@Override
-	public Long getTotalTokens() {
-		Integer totalTokens = getUsage().totalTokens();
-		if (totalTokens != null) {
-			return totalTokens.longValue();
-		}
-		return getPromptTokens() + getGenerationTokens();
-	}
-
-	@Override
-	public String toString() {
-		return getUsage().toString();
-	}
-
-}
diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java
index 6eb7d1fc278..6f8191a1114 100644
--- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java
+++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatModelObservationIT.java
@@ -148,7 +148,7 @@ private void validate(ChatResponseMetadata responseMetadata) {
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getPromptTokens()))
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-					String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+					String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
 			.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
 					String.valueOf(responseMetadata.getUsage().getTotalTokens()))
 			.hasBeenStarted()
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java
index 98376720f2d..80438008666 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java
@@ -38,6 +38,7 @@
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
@@ -59,7 +60,6 @@
 import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
 import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
 import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
-import org.springframework.ai.mistralai.metadata.MistralAiUsage;
 import org.springframework.ai.model.Media;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.model.function.FunctionCallback;
@@ -154,7 +154,7 @@ public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions option
 	public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
 		Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
-		MistralAiUsage usage = MistralAiUsage.from(result.usage());
+		DefaultUsage usage = getDefaultUsage(result.usage());
 		return ChatResponseMetadata.builder()
 			.id(result.id())
 			.model(result.model())
@@ -173,6 +173,10 @@ public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result, Usag
 			.build();
 	}
 
+	private static DefaultUsage getDefaultUsage(MistralAiApi.Usage usage) {
+		return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
+	}
+
 	@Override
 	public ChatResponse call(Prompt prompt) {
 		return this.internalCall(prompt, null);
@@ -214,7 +218,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
 				return buildGeneration(choice, metadata);
 			}).toList();
 
-			MistralAiUsage usage = MistralAiUsage.from(completionEntity.getBody().usage());
+			DefaultUsage usage = getDefaultUsage(completionEntity.getBody().usage());
 			Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
 			ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), cumulativeUsage));
@@ -287,7 +291,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 			// @formatter:on
 
 			if (chatCompletion2.usage() != null) {
-				MistralAiUsage usage = MistralAiUsage.from(chatCompletion2.usage());
+				DefaultUsage usage = getDefaultUsage(chatCompletion2.usage());
 				Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
 				return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
 			}
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java
index c3f0a13e412..834bcfcd926 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java
@@ -22,6 +22,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.document.Document;
 import org.springframework.ai.document.MetadataMode;
 import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -36,7 +37,6 @@
 import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
 import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
 import org.springframework.ai.mistralai.api.MistralAiApi;
-import org.springframework.ai.mistralai.metadata.MistralAiUsage;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.retry.support.RetryTemplate;
@@ -131,7 +131,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
 			}
 
 			var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
-					MistralAiUsage.from(apiEmbeddingResponse.usage()));
+					getDefaultUsage(apiEmbeddingResponse.usage()));
 
 			var embeddings = apiEmbeddingResponse.data()
 				.stream()
@@ -146,6 +146,10 @@ public EmbeddingResponse call(EmbeddingRequest request) {
 		});
 	}
 
+	private DefaultUsage getDefaultUsage(MistralAiApi.Usage usage) {
+		return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
+	}
+
 	@SuppressWarnings("unchecked")
 	private MistralAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request) {
 		var embeddingRequest = new MistralAiApi.EmbeddingRequest<>(request.getInstructions(),
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java
deleted file mode 100644
index dbcc9a9d49b..00000000000
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/metadata/MistralAiUsage.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright 2023-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.mistralai.metadata;
-
-import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.ai.mistralai.api.MistralAiApi;
-import org.springframework.util.Assert;
-
-/**
- * {@link Usage} implementation for {@literal Mistral AI}.
- *
- * @author Thomas Vitale
- * @since 1.0.0
- * @see Chat Completion API
- */
-public class MistralAiUsage implements Usage {
-
-	private final MistralAiApi.Usage usage;
-
-	protected MistralAiUsage(MistralAiApi.Usage usage) {
-		Assert.notNull(usage, "Mistral AI Usage must not be null");
-		this.usage = usage;
-	}
-
-	public static MistralAiUsage from(MistralAiApi.Usage usage) {
-		return new MistralAiUsage(usage);
-	}
-
-	protected MistralAiApi.Usage getUsage() {
-		return this.usage;
-	}
-
-	@Override
-	public Long getPromptTokens() {
-		return getUsage().promptTokens().longValue();
-	}
-
-	@Override
-	public Long getGenerationTokens() {
-		return getUsage().completionTokens().longValue();
-	}
-
-	@Override
-	public Long getTotalTokens() {
-		return getUsage().totalTokens().longValue();
-	}
-
-	@Override
-	public String toString() {
-		return getUsage().toString();
-	}
-
-}
diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java
index 32ed0a9f9f5..1cd169f3419 100644
--- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java
+++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java
@@ -305,7 +305,7 @@ void validateCallResponseMetadata() {
 		assertThat(response.getMetadata().getId()).isNotEmpty();
 		assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
 		assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
-		assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
+		assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
 		assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
 	}
 
diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java
index aeda26afff3..702f84874e5 100644
--- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java
+++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java
@@ -158,7 +158,7 @@ private void validate(ChatResponseMetadata responseMetadata) {
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) .hasBeenStarted() diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 9fd7dff2599..403c15a468c 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -35,6 +35,7 @@ import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.metadata.UsageUtils; @@ -65,7 +66,6 @@ import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest; import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool; import org.springframework.ai.moonshot.api.MoonshotConstants; -import org.springframework.ai.moonshot.metadata.MoonshotUsage; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; @@ -226,7 +226,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return buildGeneration(choice, metadata); }).toList(); MoonshotApi.Usage usage = completionEntity.getBody().usage(); - Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage(); + Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage(); Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), cumulativeUsage)); @@ -247,6 +247,10 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS return response; } + private DefaultUsage getDefaultUsage(MoonshotApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + @Override public ChatOptions getDefaultOptions() { return this.defaultOptions.copy(); @@ -302,7 +306,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha return buildGeneration(choice, metadata); }).toList(); MoonshotApi.Usage usage = chatCompletion2.usage(); - Usage currentUsage = (usage != null) ?
getDefaultUsage(usage) : new EmptyUsage(); Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); @@ -336,7 +340,7 @@ private ChatResponseMetadata from(ChatCompletion result) { Assert.notNull(result, "Moonshot ChatCompletionResult must not be null"); return ChatResponseMetadata.builder() .id(result.id() != null ? result.id() : "") - .usage(result.usage() != null ? MoonshotUsage.from(result.usage()) : new EmptyUsage()) + .usage(result.usage() != null ? getDefaultUsage(result.usage()) : new EmptyUsage()) .model(result.model() != null ? result.model() : "") .keyValue("created", result.created() != null ? result.created() : 0L) .build(); diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java deleted file mode 100644 index b40e51f7ad5..00000000000 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/metadata/MoonshotUsage.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.moonshot.metadata; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.moonshot.api.MoonshotApi; -import org.springframework.util.Assert; - -/** - * Represents the usage of a Moonshot model. 
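The Moonshot migration repeats the pattern used for Mistral just above: map the provider's usage record onto DefaultUsage, keeping the raw record as the native payload, and fall back to EmptyUsage when the API returns no usage. A minimal self-contained sketch of that mapping; ProviderUsage is a hypothetical stand-in for MoonshotApi.Usage and its siblings, not a type from this PR:

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;

class UsageMappingSketch {

	// Hypothetical provider payload; the real call sites pass MoonshotApi.Usage, MistralAiApi.Usage, etc.
	record ProviderUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens) {
	}

	static Usage toUsage(ProviderUsage usage) {
		// Keeping the raw record as the fourth argument preserves provider-specific fields.
		return usage != null
				? new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage)
				: new EmptyUsage();
	}

	public static void main(String[] args) {
		Usage mapped = toUsage(new ProviderUsage(9, 12, 21));
		System.out.println(mapped.getPromptTokens() + "/" + mapped.getCompletionTokens() + "/" + mapped.getTotalTokens());
	}
}
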
- * - * @author Geng Rong - */ -public class MoonshotUsage implements Usage { - - private final MoonshotApi.Usage usage; - - protected MoonshotUsage(MoonshotApi.Usage usage) { - Assert.notNull(usage, "Moonshot Usage must not be null"); - this.usage = usage; - } - - public static MoonshotUsage from(MoonshotApi.Usage usage) { - return new MoonshotUsage(usage); - } - - protected MoonshotApi.Usage getUsage() { - return this.usage; - } - - @Override - public Long getPromptTokens() { - return getUsage().promptTokens().longValue(); - } - - @Override - public Long getGenerationTokens() { - return getUsage().completionTokens().longValue(); - } - - @Override - public Long getTotalTokens() { - return getUsage().totalTokens().longValue(); - } - - @Override - public String toString() { - return getUsage().toString(); - } - -} diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java index 3aa14844824..7283f304a04 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelObservationIT.java @@ -150,7 +150,7 @@ private void validate(ChatResponseMetadata responseMetadata) { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) .hasBeenStarted() diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 9057100cd27..3329462bee4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import io.micrometer.observation.Observation; @@ -60,7 +61,6 @@ import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; -import org.springframework.ai.ollama.metadata.OllamaChatUsage; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -132,10 +132,10 @@ public static Builder builder() { static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse previousChatResponse) { Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - OllamaChatUsage newUsage = OllamaChatUsage.from(response); - Long promptTokens = newUsage.getPromptTokens(); - Long generationTokens = newUsage.getGenerationTokens(); - Long totalTokens = newUsage.getTotalTokens(); + DefaultUsage newUsage = getDefaultUsage(response); + Integer promptTokens = newUsage.getPromptTokens(); + Integer 
generationTokens = newUsage.getCompletionTokens(); + int totalTokens = newUsage.getTotalTokens(); Duration evalDuration = response.getEvalDuration(); Duration promptEvalDuration = response.getPromptEvalDuration(); @@ -158,7 +158,7 @@ static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse p } if (previousChatResponse.getMetadata().getUsage() != null) { promptTokens += previousChatResponse.getMetadata().getUsage().getPromptTokens(); - generationTokens += previousChatResponse.getMetadata().getUsage().getGenerationTokens(); + generationTokens += previousChatResponse.getMetadata().getUsage().getCompletionTokens(); totalTokens += previousChatResponse.getMetadata().getUsage().getTotalTokens(); } } @@ -170,7 +170,7 @@ static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse p .model(response.model()) .keyValue(METADATA_CREATED_AT, response.createdAt()) .keyValue(METADATA_EVAL_DURATION, evalDuration) - .keyValue(METADATA_EVAL_COUNT, aggregatedUsage.getGenerationTokens().intValue()) + .keyValue(METADATA_EVAL_COUNT, aggregatedUsage.getCompletionTokens().intValue()) .keyValue(METADATA_LOAD_DURATION, loadDuration) .keyValue(METADATA_PROMPT_EVAL_DURATION, promptEvalDuration) .keyValue(METADATA_PROMPT_EVAL_COUNT, aggregatedUsage.getPromptTokens().intValue()) @@ -179,6 +179,11 @@ static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse p .build(); } + private static DefaultUsage getDefaultUsage(OllamaApi.ChatResponse response) { + return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0), + Optional.ofNullable(response.evalCount()).orElse(0)); + } + @Override public ChatResponse call(Prompt prompt) { return this.internalCall(prompt, null); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index f796ec63307..9a5e7a8ab54 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -18,12 +18,14 @@ import java.time.Duration; import java.util.List; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; import org.springframework.ai.embedding.Embedding; @@ -45,7 +47,6 @@ import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; -import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -126,7 +127,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { .toList(); EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(), - OllamaEmbeddingUsage.from(response)); + getDefaultUsage(response)); EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, embeddingResponseMetadata); @@ -136,6 +137,10 @@ public EmbeddingResponse call(EmbeddingRequest request) { }); } + private DefaultUsage 
getDefaultUsage(OllamaApi.EmbeddingsResponse response) { + return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0), 0); + } + /** * Package access for testing. */ diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java deleted file mode 100644 index 3ccf39b4c86..00000000000 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.ollama.metadata; - -import java.util.Optional; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.util.Assert; - -/** - * {@link Usage} implementation for {@literal Ollama} - * - * @see Usage - * @author Fu Cheng - */ -public class OllamaChatUsage implements Usage { - - protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }"; - - private final OllamaApi.ChatResponse response; - - public OllamaChatUsage(OllamaApi.ChatResponse response) { - this.response = response; - } - - public static OllamaChatUsage from(OllamaApi.ChatResponse response) { - Assert.notNull(response, "OllamaApi.ChatResponse must not be null"); - return new OllamaChatUsage(response); - } - - @Override - public Long getPromptTokens() { - return Optional.ofNullable(this.response.promptEvalCount()).map(Integer::longValue).orElse(0L); - } - - @Override - public Long getGenerationTokens() { - return Optional.ofNullable(this.response.evalCount()).map(Integer::longValue).orElse(0L); - } - - @Override - public String toString() { - return AI_USAGE_STRING.formatted(getPromptTokens(), getGenerationTokens(), getTotalTokens()); - } - -} diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java deleted file mode 100644 index c75ebaac15a..00000000000 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaEmbeddingUsage.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
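Ollama is the one provider here that exposes no usage record at all; the chat and embedding models read nullable promptEvalCount/evalCount counters, which is why the new getDefaultUsage methods default them to zero. A sketch of that defaulting, assuming (as the DefaultUsage(808 + 66, 101 + 99) equality check in the OllamaChatModelTests hunk below suggests) that the two-argument constructor derives the total from prompt plus completion tokens:

import java.util.Optional;

import org.springframework.ai.chat.metadata.DefaultUsage;

class OllamaUsageSketch {

	// promptEvalCount and evalCount may be absent on an Ollama response, hence the zero defaults.
	static DefaultUsage fromCounts(Integer promptEvalCount, Integer evalCount) {
		return new DefaultUsage(Optional.ofNullable(promptEvalCount).orElse(0),
				Optional.ofNullable(evalCount).orElse(0));
	}

	public static void main(String[] args) {
		DefaultUsage usage = fromCounts(808, null);
		// Expected: 808 prompt tokens, 0 completion tokens, and a derived total of 808.
		System.out.println(usage.getPromptTokens() + "/" + usage.getCompletionTokens() + "/" + usage.getTotalTokens());
	}
}
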
- */ - -package org.springframework.ai.ollama.metadata; - -import java.util.Optional; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse; -import org.springframework.util.Assert; - -/** - * {@link Usage} implementation for {@literal Ollama} embeddings. - * - * @see Usage - * @author Christian Tzolov - */ -public class OllamaEmbeddingUsage implements Usage { - - protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }"; - - private Long promptTokens; - - public OllamaEmbeddingUsage(EmbeddingsResponse response) { - this.promptTokens = Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L); - } - - public static OllamaEmbeddingUsage from(EmbeddingsResponse response) { - Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null"); - return new OllamaEmbeddingUsage(response); - } - - @Override - public Long getPromptTokens() { - return this.promptTokens; - } - - @Override - public Long getGenerationTokens() { - return 0L; - } - - @Override - public String toString() { - return AI_USAGE_STRING.formatted(getPromptTokens(), getGenerationTokens(), getTotalTokens()); - } - -} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index bdf000c4701..8709a5b8b3a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -139,7 +139,7 @@ void usageTest() { assertThat(usage).isNotNull(); assertThat(usage.getPromptTokens()).isPositive(); - assertThat(usage.getGenerationTokens()).isPositive(); + assertThat(usage.getCompletionTokens()).isPositive(); assertThat(usage.getTotalTokens()).isPositive(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index 3635ae57c72..916a364ba65 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -147,7 +147,7 @@ private void validate(ChatResponseMetadata responseMetadata) { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) .hasBeenStarted() diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java index fee9c283877..f5ca4e68636 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -100,7 +100,7 @@ public void 
buildChatResponseMetadataAggregationWithNonEmptyMetadata() { ChatResponse previousChatResponse = ChatResponse.builder() .generations(List.of()) .metadata(ChatResponseMetadata.builder() - .usage(new DefaultUsage(66L, 99L)) + .usage(new DefaultUsage(66, 99)) .keyValue("eval-duration", Duration.ofSeconds(2)) .keyValue("prompt-eval-duration", Duration.ofSeconds(2)) .build()) @@ -108,7 +108,7 @@ public void buildChatResponseMetadataAggregationWithNonEmptyMetadata() { ChatResponseMetadata metadata = OllamaChatModel.from(response, previousChatResponse); - assertThat(metadata.getUsage()).isEqualTo(new DefaultUsage(808L + 66L, 101L + 99L)); + assertThat(metadata.getUsage()).isEqualTo(new DefaultUsage(808 + 66, 101 + 99)); assertEquals(Duration.ofNanos(evalDuration).plus(Duration.ofSeconds(2)), metadata.get("eval-duration")); assertEquals((evalCount + 99), (Integer) metadata.get("eval-count")); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 95f15d63a7a..fdb7f2020c4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -40,6 +40,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; @@ -71,7 +72,6 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.common.OpenAiApiConstants; -import org.springframework.ai.openai.metadata.OpenAiUsage; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.io.ByteArrayResource; @@ -267,7 +267,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); // Current usage OpenAiApi.Usage usage = completionEntity.getBody().usage(); - Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage(); + Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit, accumulatedUsage)); @@ -352,7 +352,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha }).toList(); // @formatter:on OpenAiApi.Usage usage = chatCompletion2.usage(); - Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage(); + Usage currentChatResponseUsage = usage != null ?
getDefaultUsage(usage) : new EmptyUsage(); Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage)); @@ -501,6 +501,10 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC chunk.systemFingerprint(), "chat.completion", chunk.usage()); } + private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + /** * Accessible for testing. */ diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 0cd71fb223d..22c99aefaa8 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -22,6 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -38,7 +39,6 @@ import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.ai.openai.api.common.OpenAiApiConstants; -import org.springframework.ai.openai.metadata.OpenAiUsage; import org.springframework.ai.retry.RetryUtils; import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; @@ -168,7 +168,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { } var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(), - OpenAiUsage.from(apiEmbeddingResponse.usage())); + getDefaultUsage(apiEmbeddingResponse.usage())); List<Embedding> embeddings = apiEmbeddingResponse.data() .stream() @@ -183,6 +183,10 @@ public EmbeddingResponse call(EmbeddingRequest request) { }); } + private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request, OpenAiEmbeddingOptions requestOptions) { return new OpenAiApi.EmbeddingRequest<>(request.getInstructions(), requestOptions.getModel(), diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java deleted file mode 100644 index b8534d53a82..00000000000 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
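OpenAI typically attaches usage only to the final streamed chunk (and only when it is requested), so each chunk's usage is folded into a running total with UsageUtils rather than read in isolation. A sketch of that accumulation, built only from calls that appear in this diff; the expected 874/200 output assumes getCumulativeUsage adds the previous response's counts, as its call sites here imply:

import java.util.List;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.ChatResponse;

class CumulativeUsageSketch {

	public static void main(String[] args) {
		// A previous response carrying 66 prompt and 99 completion tokens.
		ChatResponse previous = ChatResponse.builder()
			.generations(List.of())
			.metadata(ChatResponseMetadata.builder().usage(new DefaultUsage(66, 99)).build())
			.build();

		// Usage reported by the current (e.g. final streamed) chunk.
		Usage current = new DefaultUsage(808, 101);

		Usage cumulative = UsageUtils.getCumulativeUsage(current, previous);
		System.out.println(cumulative.getPromptTokens() + "/" + cumulative.getCompletionTokens());
	}
}
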
- */ - -package org.springframework.ai.openai.metadata; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.util.Assert; - -/** - * {@link Usage} implementation for {@literal OpenAI}. - * - * @author John Blum - * @author Thomas Vitale - * @author David Frizelle - * @author Christian Tzolov - * @since 0.7.0 - * @see Completion - * Object - */ -public class OpenAiUsage implements Usage { - - private final OpenAiApi.Usage usage; - - protected OpenAiUsage(OpenAiApi.Usage usage) { - Assert.notNull(usage, "OpenAI Usage must not be null"); - this.usage = usage; - } - - public static OpenAiUsage from(OpenAiApi.Usage usage) { - return new OpenAiUsage(usage); - } - - protected OpenAiApi.Usage getUsage() { - return this.usage; - } - - @Override - public Long getPromptTokens() { - Integer promptTokens = getUsage().promptTokens(); - return promptTokens != null ? promptTokens.longValue() : 0; - } - - @Override - public Long getGenerationTokens() { - Integer generationTokens = getUsage().completionTokens(); - return generationTokens != null ? generationTokens.longValue() : 0; - } - - @Override - public Long getTotalTokens() { - Integer totalTokens = getUsage().totalTokens(); - if (totalTokens != null) { - return totalTokens.longValue(); - } - else { - return getPromptTokens() + getGenerationTokens(); - } - } - - /** - * @deprecated Use {@link #getPromptTokensDetails()} instead. - */ - @Deprecated - public Long getPromptTokensDetailsCachedTokens() { - OpenAiApi.Usage.PromptTokensDetails promptTokenDetails = getUsage().promptTokensDetails(); - Integer cachedTokens = promptTokenDetails != null ? promptTokenDetails.cachedTokens() : null; - return cachedTokens != null ? cachedTokens.longValue() : 0; - } - - public PromptTokensDetails getPromptTokensDetails() { - var details = getUsage().promptTokensDetails(); - if (details == null) { - return new PromptTokensDetails(0, 0); - } - return new PromptTokensDetails(valueOrZero(details.audioTokens()), valueOrZero(details.cachedTokens())); - } - - /** - * @deprecated Use {@link #getCompletionTokenDetails()} instead. - */ - @Deprecated - public Long getReasoningTokens() { - OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); - Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null; - return reasoningTokens != null ? reasoningTokens.longValue() : 0; - } - - /** - * @deprecated Use {@link #getCompletionTokenDetails()} instead. - */ - @Deprecated - public Long getAcceptedPredictionTokens() { - OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); - Integer acceptedPredictionTokens = completionTokenDetails != null - ? completionTokenDetails.acceptedPredictionTokens() : null; - return acceptedPredictionTokens != null ? acceptedPredictionTokens.longValue() : 0; - } - - /** - * @deprecated Use {@link #getCompletionTokenDetails()} instead. - */ - @Deprecated - public Long getAudioTokens() { - OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); - Integer audioTokens = completionTokenDetails != null ? completionTokenDetails.audioTokens() : null; - return audioTokens != null ? audioTokens.longValue() : 0; - } - - /** - * @deprecated Use {@link #getCompletionTokenDetails()} instead. 
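The deprecated detail accessors being deleted here (getReasoningTokens, getAudioTokens, and so on) have no DefaultUsage equivalent; callers now reach the provider-specific breakdown through the native payload, as the reworked OpenAiUsageTests below demonstrate. One behavioral consequence: the old accessors zero-defaulted missing values, while the native records surface null. A hypothetical replacement helper, not part of this PR:

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.openai.api.OpenAiApi;

class NativeUsageSketch {

	// Stands in for the removed OpenAiUsage.getReasoningTokens(); returns null, not 0, when absent.
	static Integer reasoningTokens(DefaultUsage usage) {
		OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage();
		OpenAiApi.Usage.CompletionTokenDetails details = nativeUsage.completionTokenDetails();
		return details != null ? details.reasoningTokens() : null;
	}
}
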
- */ - @Deprecated - public Long getRejectedPredictionTokens() { - OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); - Integer rejectedPredictionTokens = completionTokenDetails != null - ? completionTokenDetails.rejectedPredictionTokens() : null; - return rejectedPredictionTokens != null ? rejectedPredictionTokens.longValue() : 0; - } - - public CompletionTokenDetails getCompletionTokenDetails() { - var details = getUsage().completionTokenDetails(); - if (details == null) { - return new CompletionTokenDetails(0, 0, 0, 0); - } - return new CompletionTokenDetails(valueOrZero(details.reasoningTokens()), - valueOrZero(details.acceptedPredictionTokens()), valueOrZero(details.audioTokens()), - valueOrZero(details.rejectedPredictionTokens())); - } - - @Override - public String toString() { - return getUsage().toString(); - } - - private int valueOrZero(Integer value) { - return value != null ? value : 0; - } - - public record PromptTokensDetails(// @formatter:off - Integer audioTokens, - Integer cachedTokens) { - } - - public record CompletionTokenDetails( - Integer reasoningTokens, - Integer acceptedPredictionTokens, - Integer audioTokens, - Integer rejectedPredictionTokens) { // @formatter:on - } - -} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index ae7432f60ae..ff1a5ff2ac4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -214,12 +214,12 @@ void streamingWithTokenUsage() { var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); - assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isCloseTo(referenceTokenUsage.getPromptTokens(), Percentage.withPercentage(25)); - assertThat(streamingTokenUsage.getGenerationTokens()).isCloseTo(referenceTokenUsage.getGenerationTokens(), + assertThat(streamingTokenUsage.getCompletionTokens()).isCloseTo(referenceTokenUsage.getCompletionTokens(), Percentage.withPercentage(25)); assertThat(streamingTokenUsage.getTotalTokens()).isCloseTo(referenceTokenUsage.getTotalTokens(), Percentage.withPercentage(25)); @@ -413,9 +413,9 @@ void functionCallUsageTest() { assertThat(usage).isNotNull(); assertThat(usage).isNotInstanceOf(EmptyUsage.class); assertThat(usage).isInstanceOf(DefaultUsage.class); - assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L); - assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L); - assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(900L); + assertThat(usage.getPromptTokens()).isGreaterThan(450).isLessThan(600); + assertThat(usage.getCompletionTokens()).isGreaterThan(230).isLessThan(360); + assertThat(usage.getTotalTokens()).isGreaterThan(680).isLessThan(900); } @Test @@ -442,9 +442,9 @@ void streamFunctionCallUsageTest() { assertThat(usage).isNotNull(); assertThat(usage).isNotInstanceOf(EmptyUsage.class); assertThat(usage).isInstanceOf(DefaultUsage.class); - 
assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L); - assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L); - assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(960L); + assertThat(usage.getPromptTokens()).isGreaterThan(450).isLessThan(600); + assertThat(usage.getCompletionTokens()).isGreaterThan(230).isLessThan(360); + assertThat(usage.getTotalTokens()).isGreaterThan(680).isLessThan(960); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -596,7 +596,7 @@ void validateCallResponseMetadata() { assertThat(response.getMetadata().getId()).isNotEmpty(); assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java index cecf4e1d1d6..e6e58499e4a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java @@ -152,7 +152,7 @@ private void validate(ChatResponseMetadata responseMetadata) { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) .hasBeenStarted() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java index b7229381fac..76dd5a7ee0e 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java @@ -88,7 +88,7 @@ void aiResponseContainsAiMetadata() { assertThat(usage).isNotNull(); assertThat(usage.getPromptTokens()).isEqualTo(9L); - assertThat(usage.getGenerationTokens()).isEqualTo(12L); + assertThat(usage.getCompletionTokens()).isEqualTo(12L); assertThat(usage.getTotalTokens()).isEqualTo(21L); RateLimit rateLimit = chatResponseMetadata.getRateLimit(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index f097c351297..cce105b7876 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java 
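The proxy ITs that follow (Groq, Mistral, Nvidia, Ollama, Perplexity) all pick up the same mechanical rename: getGenerationTokens() is now getCompletionTokens() on the Usage interface, matching OpenAI-style terminology. In assertion form, using the two-argument constructor; the derived total is an inference from the OpenAiUsageTests further down, not a documented contract:

import static org.assertj.core.api.Assertions.assertThat;

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;

class CompletionTokensRenameSketch {

	public static void main(String[] args) {
		Usage usage = new DefaultUsage(9, 12);
		assertThat(usage.getPromptTokens()).isEqualTo(9);
		assertThat(usage.getCompletionTokens()).isEqualTo(12); // formerly getGenerationTokens()
		assertThat(usage.getTotalTokens()).isEqualTo(21); // derived when no explicit total is given
	}
}
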
@@ -126,11 +126,11 @@ void streamingWithTokenUsage() { var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); - assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); - assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens()); assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); } @@ -371,7 +371,7 @@ void validateCallResponseMetadata(String model) { assertThat(response.getMetadata().getId()).isNotEmpty(); assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java index c648c2c71bd..7071f3608be 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java @@ -125,11 +125,11 @@ void streamingWithTokenUsage() { var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); - assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); - assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens()); assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); } @@ -376,7 +376,7 @@ void validateCallResponseMetadata(String model) { assertThat(response.getMetadata().getId()).isNotEmpty(); assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java index e7aae6a8e6d..4983e96c7e9 100644 --- 
a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java @@ -122,11 +122,11 @@ void streamingWithTokenUsage() { var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); - assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); - assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens()); assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); } @@ -305,7 +305,7 @@ void validateCallResponseMetadata() { assertThat(response.getMetadata().getId()).isNotEmpty(); assertThat(response.getMetadata().getModel()).containsIgnoringCase(DEFAULT_NVIDIA_MODEL); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index 9a7fb2e69b6..c4637097fd8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -143,11 +143,11 @@ void streamingWithTokenUsage() { var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); // assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); - assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); - assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens()); assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); } @@ -400,7 +400,7 @@ void validateCallResponseMetadata(String model) { assertThat(response.getMetadata().getId()).isNotEmpty(); assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } diff --git 
a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java index 589953fbf29..ceb6c944121 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java @@ -139,12 +139,12 @@ void streamingWithTokenUsage() { var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); - assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); - assertThat(streamingTokenUsage.getGenerationTokens()) - .isGreaterThanOrEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getCompletionTokens()) + .isGreaterThanOrEqualTo(referenceTokenUsage.getCompletionTokens()); assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThanOrEqualTo(referenceTokenUsage.getTotalTokens()); } @@ -315,7 +315,7 @@ void validateCallResponseMetadata() { assertThat(response.getMetadata().getId()).isNotEmpty(); assertThat(response.getMetadata().getModel()).containsIgnoringCase(DEFAULT_PERPLEXITY_MODEL); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index 92e8f595b3f..6d9c44ffbdb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -18,129 +18,142 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.openai.api.OpenAiApi; import static org.assertj.core.api.Assertions.assertThat; /** - * Unit tests for {@link OpenAiUsage}. + * Unit tests for OpenAI usage data. 
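Each migrated test now builds the DefaultUsage mapping through a local helper instead of OpenAiUsage.from(...). When reading the assertions below, keep the constructor's argument order in mind: OpenAiApi.Usage takes completion tokens before prompt tokens, so Usage(100, 200, 300) means 200 prompt and 100 completion tokens. A condensed sketch of the migrated shape (class and test names are illustrative):

import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.openai.api.OpenAiApi;

import static org.assertj.core.api.Assertions.assertThat;

class DefaultUsageMappingSketchTests {

	private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
		return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
	}

	@Test
	void mapsProviderFieldsOntoDefaultUsage() {
		// OpenAiApi.Usage argument order: (completionTokens, promptTokens, totalTokens).
		OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300);
		DefaultUsage usage = getDefaultUsage(openAiUsage);
		assertThat(usage.getPromptTokens()).isEqualTo(200);
		assertThat(usage.getCompletionTokens()).isEqualTo(100);
		assertThat(usage.getTotalTokens()).isEqualTo(300);
		assertThat(usage.getNativeUsage()).isEqualTo(openAiUsage);
	}
}
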
* * @author Thomas Vitale * @author Christian Tzolov + * @author Ilayaperumal Gopinathan */ class OpenAiUsageTests { + private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + @Test void whenPromptTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + DefaultUsage usage = getDefaultUsage(openAiUsage); assertThat(usage.getPromptTokens()).isEqualTo(200); } @Test void whenPromptTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, null, 100); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + DefaultUsage usage = getDefaultUsage(openAiUsage); assertThat(usage.getPromptTokens()).isEqualTo(0); } @Test void whenGenerationTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getGenerationTokens()).isEqualTo(100); + DefaultUsage usage = getDefaultUsage(openAiUsage); + assertThat(usage.getCompletionTokens()).isEqualTo(100); } @Test void whenGenerationTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(null, 200, 200); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getGenerationTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + assertThat(usage.getCompletionTokens()).isEqualTo(0); } @Test void whenTotalTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + DefaultUsage usage = getDefaultUsage(openAiUsage); assertThat(usage.getTotalTokens()).isEqualTo(300); } @Test void whenTotalTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, null); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + DefaultUsage usage = getDefaultUsage(openAiUsage); assertThat(usage.getTotalTokens()).isEqualTo(300); } @Test void whenPromptAndCompletionTokensDetailsIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + DefaultUsage usage = getDefaultUsage(openAiUsage); assertThat(usage.getTotalTokens()).isEqualTo(300); - assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.promptTokensDetails()).isNull(); + assertThat(nativeUsage.completionTokenDetails()).isNull(); } @Test void whenCompletionTokenDetailsIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + DefaultUsage usage = getDefaultUsage(openAiUsage); assertThat(usage.getTotalTokens()).isEqualTo(300); - assertThat(usage.getReasoningTokens()).isEqualTo(0); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails()).isNull(); } @Test void whenReasoningTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, null)); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - 
assertThat(usage.getReasoningTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails().reasoningTokens()).isEqualTo(null); } @Test void whenCompletionTokenDetailsIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(50, null, null, null)); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(50); - assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails().reasoningTokens()).isEqualTo(50); + assertThat(nativeUsage.completionTokenDetails().acceptedPredictionTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().audioTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().rejectedPredictionTokens()).isEqualTo(null); } @Test void whenAcceptedPredictionTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(null, 75, null, null)); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(75); - assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails().reasoningTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().acceptedPredictionTokens()).isEqualTo(75); + assertThat(nativeUsage.completionTokenDetails().audioTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().rejectedPredictionTokens()).isEqualTo(null); } @Test void whenAudioTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(null, null, 125, null)); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(125); - assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails().reasoningTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().acceptedPredictionTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().audioTokens()).isEqualTo(125); + assertThat(nativeUsage.completionTokenDetails().rejectedPredictionTokens()).isEqualTo(null); } @Test void whenRejectedPredictionTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 
200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, null)); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails().reasoningTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().acceptedPredictionTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().audioTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().rejectedPredictionTokens()).isEqualTo(null); + assertThat(nativeUsage.promptTokensDetails()).isEqualTo(null); } @@ -148,29 +161,32 @@ void whenRejectedPredictionTokensIsNull() { void whenRejectedPredictionTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, new OpenAiApi.Usage.CompletionTokenDetails(null, null, null, 25)); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getCompletionTokenDetails().reasoningTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().acceptedPredictionTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().audioTokens()).isEqualTo(0); - assertThat(usage.getCompletionTokenDetails().rejectedPredictionTokens()).isEqualTo(25); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.completionTokenDetails().reasoningTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().acceptedPredictionTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().audioTokens()).isEqualTo(null); + assertThat(nativeUsage.completionTokenDetails().rejectedPredictionTokens()).isEqualTo(25); } @Test void whenCacheTokensIsNull() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(null, null), null); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getPromptTokensDetails().audioTokens()).isEqualTo(0); - assertThat(usage.getPromptTokensDetails().cachedTokens()).isEqualTo(0); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.promptTokensDetails().audioTokens()).isEqualTo(null); + assertThat(nativeUsage.promptTokensDetails().cachedTokens()).isEqualTo(null); } @Test void whenCacheTokensIsPresent() { OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(99, 15), null); - OpenAiUsage usage = OpenAiUsage.from(openAiUsage); - assertThat(usage.getPromptTokensDetails().audioTokens()).isEqualTo(99); - assertThat(usage.getPromptTokensDetails().cachedTokens()).isEqualTo(15); + DefaultUsage usage = getDefaultUsage(openAiUsage); + OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage(); + assertThat(nativeUsage.promptTokensDetails().audioTokens()).isEqualTo(99); + assertThat(nativeUsage.promptTokensDetails().cachedTokens()).isEqualTo(15); } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java 
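With the OpenAiUsage wrapper deleted, its convenience getters for token details (reasoning, audio, cached, predicted tokens) are gone; as the reworked tests above show, callers now cast getNativeUsage() back to the provider record. A minimal sketch of that access pattern, mirroring the constructor arguments used in the tests above (all values illustrative, not part of the patch):

// Wrap the provider's usage record; the record itself rides along as the native payload.
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
		new OpenAiApi.Usage.CompletionTokenDetails(50, null, null, null));
DefaultUsage usage = new DefaultUsage(openAiUsage.promptTokens(), openAiUsage.completionTokens(),
		openAiUsage.totalTokens(), openAiUsage);

usage.getPromptTokens();     // 200, the provider-agnostic view
usage.getCompletionTokens(); // 100
usage.getTotalTokens();      // 300

// Provider-specific detail fields require the native payload and a cast:
OpenAiApi.Usage nativeUsage = (OpenAiApi.Usage) usage.getNativeUsage();
nativeUsage.completionTokenDetails().reasoningTokens(); // 50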
b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java index e76e76b0609..5f3d1ab71a3 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -50,7 +51,6 @@ import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage.Role; import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest; import org.springframework.ai.qianfan.api.QianFanConstants; -import org.springframework.ai.qianfan.metadata.QianFanUsage; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; @@ -292,12 +292,16 @@ private ChatResponseMetadata from(QianFanApi.ChatCompletion result, String model Assert.notNull(result, "QianFan ChatCompletionResult must not be null"); return ChatResponseMetadata.builder() .id(result.id() != null ? result.id() : "") - .usage(result.usage() != null ? QianFanUsage.from(result.usage()) : new EmptyUsage()) + .usage(result.usage() != null ? getDefaultUsage(result.usage()) : new EmptyUsage()) .model(model) .keyValue("created", result.created() != null ? result.created() : 0L) .build(); } + private DefaultUsage getDefaultUsage(QianFanApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + public void setObservationConvention(ChatModelObservationConvention observationConvention) { this.observationConvention = observationConvention; } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java index 6740031b1a0..a77635d4a39 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java @@ -22,6 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -38,7 +39,6 @@ import org.springframework.ai.qianfan.api.QianFanApi; import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList; import org.springframework.ai.qianfan.api.QianFanConstants; -import org.springframework.ai.qianfan.metadata.QianFanUsage; import org.springframework.ai.retry.RetryUtils; import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; @@ -176,7 +176,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { } var metadata = new EmbeddingResponseMetadata(apiRequest.model(), - QianFanUsage.from(apiEmbeddingResponse.usage())); + getDefaultUsage(apiEmbeddingResponse.usage())); List embeddings = apiEmbeddingResponse.data() .stream() @@ -192,6 +192,10 @@ public EmbeddingResponse call(EmbeddingRequest request) { } + private 
DefaultUsage getDefaultUsage(QianFanApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + /** * Merge runtime and default {@link EmbeddingOptions} to compute the final options to * use in the request. diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java deleted file mode 100644 index eaa69e75502..00000000000 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/metadata/QianFanUsage.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.qianfan.metadata; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.qianfan.api.QianFanApi; -import org.springframework.util.Assert; - -/** - * {@link Usage} implementation for {@literal QianFan}. - * - * @author Thomas Vitale - */ -public class QianFanUsage implements Usage { - - private final QianFanApi.Usage usage; - - protected QianFanUsage(QianFanApi.Usage usage) { - Assert.notNull(usage, "QianFan Usage must not be null"); - this.usage = usage; - } - - public static QianFanUsage from(QianFanApi.Usage usage) { - return new QianFanUsage(usage); - } - - protected QianFanApi.Usage getUsage() { - return this.usage; - } - - @Override - public Long getPromptTokens() { - return getUsage().promptTokens().longValue(); - } - - @Override - public Long getGenerationTokens() { - return 0L; - } - - @Override - public Long getTotalTokens() { - return getUsage().totalTokens().longValue(); - } - - @Override - public String toString() { - return getUsage().toString(); - } - -} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java index 518b3d81c3c..24efc866c28 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java @@ -148,7 +148,7 @@ private void validate(ChatResponseMetadata responseMetadata) { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) .hasBeenStarted() diff --git 
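The deleted QianFanUsage above even hard-coded getGenerationTokens() to 0L, so the move to the shared adapter also stops completion counts from being lost. Measured against the DefaultUsage changes later in this patch, the replacement behaves as follows; a minimal sketch (values illustrative, not part of the patch):

// A null total is derived, and a null native payload falls back to a map of the counts.
DefaultUsage usage = new DefaultUsage(120, 30, null, null);

usage.getTotalTokens();      // 150, calculated as promptTokens + completionTokens
usage.getGenerationTokens(); // 30L, via the deprecated bridge to getCompletionTokens()
usage.getNativeUsage();      // Map containing promptTokens, completionTokens, totalTokens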
a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java deleted file mode 100644 index 602afbd80e3..00000000000 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUsage.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.vertexai.embedding; - -import org.springframework.ai.chat.metadata.Usage; - -public class VertexAiEmbeddingUsage implements Usage { - - private final Integer totalTokens; - - public VertexAiEmbeddingUsage(Integer totalTokens) { - this.totalTokens = totalTokens; - } - - @Override - public Long getPromptTokens() { - return 0L; - } - - @Override - public Long getGenerationTokens() { - return 0L; - } - - @Override - public Long getTotalTokens() { - return Long.valueOf(this.totalTokens); - } - -} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java index a269ca80603..9062f6ba35b 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModel.java @@ -32,6 +32,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.DocumentEmbeddingModel; @@ -44,7 +45,6 @@ import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.MultimodalInstanceBuilder; @@ -242,10 +242,14 @@ else if (media.getMimeType().isCompatibleWith(VIDEO_MIME_TYPE)) { private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens, Map metadataToUse) { - Usage usage = new VertexAiEmbeddingUsage(totalTokens); + Usage usage = getDefaultUsage(totalTokens); return new EmbeddingResponseMetadata(model, usage, metadataToUse); } + private DefaultUsage getDefaultUsage(Integer totalTokens) { + return new DefaultUsage(0, 
0, totalTokens); + } + @Override public int dimensions() { return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), 768); diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 2e839736472..c2cc0ea5cbc 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -30,6 +30,7 @@ import com.google.protobuf.Value; import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -45,7 +46,6 @@ import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails; -import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder; @@ -222,11 +222,15 @@ PredictResponse getPredictResponse(PredictionServiceClient client, PredictReques private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) { EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); metadata.setModel(model); - Usage usage = new VertexAiEmbeddingUsage(totalTokens); + Usage usage = getDefaultUsage(totalTokens); metadata.setUsage(usage); return metadata; } + private DefaultUsage getDefaultUsage(Integer totalTokens) { + return new DefaultUsage(0, 0, totalTokens); + } + @Override public int dimensions() { return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 5fa97940acf..ac83c57176d 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -57,6 +57,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -77,7 +78,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; -import 
org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage; import org.springframework.beans.factory.DisposableBean; import org.springframework.lang.NonNull; import org.springframework.retry.support.RetryTemplate; @@ -428,7 +428,12 @@ protected List responseCandidateToGeneration(Candidate candidate) { } private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) { - return ChatResponseMetadata.builder().usage(new VertexAiUsage(response.getUsageMetadata())).build(); + return ChatResponseMetadata.builder().usage(getDefaultUsage(response.getUsageMetadata())).build(); + } + + private DefaultUsage getDefaultUsage(GenerateContentResponse.UsageMetadata usageMetadata) { + return new DefaultUsage(usageMetadata.getPromptTokenCount(), usageMetadata.getCandidatesTokenCount(), + usageMetadata.getTotalTokenCount(), usageMetadata); } private VertexAiGeminiChatOptions vertexAiGeminiChatOptions(Prompt prompt) { diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java deleted file mode 100644 index a0dcc376618..00000000000 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/metadata/VertexAiUsage.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.vertexai.gemini.metadata; - -import com.google.cloud.vertexai.api.GenerateContentResponse.UsageMetadata; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.util.Assert; - -/** - * Represents the usage of a Vertex AI model. 
- * - * @author Christian Tzolov - * @since 0.8.1 - * - */ -public class VertexAiUsage implements Usage { - - private final UsageMetadata usageMetadata; - - public VertexAiUsage(UsageMetadata usageMetadata) { - Assert.notNull(usageMetadata, "UsageMetadata must not be null"); - this.usageMetadata = usageMetadata; - } - - @Override - public Long getPromptTokens() { - return Long.valueOf(this.usageMetadata.getPromptTokenCount()); - } - - @Override - public Long getGenerationTokens() { - return Long.valueOf(this.usageMetadata.getCandidatesTokenCount()); - } - -} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java index c6f68cfc5c4..2cb70793ade 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java @@ -148,7 +148,7 @@ private void validate(ChatResponseMetadata responseMetadata) { String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue( ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue( ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index c854864718b..17d56285277 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -38,6 +38,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; @@ -68,7 +69,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; -import org.springframework.ai.zhipuai.metadata.ZhiPuAiUsage; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -326,13 +326,17 @@ private ChatResponseMetadata from(ChatCompletion result) { Assert.notNull(result, "ZhiPuAI ChatCompletionResult must not be null"); return ChatResponseMetadata.builder() .id(result.id() != null ? result.id() : "") - .usage(result.usage() != null ? ZhiPuAiUsage.from(result.usage()) : new EmptyUsage()) + .usage(result.usage() != null ? getDefaultUsage(result.usage()) : new EmptyUsage()) .model(result.model() != null ? result.model() : "") .keyValue("created", result.created() != null ? 
result.created() : 0L) .keyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "") .build(); } + private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index eb21c21ea50..0ba5666e72f 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -24,6 +24,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingModel; @@ -40,7 +41,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; -import org.springframework.ai.zhipuai.metadata.ZhiPuAiUsage; import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -190,7 +190,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { String model = (request.getOptions() != null && request.getOptions().getModel() != null) ? request.getOptions().getModel() : "unknown"; - var metadata = new EmbeddingResponseMetadata(model, ZhiPuAiUsage.from(totalUsage)); + var metadata = new EmbeddingResponseMetadata(model, getDefaultUsage(totalUsage)); var indexCounter = new AtomicInteger(0); @@ -206,6 +206,10 @@ public EmbeddingResponse call(EmbeddingRequest request) { }); } + private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) { + return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); + } + /** * Merge runtime and default {@link EmbeddingOptions} to compute the final options to * use in the request. diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java deleted file mode 100644 index 88d197e9f48..00000000000 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.springframework.ai.zhipuai.metadata; - -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.zhipuai.api.ZhiPuAiApi; -import org.springframework.util.Assert; - -/** - * {@link Usage} implementation for {@literal ZhiPuAI}. - * - * @author Geng Rong - * @since 1.0.0 M1 - */ -public class ZhiPuAiUsage implements Usage { - - private final ZhiPuAiApi.Usage usage; - - protected ZhiPuAiUsage(ZhiPuAiApi.Usage usage) { - Assert.notNull(usage, "ZhiPuAI Usage must not be null"); - this.usage = usage; - } - - public static ZhiPuAiUsage from(ZhiPuAiApi.Usage usage) { - return new ZhiPuAiUsage(usage); - } - - protected ZhiPuAiApi.Usage getUsage() { - return this.usage; - } - - @Override - public Long getPromptTokens() { - return getUsage().promptTokens().longValue(); - } - - @Override - public Long getGenerationTokens() { - return getUsage().completionTokens().longValue(); - } - - @Override - public Long getTotalTokens() { - return getUsage().totalTokens().longValue(); - } - - @Override - public String toString() { - return getUsage().toString(); - } - -} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java index a46f50f1a94..118cef3bad2 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelObservationIT.java @@ -143,7 +143,7 @@ private void validate(ChatResponseMetadata responseMetadata) { .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getPromptTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + String.valueOf(responseMetadata.getUsage().getCompletionTokens())) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), String.valueOf(responseMetadata.getUsage().getTotalTokens())) .hasBeenStarted() diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java index a9fa52a30db..789fe429f8c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,74 +16,120 @@ package org.springframework.ai.chat.metadata; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; /** * Default implementation of the {@link Usage} interface. 
* * @author Mark Pollack + * @author Ilayaperumal Gopinathan * @since 1.0.0 */ public class DefaultUsage implements Usage { - private final Long promptTokens; + private final Integer promptTokens; + private final Integer completionTokens; + + @Deprecated(forRemoval = true, since = "1.0.0-M6") private final Long generationTokens; - private final Long totalTokens; + private final int totalTokens; + + private final Object nativeUsage; /** - * Create a new DefaultUsage with promptTokens and generationTokens. + * Create a new DefaultUsage with promptTokens, completionTokens, totalTokens and + * the native usage object. * @param promptTokens the number of tokens in the prompt, or {@code null} if not * available - * @param generationTokens the number of tokens in the generation, or {@code null} if + * @param completionTokens the number of tokens in the generation, or {@code null} if * not available + * @param totalTokens the total number of tokens, or {@code null} to calculate from + * promptTokens and completionTokens + * @param nativeUsage the native usage object returned by the model provider, or + * {@code null} to fall back to a map of the prompt, completion, and total token counts. */ - public DefaultUsage(Long promptTokens, Long generationTokens) { - this(promptTokens, generationTokens, null); + public DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens, Object nativeUsage) { + this.promptTokens = promptTokens != null ? promptTokens : 0; + this.completionTokens = completionTokens != null ? completionTokens : 0; + this.generationTokens = Long.valueOf(this.completionTokens); + this.totalTokens = totalTokens != null ? totalTokens + : calculateTotalTokens(this.promptTokens, this.completionTokens); + this.nativeUsage = (nativeUsage != null) ? nativeUsage : getDefaultNativeUsage(); } /** - * Create a new DefaultUsage with promptTokens, generationTokens, and totalTokens. + * Create a new DefaultUsage with promptTokens and completionTokens. * @param promptTokens the number of tokens in the prompt, or {@code null} if not * available - * @param generationTokens the number of tokens in the generation, or {@code null} if + * @param completionTokens the number of tokens in the generation, or {@code null} if + * not available + */ + public DefaultUsage(Integer promptTokens, Integer completionTokens) { + this(promptTokens, completionTokens, null); + } + + /** + * Create a new DefaultUsage with promptTokens, completionTokens, and totalTokens. + * @param promptTokens the number of tokens in the prompt, or {@code null} if not + * available + * @param completionTokens the number of tokens in the generation, or {@code null} if * not available * @param totalTokens the total number of tokens, or {@code null} to calculate from - * promptTokens and generationTokens + * promptTokens and completionTokens */ @JsonCreator - public DefaultUsage(@JsonProperty("promptTokens") Long promptTokens, - @JsonProperty("generationTokens") Long generationTokens, @JsonProperty("totalTokens") Long totalTokens) { - this.promptTokens = promptTokens != null ? promptTokens : 0L; - this.generationTokens = generationTokens != null ? generationTokens : 0L; - this.totalTokens = totalTokens != null ?
totalTokens - : calculateTotalTokens(this.promptTokens, this.generationTokens); + public DefaultUsage(@JsonProperty("promptTokens") Integer promptTokens, + @JsonProperty("completionTokens") Integer completionTokens, + @JsonProperty("totalTokens") Integer totalTokens) { + this(promptTokens, completionTokens, totalTokens, null); } @Override @JsonProperty("promptTokens") - public Long getPromptTokens() { + public Integer getPromptTokens() { return this.promptTokens; } @Override - @JsonProperty("generationTokens") - public Long getGenerationTokens() { - return this.generationTokens; + @JsonProperty("completionTokens") + public Integer getCompletionTokens() { + return this.completionTokens; } @Override @JsonProperty("totalTokens") - public Long getTotalTokens() { + public Integer getTotalTokens() { return this.totalTokens; } - private Long calculateTotalTokens(Long promptTokens, Long generationTokens) { - return promptTokens + generationTokens; + @Override + @JsonIgnore + public Object getNativeUsage() { + return this.nativeUsage; + } + + /** + * By default, return the Map of prompt, completion and total tokens. + * @return map containing the prompt, completion and total tokens. + */ + private Map getDefaultNativeUsage() { + Map usage = new HashMap<>(); + usage.put("promptTokens", this.promptTokens); + usage.put("completionTokens", this.completionTokens); + usage.put("totalTokens", this.totalTokens); + return usage; + } + + private Integer calculateTotalTokens(Integer promptTokens, Integer completionTokens) { + return promptTokens + completionTokens; } @Override @@ -96,18 +142,18 @@ public boolean equals(Object o) { } DefaultUsage that = (DefaultUsage) o; return Objects.equals(this.promptTokens, that.promptTokens) - && Objects.equals(this.generationTokens, that.generationTokens) + && Objects.equals(this.completionTokens, that.completionTokens) && Objects.equals(this.totalTokens, that.totalTokens); } @Override public int hashCode() { - return Objects.hash(this.promptTokens, this.generationTokens, this.totalTokens); + return Objects.hash(this.promptTokens, this.completionTokens, this.totalTokens); } @Override public String toString() { - return "DefaultUsage{" + "promptTokens=" + this.promptTokens + ", generationTokens=" + this.generationTokens + return "DefaultUsage{" + "promptTokens=" + this.promptTokens + ", completionTokens=" + this.completionTokens + ", totalTokens=" + this.totalTokens + '}'; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java index b9cdaf87249..5aeefe809d5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/EmptyUsage.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,22 +16,30 @@ package org.springframework.ai.chat.metadata; +import java.util.Map; + /** * An EmptyUsage implementation that returns zero for all property getters. * * @author John Blum + * @author Ilayaperumal Gopinathan * @since 0.7.0 */ public class EmptyUsage implements Usage { @Override - public Long getPromptTokens() { - return 0L; + public Integer getPromptTokens() { + return 0; + } + + @Override + public Integer getCompletionTokens() { + return 0; } @Override - public Long getGenerationTokens() { - return 0L; + public Object getNativeUsage() { + return Map.of(); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java index 887bfbaa4c5..fa43712ff6a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,26 +21,32 @@ * per AI request. * * @author John Blum + * @author Ilayaperumal Gopinathan * @since 0.7.0 */ public interface Usage { /** * Returns the number of tokens used in the {@literal prompt} of the AI request. - * @return an {@link Long} with the number of tokens used in the {@literal prompt} of - * the AI request. - * @see #getGenerationTokens() + * @return an {@link Integer} with the number of tokens used in the {@literal prompt} + * of the AI request. + * @see #getCompletionTokens() */ - Long getPromptTokens(); + Integer getPromptTokens(); + + @Deprecated(forRemoval = true, since = "1.0.0-M6") + default Long getGenerationTokens() { + return getCompletionTokens().longValue(); + } /** * Returns the number of tokens returned in the {@literal generation (aka completion)} * of the AI's response. - * @return an {@link Long} with the number of tokens returned in the + * @return an {@link Integer} with the number of tokens returned in the * {@literal generation (aka completion)} of the AI's response. * @see #getPromptTokens() */ - Long getGenerationTokens(); + Integer getCompletionTokens(); /** * Return the total number of tokens from both the {@literal prompt} of an AI request @@ -48,14 +54,20 @@ public interface Usage { * @return the total number of tokens from both the {@literal prompt} of an AI request * and {@literal generation} of the AI's response. * @see #getPromptTokens() - * @see #getGenerationTokens() + * @see #getCompletionTokens() */ - default Long getTotalTokens() { - Long promptTokens = getPromptTokens(); + default Integer getTotalTokens() { + Integer promptTokens = getPromptTokens(); promptTokens = promptTokens != null ? promptTokens : 0; - Long completionTokens = getGenerationTokens(); + Integer completionTokens = getCompletionTokens(); completionTokens = completionTokens != null ? completionTokens : 0; return promptTokens + completionTokens; } + /** + * Return the usage data from the underlying model API response. + * @return the usage object, with a concrete type specific to the model API response.
+ */ + Object getNativeUsage(); + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java index a430d4cb5a3..bffd7fb011d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java @@ -50,12 +50,12 @@ public static Usage getCumulativeUsage(final Usage currentUsage, final ChatRespo // For a valid usage from previous chat response, accumulate it to the current // usage. if (!isEmpty(currentUsage)) { - Long promptTokens = currentUsage.getPromptTokens().longValue(); - Long generationTokens = currentUsage.getGenerationTokens().longValue(); - Long totalTokens = currentUsage.getTotalTokens().longValue(); + Integer promptTokens = currentUsage.getPromptTokens(); + Integer generationTokens = currentUsage.getCompletionTokens(); + Integer totalTokens = currentUsage.getTotalTokens(); // Make sure to accumulate the usage from the previous chat response. promptTokens += usageFromPreviousChatResponse.getPromptTokens(); - generationTokens += usageFromPreviousChatResponse.getGenerationTokens(); + generationTokens += usageFromPreviousChatResponse.getCompletionTokens(); totalTokens += usageFromPreviousChatResponse.getTotalTokens(); return new DefaultUsage(promptTokens, generationTokens, totalTokens); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 965b2f6d629..b94a91f6dbd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -81,9 +81,9 @@ public Flux aggregate(Flux fluxChatResponse, ChatGenerationMetadata.NULL); // Usage - AtomicReference metadataUsagePromptTokensRef = new AtomicReference<>(0L); - AtomicReference metadataUsageGenerationTokensRef = new AtomicReference<>(0L); - AtomicReference metadataUsageTotalTokensRef = new AtomicReference<>(0L); + AtomicReference metadataUsagePromptTokensRef = new AtomicReference(0); + AtomicReference metadataUsageGenerationTokensRef = new AtomicReference(0); + AtomicReference metadataUsageTotalTokensRef = new AtomicReference(0); AtomicReference metadataPromptMetadataRef = new AtomicReference<>(PromptMetadata.empty()); AtomicReference metadataRateLimitRef = new AtomicReference<>(new EmptyRateLimit()); @@ -96,9 +96,9 @@ public Flux aggregate(Flux fluxChatResponse, messageMetadataMapRef.set(new HashMap<>()); metadataIdRef.set(""); metadataModelRef.set(""); - metadataUsagePromptTokensRef.set(0L); - metadataUsageGenerationTokensRef.set(0L); - metadataUsageTotalTokensRef.set(0L); + metadataUsagePromptTokensRef.set(0); + metadataUsageGenerationTokensRef.set(0); + metadataUsageTotalTokensRef.set(0); metadataPromptMetadataRef.set(PromptMetadata.empty()); metadataRateLimitRef.set(new EmptyRateLimit()); @@ -121,7 +121,7 @@ public Flux aggregate(Flux fluxChatResponse, Usage usage = chatResponse.getMetadata().getUsage(); metadataUsagePromptTokensRef.set( usage.getPromptTokens() > 0 ? usage.getPromptTokens() : metadataUsagePromptTokensRef.get()); - metadataUsageGenerationTokensRef.set(usage.getGenerationTokens() > 0 ? usage.getGenerationTokens() + metadataUsageGenerationTokensRef.set(usage.getCompletionTokens() > 0 ? 
usage.getCompletionTokens() : metadataUsageGenerationTokensRef.get()); metadataUsageTotalTokensRef .set(usage.getTotalTokens() > 0 ? usage.getTotalTokens() : metadataUsageTotalTokensRef.get()); @@ -162,32 +162,40 @@ public Flux aggregate(Flux fluxChatResponse, messageMetadataMapRef.set(new HashMap<>()); metadataIdRef.set(""); metadataModelRef.set(""); - metadataUsagePromptTokensRef.set(0L); - metadataUsageGenerationTokensRef.set(0L); - metadataUsageTotalTokensRef.set(0L); + metadataUsagePromptTokensRef.set(0); + metadataUsageGenerationTokensRef.set(0); + metadataUsageTotalTokensRef.set(0); metadataPromptMetadataRef.set(PromptMetadata.empty()); metadataRateLimitRef.set(new EmptyRateLimit()); }).doOnError(e -> logger.error("Aggregation Error", e)); } - public record DefaultUsage(long promptTokens, long generationTokens, long totalTokens) implements Usage { + public record DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens) implements Usage { @Override - public Long getPromptTokens() { + public Integer getPromptTokens() { return promptTokens(); } @Override - public Long getGenerationTokens() { - return generationTokens(); + public Integer getCompletionTokens() { + return completionTokens(); } @Override - public Long getTotalTokens() { + public Integer getTotalTokens() { return totalTokens(); } + @Override + public Map getNativeUsage() { + Map usage = new HashMap<>(); + usage.put("promptTokens", promptTokens()); + usage.put("completionTokens", completionTokens()); + usage.put("totalTokens", totalTokens()); + return usage; + } } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java index a9d4bf0504a..0a2b16bb19e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java @@ -222,10 +222,10 @@ protected KeyValues usageInputTokens(KeyValues keyValues, ChatModelObservationCo protected KeyValues usageOutputTokens(KeyValues keyValues, ChatModelObservationContext context) { if (context.getResponse() != null && context.getResponse().getMetadata() != null && context.getResponse().getMetadata().getUsage() != null - && context.getResponse().getMetadata().getUsage().getGenerationTokens() != null) { + && context.getResponse().getMetadata().getUsage().getCompletionTokens() != null) { return keyValues.and( ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), - String.valueOf(context.getResponse().getMetadata().getUsage().getGenerationTokens())); + String.valueOf(context.getResponse().getMetadata().getUsage().getCompletionTokens())); } return keyValues; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java index 4a5eb8eaf71..c9443e50027 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ModelUsageMetricsGenerator.java @@ -54,13 +54,13 @@ public static void generate(Usage usage, Observation.Context context, MeterRegis .increment(usage.getPromptTokens()); } - if 
(usage.getGenerationTokens() != null) { + if (usage.getCompletionTokens() != null) { Counter.builder(AiObservationMetricNames.TOKEN_USAGE.value()) .tag(AiObservationMetricAttributes.TOKEN_TYPE.value(), AiTokenType.OUTPUT.value()) .description(DESCRIPTION) .tags(createTags(context)) .register(meterRegistry) - .increment(usage.getGenerationTokens()); + .increment(usage.getCompletionTokens()); } if (usage.getTotalTokens() != null) { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java index 8511409a967..220479c5c73 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java @@ -104,7 +104,7 @@ public Long getTokensRemaining() { public Duration getTokensReset() { return Duration.ofSeconds(9); } - }).usage(new DefaultUsage(6L, 7L)) + }).usage(new DefaultUsage(6, 7)) .build())); // @formatter:on @@ -137,7 +137,7 @@ public Duration getTokensReset() { assertThat(response.getMetadata().getRateLimit().getTokensRemaining()).isEqualTo(8L); assertThat(response.getMetadata().getRateLimit().getTokensReset()).isEqualTo(Duration.ofSeconds(9)); assertThat(response.getMetadata().getUsage().getPromptTokens()).isEqualTo(6L); - assertThat(response.getMetadata().getUsage().getGenerationTokens()).isEqualTo(7L); + assertThat(response.getMetadata().getUsage().getCompletionTokens()).isEqualTo(7L); assertThat(response.getMetadata().getUsage().getTotalTokens()).isEqualTo(6L + 7L); assertThat(response.getMetadata().get("key6").toString()).isEqualTo("value6"); assertThat(response.getMetadata().get("key1").toString()).isEqualTo("value1"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java index 8059faf9580..ef426a057c3 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java @@ -27,73 +27,75 @@ public class DefaultUsageTests { @Test void testSerializationWithAllFields() throws Exception { - DefaultUsage usage = new DefaultUsage(100L, 50L, 150L); + DefaultUsage usage = new DefaultUsage(100, 50, 150); String json = this.objectMapper.writeValueAsString(usage); - assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json); + assertEquals("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}", + json); } @Test void testDeserializationWithAllFields() throws Exception { - String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}"; + String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}"; DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); - assertEquals(100L, usage.getPromptTokens()); - assertEquals(50L, usage.getGenerationTokens()); - assertEquals(150L, usage.getTotalTokens()); + assertEquals(100, usage.getPromptTokens()); + assertEquals(50, usage.getCompletionTokens()); + assertEquals(150, usage.getTotalTokens()); } @Test void testSerializationWithNullFields() throws Exception { DefaultUsage usage = new DefaultUsage(null, null, null); String json = 
this.objectMapper.writeValueAsString(usage); - assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json); + assertEquals("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}", json); } @Test void testDeserializationWithMissingFields() throws Exception { String json = "{\"promptTokens\":100}"; DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); - assertEquals(100L, usage.getPromptTokens()); - assertEquals(0L, usage.getGenerationTokens()); - assertEquals(100L, usage.getTotalTokens()); + assertEquals(100, usage.getPromptTokens()); + assertEquals(0, usage.getCompletionTokens()); + assertEquals(100, usage.getTotalTokens()); } @Test void testDeserializationWithNullFields() throws Exception { - String json = "{\"promptTokens\":null,\"generationTokens\":null,\"totalTokens\":null}"; + String json = "{\"promptTokens\":null,\"completionTokens\":null,\"totalTokens\":null}"; DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); - assertEquals(0L, usage.getPromptTokens()); - assertEquals(0L, usage.getGenerationTokens()); - assertEquals(0L, usage.getTotalTokens()); + assertEquals(0, usage.getPromptTokens()); + assertEquals(0, usage.getCompletionTokens()); + assertEquals(0, usage.getTotalTokens()); } @Test void testRoundTripSerialization() throws Exception { - DefaultUsage original = new DefaultUsage(100L, 50L, 150L); + DefaultUsage original = new DefaultUsage(100, 50, 150); String json = this.objectMapper.writeValueAsString(original); DefaultUsage deserialized = this.objectMapper.readValue(json, DefaultUsage.class); assertEquals(original.getPromptTokens(), deserialized.getPromptTokens()); - assertEquals(original.getGenerationTokens(), deserialized.getGenerationTokens()); + assertEquals(original.getCompletionTokens(), deserialized.getCompletionTokens()); assertEquals(original.getTotalTokens(), deserialized.getTotalTokens()); } @Test void testTwoArgumentConstructorAndSerialization() throws Exception { - DefaultUsage usage = new DefaultUsage(100L, 50L); + DefaultUsage usage = new DefaultUsage(100, 50); // Test that the fields are set correctly - assertEquals(100L, usage.getPromptTokens()); - assertEquals(50L, usage.getGenerationTokens()); - assertEquals(150L, usage.getTotalTokens()); // 100 + 50 = 150 + assertEquals(100, usage.getPromptTokens()); + assertEquals(50, usage.getCompletionTokens()); + assertEquals(150, usage.getTotalTokens()); // 100 + 50 = 150 // Test serialization String json = this.objectMapper.writeValueAsString(usage); - assertEquals("{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}", json); + assertEquals("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}", + json); // Test deserialization DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); - assertEquals(100L, deserializedUsage.getPromptTokens()); - assertEquals(50L, deserializedUsage.getGenerationTokens()); - assertEquals(150L, deserializedUsage.getTotalTokens()); + assertEquals(100, deserializedUsage.getPromptTokens()); + assertEquals(50, deserializedUsage.getCompletionTokens()); + assertEquals(150, deserializedUsage.getTotalTokens()); } @Test @@ -101,19 +103,19 @@ void testTwoArgumentConstructorWithNullValues() throws Exception { DefaultUsage usage = new DefaultUsage(null, null); // Test that null values are converted to 0 - assertEquals(0L, usage.getPromptTokens()); - assertEquals(0L, usage.getGenerationTokens()); - assertEquals(0L, 
 				usage.getTotalTokens());
+		assertEquals(0, usage.getPromptTokens());
+		assertEquals(0, usage.getCompletionTokens());
+		assertEquals(0, usage.getTotalTokens());

 		// Test serialization
 		String json = this.objectMapper.writeValueAsString(usage);
-		assertEquals("{\"promptTokens\":0,\"generationTokens\":0,\"totalTokens\":0}", json);
+		assertEquals("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}", json);

 		// Test deserialization
 		DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class);
-		assertEquals(0L, deserializedUsage.getPromptTokens());
-		assertEquals(0L, deserializedUsage.getGenerationTokens());
-		assertEquals(0L, deserializedUsage.getTotalTokens());
+		assertEquals(0, deserializedUsage.getPromptTokens());
+		assertEquals(0, deserializedUsage.getCompletionTokens());
+		assertEquals(0, deserializedUsage.getTotalTokens());
 	}

 }
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java
index 5112ed627d6..06acfb0d218 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java
@@ -16,7 +16,9 @@
 package org.springframework.ai.chat.observation;

+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;

 import io.micrometer.core.instrument.MeterRegistry;
 import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
@@ -106,13 +108,22 @@ private Prompt generatePrompt() {
 	static class TestUsage implements Usage {

 		@Override
-		public Long getPromptTokens() {
-			return 1000L;
+		public Integer getPromptTokens() {
+			return 1000;
 		}

 		@Override
-		public Long getGenerationTokens() {
-			return 500L;
+		public Integer getCompletionTokens() {
+			return 500;
+		}
+
+		@Override
+		public Map<String, Integer> getNativeUsage() {
+			Map<String, Integer> usage = new HashMap<>();
+			usage.put("promptTokens", getPromptTokens());
+			usage.put("completionTokens", getCompletionTokens());
+			usage.put("totalTokens", getTotalTokens());
+			return usage;
 		}

 	}
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java
index f847088e1e2..6446b5b151e 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java
@@ -16,7 +16,9 @@
 package org.springframework.ai.chat.observation;

+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;

 import io.micrometer.common.KeyValue;
 import io.micrometer.observation.Observation;
@@ -30,6 +32,7 @@
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.lang.Nullable;

 import static org.assertj.core.api.Assertions.assertThat;
 import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
@@ -183,13 +186,22 @@ private Prompt generatePrompt() {
 	static class TestUsage implements Usage {

 		@Override
-		public Long getPromptTokens() {
-			return 1000L;
+		public Integer getPromptTokens() {
+			return 1000;
 		}

 		@Override
-		public Long getGenerationTokens() {
-			return 500L;
+		public Integer getCompletionTokens() {
+			return 500;
+		}
+
+		@Override
+		public Map<String, Integer> getNativeUsage() {
+			Map<String, Integer> usage = new HashMap<>();
+			usage.put("promptTokens", getPromptTokens());
+			usage.put("completionTokens", getCompletionTokens());
+			usage.put("totalTokens", getTotalTokens());
+			return usage;
 		}

 	}
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java
index 977c30a443a..7ed12ac161d 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java
@@ -16,6 +16,7 @@
 package org.springframework.ai.embedding.observation;

+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;

@@ -134,13 +135,22 @@ private EmbeddingRequest generateEmbeddingRequest() {
 	static class TestUsage implements Usage {

 		@Override
-		public Long getPromptTokens() {
-			return 1000L;
+		public Integer getPromptTokens() {
+			return 1000;
 		}

 		@Override
-		public Long getGenerationTokens() {
-			return 0L;
+		public Integer getCompletionTokens() {
+			return 0;
+		}
+
+		@Override
+		public Map<String, Integer> getNativeUsage() {
+			Map<String, Integer> usage = new HashMap<>();
+			usage.put("promptTokens", getPromptTokens());
+			usage.put("completionTokens", getCompletionTokens());
+			usage.put("totalTokens", getTotalTokens());
+			return usage;
 		}

 	}
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java
index a97afb9d1d3..0b36b59e152 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java
@@ -16,8 +16,10 @@
 package org.springframework.ai.embedding.observation;

+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;

 import io.micrometer.core.instrument.MeterRegistry;
 import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
@@ -104,18 +106,27 @@ private EmbeddingRequest generateEmbeddingRequest() {
 	static class TestUsage implements Usage {

 		@Override
-		public Long getPromptTokens() {
-			return 1000L;
+		public Integer getPromptTokens() {
+			return 1000;
 		}

 		@Override
-		public Long getGenerationTokens() {
-			return 0L;
+		public Integer getCompletionTokens() {
+			return 0;
 		}

 		@Override
-		public Long getTotalTokens() {
-			return 1000L;
+		public Integer getTotalTokens() {
+			return 1000;
+		}
+
+		@Override
+		public Map<String, Integer> getNativeUsage() {
+			Map<String, Integer> usage = new HashMap<>();
+			usage.put("promptTokens", getPromptTokens());
+			usage.put("completionTokens", getCompletionTokens());
+			usage.put("totalTokens", getTotalTokens());
+			return usage;
 		}

 	}
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java
index cac20367870..72b1c6cc169 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java
@@ -36,10 +36,10 @@
  */
 public class UsageTests {

-	private Usage mockUsage(Long promptTokens, Long generationTokens) {
+	private Usage mockUsage(Integer promptTokens, Integer generationTokens) {
 		Usage mockUsage = mock(Usage.class);
 		doReturn(promptTokens).when(mockUsage).getPromptTokens();
-		doReturn(generationTokens).when(mockUsage).getGenerationTokens();
+		doReturn(generationTokens).when(mockUsage).getCompletionTokens();
 		doCallRealMethod().when(mockUsage).getTotalTokens();
 		return mockUsage;
 	}
@@ -47,7 +47,7 @@ private Usage mockUsage(Long promptTokens, Long generationTokens) {
 	private void verifyUsage(Usage usage) {
 		verify(usage, times(1)).getTotalTokens();
 		verify(usage, times(1)).getPromptTokens();
-		verify(usage, times(1)).getGenerationTokens();
+		verify(usage, times(1)).getCompletionTokens();
 		verifyNoMoreInteractions(usage);
 	}

@@ -63,27 +63,27 @@ void totalTokensIsZeroWhenNoPromptOrGenerationMetadataPresent() {

 	@Test
 	void totalTokensEqualsPromptTokens() {
-		Usage usage = mockUsage(10L, null);
+		Usage usage = mockUsage(10, null);

-		assertThat(usage.getTotalTokens()).isEqualTo(10L);
+		assertThat(usage.getTotalTokens()).isEqualTo(10);
 		verifyUsage(usage);
 	}

 	@Test
 	void totalTokensEqualsGenerationTokens() {
-		Usage usage = mockUsage(null, 15L);
+		Usage usage = mockUsage(null, 15);

-		assertThat(usage.getTotalTokens()).isEqualTo(15L);
+		assertThat(usage.getTotalTokens()).isEqualTo(15);
 		verifyUsage(usage);
 	}

 	@Test
 	void totalTokensEqualsPromptTokensPlusGenerationTokens() {
-		Usage usage = mockUsage(10L, 15L);
+		Usage usage = mockUsage(10, 15);

-		assertThat(usage.getTotalTokens()).isEqualTo(25L);
+		assertThat(usage.getTotalTokens()).isEqualTo(25);
 		verifyUsage(usage);
 	}
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java
index 3ef061335c0..01dc2064de1 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/model/observation/ModelUsageMetricsGeneratorTests.java
@@ -16,6 +16,9 @@
 package org.springframework.ai.model.observation;

+import java.util.HashMap;
+import java.util.Map;
+
 import io.micrometer.common.KeyValue;
 import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
 import io.micrometer.observation.Observation;
@@ -38,7 +41,7 @@ class ModelUsageMetricsGeneratorTests {
 	@Test
 	void whenTokenUsageThenMetrics() {
 		var meterRegistry = new SimpleMeterRegistry();
-		var usage = new TestUsage(1000L, 500L, 1500L);
+		var usage = new TestUsage(1000, 500, 1500);
 		ModelUsageMetricsGenerator.generate(usage, buildContext(), meterRegistry);

 		assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(3);
@@ -59,7 +62,7 @@ void whenTokenUsageThenMetrics() {
 	@Test
 	void whenPartialTokenUsageThenMetrics() {
 		var meterRegistry = new SimpleMeterRegistry();
-		var usage = new TestUsage(1000L, null, 1000L);
+		var usage = new TestUsage(1000, null, 1000);
 		ModelUsageMetricsGenerator.generate(usage, buildContext(), meterRegistry);

 		assertThat(meterRegistry.get(AiObservationMetricNames.TOKEN_USAGE.value()).meters()).hasSize(2);
@@ -82,33 +85,42 @@ private Observation.Context buildContext() {
 	static class TestUsage implements Usage {

-		private final Long promptTokens;
+		private final Integer promptTokens;

-		private final Long generationTokens;
+		private final Integer generationTokens;

-		private final Long totalTokens;
+		private final int totalTokens;

-		TestUsage(Long promptTokens, Long generationTokens, Long totalTokens) {
+		TestUsage(Integer promptTokens, Integer generationTokens, int totalTokens) {
 			this.promptTokens = promptTokens;
 			this.generationTokens = generationTokens;
 			this.totalTokens = totalTokens;
 		}

 		@Override
-		public Long getPromptTokens() {
+		public Integer getPromptTokens() {
 			return this.promptTokens;
 		}

 		@Override
-		public Long getGenerationTokens() {
+		public Integer getCompletionTokens() {
 			return this.generationTokens;
 		}

 		@Override
-		public Long getTotalTokens() {
+		public Integer getTotalTokens() {
 			return this.totalTokens;
 		}

+		@Override
+		public Map<String, Integer> getNativeUsage() {
+			Map<String, Integer> usage = new HashMap<>();
+			usage.put("promptTokens", getPromptTokens());
+			usage.put("completionTokens", getCompletionTokens());
+			usage.put("totalTokens", getTotalTokens());
+			return usage;
+		}
+
 	}

 }
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java
index 6ab471eecc3..0ebe376bcaf 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java
@@ -151,7 +151,7 @@ void streamingWithTokenUsage() {
 		}).collect(Collectors.joining());

 		assertThat(streamingTokenUsage[0].getPromptTokens()).isGreaterThan(0);
-		assertThat(streamingTokenUsage[0].getGenerationTokens()).isGreaterThan(0);
+		assertThat(streamingTokenUsage[0].getCompletionTokens()).isGreaterThan(0);
 		assertThat(streamingTokenUsage[0].getTotalTokens()).isGreaterThan(0);
 		assertThat(response).isNotEmpty();
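
Note for reviewers (illustrative, not part of the patch): every hunk above migrates a test fixture from the old Long-based getPromptTokens()/getGenerationTokens() contract to the revised Usage interface with Integer counts, the renamed getCompletionTokens(), and the new getNativeUsage() accessor. A minimal sketch of what a downstream implementation looks like against the revised contract, mirroring the TestUsage fixtures above; the class name is hypothetical, and getTotalTokens() is assumed to keep the default null-safe summing behavior that UsageTests exercises via doCallRealMethod():

import java.util.HashMap;
import java.util.Map;

import org.springframework.ai.chat.metadata.Usage;

// Hypothetical provider-side Usage implementation under the revised contract.
class ExampleProviderUsage implements Usage {

	private final Integer promptTokens;

	private final Integer completionTokens;

	ExampleProviderUsage(Integer promptTokens, Integer completionTokens) {
		this.promptTokens = promptTokens;
		this.completionTokens = completionTokens;
	}

	@Override
	public Integer getPromptTokens() {
		return this.promptTokens;
	}

	@Override
	public Integer getCompletionTokens() { // replaces the old getGenerationTokens()
		return this.completionTokens;
	}

	// getTotalTokens() is intentionally not overridden: per UsageTests above,
	// the default implementation sums prompt and completion tokens, treating
	// null as zero.

	@Override
	public Map<String, Integer> getNativeUsage() {
		// Expose the raw counts in the same shape as the TestUsage fixtures
		// in this patch.
		Map<String, Integer> usage = new HashMap<>();
		usage.put("promptTokens", getPromptTokens());
		usage.put("completionTokens", getCompletionTokens());
		usage.put("totalTokens", getTotalTokens());
		return usage;
	}

}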