From ece3665874ce383170682adcf35002b647d6d7d9 Mon Sep 17 00:00:00 2001 From: dafriz Date: Sun, 22 Sep 2024 21:53:20 +1000 Subject: [PATCH] Add completion_tokens_details with reasoning_tokens to OpenAi Usage --- .../ai/openai/api/OpenAiApi.java | 18 +++++++++++++- .../ai/openai/metadata/OpenAiUsage.java | 6 +++++ .../ai/openai/metadata/OpenAiUsageTests.java | 24 +++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 64e0d58cef6..094e1c1f1bc 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -929,12 +929,28 @@ public record TopLogProbs(// @formatter:off * @param promptTokens Number of tokens in the prompt. * @param totalTokens Total number of tokens used in the request (prompt + * completion). + * @param completionTokenDetails Breakdown of tokens used in a completion */ @JsonInclude(Include.NON_NULL) public record Usage(// @formatter:off @JsonProperty("completion_tokens") Integer completionTokens, @JsonProperty("prompt_tokens") Integer promptTokens, - @JsonProperty("total_tokens") Integer totalTokens) {// @formatter:on + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on + + public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + this(completionTokens, promptTokens, totalTokens, null); + } + + /** + * Breakdown of tokens used in a completion + * + * @param reasoningTokens Number of tokens generated by the model for reasoning. + */ + @JsonInclude(Include.NON_NULL) + public record CompletionTokenDetails(// @formatter:off + @JsonProperty("reasoning_tokens") Integer reasoningTokens) {// @formatter:on + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 821a0325d01..add5d896b57 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -58,6 +58,12 @@ public Long getGenerationTokens() { return generationTokens != null ? generationTokens.longValue() : 0; } + public Long getReasoningTokens() { + OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); + Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null; + return reasoningTokens != null ? reasoningTokens.longValue() : 0; + } + @Override public Long getTotalTokens() { Integer totalTokens = getUsage().totalTokens(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index 58c378f35bb..b9215b4c3df 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -54,4 +54,28 @@ void whenTotalTokensIsNull() { assertThat(usage.getTotalTokens()).isEqualTo(300); } + @Test + void whenCompletionTokenDetailsIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getTotalTokens()).isEqualTo(300); + assertThat(usage.getReasoningTokens()).isEqualTo(0); + } + + @Test + void whenReasoningTokensIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + new OpenAiApi.Usage.CompletionTokenDetails(null)); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getReasoningTokens()).isEqualTo(0); + } + + @Test + void whenCompletionTokenDetailsIsPresent() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + new OpenAiApi.Usage.CompletionTokenDetails(50)); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getReasoningTokens()).isEqualTo(50); + } + }