Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,14 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
.stream()
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
ChatGenerationMetadata.from(chatCompletion.stopReason(), null)))
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()))
.toList();

List<Generation> allGenerations = new ArrayList<>(generations);

if (chatCompletion.stopReason() != null && generations.isEmpty()) {
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
allGenerations.add(generation);
}

Expand All @@ -322,7 +322,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
allGenerations.add(toolCallGeneration);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,10 @@ else if (data instanceof byte[] dataBytes) {
}

private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) {
return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults());
return ChatGenerationMetadata.builder()
.finishReason(String.valueOf(choice.getFinishReason()))
.metadata("contentFilterResults", choice.getContentFilterResults())
.build();
}

private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ private void assertChoiceMetadata(Generation generation) {

assertThat(chatGenerationMetadata).isNotNull();
assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop");
assertContentFilterResults(chatGenerationMetadata.getContentFilterMetadata());
assertContentFilterResults(chatGenerationMetadata.get("contentFilterResults"));
}

private void assertContentFilterResultsForPrompt(ContentFilterResultDetailsForPrompt contentFilterResultForPrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
.stream()
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
ChatGenerationMetadata.from(response.stopReasonAsString(), null)))
ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()))
.toList();

List<Generation> allGenerations = new ArrayList<>(generations);

if (response.stopReasonAsString() != null && generations.isEmpty()) {
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
ChatGenerationMetadata.from(response.stopReasonAsString(), null));
ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build());
allGenerations.add(generation);
}

Expand All @@ -451,7 +451,7 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.from(response.stopReasonAsString(), null));
ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build());
allGenerations.add(toolCallGeneration);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.from("tool_use", null));
ChatGenerationMetadata.builder().finishReason("tool_use").build());

var chatResponseMetaData = ChatResponseMetadata.builder()
.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens))
Expand Down Expand Up @@ -176,7 +176,9 @@ else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) {

var generation = new Generation(
new AssistantMessage(contentBlockDeltaEvent.delta().text(), Map.of()),
ChatGenerationMetadata.from(lastAggregation.metadataAggregation().stopReason(), null));
ChatGenerationMetadata.builder()
.finishReason(lastAggregation.metadataAggregation().stopReason())
.build());

return new Aggregation(
MetadataAggregation.builder().copy(lastAggregation.metadataAggregation()).build(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
String stopReason = response.stopReason() != null ? response.stopReason() : null;
ChatGenerationMetadata chatGenerationMetadata = null;
if (response.amazonBedrockInvocationMetrics() != null) {
chatGenerationMetadata = ChatGenerationMetadata.from(stopReason,
response.amazonBedrockInvocationMetrics());
chatGenerationMetadata = ChatGenerationMetadata.builder()
.finishReason(stopReason)
.metadata("metrics", response.amazonBedrockInvocationMetrics())
.build();
}
return new ChatResponse(
List.of(new Generation(new AssistantMessage(response.completion()), chatGenerationMetadata)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public ChatResponse call(Prompt prompt) {
List<Generation> generations = response.content()
.stream()
.map(content -> new Generation(new AssistantMessage(content.text()),
ChatGenerationMetadata.from(response.stopReason(), null)))
ChatGenerationMetadata.builder().finishReason(response.stopReason()).build()))
.toList();

ChatResponseMetadata metadata = ChatResponseMetadata.builder()
Expand Down Expand Up @@ -116,9 +116,12 @@ public Flux<ChatResponse> stream(Prompt prompt) {
String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : "";
ChatGenerationMetadata chatGenerationMetadata = null;
if (response.type() == StreamingType.MESSAGE_DELTA) {
chatGenerationMetadata = ChatGenerationMetadata.from(response.delta().stopReason(),
new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(),
response.usage().outputTokens()));
chatGenerationMetadata = ChatGenerationMetadata.builder()
.finishReason(response.delta().stopReason())
.metadata("usage",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move all these common metadata keys to constants?

new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(),
response.usage().outputTokens()))
.build();
}
return new ChatResponse(List.of(new Generation(new AssistantMessage(content), chatGenerationMetadata)));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
if (g.isFinished()) {
String finishReason = g.finishReason().name();
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
return new ChatResponse(List
.of(new Generation(new AssistantMessage(""), ChatGenerationMetadata.from(finishReason, usage))));
return new ChatResponse(List.of(new Generation(new AssistantMessage(""),
ChatGenerationMetadata.builder().finishReason(finishReason).metadata("usage", usage).build())));
}
return new ChatResponse(List.of(new Generation(new AssistantMessage(g.text()))));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public ChatResponse call(Prompt prompt) {
return new ChatResponse(response.completions()
.stream()
.map(completion -> new Generation(new AssistantMessage(completion.data().text()),
ChatGenerationMetadata.from(completion.finishReason().reason(), null)))
ChatGenerationMetadata.builder().finishReason(completion.finishReason().reason()).build()))
.toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ public ChatResponse call(Prompt prompt) {
LlamaChatResponse response = this.chatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(new AssistantMessage(response.generation()),
ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response)))));
ChatGenerationMetadata.builder()
.finishReason(response.stopReason().name())
.metadata("usage", extractUsage(response))
.build())));
}

