Skip to content

Commit 482c07a

Browse files
committed
Refactor Usage handling
- Remove model-specific Usage implementations - Add `Object getNativeUsage()` to the Usage interface - This allows the model-specific Usage data to be returned - On the client side, the client needs to cast the return value of `getNativeUsage()` to the corresponding Usage type returned by the model API - Rename `generationTokens` to `completionTokens` - Since the `completion` token name is more common among the models, generation tokens are renamed to completion tokens - Change the prompt, completion, and total token return types to Integer - Use DefaultUsage for most of the model-specific usage handling - When initializing, set the native usage to the model-specific usage type
1 parent f5761de commit 482c07a

File tree

83 files changed

+582
-1375
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

83 files changed

+582
-1375
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@
4040
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4141
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4242
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
43-
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
4443
import org.springframework.ai.chat.messages.AssistantMessage;
4544
import org.springframework.ai.chat.messages.MessageType;
4645
import org.springframework.ai.chat.messages.ToolResponseMessage;
4746
import org.springframework.ai.chat.messages.UserMessage;
4847
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4948
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
49+
import org.springframework.ai.chat.metadata.DefaultUsage;
5050
import org.springframework.ai.chat.metadata.EmptyUsage;
5151
import org.springframework.ai.chat.metadata.Usage;
5252
import org.springframework.ai.chat.metadata.UsageUtils;
@@ -237,7 +237,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
237237
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
238238
AnthropicApi.Usage usage = completionResponse.usage();
239239

240-
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
240+
Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(completionResponse.usage())
241241
: new EmptyUsage();
242242
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
243243

@@ -256,6 +256,11 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
256256
return response;
257257
}
258258

259+
private DefaultUsage getDefaultUsage(AnthropicApi.Usage usage) {
260+
return new DefaultUsage(usage.inputTokens(), usage.outputTokens(), usage.inputTokens() + usage.outputTokens(),
261+
usage);
262+
}
263+
259264
@Override
260265
public Flux<ChatResponse> stream(Prompt prompt) {
261266
return this.internalStream(prompt, null);
@@ -282,7 +287,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
282287
// @formatter:off
283288
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
284289
AnthropicApi.Usage usage = chatCompletionResponse.usage();
285-
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
290+
Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage();
286291
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
287292
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
288293

@@ -352,7 +357,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
352357
}
353358

354359
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
355-
return from(result, AnthropicUsage.from(result.usage()));
360+
return from(result, this.getDefaultUsage(result.usage()));
356361
}
357362

358363
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java

Lines changed: 0 additions & 66 deletions
This file was deleted.

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ private static void validateChatResponseMetadata(ChatResponse response, String m
8484
assertThat(response.getMetadata().getId()).isNotEmpty();
8585
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
8686
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
87-
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
87+
assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
8888
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
8989
}
9090

@@ -100,11 +100,11 @@ void roleTest(String modelName) {
100100
AnthropicChatOptions.builder().model(modelName).build());
101101
ChatResponse response = this.chatModel.call(prompt);
102102
assertThat(response.getResults()).hasSize(1);
103-
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
103+
assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0);
104104
assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0);
105105
assertThat(response.getMetadata().getUsage().getTotalTokens())
106106
.isEqualTo(response.getMetadata().getUsage().getPromptTokens()
107-
+ response.getMetadata().getUsage().getGenerationTokens());
107+
+ response.getMetadata().getUsage().getCompletionTokens());
108108
Generation generation = response.getResults().get(0);
109109
assertThat(generation.getOutput().getText()).contains("Blackbeard");
110110
assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn");
@@ -139,11 +139,11 @@ void streamingWithTokenUsage() {
139139
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
140140

141141
assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
142-
assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
142+
assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
143143
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
144144

145145
assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
146-
// assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
146+
// assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
147147
// assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());
148148

149149
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ private void validate(ChatResponseMetadata responseMetadata, String finishReason
147147
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
148148
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
149149
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
150-
String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
150+
String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
151151
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
152152
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
153153
.hasBeenStarted()

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@
6060
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
6161
import reactor.core.publisher.Flux;
6262

63-
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
6463
import org.springframework.ai.chat.messages.AssistantMessage;
6564
import org.springframework.ai.chat.messages.Message;
6665
import org.springframework.ai.chat.messages.ToolResponseMessage;
6766
import org.springframework.ai.chat.messages.UserMessage;
6867
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
6968
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
69+
import org.springframework.ai.chat.metadata.DefaultUsage;
7070
import org.springframework.ai.chat.metadata.EmptyUsage;
7171
import org.springframework.ai.chat.metadata.PromptMetadata;
7272
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
@@ -194,13 +194,14 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM
194194
}
195195

196196
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
197-
Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
197+
Usage usage = (chatCompletions.getUsage() != null) ? getDefaultUsage(chatCompletions.getUsage())
198+
: new EmptyUsage();
198199
return from(chatCompletions, promptFilterMetadata, usage);
199200
}
200201

