diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java index 236fb2cc011..3ed2b0a9359 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java @@ -46,6 +46,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** + * @author Jemin Huh * @author Mick Semb Wever * @author Jihoon Kim * @author Enrico Rampazzo @@ -83,8 +84,10 @@ void addAndGet() { memory.deleteByConversationId(sessionId); assertThat(memory.findByConversationId(sessionId)).isEmpty(); - AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(), - List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments"))); + AssistantMessage assistantMessage = AssistantMessage.builder() + .text("test answer") + .toolCalls(List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments"))) + .build(); memory.saveAll(sessionId, List.of(userMessage, assistantMessage)); messages = memory.findByConversationId(sessionId); @@ -112,10 +115,11 @@ void addAndGet() { assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator() .isEqualTo(media); memory.deleteByConversationId(sessionId); - ToolResponseMessage toolResponseMessage = new ToolResponseMessage( - List.of(new ToolResponse("id", "name", "responseData"), - new ToolResponse("id2", "name2", "responseData2")), - Map.of("id", "id", "metadataKey", "metadata")); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("id", "name", "responseData"), + new ToolResponse("id2", "name2", "responseData2"))) + .metadata(Map.of("id", "id", "metadataKey", "metadata")) + .build(); memory.saveAll(sessionId, List.of(toolResponseMessage)); messages = memory.findByConversationId(sessionId); assertThat(messages.size()).isEqualTo(1); diff --git a/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java index 9f7c71666db..c36ccfc4838 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java @@ -209,14 +209,14 @@ private Message getMessage(UdtValue udt) { Map props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn)); switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) { case ASSISTANT: - return new AssistantMessage(content, props); + return AssistantMessage.builder().text(content).metadata(props).build(); case USER: return UserMessage.builder().text(content).metadata(props).build(); case SYSTEM: return SystemMessage.builder().text(content).metadata(props).build(); case TOOL: // todo – persist ToolResponse somehow - return new ToolResponseMessage(List.of(), props); + return ToolResponseMessage.builder().responses(List.of()).metadata(props).build(); default: throw new IllegalStateException( String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn))); diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java index 21cdd80a54e..7d751c9c636 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java @@ -172,25 +172,28 @@ public Neo4jChatMemoryRepositoryConfig getConfig() { private Message buildToolMessage(org.neo4j.driver.Record record) { Message message; - message = new ToolResponseMessage(record.get("toolResponses").asList(v -> { + message = ToolResponseMessage.builder().responses(record.get("toolResponses").asList(v -> { Map trMap = v.asMap(); return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()), (String) trMap.get(ToolResponseAttributes.NAME.getValue()), (String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue())); - }), record.get("metadata").asMap()); + })).metadata(record.get("metadata").asMap()).build(); return message; } private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, List mediaList) { Message message; - message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), - record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> { - var toolCallMap = v.asMap(); - return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), - (String) toolCallMap.get("type"), (String) toolCallMap.get("name"), - (String) toolCallMap.get("arguments")); - }), mediaList); + message = AssistantMessage.builder() + .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()) + .metadata(record.get("metadata").asMap(Map.of())) + .toolCalls(record.get("toolCalls").asList(v -> { + var toolCallMap = v.asMap(); + return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), (String) toolCallMap.get("type"), + (String) toolCallMap.get("name"), (String) toolCallMap.get("arguments")); + })) + .media(mediaList) + .build(); return message; } diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java index 83ff42a71ae..6b288de02cb 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java @@ -53,6 +53,7 @@ /** * Integration tests for {@link Neo4jChatMemoryRepository}. * + * @author Jemin Huh * @author Enrico Rampazzo * @since 1.0.0 */ @@ -263,9 +264,11 @@ void handleMediaContent() { void handleAssistantMessageWithToolCalls() { var conversationId = UUID.randomUUID().toString(); - AssistantMessage assistantMessage = new AssistantMessage("Message with tool calls", Map.of(), - List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"), - new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2"))); + AssistantMessage assistantMessage = AssistantMessage.builder() + .text("Message with tool calls") + .toolCalls(List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"), + new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2"))) + .build(); this.chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage)); @@ -282,9 +285,11 @@ void handleAssistantMessageWithToolCalls() { void handleToolResponseMessage() { var conversationId = UUID.randomUUID().toString(); - ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List - .of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")), - Map.of("metadataKey", "metadataValue")); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("id1", "name1", "responseData1"), + new ToolResponse("id2", "name2", "responseData2"))) + .metadata(Map.of("metadataKey", "metadataValue")) + .build(); this.chatMemoryRepository.saveAll(conversationId, List.of(toolResponseMessage)); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 270f3bef43d..e7597e40f67 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -82,6 +82,7 @@ /** * The {@link ChatModel} implementation for the Anthropic service. * + * @author Jemin Huh * @author Christian Tzolov * @author luocongqiu * @author Mariusz Bernacki @@ -302,19 +303,21 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage for (ContentBlock content : chatCompletion.content()) { switch (content.type()) { case TEXT, TEXT_DELTA: - generations.add(new Generation(new AssistantMessage(content.text(), Map.of()), + generations.add(new Generation(new AssistantMessage(content.text()), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); break; case THINKING, THINKING_DELTA: Map thinkingProperties = new HashMap<>(); thinkingProperties.put("signature", content.signature()); - generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties), + generations.add(new Generation( + AssistantMessage.builder().text(content.thinking()).metadata(thinkingProperties).build(), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); break; case REDACTED_THINKING: Map redactedProperties = new HashMap<>(); redactedProperties.put("data", content.data()); - generations.add(new Generation(new AssistantMessage(null, redactedProperties), + generations.add(new Generation( + AssistantMessage.builder().text((String) null).metadata(redactedProperties).build(), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); break; case TOOL_USE: @@ -328,13 +331,13 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage } if (chatCompletion.stopReason() != null && generations.isEmpty()) { - Generation generation = new Generation(new AssistantMessage(null, Map.of()), + Generation generation = new Generation(new AssistantMessage((String) null), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); generations.add(generation); } if (!CollectionUtils.isEmpty(toolCalls)) { - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder().text("").toolCalls(toolCalls).build(); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); generations.add(toolCallGeneration); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 1933f575300..8f8b6601c6b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -106,6 +106,7 @@ * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. * + * @author Jemin Huh * @author Mark Pollack * @author Ueibin Kim * @author John Blum @@ -485,7 +486,7 @@ private Generation buildGeneration(ChatChoice choice, Map metada } var content = responseMessage == null ? "" : responseMessage.getContent(); - var assistantMessage = new AssistantMessage(content, metadata, toolCalls); + var assistantMessage = AssistantMessage.builder().text(content).metadata(metadata).toolCalls(toolCalls).build(); var generationMetadata = generateChoiceMetadata(choice); return new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 380951c7265..299c2a9c0b5 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -25,7 +25,6 @@ import java.util.ArrayList; import java.util.Base64; import java.util.List; -import java.util.Map; import java.util.Set; import io.micrometer.observation.Observation; @@ -128,6 +127,7 @@ *

* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html * + * @author Jemin Huh * @author Christian Tzolov * @author Wei Jiang * @author Alexandros Pappas @@ -566,14 +566,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv List generations = message.content() .stream() .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) - .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), + .map(content -> new Generation(new AssistantMessage(content.text()), ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build())) .toList(); List allGenerations = new ArrayList<>(generations); if (response.stopReasonAsString() != null && generations.isEmpty()) { - Generation generation = new Generation(new AssistantMessage(null, Map.of()), + Generation generation = new Generation(new AssistantMessage((String) null), ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()); allGenerations.add(generation); } @@ -597,7 +597,7 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); } - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder().text("").toolCalls(toolCalls).build(); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()); allGenerations.add(toolCallGeneration); diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index a19de831a7e..8fce75798ce 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -60,6 +60,7 @@ /** * Amazon Bedrock Converse API utils. * + * @author Jemin Huh * @author Wei Jiang * @author Christian Tzolov * @author Alexandros Pappas @@ -140,7 +141,7 @@ public static Flux toChatResponse(Flux respo } } - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder().text("").toolCalls(toolCalls).build(); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason("tool_use").build()); @@ -175,8 +176,7 @@ else if (nextEvent instanceof ContentBlockStartEvent contentBlockStartEvent) { else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) { if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TEXT)) { - var generation = new Generation( - new AssistantMessage(contentBlockDeltaEvent.delta().text(), Map.of()), + var generation = new Generation(new AssistantMessage(contentBlockDeltaEvent.delta().text()), ChatGenerationMetadata.builder() .finishReason(lastAggregation.metadataAggregation().stopReason()) .build()); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java index 6159d9beadb..70ab3d4e274 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java @@ -23,6 +23,14 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.content.Media; +/** + * Represents an assistant message generated by the DeepSeek model. + * + * @author Jemin Huh + * @author Soby Chacko + * @author Mark Pollack + * @since 1.0.0 + */ public class DeepSeekAssistantMessage extends AssistantMessage { private Boolean prefix; @@ -38,22 +46,14 @@ public DeepSeekAssistantMessage(String content, String reasoningContent) { this.reasoningContent = reasoningContent; } - public DeepSeekAssistantMessage(String content, Map properties) { - super(content, properties); - } - - public DeepSeekAssistantMessage(String content, Map properties, List toolCalls) { - super(content, properties, toolCalls); - } - - public DeepSeekAssistantMessage(String content, String reasoningContent, Map properties, + public DeepSeekAssistantMessage(String content, String reasoningContent, Map metadata, List toolCalls) { - this(content, reasoningContent, properties, toolCalls, List.of()); + this(content, reasoningContent, metadata, toolCalls, List.of()); } - public DeepSeekAssistantMessage(String content, String reasoningContent, Map properties, + public DeepSeekAssistantMessage(String content, String reasoningContent, Map metadata, List toolCalls, List media) { - super(content, properties, toolCalls, media); + super(content, metadata, toolCalls, media); this.reasoningContent = reasoningContent; } @@ -102,9 +102,9 @@ public int hashCode() { @Override public String toString() { - return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + super.getToolCalls() - + ", textContent=" + this.textContent + ", reasoningContent=" + this.reasoningContent + ", prefix=" - + this.prefix + ", metadata=" + this.metadata + "]"; + return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + this.toolCalls + ", textContent=" + + this.textContent + ", reasoningContent=" + this.reasoningContent + ", prefix=" + this.prefix + + ", metadata=" + this.metadata + "]"; } } diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index 5546b4c54d2..897f1d2655f 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -39,6 +39,7 @@ * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference * Endpoints for text generation. * + * @author Jemin Huh * @author Mark Pollack * @author Jihoon Kim */ @@ -104,7 +105,8 @@ public ChatResponse call(Prompt prompt) { new TypeReference>() { }); - Generation generation = new Generation(new AssistantMessage(generatedText, detailsMap)); + Generation generation = new Generation( + AssistantMessage.builder().text(generatedText).metadata(detailsMap).build()); generations.add(generation); } return new ChatResponse(generations); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index e5a774cacf9..1de0ce8ecc0 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -76,6 +76,7 @@ * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax} * backed by {@link MiniMaxApi}. * + * @author Jemin Huh * @author Geng Rong * @author Alexandros Pappas * @author Ilayaperumal Gopinathan @@ -225,7 +226,11 @@ private static Generation buildGeneration(Choice choice, Map met acc1.addAll(acc2); return acc1; }); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .text(choice.message().content()) + .metadata(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); @@ -424,7 +429,11 @@ private Generation buildGeneration(ChatCompletionMessage message, ChatCompletion toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .text(message.content()) + .metadata(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (completionFinishReason != null ? completionFinishReason.name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index b9838dcedf1..0cfc35f5da8 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -354,7 +354,11 @@ private Generation buildGeneration(Choice choice, Map metadata) toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .text(choice.message().content()) + .metadata(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java index b9c771b58bb..898748e30ea 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java @@ -18,7 +18,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Objects; import com.oracle.bmc.generativeaiinference.GenerativeAiInference; @@ -61,6 +60,7 @@ /** * {@link ChatModel} implementation that uses the OCI GenAI Chat API. * + * @author Jemin Huh * @author Anders Swanson * @author Alexandros Pappas * @since 1.0.0 @@ -200,7 +200,7 @@ private List toGenerations(com.oracle.bmc.generativeaiinference.resp ChatGenerationMetadata metadata = ChatGenerationMetadata.builder() .finishReason(resp.getFinishReason().getValue()) .build(); - AssistantMessage message = new AssistantMessage(resp.getText(), Map.of()); + AssistantMessage message = new AssistantMessage(resp.getText()); generations.add(new Generation(message, metadata)); return generations; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index a75a274a797..5bd2712ab43 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -19,7 +19,6 @@ import java.time.Duration; import java.util.Base64; import java.util.List; -import java.util.Map; import java.util.Optional; import com.fasterxml.jackson.core.type.TypeReference; @@ -78,6 +77,7 @@ * Hugging Face. Please refer to the official Ollama * website for the most up-to-date information on available models. * + * @author Jemin Huh * @author Christian Tzolov * @author luocongqiu * @author Thomas Vitale @@ -243,7 +243,10 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ModelOptionsUtils.toJsonString(toolCall.function().arguments()))) .toList(); - var assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls); + var assistantMessage = AssistantMessage.builder() + .text(ollamaResponse.message().content()) + .toolCalls(toolCalls) + .build(); ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { @@ -321,7 +324,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh .toList(); } - var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls); + var assistantMessage = AssistantMessage.builder().text(content).toolCalls(toolCalls).build(); ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (chunk.promptEvalCount() != null && chunk.evalCount() != null) { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 58a0062eccd..3f0c92f1dd8 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -440,7 +440,12 @@ private Generation buildGeneration(Choice choice, Map metadata, generationMetadataBuilder.metadata("logprobs", choice.logprobs()); } - var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media); + var assistantMessage = AssistantMessage.builder() + .text(textContent) + .metadata(metadata) + .toolCalls(toolCalls) + .media(media) + .build(); return new Generation(assistantMessage, generationMetadataBuilder.build()); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 01ab8b96c02..2e52618af74 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -127,6 +127,7 @@ * .build(); * } * + * @author Jemin Huh * @author Christian Tzolov * @author Grogdunn * @author luocongqiu @@ -602,7 +603,11 @@ protected List responseCandidateToGeneration(Candidate candidate) { }) .toList(); - AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls); + AssistantMessage assistantMessage = AssistantMessage.builder() + .text("") + .metadata(messageMetadata) + .toolCalls(assistantToolCalls) + .build(); return List.of(new Generation(assistantMessage, chatGenerationMetadata)); } @@ -610,7 +615,7 @@ protected List responseCandidateToGeneration(Candidate candidate) { List generations = candidate.getContent() .getPartsList() .stream() - .map(part -> new AssistantMessage(part.getText(), messageMetadata)) + .map(part -> AssistantMessage.builder().text(part.getText()).metadata(messageMetadata).build()) .map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)) .toList(); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 408666fdc34..04224ae6b05 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -79,6 +79,7 @@ * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal ZhiPuAI} * backed by {@link ZhiPuAiApi}. * + * @author Jemin Huh * @author Geng Rong * @author Alexandros Pappas * @author Ilayaperumal Gopinathan @@ -226,7 +227,11 @@ private static Generation buildGeneration(Choice choice, Map met toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + var assistantMessage = AssistantMessage.builder() + .text(choice.message().content()) + .metadata(metadata) + .toolCalls(toolCalls) + .build(); String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index fd795971a2e..8b271b3e4b8 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -784,7 +784,7 @@ void whenChatResponseContentIsNull() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.call(promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient @@ -814,7 +814,7 @@ void whenResponseEntityWithParameterizedTypeAndChatResponseContentNull() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.call(promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient @@ -874,7 +874,7 @@ void whenResponseEntityWithConverterAndChatResponseContentNull() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.call(promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient @@ -926,7 +926,7 @@ void whenResponseEntityWithTypeAndChatResponseContentNull() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.call(promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient @@ -978,7 +978,7 @@ void whenEntityWithParameterizedTypeAndChatResponseContentNull() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.call(promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient @@ -1076,7 +1076,7 @@ void whenEntityWithTypeAndChatResponseContentNull() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.call(promptCaptor.capture())) - .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient @@ -1282,7 +1282,7 @@ void whenChatResponseContentIsNullThenReturnFlux() { ChatModel chatModel = mock(ChatModel.class); ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); given(chatModel.stream(promptCaptor.capture())) - .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))))); + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))))); ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index b092de2d6da..c190839d05d 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -16,14 +16,20 @@ package org.springframework.ai.chat.messages; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import org.springframework.ai.content.Media; import org.springframework.ai.content.MediaContent; +import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * Lets the generative know the content was generated as a response to the user. This role @@ -31,31 +37,28 @@ * including assistant messages in the series, you provide context to the generative about * prior exchanges in the conversation. * + * @author Jemin Huh * @author Mark Pollack * @author Christian Tzolov * @since 1.0.0 */ public class AssistantMessage extends AbstractMessage implements MediaContent { - private final List toolCalls; + protected final List toolCalls; protected final List media; public AssistantMessage(String content) { - this(content, Map.of()); + this(content, Map.of(), List.of(), List.of()); } - public AssistantMessage(String content, Map properties) { - this(content, properties, List.of()); + public AssistantMessage(Resource resource) { + this(MessageUtils.readResource(resource)); } - public AssistantMessage(String content, Map properties, List toolCalls) { - this(content, properties, toolCalls, List.of()); - } - - public AssistantMessage(String content, Map properties, List toolCalls, + protected AssistantMessage(String content, Map metadata, List toolCalls, List media) { - super(MessageType.ASSISTANT, content, properties); + super(MessageType.ASSISTANT, content, metadata); Assert.notNull(toolCalls, "Tool calls must not be null"); Assert.notNull(media, "Media must not be null"); this.toolCalls = toolCalls; @@ -104,4 +107,88 @@ public record ToolCall(String id, String type, String name, String arguments) { } + public AssistantMessage copy() { + return new Builder().text(getText()) + .metadata(Map.copyOf(getMetadata())) + .toolCalls(List.copyOf(getToolCalls())) + .media(List.copyOf(getMedia())) + .build(); + } + + public AssistantMessage.Builder mutate() { + return new Builder().text(getText()) + .metadata(Map.copyOf(getMetadata())) + .toolCalls(List.copyOf(getToolCalls())) + .media(List.copyOf(getMedia())); + } + + public static AssistantMessage.Builder builder() { + return new AssistantMessage.Builder(); + } + + public static class Builder { + + @Nullable + private String textContent; + + @Nullable + private Resource resource; + + private Map metadata = new HashMap<>(); + + private List toolCalls = new ArrayList<>(); + + private List media = new ArrayList<>(); + + public AssistantMessage.Builder text(String textContent) { + this.textContent = textContent; + return this; + } + + public AssistantMessage.Builder text(Resource resource) { + this.resource = resource; + return this; + } + + public AssistantMessage.Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public AssistantMessage.Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + public AssistantMessage.Builder toolCalls(@Nullable ToolCall... toolCalls) { + if (media != null) { + this.toolCalls = Arrays.asList(toolCalls); + } + return this; + } + + public AssistantMessage.Builder media(List media) { + this.media = media; + return this; + } + + public AssistantMessage.Builder media(@Nullable Media... media) { + if (media != null) { + this.media = Arrays.asList(media); + } + return this; + } + + public AssistantMessage build() { + if (StringUtils.hasText(this.textContent) && this.resource != null) { + throw new IllegalArgumentException("textContent and resource cannot be set at the same time"); + } + else if (this.resource != null) { + this.textContent = MessageUtils.readResource(this.resource); + } + return new AssistantMessage(this.textContent, this.metadata, this.toolCalls, this.media); + } + + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 47da252180f..de41d334c8e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -16,6 +16,10 @@ package org.springframework.ai.chat.messages; +import org.springframework.lang.Nullable; + +import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -24,6 +28,7 @@ * The ToolResponseMessage class represents a message with a function content in a chat * application. * + * @author Jemin Huh * @author Christian Tzolov * @since 1.0.0 */ @@ -35,7 +40,7 @@ public ToolResponseMessage(List responses) { this(responses, Map.of()); } - public ToolResponseMessage(List responses, Map metadata) { + protected ToolResponseMessage(List responses, Map metadata) { super(MessageType.TOOL, "", metadata); this.responses = responses; } @@ -73,4 +78,45 @@ public record ToolResponse(String id, String name, String responseData) { } + public ToolResponseMessage copy() { + return new ToolResponseMessage(getResponses(), Map.copyOf(this.metadata)); + } + + public ToolResponseMessage.Builder mutate() { + return new Builder().responses(getResponses()).metadata(Map.copyOf(this.metadata)); + } + + public static ToolResponseMessage.Builder builder() { + return new ToolResponseMessage.Builder(); + } + + public static class Builder { + + private List responses; + + private Map metadata = new HashMap<>(); + + public ToolResponseMessage.Builder responses(List responses) { + this.responses = responses; + return this; + } + + public ToolResponseMessage.Builder media(@Nullable ToolResponse... responses) { + if (responses != null) { + this.responses = Arrays.asList(responses); + } + return this; + } + + public ToolResponseMessage.Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public ToolResponseMessage build() { + return new ToolResponseMessage(this.responses, this.metadata); + } + + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 839d99e23d8..4851834c0e5 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -39,6 +39,7 @@ * Helper that for streaming chat responses, aggregate the chat response messages into a * single AssistantMessage. Job is performed in parallel to the chat response processing. * + * @author Jemin Huh * @author Christian Tzolov * @author Alexandros Pappas * @author Thomas Vitale @@ -133,9 +134,10 @@ public Flux aggregate(Flux fluxChatResponse, .promptMetadata(metadataPromptMetadataRef.get()) .build(); - onAggregationComplete.accept(new ChatResponse(List.of(new Generation( - new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()), - generationMetadataRef.get())), chatResponseMetadata)); + onAggregationComplete.accept(new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .text(messageTextContentRef.get().toString()) + .metadata(messageMetadataMapRef.get()) + .build(), generationMetadataRef.get())), chatResponseMetadata)); messageTextContentRef.set(new StringBuilder()); messageMetadataMapRef.set(new HashMap<>()); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 2d5d8b64e9d..ece549bc6dd 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -39,6 +39,7 @@ * The Prompt class represents a prompt used in AI model requests. A prompt consists of * one or more messages and additional chat options. * + * @author Jemin Huh * @author Mark Pollack * @author luocongqiu * @author Thomas Vitale @@ -177,12 +178,13 @@ else if (message instanceof SystemMessage systemMessage) { messagesCopy.add(systemMessage.copy()); } else if (message instanceof AssistantMessage assistantMessage) { - messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); + messagesCopy.add(assistantMessage.copy()); } else if (message instanceof ToolResponseMessage toolResponseMessage) { - messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()), - new HashMap<>(toolResponseMessage.getMetadata()))); + messagesCopy.add(ToolResponseMessage.builder() + .responses(toolResponseMessage.getResponses()) + .metadata(new HashMap<>(toolResponseMessage.getMetadata())) + .build()); } else { throw new IllegalArgumentException("Unsupported message type: " + message.getClass().getName()); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 887ba56bb72..113c349bdf2 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -50,6 +50,7 @@ /** * Default implementation of {@link ToolCallingManager}. * + * @author Jemin Huh * @author Thomas Vitale * @since 1.0.0 */ @@ -154,8 +155,7 @@ private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assi toolContextMap = new HashMap<>(toolCallingChatOptions.getToolContext()); List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); - messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); + messageHistory.add(assistantMessage.copy()); toolContextMap.put(ToolContext.TOOL_CALL_HISTORY, buildConversationHistoryBeforeToolExecution(prompt, assistantMessage)); @@ -167,8 +167,7 @@ private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assi private static List buildConversationHistoryBeforeToolExecution(Prompt prompt, AssistantMessage assistantMessage) { List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); - messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); + messageHistory.add(assistantMessage.copy()); return messageHistory; } @@ -234,7 +233,8 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess toolCallResult != null ? toolCallResult : "")); } - return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect); + return new InternalToolExecutionResult(ToolResponseMessage.builder().responses(toolResponses).build(), + returnDirect); } private List buildConversationHistoryAfterToolExecution(List previousMessages, diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java index a9d673e23be..6cdf8687a5e 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java @@ -38,8 +38,10 @@ class ChatResponseTests { @Test void whenToolCallsArePresentThenReturnTrue() { ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")) + .build()))) .build(); assertThat(chatResponse.hasToolCalls()).isTrue(); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java index d7d6dd10050..3a45b9118f2 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java @@ -73,8 +73,10 @@ void observationForToolCall() { .build(); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("Answer", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("Answer") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")) + .build()))) .build(); ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java index 7dba4ad2518..0c4f4fce176 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -158,8 +158,10 @@ void whenSingleToolCallInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -180,8 +182,10 @@ void whenSingleToolCallWithReturnDirectInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -205,9 +209,11 @@ void whenMultipleToolCallsInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -229,8 +235,10 @@ void whenDuplicateMixedToolCallsInChatResponseThenExecute() { .toolNames("toolA") .build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -253,9 +261,11 @@ void whenMultipleToolCallsWithReturnDirectInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -280,9 +290,11 @@ void whenMultipleToolCallsWithMixedReturnDirectInChatResponseThenExecute() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( @@ -305,8 +317,10 @@ void whenToolCallWithExceptionThenReturnError() { Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}")))))) + .generations(List.of(new Generation(AssistantMessage.builder() + .text("") + .toolCalls(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}")) + .build()))) .build(); ToolResponseMessage expectedToolResponse = new ToolResponseMessage( diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java index 8b92a3fad79..116ce9ffcda 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java @@ -44,7 +44,11 @@ void whenToolExecutionEnabledAndHasToolCalls() { // Create a ChatResponse with tool calls AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .text("test") + .metadata(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate @@ -73,7 +77,11 @@ void whenToolExecutionDisabledAndHasToolCalls() { // Create a ChatResponse with tool calls AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .text("test") + .metadata(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate @@ -102,7 +110,11 @@ void whenRegularChatOptionsAndHasToolCalls() { // Create a ChatResponse with tool calls AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); + AssistantMessage assistantMessage = AssistantMessage.builder() + .text("test") + .metadata(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); // Test the predicate - should use default value (true) for internal tool