diff --git a/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java index 9f7c71666db..1e18284a0e4 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java @@ -209,7 +209,7 @@ private Message getMessage(UdtValue udt) { Map props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn)); switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) { case ASSISTANT: - return new AssistantMessage(content, props); + return new AssistantMessage(content, props, List.of(), List.of(), null); case USER: return UserMessage.builder().text(content).metadata(props).build(); case SYSTEM: diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java index 21cdd80a54e..a7dc04bc28f 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java @@ -190,7 +190,7 @@ private Message buildAssistantMessage(org.neo4j.driver.Record record, Map fromSpringAiMessage(Message message) { } var azureAssistantMessage = new ChatRequestAssistantMessage(message.getText()); azureAssistantMessage.setToolCalls(toolCalls); + // Try to set name field if supported by Azure OpenAI SDK + try { + // Use reflection to check if setName method exists and call it + Method setNameMethod = azureAssistantMessage.getClass().getMethod("setName", String.class); + setNameMethod.invoke(azureAssistantMessage, assistantMessage.getName()); + } + catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + // Name field not supported in current Azure OpenAI SDK version + // This is expected behavior for some SDK versions + } return List.of(azureAssistantMessage); case TOOL: ToolResponseMessage toolMessage = (ToolResponseMessage) message; diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java index 6159d9beadb..a13e5a0c84f 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java @@ -57,6 +57,28 @@ public DeepSeekAssistantMessage(String content, String reasoningContent, Map properties, String name) { + super(content, properties, name); + } + + public DeepSeekAssistantMessage(String content, Map properties, List toolCalls, + String name) { + super(content, properties, toolCalls, name); + } + + public DeepSeekAssistantMessage(String content, String reasoningContent, Map properties, + List toolCalls, String name) { + super(content, properties, toolCalls, name); + this.reasoningContent = reasoningContent; + } + + public DeepSeekAssistantMessage(String content, String reasoningContent, Map properties, + List toolCalls, List media, String name) { + super(content, properties, toolCalls, media, name); + this.reasoningContent = reasoningContent; + } + public static DeepSeekAssistantMessage prefixAssistantMessage(String context) { return prefixAssistantMessage(context, null); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index 6295666e07f..ed6e47422ac 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -446,8 +446,9 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { && Boolean.TRUE.equals(((DeepSeekAssistantMessage) message).getPrefix())) { isPrefixAssistantMessage = true; } - return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, isPrefixAssistantMessage, null)); + return List + .of(new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, + assistantMessage.getName(), null, toolCalls, isPrefixAssistantMessage, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 19f821b7fb3..828abd12a53 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -519,7 +519,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { }).toList(); } return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls)); + ChatCompletionMessage.Role.ASSISTANT, assistantMessage.getName(), null, toolCalls)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index b1449fe580a..4b51ba890e5 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -454,7 +454,8 @@ else if (message instanceof AssistantMessage assistantMessage) { } return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(), - MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null)); + MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, assistantMessage.getName(), toolCalls, + null)); } else if (message instanceof ToolResponseMessage toolResponseMessage) { toolResponseMessage.getResponses() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 2ad584fa82f..f4691a07f44 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -591,8 +591,9 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null); } - return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null)); + return List + .of(new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, + assistantMessage.getName(), null, toolCalls, null, audioOutput, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 01402acc36a..e5c3f4de37b 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -510,7 +510,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { }).toList(); } return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls)); + ChatCompletionMessage.Role.ASSISTANT, assistantMessage.getName(), null, toolCalls)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java index b092de2d6da..449f8a4437e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java @@ -41,6 +41,8 @@ public class AssistantMessage extends AbstractMessage implements MediaContent { protected final List media; + private final String name; + public AssistantMessage(String content) { this(content, Map.of()); } @@ -55,11 +57,29 @@ public AssistantMessage(String content, Map properties, List properties, List toolCalls, List media) { + this(content, properties, toolCalls, media, null); + } + + public AssistantMessage(String content, String name) { + this(content, Map.of(), name); + } + + public AssistantMessage(String content, Map properties, String name) { + this(content, properties, List.of(), name); + } + + public AssistantMessage(String content, Map properties, List toolCalls, String name) { + this(content, properties, toolCalls, List.of(), name); + } + + public AssistantMessage(String content, Map properties, List toolCalls, List media, + String name) { super(MessageType.ASSISTANT, content, properties); Assert.notNull(toolCalls, "Tool calls must not be null"); Assert.notNull(media, "Media must not be null"); this.toolCalls = toolCalls; this.media = media; + this.name = name; } public List getToolCalls() { @@ -75,6 +95,16 @@ public List getMedia() { return this.media; } + /** + * Get the name of the assistant. This field allows the model to distinguish the name + * of the assistant, making it easier for building multi-agent systems to share global + * context. + * @return the assistant name, or null if not set + */ + public String getName() { + return this.name; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -86,18 +116,19 @@ public boolean equals(Object o) { if (!super.equals(o)) { return false; } - return Objects.equals(this.toolCalls, that.toolCalls) && Objects.equals(this.media, that.media); + return Objects.equals(this.toolCalls, that.toolCalls) && Objects.equals(this.media, that.media) + && Objects.equals(this.name, that.name); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), this.toolCalls, this.media); + return Objects.hash(super.hashCode(), this.toolCalls, this.media, this.name); } @Override public String toString() { return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + this.toolCalls + ", textContent=" - + this.textContent + ", metadata=" + this.metadata + "]"; + + this.textContent + ", name=" + this.name + ", metadata=" + this.metadata + "]"; } public record ToolCall(String id, String type, String name, String arguments) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 471e5a48233..1679fbdfb00 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -178,7 +178,7 @@ else if (message instanceof SystemMessage systemMessage) { } else if (message instanceof AssistantMessage assistantMessage) { messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); + assistantMessage.getToolCalls(), assistantMessage.getMedia(), assistantMessage.getName())); } else if (message instanceof ToolResponseMessage toolResponseMessage) { messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()), diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/AssistantMessageTest.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/AssistantMessageTest.java new file mode 100644 index 00000000000..249cdaf605d --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/AssistantMessageTest.java @@ -0,0 +1,116 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.messages; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.content.Media; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AssistantMessage} with name property support. + * + * @author Spring AI Team + */ +class AssistantMessageTest { + + @Test + void shouldCreateAssistantMessageWithName() { + AssistantMessage message = new AssistantMessage("Hello", "Alice"); + assertThat(message.getText()).isEqualTo("Hello"); + assertThat(message.getName()).isEqualTo("Alice"); + assertThat(message.getMessageType()).isEqualTo(MessageType.ASSISTANT); + } + + @Test + void shouldCreateAssistantMessageWithNameAndProperties() { + Map properties = Map.of("key", "value"); + AssistantMessage message = new AssistantMessage("Hello", properties, "Bob"); + assertThat(message.getText()).isEqualTo("Hello"); + assertThat(message.getName()).isEqualTo("Bob"); + assertThat(message.getMetadata()).containsEntry("key", "value"); + } + + @Test + void shouldCreateAssistantMessageWithNameAndToolCalls() { + List toolCalls = List + .of(new AssistantMessage.ToolCall("1", "function", "testTool", "{}")); + AssistantMessage message = new AssistantMessage("Hello", Map.of(), toolCalls, "Charlie"); + assertThat(message.getText()).isEqualTo("Hello"); + assertThat(message.getName()).isEqualTo("Charlie"); + assertThat(message.getToolCalls()).hasSize(1); + assertThat(message.getToolCalls().get(0).name()).isEqualTo("testTool"); + } + + @Test + void shouldCreateAssistantMessageWithNameAndMedia() { + List toolCalls = List.of(); + List media = List.of(); + AssistantMessage message = new AssistantMessage("Hello", Map.of(), toolCalls, media, "David"); + assertThat(message.getText()).isEqualTo("Hello"); + assertThat(message.getName()).isEqualTo("David"); + assertThat(message.getToolCalls()).isEmpty(); + assertThat(message.getMedia()).isEmpty(); + } + + @Test + void shouldHandleNullName() { + AssistantMessage message = new AssistantMessage("Hello", Map.of(), List.of(), List.of(), null); + assertThat(message.getText()).isEqualTo("Hello"); + assertThat(message.getName()).isNull(); + } + + @Test + void shouldHandleEmptyName() { + AssistantMessage message = new AssistantMessage("Hello", ""); + assertThat(message.getText()).isEqualTo("Hello"); + assertThat(message.getName()).isEqualTo(""); + } + + @Test + void shouldBeEqualWithSameName() { + AssistantMessage message1 = new AssistantMessage("Hello", "Alice"); + AssistantMessage message2 = new AssistantMessage("Hello", "Alice"); + assertThat(message1).isEqualTo(message2); + assertThat(message1.hashCode()).isEqualTo(message2.hashCode()); + } + + @Test + void shouldNotBeEqualWithDifferentName() { + AssistantMessage message1 = new AssistantMessage("Hello", "Alice"); + AssistantMessage message2 = new AssistantMessage("Hello", "Bob"); + assertThat(message1).isNotEqualTo(message2); + } + + @Test + void shouldIncludeNameInToString() { + AssistantMessage message = new AssistantMessage("Hello", "Alice"); + String toString = message.toString(); + assertThat(toString).contains("name=Alice"); + } + + @Test + void shouldHandleNullNameInToString() { + AssistantMessage message = new AssistantMessage("Hello", Map.of(), List.of(), List.of(), null); + String toString = message.toString(); + assertThat(toString).contains("name=null"); + } + +} \ No newline at end of file