@@ -40,13 +40,13 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
- import org.springframework.ai.anthropic.metadata.AnthropicUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+ import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
@@ -237,7 +237,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse)
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
AnthropicApi.Usage usage = completionResponse.usage();

- Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
+ Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(completionResponse.usage())
: new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
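Note: `UsageUtils.getCumulativeUsage(...)` is what keeps token accounting correct across multi-round (tool-calling) exchanges: the current response's usage is folded into whatever the previous response already consumed. Conceptually it behaves like the sketch below; this is an illustration under assumptions, not the spring-ai-core source, and the three-argument `DefaultUsage` constructor and null checks are assumed:

```java
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;

final class CumulativeUsageSketch {

	// Conceptual fold: add the previous round's counts onto the current ones.
	static Usage cumulative(Usage current, ChatResponse previous) {
		if (previous == null || previous.getMetadata() == null || previous.getMetadata().getUsage() == null) {
			return current;
		}
		Usage prior = previous.getMetadata().getUsage();
		return new DefaultUsage(current.getPromptTokens() + prior.getPromptTokens(),
				current.getCompletionTokens() + prior.getCompletionTokens(),
				current.getTotalTokens() + prior.getTotalTokens());
	}

}
```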

@@ -256,6 +256,11 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse)
return response;
}

+ private DefaultUsage getDefaultUsage(AnthropicApi.Usage usage) {
+ 	return new DefaultUsage(usage.inputTokens(), usage.outputTokens(), usage.inputTokens() + usage.outputTokens(),
+ 			usage);
+ }
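Note: Anthropic's API reports only input and output token counts, so the helper above computes the total itself rather than relying on a provider-reported total. A minimal sketch of what a caller sees once the response metadata carries the mapped usage; the class and logging below are illustrative, only the `Usage` accessors come from this diff:

```java
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;

final class AnthropicUsageLoggingSketch {

	static void logUsage(ChatResponse response) {
		Usage usage = response.getMetadata().getUsage();
		// For Anthropic, total == prompt + completion by construction above.
		System.out.printf("prompt=%d completion=%d total=%d%n", usage.getPromptTokens(),
				usage.getCompletionTokens(), usage.getTotalTokens());
	}

}
```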

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.internalStream(prompt, null);
@@ -282,7 +287,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse)
// @formatter:off
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
AnthropicApi.Usage usage = chatCompletionResponse.usage();
- Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
+ Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

@@ -352,7 +357,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
- return from(result, AnthropicUsage.from(result.usage()));
+ return from(result, this.getDefaultUsage(result.usage()));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {

This file was deleted.

@@ -84,7 +84,7 @@ private static void validateChatResponseMetadata(ChatResponse response, String model)
assertThat(response.getMetadata().getId()).isNotEmpty();
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
- assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
+ assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
}

@@ -100,11 +100,11 @@ void roleTest(String modelName) {
AnthropicChatOptions.builder().model(modelName).build());
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
- assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
+ assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0);
assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0);
assertThat(response.getMetadata().getUsage().getTotalTokens())
.isEqualTo(response.getMetadata().getUsage().getPromptTokens()
- + response.getMetadata().getUsage().getGenerationTokens());
+ + response.getMetadata().getUsage().getCompletionTokens());
Generation generation = response.getResults().get(0);
assertThat(generation.getOutput().getText()).contains("Blackbeard");
assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn");
@@ -139,11 +139,11 @@ void streamingWithTokenUsage() {
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();

assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
- assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
+ assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);

assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
- // assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
+ // assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
// assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());

}
@@ -147,7 +147,7 @@ private void validate(ChatResponseMetadata responseMetadata, String finishReason)
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
- String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+ String.valueOf(responseMetadata.getUsage().getCompletionTokens()))
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
.hasBeenStarted()
@@ -60,13 +60,13 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;

- import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+ import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
@@ -194,13 +194,14 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata)
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
- Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
+ Usage usage = (chatCompletions.getUsage() != null) ? getDefaultUsage(chatCompletions.getUsage())
+ 		: new EmptyUsage();
return from(chatCompletions, promptFilterMetadata, usage);
}

public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
CompletionsUsage usage) {
- return from(chatCompletions, promptFilterMetadata, AzureOpenAiUsage.from(usage));
+ return from(chatCompletions, promptFilterMetadata, getDefaultUsage(usage));
}

public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage) {
@@ -217,6 +218,10 @@ public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage)
return builder.build();
}

+ private static DefaultUsage getDefaultUsage(CompletionsUsage usage) {
+ 	return new DefaultUsage(usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens(), usage);
+ }
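Note: unlike the Anthropic helper, this mapping forwards the SDK-reported total from `CompletionsUsage` instead of recomputing it. The consumer-visible effect of dropping `AzureOpenAiUsage` (one of the deleted classes in this PR) is that call sites need only the generic `Usage` contract; a hedged sketch, with the rename from `getGenerationTokens()` applied:

```java
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;

final class ProviderAgnosticUsageSketch {

	static long completionTokens(ChatResponse response) {
		// No AzureOpenAiUsage cast needed any more; the generic contract
		// is enough, and the raw CompletionsUsage stays attached as the
		// native usage object passed to DefaultUsage above.
		Usage usage = response.getMetadata().getUsage();
		return usage.getCompletionTokens(); // formerly getGenerationTokens()
	}

}
```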

public AzureOpenAiChatOptions getDefaultOptions() {
return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
}
@@ -321,7 +326,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse)
}
// Accumulate the usage from the previous chat response
CompletionsUsage usage = chatCompletion.getUsage();
- Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage.from(usage) : new EmptyUsage();
+ Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
return toChatResponse(chatCompletion, accumulatedUsage);
}).buffer(2, 1).map(bufferList -> {
@@ -412,7 +417,7 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatResponse previousChatResponse)
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
Usage currentUsage = null;
if (chatCompletions.getUsage() != null) {
- currentUsage = AzureOpenAiUsage.from(chatCompletions);
+ currentUsage = getDefaultUsage(chatCompletions.getUsage());
}
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage));
@@ -23,11 +23,12 @@
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
+ import com.azure.ai.openai.models.EmbeddingsUsage;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

- import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage;
+ import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -159,10 +160,14 @@ EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
List<Embedding> data = generateEmbeddingList(embeddings.getData());
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
- metadata.setUsage(AzureOpenAiEmbeddingUsage.from(embeddings.getUsage()));
+ metadata.setUsage(getDefaultUsage(embeddings.getUsage()));
return new EmbeddingResponse(data, metadata);
}

+ private DefaultUsage getDefaultUsage(EmbeddingsUsage usage) {
+ 	return new DefaultUsage(usage.getPromptTokens(), 0, usage.getTotalTokens(), usage);
+ }
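Note: embedding requests consume prompt tokens only, so completion tokens are pinned to zero here and Azure's reported total should equal the prompt count. Illustrative numbers below, assuming the four-argument `DefaultUsage` constructor accepts a null native-usage object:

```java
import org.springframework.ai.chat.metadata.DefaultUsage;

final class EmbeddingUsageSketch {

	public static void main(String[] args) {
		// Illustrative values: 128 prompt tokens, no completion tokens.
		DefaultUsage usage = new DefaultUsage(128, 0, 128, null);
		assert usage.getCompletionTokens() == 0;
		assert usage.getTotalTokens().equals(usage.getPromptTokens());
	}

}
```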

private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
List<Embedding> data = new ArrayList<>();
for (EmbeddingItem nativeDatum : nativeData) {

This file was deleted.