@Override
Expand All @@ -83,7 +86,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
return fluxResponse.map(response -> {
String stopReason = response.stopReason() != null ? response.stopReason().name() : null;
return new ChatResponse(List.of(new Generation(new AssistantMessage(response.generation()),
ChatGenerationMetadata.from(stopReason, extractUsage(response)))));
ChatGenerationMetadata.builder()
.finishReason(stopReason)
.metadata("usage", extractUsage(response))
.build())));
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,17 @@ public Flux<ChatResponse> stream(Prompt prompt) {
ChatGenerationMetadata chatGenerationMetadata = null;
if (chunk.amazonBedrockInvocationMetrics() != null) {
String completionReason = chunk.completionReason().name();
chatGenerationMetadata = ChatGenerationMetadata.from(completionReason,
chunk.amazonBedrockInvocationMetrics());
chatGenerationMetadata = ChatGenerationMetadata.builder()
.finishReason(completionReason)
.metadata("usage", chunk.amazonBedrockInvocationMetrics())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this key be "metrics" rather than "usage"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class BedrockUsage has AmazonBedrockInvocationMetrics getUsage(), so maybe it is the other way around and the class name should change? Probably not that important, as we will remove this on the switch to Bedrock Converse.

.build();
}
else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) {
String completionReason = chunk.completionReason().name();
chatGenerationMetadata = ChatGenerationMetadata.from(completionReason, extractUsage(chunk));
chatGenerationMetadata = ChatGenerationMetadata.builder()
.finishReason(completionReason)
.metadata("usage", extractUsage(chunk))
.build();
}
return new ChatResponse(
List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
});
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
}

Expand Down Expand Up @@ -408,7 +408,7 @@ private Generation buildGeneration(ChatCompletionMessage message, ChatCompletion

var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls);
String finishReason = (completionFinishReason != null ? completionFinishReason.name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata)

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ private List<Generation> toGenerations(com.oracle.bmc.generativeaiinference.resp
BaseChatResponse cr = ociChatResponse.getChatResult().getChatResponse();
if (cr instanceof CohereChatResponse resp) {
List<Generation> generations = new ArrayList<>();
ChatGenerationMetadata metadata = ChatGenerationMetadata.from(resp.getFinishReason().getValue(), null);
ChatGenerationMetadata metadata = ChatGenerationMetadata.builder()
.finishReason(resp.getFinishReason().getValue())
.build();
AssistantMessage message = new AssistantMessage(resp.getText(), Map.of());
generations.add(new Generation(message, metadata));
return generations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ public ChatResponse call(Prompt prompt) {

ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
generationMetadata = ChatGenerationMetadata.from(ollamaResponse.doneReason(), null);
generationMetadata = ChatGenerationMetadata.builder()
.finishReason(ollamaResponse.doneReason())
.build();
}

var generator = new Generation(assistantMessage, generationMetadata);
Expand Down Expand Up @@ -217,7 +219,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {

ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
generationMetadata = ChatGenerationMetadata.from(chunk.doneReason(), null);
generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build();
}

var generator = new Generation(assistantMessage, generationMetadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata)

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void aiResponseContainsAiMetadata() {
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
assertThat(chatGenerationMetadata).isNotNull();
assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP");
assertThat(chatGenerationMetadata.<Object>getContentFilterMetadata()).isNull();
assertThat(chatGenerationMetadata.getContentFilters()).isEmpty();
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
Map<String, Object> messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason",
candidateFinishReason);

ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.from(candidateFinishReason.name(), null);
ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.builder()
.finishReason(candidateFinishReason.name())
.build();

boolean isFunctionCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ public ChatResponse call(Prompt prompt) {

WatsonxAiChatResponse response = this.watsonxAiApi.generate(request).getBody();
var generation = new Generation(new AssistantMessage(response.results().get(0).generatedText()),
ChatGenerationMetadata.from(response.results().get(0).stopReason(), response.system()));
ChatGenerationMetadata.builder()
.finishReason(response.results().get(0).stopReason())
.metadata("system", response.system())
.build());

return new ChatResponse(List.of(generation));
}
Expand All @@ -103,7 +106,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {

ChatGenerationMetadata metadata = ChatGenerationMetadata.NULL;
if (chunk.system() != null) {
metadata = ChatGenerationMetadata.from(chunk.results().get(0).stopReason(), chunk.system());
metadata = ChatGenerationMetadata.builder()
.finishReason(chunk.results().get(0).stopReason())
.metadata("system", chunk.system())
.build();
}

Generation generation = new Generation(assistantMessage, metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,11 @@ public void testCallMethod() {
.willReturn(ResponseEntity.of(Optional.of(fakeResponse)));

Generation expectedGenerator = new Generation(new AssistantMessage("LLM response"),
ChatGenerationMetadata.from("max_tokens",
Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))));
ChatGenerationMetadata.builder()
.finishReason("max_tokens")
.metadata("system",
Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning"))))
.build());

ChatResponse expectedResponse = new ChatResponse(List.of(expectedGenerator));
ChatResponse response = chatModel.call(prompt);
Expand Down Expand Up @@ -206,8 +209,12 @@ public void testStreamMethod() {
Flux<WatsonxAiChatResponse> fakeResponse = Flux.just(fakeResponseFirst, fakeResponseSecond);
given(mockChatApi.generateStreaming(any(WatsonxAiChatRequest.class))).willReturn(fakeResponse);

Generation firstGen = new Generation(new AssistantMessage("LLM resp"), ChatGenerationMetadata.from("max_tokens",
Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))));
Generation firstGen = new Generation(new AssistantMessage("LLM resp"),
ChatGenerationMetadata.builder()
.finishReason("max_tokens")
.metadata("system",
Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning"))))
.build());
Generation secondGen = new Generation(new AssistantMessage("onse"));

Flux<ChatResponse> response = chatModel.stream(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
}

Expand Down
Loading