Skip to content

Commit 71e16c9

Browse files
committed
This change introduces a new field for tracking cached tokens in the
OpenAI API response. It extends the Usage record to include PromptTokensDetails, allowing for more granular token usage reporting. The OpenAiUsage class is updated to expose this new data, and corresponding unit tests are added to verify the behavior. This enhancement provides more detailed insights into token usage, indicating how many of the prompt tokens were cache hits.
1 parent 4c83fe8 commit 71e16c9

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -939,17 +939,29 @@ public record TopLogProbs(// @formatter:off
939939
* @param promptTokens Number of tokens in the prompt.
940940
* @param totalTokens Total number of tokens used in the request (prompt +
941941
* completion).
942-
* @param completionTokenDetails Breakdown of tokens used in a completion
942+
* @param promptTokensDetails Breakdown of tokens used in the prompt.
943+
* @param completionTokenDetails Breakdown of tokens used in a completion.
943944
*/
944945
@JsonInclude(Include.NON_NULL)
945946
public record Usage(// @formatter:off
946947
@JsonProperty("completion_tokens") Integer completionTokens,
947948
@JsonProperty("prompt_tokens") Integer promptTokens,
948949
@JsonProperty("total_tokens") Integer totalTokens,
950+
@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails,
949951
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on
950952

951953
public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
952-
this(completionTokens, promptTokens, totalTokens, null);
954+
this(completionTokens, promptTokens, totalTokens, null, null);
955+
}
956+
957+
/**
958+
* Breakdown of tokens used in the prompt.
959+
*
960+
* @param cachedTokens Cached tokens present in the prompt.
961+
*/
962+
@JsonInclude(Include.NON_NULL)
963+
public record PromptTokensDetails(// @formatter:off
964+
@JsonProperty("cached_tokens") Integer cachedTokens) {// @formatter:on
953965
}
954966

955967
/**

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ public Long getGenerationTokens() {
5858
return generationTokens != null ? generationTokens.longValue() : 0;
5959
}
6060

61+
public Long getCachedTokens() {
62+
OpenAiApi.Usage.PromptTokensDetails promptTokenDetails = getUsage().promptTokensDetails();
63+
Integer cachedTokens = promptTokenDetails != null ? promptTokenDetails.cachedTokens() : null;
64+
return cachedTokens != null ? cachedTokens.longValue() : 0;
65+
}
66+
6167
public Long getReasoningTokens() {
6268
OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails();
6369
Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null;

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,44 @@ void whenTotalTokensIsNull() {
5555
}
5656

5757
@Test
58-
void whenCompletionTokenDetailsIsNull() {
59-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null);
58+
void whenPromptAndCompletionTokensDetailsIsNull() {
59+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null, null);
6060
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
6161
assertThat(usage.getTotalTokens()).isEqualTo(300);
62+
assertThat(usage.getCachedTokens()).isEqualTo(0);
6263
assertThat(usage.getReasoningTokens()).isEqualTo(0);
6364
}
6465

6566
@Test
6667
void whenReasoningTokensIsNull() {
67-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300,
68+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
6869
new OpenAiApi.Usage.CompletionTokenDetails(null));
6970
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
7071
assertThat(usage.getReasoningTokens()).isEqualTo(0);
7172
}
7273

7374
@Test
7475
void whenCompletionTokenDetailsIsPresent() {
75-
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300,
76+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null,
7677
new OpenAiApi.Usage.CompletionTokenDetails(50));
7778
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
7879
assertThat(usage.getReasoningTokens()).isEqualTo(50);
7980
}
8081

82+
@Test
83+
void whenCacheTokensIsNull() {
84+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(null),
85+
null);
86+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
87+
assertThat(usage.getCachedTokens()).isEqualTo(0);
88+
}
89+
90+
@Test
91+
void whenCacheTokensIsPresent() {
92+
OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, new OpenAiApi.Usage.PromptTokensDetails(15),
93+
null);
94+
OpenAiUsage usage = OpenAiUsage.from(openAiUsage);
95+
assertThat(usage.getCachedTokens()).isEqualTo(15);
96+
}
97+
8198
}

0 commit comments

Comments
 (0)