diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index fc4da3e8bd6..5904697470b 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -291,14 +291,14 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) { .stream() .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), - ChatGenerationMetadata.from(chatCompletion.stopReason(), null))) + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())) .toList(); List allGenerations = new ArrayList<>(generations); if (chatCompletion.stopReason() != null && generations.isEmpty()) { Generation generation = new Generation(new AssistantMessage(null, Map.of()), - ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); allGenerations.add(generation); } @@ -322,7 +322,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) { AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); Generation toolCallGeneration = new Generation(assistantMessage, - ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); allGenerations.add(toolCallGeneration); } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index e670d538293..c0980df78b0 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -459,7 +459,10 @@ else if (data instanceof byte[] dataBytes) { } private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { - return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()), choice.getContentFilterResults()); + return ChatGenerationMetadata.builder() + .finishReason(String.valueOf(choice.getFinishReason())) + .metadata("contentFilterResults", choice.getContentFilterResults()) + .build(); } private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java index d962252ac71..8de56f2dbd1 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java @@ -128,7 +128,7 @@ private void assertChoiceMetadata(Generation generation) { assertThat(chatGenerationMetadata).isNotNull(); assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop"); - assertContentFilterResults(chatGenerationMetadata.getContentFilterMetadata()); + assertContentFilterResults(chatGenerationMetadata.get("contentFilterResults")); } private void assertContentFilterResultsForPrompt(ContentFilterResultDetailsForPrompt contentFilterResultForPrompt, diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index d2f200c5035..6116cda4210 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -419,14 +419,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv .stream() .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), - ChatGenerationMetadata.from(response.stopReasonAsString(), null))) + ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build())) .toList(); List allGenerations = new ArrayList<>(generations); if (response.stopReasonAsString() != null && generations.isEmpty()) { Generation generation = new Generation(new AssistantMessage(null, Map.of()), - ChatGenerationMetadata.from(response.stopReasonAsString(), null)); + ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()); allGenerations.add(generation); } @@ -451,7 +451,7 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); Generation toolCallGeneration = new Generation(assistantMessage, - ChatGenerationMetadata.from(response.stopReasonAsString(), null)); + ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()); allGenerations.add(toolCallGeneration); } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index fa8da7dbe4b..88047c0cc8d 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -141,7 +141,7 @@ public static Flux toChatResponse(Flux respo AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); Generation toolCallGeneration = new Generation(assistantMessage, - ChatGenerationMetadata.from("tool_use", null)); + ChatGenerationMetadata.builder().finishReason("tool_use").build()); var chatResponseMetaData = ChatResponseMetadata.builder() .withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens)) @@ -176,7 +176,9 @@ else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) { var generation = new Generation( new AssistantMessage(contentBlockDeltaEvent.delta().text(), Map.of()), - ChatGenerationMetadata.from(lastAggregation.metadataAggregation().stopReason(), null)); + ChatGenerationMetadata.builder() + .finishReason(lastAggregation.metadataAggregation().stopReason()) + .build()); return new Aggregation( MetadataAggregation.builder().copy(lastAggregation.metadataAggregation()).build(), diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java index ef93a2bdac6..61f2e7fa4c9 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java @@ -83,8 +83,10 @@ public Flux stream(Prompt prompt) { String stopReason = response.stopReason() != null ? response.stopReason() : null; ChatGenerationMetadata chatGenerationMetadata = null; if (response.amazonBedrockInvocationMetrics() != null) { - chatGenerationMetadata = ChatGenerationMetadata.from(stopReason, - response.amazonBedrockInvocationMetrics()); + chatGenerationMetadata = ChatGenerationMetadata.builder() + .finishReason(stopReason) + .metadata("metrics", response.amazonBedrockInvocationMetrics()) + .build(); } return new ChatResponse( List.of(new Generation(new AssistantMessage(response.completion()), chatGenerationMetadata))); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index db122878569..a9d1a620f2f 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -88,7 +88,7 @@ public ChatResponse call(Prompt prompt) { List generations = response.content() .stream() .map(content -> new Generation(new AssistantMessage(content.text()), - ChatGenerationMetadata.from(response.stopReason(), null))) + ChatGenerationMetadata.builder().finishReason(response.stopReason()).build())) .toList(); ChatResponseMetadata metadata = ChatResponseMetadata.builder() @@ -116,9 +116,12 @@ public Flux stream(Prompt prompt) { String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : ""; ChatGenerationMetadata chatGenerationMetadata = null; if (response.type() == StreamingType.MESSAGE_DELTA) { - chatGenerationMetadata = ChatGenerationMetadata.from(response.delta().stopReason(), - new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), - response.usage().outputTokens())); + chatGenerationMetadata = ChatGenerationMetadata.builder() + .finishReason(response.delta().stopReason()) + .metadata("usage", + new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), + response.usage().outputTokens())) + .build(); } return new ChatResponse(List.of(new Generation(new AssistantMessage(content), chatGenerationMetadata))); }); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java index b925b3bd813..ed73af42c99 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java @@ -78,8 +78,8 @@ public Flux stream(Prompt prompt) { if (g.isFinished()) { String finishReason = g.finishReason().name(); Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); - return new ChatResponse(List - .of(new Generation(new AssistantMessage(""), ChatGenerationMetadata.from(finishReason, usage)))); + return new ChatResponse(List.of(new Generation(new AssistantMessage(""), + ChatGenerationMetadata.builder().finishReason(finishReason).metadata("usage", usage).build()))); } return new ChatResponse(List.of(new Generation(new AssistantMessage(g.text())))); }); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java index d18492b5380..2258dbd84ca 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java @@ -70,7 +70,7 @@ public ChatResponse call(Prompt prompt) { return new ChatResponse(response.completions() .stream() .map(completion -> new Generation(new AssistantMessage(completion.data().text()), - ChatGenerationMetadata.from(completion.finishReason().reason(), null))) + ChatGenerationMetadata.builder().finishReason(completion.finishReason().reason()).build())) .toList()); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java index b158a7a3fc1..c6658641c43 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java @@ -70,7 +70,10 @@ public ChatResponse call(Prompt prompt) { LlamaChatResponse response = this.chatApi.chatCompletion(request); return new ChatResponse(List.of(new Generation(new AssistantMessage(response.generation()), - ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); + ChatGenerationMetadata.builder() + .finishReason(response.stopReason().name()) + .metadata("usage", extractUsage(response)) + .build()))); } @Override @@ -83,7 +86,10 @@ public Flux stream(Prompt prompt) { return fluxResponse.map(response -> { String stopReason = response.stopReason() != null ? response.stopReason().name() : null; return new ChatResponse(List.of(new Generation(new AssistantMessage(response.generation()), - ChatGenerationMetadata.from(stopReason, extractUsage(response))))); + ChatGenerationMetadata.builder() + .finishReason(stopReason) + .metadata("usage", extractUsage(response)) + .build()))); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index 57c832fd6de..7d259ff9c7d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -78,12 +78,17 @@ public Flux stream(Prompt prompt) { ChatGenerationMetadata chatGenerationMetadata = null; if (chunk.amazonBedrockInvocationMetrics() != null) { String completionReason = chunk.completionReason().name(); - chatGenerationMetadata = ChatGenerationMetadata.from(completionReason, - chunk.amazonBedrockInvocationMetrics()); + chatGenerationMetadata = ChatGenerationMetadata.builder() + .finishReason(completionReason) + .metadata("usage", chunk.amazonBedrockInvocationMetrics()) + .build(); } else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) { String completionReason = chunk.completionReason().name(); - chatGenerationMetadata = ChatGenerationMetadata.from(completionReason, extractUsage(chunk)); + chatGenerationMetadata = ChatGenerationMetadata.builder() + .finishReason(completionReason) + .metadata("usage", extractUsage(chunk)) + .build(); } return new ChatResponse( List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata))); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index f32fe5d3c5a..b9ff96e26d2 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -203,7 +203,7 @@ private static Generation buildGeneration(Choice choice, Map met }); var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); } @@ -408,7 +408,7 @@ private Generation buildGeneration(ChatCompletionMessage message, ChatCompletion var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls); String finishReason = (completionFinishReason != null ? completionFinishReason.name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 25657ec39b1..6df96c45e4c 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -304,7 +304,7 @@ private Generation buildGeneration(Choice choice, Map metadata) var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 09c18bdd54b..51857fd6c1f 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -173,7 +173,7 @@ private static Generation buildGeneration(Choice choice, Map met var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); } diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java index 941658cb83a..4fcb82241f9 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java @@ -181,7 +181,9 @@ private List toGenerations(com.oracle.bmc.generativeaiinference.resp BaseChatResponse cr = ociChatResponse.getChatResult().getChatResponse(); if (cr instanceof CohereChatResponse resp) { List generations = new ArrayList<>(); - ChatGenerationMetadata metadata = ChatGenerationMetadata.from(resp.getFinishReason().getValue(), null); + ChatGenerationMetadata metadata = ChatGenerationMetadata.builder() + .finishReason(resp.getFinishReason().getValue()) + .build(); AssistantMessage message = new AssistantMessage(resp.getText(), Map.of()); generations.add(new Generation(message, metadata)); return generations; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index eae56f6206c..74c735c0cf4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -156,7 +156,9 @@ public ChatResponse call(Prompt prompt) { ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { - generationMetadata = ChatGenerationMetadata.from(ollamaResponse.doneReason(), null); + generationMetadata = ChatGenerationMetadata.builder() + .finishReason(ollamaResponse.doneReason()) + .build(); } var generator = new Generation(assistantMessage, generationMetadata); @@ -217,7 +219,7 @@ public Flux stream(Prompt prompt) { ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (chunk.promptEvalCount() != null && chunk.evalCount() != null) { - generationMetadata = ChatGenerationMetadata.from(chunk.doneReason(), null); + generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build(); } var generator = new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index f21a064ee6b..b41417821f5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -376,7 +376,7 @@ private Generation buildGeneration(Choice choice, Map metadata) var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java index ae598fcaad1..b7229381fac 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java @@ -119,7 +119,7 @@ void aiResponseContainsAiMetadata() { ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); assertThat(chatGenerationMetadata).isNotNull(); assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP"); - assertThat(chatGenerationMetadata.getContentFilterMetadata()).isNull(); + assertThat(chatGenerationMetadata.getContentFilters()).isEmpty(); }); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 30f030c3737..dfe932e4888 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -390,7 +390,9 @@ protected List responseCandidateToGeneration(Candidate candidate) { Map messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason", candidateFinishReason); - ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.from(candidateFinishReason.name(), null); + ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.builder() + .finishReason(candidateFinishReason.name()) + .build(); boolean isFunctionCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall); diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java index 8b4e78b34ec..bf40d60a9f2 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java @@ -85,7 +85,10 @@ public ChatResponse call(Prompt prompt) { WatsonxAiChatResponse response = this.watsonxAiApi.generate(request).getBody(); var generation = new Generation(new AssistantMessage(response.results().get(0).generatedText()), - ChatGenerationMetadata.from(response.results().get(0).stopReason(), response.system())); + ChatGenerationMetadata.builder() + .finishReason(response.results().get(0).stopReason()) + .metadata("system", response.system()) + .build()); return new ChatResponse(List.of(generation)); } @@ -103,7 +106,10 @@ public Flux stream(Prompt prompt) { ChatGenerationMetadata metadata = ChatGenerationMetadata.NULL; if (chunk.system() != null) { - metadata = ChatGenerationMetadata.from(chunk.results().get(0).stopReason(), chunk.system()); + metadata = ChatGenerationMetadata.builder() + .finishReason(chunk.results().get(0).stopReason()) + .metadata("system", chunk.system()) + .build(); } Generation generation = new Generation(assistantMessage, metadata); diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java index 38e9939bca4..3607fbbe26b 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java @@ -174,8 +174,11 @@ public void testCallMethod() { .willReturn(ResponseEntity.of(Optional.of(fakeResponse))); Generation expectedGenerator = new Generation(new AssistantMessage("LLM response"), - ChatGenerationMetadata.from("max_tokens", - Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning"))))); + ChatGenerationMetadata.builder() + .finishReason("max_tokens") + .metadata("system", + Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))) + .build()); ChatResponse expectedResponse = new ChatResponse(List.of(expectedGenerator)); ChatResponse response = chatModel.call(prompt); @@ -206,8 +209,12 @@ public void testStreamMethod() { Flux fakeResponse = Flux.just(fakeResponseFirst, fakeResponseSecond); given(mockChatApi.generateStreaming(any(WatsonxAiChatRequest.class))).willReturn(fakeResponse); - Generation firstGen = new Generation(new AssistantMessage("LLM resp"), ChatGenerationMetadata.from("max_tokens", - Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning"))))); + Generation firstGen = new Generation(new AssistantMessage("LLM resp"), + ChatGenerationMetadata.builder() + .finishReason("max_tokens") + .metadata("system", + Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))) + .build()); Generation secondGen = new Generation(new AssistantMessage("onse")); Flux response = chatModel.stream(prompt); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 7da150c62e7..7ac8985f2c9 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -188,7 +188,7 @@ private static Generation buildGeneration(Choice choice, Map met var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java index 77728e276c8..110b38e068b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/ChatGenerationMetadata.java @@ -16,61 +16,68 @@ package org.springframework.ai.chat.metadata; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + import org.springframework.ai.model.ResultMetadata; -import org.springframework.lang.Nullable; /** - * Abstract Data Type (ADT) encapsulating information on the completion choices in the AI - * response. + * + * Represents the metadata associated with the generation of a chat response. * * @author John Blum + * @author Christian Tzolov * @since 0.7.0 */ public interface ChatGenerationMetadata extends ResultMetadata { - ChatGenerationMetadata NULL = ChatGenerationMetadata.from(null, null); + ChatGenerationMetadata NULL = builder().build(); - /** - * Factory method used to construct a new {@link ChatGenerationMetadata} from the - * given {@link String finish reason} and content filter metadata. - * @param finishReason {@link String} contain the reason for the choice completion. - * @param contentFilterMetadata underlying AI provider metadata for filtering applied - * to generation content. - * @return a new {@link ChatGenerationMetadata} from the given {@link String finish - * reason} and content filter metadata. - */ - static ChatGenerationMetadata from(String finishReason, Object contentFilterMetadata) { - return new ChatGenerationMetadata() { - - @Override - @SuppressWarnings("unchecked") - public T getContentFilterMetadata() { - return (T) contentFilterMetadata; - } - - @Override - public String getFinishReason() { - return finishReason; - } - - @Override - public String toString() { - return "ChatGenerationMetadata{finishReason=" + finishReason + "," + "contentFilterMetadata=" - + contentFilterMetadata + "}"; - } - }; - } + // /** + // * Factory method used to construct a new {@link ChatGenerationMetadata} from the + // * given {@link String finish reason} and content filter metadata. + // * @param finishReason {@link String} contain the reason for the choice completion. + // * @param contentFilterMetadata underlying AI provider metadata for filtering + // applied + // * to generation content. + // * @return a new {@link ChatGenerationMetadata} from the given {@link String finish + // * reason} and content filter metadata. + // */ + // static ChatGenerationMetadata from(String finishReason, Object + // contentFilterMetadata) { + // return new ChatGenerationMetadata() { - /** - * Returns the underlying AI provider metadata for filtering applied to generation - * content. - * @param {@link Class Type} used to cast the filtered content metadata into the - * AI provider-specific type. - * @return the underlying AI provider metadata for filtering applied to generation - * content. - */ - @Nullable - T getContentFilterMetadata(); + // @Override + // @SuppressWarnings("unchecked") + // public T getContentFilterMetadata() { + // return (T) contentFilterMetadata; + // } + + // @Override + // public String getFinishReason() { + // return finishReason; + // } + + // @Override + // public String toString() { + // return "ChatGenerationMetadata{finishReason=" + finishReason + "," + + // "contentFilterMetadata=" + // + contentFilterMetadata + "}"; + // } + // }; + // } + + // /** + // * Returns the underlying AI provider metadata for filtering applied to generation + // * content. + // * @param {@link Class Type} used to cast the filtered content metadata into the + // * AI provider-specific type. + // * @return the underlying AI provider metadata for filtering applied to generation + // * content. + // */ + // @Nullable + // T getContentFilterMetadata(); /** * Get the {@link String reason} this choice completed for the generation. @@ -78,4 +85,60 @@ public String toString() { */ String getFinishReason(); + Set getContentFilters(); + + T get(String key); + + boolean containsKey(String key); + + T getOrDefault(String key, T defaultObject); + + Set> entrySet(); + + Set keySet(); + + boolean isEmpty(); + + public static Builder builder() { + return new DefaultChatGenerationMetadataBuilder(); + } + + /** + * @author Christian Tzolov + * @since 1.0.0 + */ + public interface Builder { + + /** + * Set the reason this choice completed for the generation. + */ + Builder finishReason(String id); + + /** + * Add metadata to the Generation result. + */ + Builder metadata(String key, T value); + + /** + * Add metadata to the Generation result. + */ + Builder metadata(Map metadata); + + /** + * Add content filter to the Generation result. + */ + Builder contentFilter(String contentFilter); + + /** + * Add content filters to the Generation result. + */ + Builder contentFilters(Set contentFilters); + + /** + * Build the Generation metadata. + */ + ChatGenerationMetadata build(); + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultChatGenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultChatGenerationMetadata.java new file mode 100644 index 00000000000..5d8a55179ad --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultChatGenerationMetadata.java @@ -0,0 +1,98 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.metadata; + +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public class DefaultChatGenerationMetadata implements ChatGenerationMetadata { + + private final Map metadata; + + private final String finishReason; + + private final Set contentFilters = new HashSet<>(); + + DefaultChatGenerationMetadata(Map metadata, String finishReason, Set contentFilters) { + this.metadata = metadata; + this.finishReason = finishReason; + this.contentFilters.addAll(contentFilters); + } + + @Override + public T get(String key) { + return (T) this.metadata.get(key); + } + + @Override + public boolean containsKey(String key) { + return this.metadata.containsKey(key); + } + + @Override + public T getOrDefault(String key, T defaultObject) { + return containsKey(key) ? get(key) : defaultObject; + } + + @Override + public Set> entrySet() { + return this.metadata.entrySet(); + } + + @Override + public Set keySet() { + return this.metadata.keySet(); + } + + @Override + public boolean isEmpty() { + return this.metadata.isEmpty(); + } + + @Override + public String getFinishReason() { + return this.finishReason; + } + + @Override + public Set getContentFilters() { + return this.contentFilters; + } + + @Override + public int hashCode() { + return Objects.hash(metadata, finishReason, contentFilters); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || getClass() != obj.getClass()) + return false; + DefaultChatGenerationMetadata other = (DefaultChatGenerationMetadata) obj; + return Objects.equals(metadata, other.metadata) && Objects.equals(finishReason, other.finishReason) + && Objects.equals(contentFilters, other.contentFilters); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultChatGenerationMetadataBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultChatGenerationMetadataBuilder.java new file mode 100644 index 00000000000..e0a74305710 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultChatGenerationMetadataBuilder.java @@ -0,0 +1,76 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.metadata; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.ai.chat.metadata.ChatGenerationMetadata.Builder; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ + +public class DefaultChatGenerationMetadataBuilder implements ChatGenerationMetadata.Builder { + + private String finishReason; + + private Map metadata = new HashMap<>(); + + private Set contentFilters = new HashSet<>(); + + DefaultChatGenerationMetadataBuilder() { + } + + @Override + public Builder finishReason(String finishReason) { + this.finishReason = finishReason; + return this; + } + + @Override + public Builder metadata(String key, T value) { + this.metadata.put(key, value); + return this; + } + + @Override + public Builder metadata(Map metadata) { + this.metadata.putAll(metadata); + return this; + } + + @Override + public Builder contentFilter(String contentFilter) { + this.contentFilters.add(contentFilter); + return this; + } + + @Override + public Builder contentFilters(Set contentFilters) { + this.contentFilters.addAll(contentFilters); + return this; + } + + @Override + public ChatGenerationMetadata build() { + return new DefaultChatGenerationMetadata(this.metadata, this.finishReason, this.contentFilters); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index 929637d788f..8c371bfd2e2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -111,7 +111,7 @@ void shouldHaveKeyValuesWhenDefinedAndResponse() { .build(); observationContext.setResponse(new ChatResponse( List.of(new Generation(new AssistantMessage("response"), - ChatGenerationMetadata.from("this-is-the-end", null))), + ChatGenerationMetadata.builder().finishReason("this-is-the-end").build())), ChatResponseMetadata.builder() .withId("say33") .withModel("mistral-42") @@ -168,7 +168,8 @@ void shouldNotHaveKeyValuesWhenEmptyValues() { .requestOptions(ChatOptionsBuilder.builder().withStopSequences(List.of()).build()) .build(); observationContext.setResponse(new ChatResponse( - List.of(new Generation(new AssistantMessage("response"), ChatGenerationMetadata.from("", null))), + List.of(new Generation(new AssistantMessage("response"), + ChatGenerationMetadata.builder().finishReason("").build())), ChatResponseMetadata.builder().withId("").build())); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext) .stream()