201202
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
202203
CompletionsUsage usage) {
203-
return from(chatCompletions, promptFilterMetadata, AzureOpenAiUsage.from(usage));
204+
return from(chatCompletions, promptFilterMetadata, getDefaultUsage(usage));
204205
}
205206

206207
public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage) {
@@ -217,6 +218,10 @@ public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage)
217218
return builder.build();
218219
}
219220

221+
private static DefaultUsage getDefaultUsage(CompletionsUsage usage) {
222+
return new DefaultUsage(usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens(), usage);
223+
}
224+
220225
public AzureOpenAiChatOptions getDefaultOptions() {
221226
return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
222227
}
@@ -321,7 +326,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
321326
}
322327
// Accumulate the usage from the previous chat response
323328
CompletionsUsage usage = chatCompletion.getUsage();
324-
Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage.from(usage) : new EmptyUsage();
329+
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
325330
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
326331
return toChatResponse(chatCompletion, accumulatedUsage);
327332
}).buffer(2, 1).map(bufferList -> {
@@ -412,7 +417,7 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatRespons
412417
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
413418
Usage currentUsage = null;
414419
if (chatCompletions.getUsage() != null) {
415-
currentUsage = AzureOpenAiUsage.from(chatCompletions);
420+
currentUsage = getDefaultUsage(chatCompletions.getUsage());
416421
}
417422
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
418423
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage));

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@
2323
import com.azure.ai.openai.models.EmbeddingItem;
2424
import com.azure.ai.openai.models.Embeddings;
2525
import com.azure.ai.openai.models.EmbeddingsOptions;
26+
import com.azure.ai.openai.models.EmbeddingsUsage;
2627
import io.micrometer.observation.ObservationRegistry;
2728
import org.slf4j.Logger;
2829
import org.slf4j.LoggerFactory;
2930

30-
import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage;
31+
import org.springframework.ai.chat.metadata.DefaultUsage;
3132
import org.springframework.ai.document.Document;
3233
import org.springframework.ai.document.MetadataMode;
3334
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -159,10 +160,14 @@ EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
159160
private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
160161
List<Embedding> data = generateEmbeddingList(embeddings.getData());
161162
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
162-
metadata.setUsage(AzureOpenAiEmbeddingUsage.from(embeddings.getUsage()));
163+
metadata.setUsage(getDefaultUsage(embeddings.getUsage()));
163164
return new EmbeddingResponse(data, metadata);
164165
}
165166

167+
private DefaultUsage getDefaultUsage(EmbeddingsUsage usage) {
168+
return new DefaultUsage(usage.getPromptTokens(), 0, usage.getTotalTokens(), usage);
169+
}
170+
166171
private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
167172
List<Embedding> data = new ArrayList<>();
168173
for (EmbeddingItem nativeDatum : nativeData) {

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java

Lines changed: 0 additions & 68 deletions
This file was deleted.

0 commit comments

Comments (0)