diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java index 236fb2cc011..82489c2d2de 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java @@ -83,8 +83,11 @@ void addAndGet() { memory.deleteByConversationId(sessionId); assertThat(memory.findByConversationId(sessionId)).isEmpty(); - AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(), - List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments"))); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("test answer") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments"))) + .build(); memory.saveAll(sessionId, List.of(userMessage, assistantMessage)); messages = memory.findByConversationId(sessionId); diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java index 4a3a9ec6ebd..9f7e90983af 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-chroma/src/test/java/org/springframework/ai/vectorstore/chroma/autoconfigure/ChromaVectorStoreAutoConfigurationIT.java @@ -95,8 +95,10 @@ public void verifyThatChromaCanHandleComplexMetadataValues() { var response = ChatClientResponse.builder() .chatResponse(ChatResponse.builder() - .generations(List - .of(new Generation(new AssistantMessage("AssistantMessage", Map.of("annotations", List.of()))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("AssistantMessage") + .properties(Map.of("annotations", List.of())) + .build()))) .build()) .build(); var res2 = advisor.after(response, null); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 909bfda5a39..9f5167b463b 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -433,7 +433,7 @@ public static List getToolCallbacksFromAsyncClients(List props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn)); switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) { case ASSISTANT: - return new AssistantMessage(content, props); + return AssistantMessage.builder().content(content).properties(props).build(); case USER: return UserMessage.builder().text(content).metadata(props).build(); case SYSTEM: diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java index 21cdd80a54e..2d9a2099906 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java @@ -183,14 +183,16 @@ private Message buildToolMessage(org.neo4j.driver.Record record) { private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, List mediaList) { - Message message; - message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), - record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> { - var toolCallMap = v.asMap(); - return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), - (String) toolCallMap.get("type"), (String) toolCallMap.get("name"), - (String) toolCallMap.get("arguments")); - }), mediaList); + Message message = AssistantMessage.builder() + .content(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()) + .properties(record.get("metadata").asMap(Map.of())) + .toolCalls(record.get("toolCalls").asList(v -> { + var toolCallMap = v.asMap(); + return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), (String) toolCallMap.get("type"), + (String) toolCallMap.get("name"), (String) toolCallMap.get("arguments")); + })) + .media(mediaList) + .build(); return message; } diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java index 83ff42a71ae..acb06ede872 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java @@ -263,9 +263,12 @@ void handleMediaContent() { void handleAssistantMessageWithToolCalls() { var conversationId = UUID.randomUUID().toString(); - AssistantMessage assistantMessage = new AssistantMessage("Message with tool calls", Map.of(), - List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"), - new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2"))); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("Message with tool calls") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"), + new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2"))) + .build(); this.chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage)); @@ -389,7 +392,8 @@ void saveAndFindMessagesWithEmptyContentOrMetadata() { assertThat(retrievedEmptyContentMsg.getMetadata().keySet()).hasSize(1); // Only // messageType - // Verify second message (empty metadata from input, should only have messageType + // Verify second message (empty metadata from input, should only have + // messageType // after retrieval) Message retrievedEmptyMetadataMsg = retrievedMessages.get(1); assertThat(retrievedEmptyMetadataMsg).isInstanceOf(UserMessage.class); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 880a04601c1..5f3cce8b0f5 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -321,19 +321,24 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage for (ContentBlock content : chatCompletion.content()) { switch (content.type()) { case TEXT, TEXT_DELTA: - generations.add(new Generation(new AssistantMessage(content.text(), Map.of()), + generations.add(new Generation( + AssistantMessage.builder().content(content.text()).properties(Map.of()).build(), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); break; case THINKING, THINKING_DELTA: Map thinkingProperties = new HashMap<>(); thinkingProperties.put("signature", content.signature()); - generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties), + generations.add(new Generation( + AssistantMessage.builder() + .content(content.thinking()) + .properties(thinkingProperties) + .build(), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); break; case REDACTED_THINKING: Map redactedProperties = new HashMap<>(); redactedProperties.put("data", content.data()); - generations.add(new Generation(new AssistantMessage(null, redactedProperties), + generations.add(new Generation(AssistantMessage.builder().properties(redactedProperties).build(), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); break; case TOOL_USE: @@ -347,13 +352,17 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage } if (chatCompletion.stopReason() != null && generations.isEmpty()) { - Generation generation = new Generation(new AssistantMessage(null, Map.of()), + Generation generation = new Generation(AssistantMessage.builder().content("").properties(Map.of()).build(), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); generations.add(generation); } if (!CollectionUtils.isEmpty(toolCalls)) { - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(toolCalls) + .build(); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); generations.add(toolCallGeneration); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index b434e5d0b04..e00a64edc69 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -493,7 +493,11 @@ private Generation buildGeneration(ChatChoice choice, Map metada } var content = responseMessage == null ? "" : responseMessage.getContent(); - var assistantMessage = new AssistantMessage(content, metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .content(content) + .properties(metadata) + .toolCalls(toolCalls) + .build(); var generationMetadata = generateChoiceMetadata(choice); return new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 27695a42112..0ae7d3e9468 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -572,14 +572,15 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv List generations = message.content() .stream() .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) - .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), + .map(content -> new Generation( + AssistantMessage.builder().content(content.text()).properties(Map.of()).build(), ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build())) .toList(); List allGenerations = new ArrayList<>(generations); if (response.stopReasonAsString() != null && generations.isEmpty()) { - Generation generation = new Generation(new AssistantMessage(null, Map.of()), + Generation generation = new Generation(AssistantMessage.builder().properties(Map.of()).build(), ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()); allGenerations.add(generation); } @@ -603,7 +604,11 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); } - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(toolCalls) + .build(); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()); allGenerations.add(toolCallGeneration); diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index d58fdbad8cf..21767620ce7 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -140,7 +140,11 @@ public static Flux toChatResponse(Flux respo } } - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(toolCalls) + .build(); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason("tool_use").build()); @@ -176,7 +180,10 @@ else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) { if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TEXT)) { var generation = new Generation( - new AssistantMessage(contentBlockDeltaEvent.delta().text(), Map.of()), + AssistantMessage.builder() + .content(contentBlockDeltaEvent.delta().text()) + .properties(Map.of()) + .build(), ChatGenerationMetadata.builder() .finishReason(lastAggregation.metadataAggregation().stopReason()) .build()); diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 668c1e5a0d7..3d11dbe320e 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -616,7 +616,11 @@ protected List responseCandidateToGeneration(Candidate candidate) { }) .toList(); - AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(messageMetadata) + .toolCalls(assistantToolCalls) + .build(); return List.of(new Generation(assistantMessage, chatGenerationMetadata)); } @@ -626,7 +630,10 @@ protected List responseCandidateToGeneration(Candidate candidate) { .parts() .orElse(List.of()) .stream() - .map(part -> new AssistantMessage(part.text().orElse(""), messageMetadata)) + .map(part -> AssistantMessage.builder() + .content(part.text().orElse("")) + .properties(messageMetadata) + .build()) .map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)) .toList(); } diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index 01762f50949..2a11a12145f 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -104,7 +104,8 @@ public ChatResponse call(Prompt prompt) { new TypeReference<>() { }); - Generation generation = new Generation(new AssistantMessage(generatedText, detailsMap)); + Generation generation = new Generation( + AssistantMessage.builder().content(generatedText).properties(detailsMap).build()); generations.add(generation); } return new ChatResponse(generations); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 6116a058edf..5c771b2f5db 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -226,7 +226,11 @@ private static Generation buildGeneration(Choice choice, Map met acc1.addAll(acc2); return acc1; }); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .content(choice.message().content()) + .properties(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); @@ -432,7 +436,11 @@ private Generation buildGeneration(ChatCompletionMessage message, ChatCompletion toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .content(message.content()) + .properties(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (completionFinishReason != null ? completionFinishReason.name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index e05130ffb7f..991a67b1a94 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -362,7 +362,11 @@ private Generation buildGeneration(Choice choice, Map metadata) toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .content(choice.message().content()) + .properties(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java index b9c771b58bb..e90514b1946 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java @@ -200,7 +200,7 @@ private List toGenerations(com.oracle.bmc.generativeaiinference.resp ChatGenerationMetadata metadata = ChatGenerationMetadata.builder() .finishReason(resp.getFinishReason().getValue()) .build(); - AssistantMessage message = new AssistantMessage(resp.getText(), Map.of()); + AssistantMessage message = AssistantMessage.builder().content(resp.getText()).properties(Map.of()).build(); generations.add(new Generation(message, metadata)); return generations; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 32f5457ba69..c1091b976ba 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -255,7 +255,11 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ModelOptionsUtils.toJsonString(toolCall.function().arguments()))) .toList(); - var assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls); + var assistantMessage = AssistantMessage.builder() + .content(ollamaResponse.message().content()) + .properties(Map.of()) + .toolCalls(toolCalls) + .build(); ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { @@ -333,7 +337,11 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh .toList(); } - var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls); + var assistantMessage = AssistantMessage.builder() + .content(content) + .properties(Map.of()) + .toolCalls(toolCalls) + .build(); ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (chunk.promptEvalCount() != null && chunk.evalCount() != null) { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index cb0fed3e549..246b7893c4a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -448,7 +448,12 @@ private Generation buildGeneration(Choice choice, Map metadata, generationMetadataBuilder.metadata("logprobs", choice.logprobs()); } - var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media); + var assistantMessage = AssistantMessage.builder() + .content(textContent) + .properties(metadata) + .toolCalls(toolCalls) + .media(media) + .build(); return new Generation(assistantMessage, generationMetadataBuilder.build()); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index ed9789e861d..8ca44afaf28 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -631,7 +631,11 @@ protected List responseCandidateToGeneration(Candidate candidate) { }) .toList(); - AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(messageMetadata) + .toolCalls(assistantToolCalls) + .build(); return List.of(new Generation(assistantMessage, chatGenerationMetadata)); } @@ -639,7 +643,7 @@ protected List responseCandidateToGeneration(Candidate candidate) { List generations = candidate.getContent() .getPartsList() .stream() - .map(part -> new AssistantMessage(part.getText(), messageMetadata)) + .map(part -> AssistantMessage.builder().content(part.getText()).properties(messageMetadata).build()) .map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)) .toList(); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 498c35b8d17..59f7db03d57 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -28,6 +28,7 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.EmptyRateLimit; @@ -157,12 +158,17 @@ public Flux aggregate(Flux fluxChatResponse, if (!CollectionUtils.isEmpty(collectedToolCalls)) { - finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), - messageMetadataMapRef.get(), collectedToolCalls); + finalAssistantMessage = AssistantMessage.builder() + .content(messageTextContentRef.get().toString()) + .properties(messageMetadataMapRef.get()) + .toolCalls(collectedToolCalls) + .build(); } else { - finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), - messageMetadataMapRef.get()); + finalAssistantMessage = AssistantMessage.builder() + .content(messageTextContentRef.get().toString()) + .properties(messageMetadataMapRef.get()) + .build(); } onAggregationComplete.accept(new ChatResponse(List.of(new Generation(finalAssistantMessage, diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 471e5a48233..b77421de9aa 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -177,8 +177,11 @@ else if (message instanceof SystemMessage systemMessage) { messagesCopy.add(systemMessage.copy()); } else if (message instanceof AssistantMessage assistantMessage) { - messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); + messagesCopy.add(AssistantMessage.builder() + .content(assistantMessage.getText()) + .properties(assistantMessage.getMetadata()) + .toolCalls(assistantMessage.getToolCalls()) + .build()); } else if (message instanceof ToolResponseMessage toolResponseMessage) { messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()), diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 36f7cf8eb97..bc007054171 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -164,8 +164,11 @@ private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assi private static List buildConversationHistoryBeforeToolExecution(Prompt prompt, AssistantMessage assistantMessage) { List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); - messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); + messageHistory.add(AssistantMessage.builder() + .content(assistantMessage.getText()) + .properties(assistantMessage.getMetadata()) + .toolCalls(assistantMessage.getToolCalls()) + .build()); return messageHistory; } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java index 69225009797..e2655204a32 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java @@ -43,8 +43,11 @@ class ChatResponseTests { @Test void whenToolCallsArePresentThenReturnTrue() { ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation( - new AssistantMessage("", Map.of(), List.of(new ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new ToolCall("toolA", "function", "toolA", "{}"))) + .build()))) .build(); assertThat(chatResponse.hasToolCalls()).isTrue(); } @@ -136,8 +139,11 @@ void whenEmptyGenerationsListThenReturnFalse() { void whenMultipleGenerationsWithToolCallsThenReturnTrue() { ChatResponse chatResponse = ChatResponse.builder() .generations(List.of(new Generation(new AssistantMessage("First response")), - new Generation(new AssistantMessage("", Map.of(), - List.of(new ToolCall("toolB", "function", "toolB", "{}")))))) + new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new ToolCall("toolB", "function", "toolB", "{}"))) + .build()))) .build(); assertThat(chatResponse.hasToolCalls()).isTrue(); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java index d7d6dd10050..0aa31889cdf 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java @@ -73,8 +73,11 @@ void observationForToolCall() { .build(); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("Answer", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("Answer") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))) + .build()))) .build(); ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index b1a83bca9fa..ce775b20cd5 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -72,7 +72,11 @@ public String call(String toolInput) { AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool", null); // Create a ChatResponse - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); Generation generation = new Generation(assistantMessage); ChatResponse chatResponse = new ChatResponse(List.of(generation)); @@ -125,7 +129,11 @@ public String call(String toolInput) { AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool", ""); // Create a ChatResponse - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); Generation generation = new Generation(assistantMessage); ChatResponse chatResponse = new ChatResponse(List.of(generation)); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java index c25e221eb09..41c27409bde 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -162,8 +162,11 @@ void whenSingleToolCallInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -184,8 +187,11 @@ void whenSingleToolCallWithReturnDirectInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -209,9 +215,12 @@ void whenMultipleToolCallsInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -233,8 +242,11 @@ void whenDuplicateMixedToolCallsInChatResponseThenExecute() { .toolNames("toolA") .build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -257,9 +269,12 @@ void whenMultipleToolCallsWithReturnDirectInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -284,9 +299,12 @@ void whenMultipleToolCallsWithMixedReturnDirectInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -309,8 +327,11 @@ void whenToolCallWithExceptionThenReturnError() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -349,9 +370,12 @@ void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodExce .build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java index b8e32f1ade0..8c6e22e1742 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java @@ -44,7 +44,11 @@ void whenToolExecutionEnabledAndHasToolCalls() { // Create a ChatResponse with tool calls AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("test") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate @@ -73,7 +77,11 @@ void whenToolExecutionDisabledAndHasToolCalls() { // Create a ChatResponse with tool calls AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("test") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate @@ -102,7 +110,11 @@ void whenRegularChatOptionsAndHasToolCalls() { // Create a ChatResponse with tool calls AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("test") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate - should use default value (true) for internal tool @@ -141,7 +153,11 @@ void whenMultipleGenerationsWithMixedToolCalls() { // Create multiple generations - some with tool calls, some without AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage messageWithToolCall = new AssistantMessage("test1", Map.of(), List.of(toolCall)); + AssistantMessage messageWithToolCall = AssistantMessage.builder() + .content("test1") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); AssistantMessage messageWithoutToolCall = new AssistantMessage("test2"); ChatResponse chatResponse = new ChatResponse( @@ -174,7 +190,11 @@ void whenAssistantMessageHasEmptyToolCallsList() { ToolCallingChatOptions options = ToolCallingChatOptions.builder().internalToolExecutionEnabled(true).build(); // Create a ChatResponse with AssistantMessage having empty tool calls list - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of()); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("test") + .properties(Map.of()) + .toolCalls(List.of()) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate @@ -191,7 +211,11 @@ void whenMultipleToolCallsPresent() { AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("id1", "function", "testTool1", "{}"); AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("id2", "function", "testTool2", "{\"param\": \"value\"}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall1, toolCall2)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("test") + .properties(Map.of()) + .toolCalls(List.of(toolCall1, toolCall2)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate