Skip to content

Feature/add assistant name support #4047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ private Message getMessage(UdtValue udt) {
Map<String, Object> props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn));
switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
case ASSISTANT:
return new AssistantMessage(content, props);
return new AssistantMessage(content, props, List.of(), List.of(), null);
case USER:
return UserMessage.builder().text(content).metadata(props).build();
case SYSTEM:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ private Message buildAssistantMessage(org.neo4j.driver.Record record, Map<String
return new AssistantMessage.ToolCall((String) toolCallMap.get("id"),
(String) toolCallMap.get("type"), (String) toolCallMap.get("name"),
(String) toolCallMap.get("arguments"));
}), mediaList);
}), mediaList, (String) messageMap.get("name"));
return message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.azure.openai;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
Expand Down Expand Up @@ -596,6 +598,16 @@ private List<ChatRequestMessage> fromSpringAiMessage(Message message) {
}
var azureAssistantMessage = new ChatRequestAssistantMessage(message.getText());
azureAssistantMessage.setToolCalls(toolCalls);
// Try to set name field if supported by Azure OpenAI SDK
try {
// Use reflection to check if setName method exists and call it
Method setNameMethod = azureAssistantMessage.getClass().getMethod("setName", String.class);
setNameMethod.invoke(azureAssistantMessage, assistantMessage.getName());
}
catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
// Name field not supported in current Azure OpenAI SDK version
// This is expected behavior for some SDK versions
}
return List.of(azureAssistantMessage);
case TOOL:
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ public DeepSeekAssistantMessage(String content, String reasoningContent, Map<Str
this.reasoningContent = reasoningContent;
}

// Constructors with name parameter
public DeepSeekAssistantMessage(String content, Map<String, Object> properties, String name) {
super(content, properties, name);
}

public DeepSeekAssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls,
String name) {
super(content, properties, toolCalls, name);
}

public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
List<ToolCall> toolCalls, String name) {
super(content, properties, toolCalls, name);
this.reasoningContent = reasoningContent;
}

public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
List<ToolCall> toolCalls, List<Media> media, String name) {
super(content, properties, toolCalls, media, name);
this.reasoningContent = reasoningContent;
}

public static DeepSeekAssistantMessage prefixAssistantMessage(String context) {
return prefixAssistantMessage(context, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,9 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
&& Boolean.TRUE.equals(((DeepSeekAssistantMessage) message).getPrefix())) {
isPrefixAssistantMessage = true;
}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, isPrefixAssistantMessage, null));
return List
.of(new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT,
assistantMessage.getName(), null, toolCalls, isPrefixAssistantMessage, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
}).toList();
}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
ChatCompletionMessage.Role.ASSISTANT, assistantMessage.getName(), null, toolCalls));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ else if (message instanceof AssistantMessage assistantMessage) {
}

return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(),
MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, assistantMessage.getName(), toolCalls,
null));
}
else if (message instanceof ToolResponseMessage toolResponseMessage) {
toolResponseMessage.getResponses()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,9 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null);

}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
return List
.of(new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT,
assistantMessage.getName(), null, toolCalls, null, audioOutput, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
}).toList();
}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
ChatCompletionMessage.Role.ASSISTANT, assistantMessage.getName(), null, toolCalls));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public class AssistantMessage extends AbstractMessage implements MediaContent {

protected final List<Media> media;

private final String name;

public AssistantMessage(String content) {
this(content, Map.of());
}
Expand All @@ -55,11 +57,29 @@ public AssistantMessage(String content, Map<String, Object> properties, List<Too

public AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls,
List<Media> media) {
this(content, properties, toolCalls, media, null);
}

public AssistantMessage(String content, String name) {
this(content, Map.of(), name);
}

public AssistantMessage(String content, Map<String, Object> properties, String name) {
this(content, properties, List.of(), name);
}

public AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls, String name) {
this(content, properties, toolCalls, List.of(), name);
}

public AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls, List<Media> media,
String name) {
super(MessageType.ASSISTANT, content, properties);
Assert.notNull(toolCalls, "Tool calls must not be null");
Assert.notNull(media, "Media must not be null");
this.toolCalls = toolCalls;
this.media = media;
this.name = name;
}

public List<ToolCall> getToolCalls() {
Expand All @@ -75,6 +95,16 @@ public List<Media> getMedia() {
return this.media;
}

/**
* Get the name of the assistant. This field allows the model to distinguish the name
* of the assistant, making it easier for building multi-agent systems to share global
* context.
* @return the assistant name, or null if not set
*/
public String getName() {
return this.name;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -86,18 +116,19 @@ public boolean equals(Object o) {
if (!super.equals(o)) {
return false;
}
return Objects.equals(this.toolCalls, that.toolCalls) && Objects.equals(this.media, that.media);
return Objects.equals(this.toolCalls, that.toolCalls) && Objects.equals(this.media, that.media)
&& Objects.equals(this.name, that.name);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), this.toolCalls, this.media);
return Objects.hash(super.hashCode(), this.toolCalls, this.media, this.name);
}

@Override
public String toString() {
return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + this.toolCalls + ", textContent="
+ this.textContent + ", metadata=" + this.metadata + "]";
+ this.textContent + ", name=" + this.name + ", metadata=" + this.metadata + "]";
}

public record ToolCall(String id, String type, String name, String arguments) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ else if (message instanceof SystemMessage systemMessage) {
}
else if (message instanceof AssistantMessage assistantMessage) {
messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
assistantMessage.getToolCalls()));
assistantMessage.getToolCalls(), assistantMessage.getMedia(), assistantMessage.getName()));
}
else if (message instanceof ToolResponseMessage toolResponseMessage) {
messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.messages;

import org.junit.jupiter.api.Test;
import org.springframework.ai.content.Media;

import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link AssistantMessage} with name property support.
*
* @author Spring AI Team
*/
class AssistantMessageTest {

@Test
void shouldCreateAssistantMessageWithName() {
AssistantMessage message = new AssistantMessage("Hello", "Alice");
assertThat(message.getText()).isEqualTo("Hello");
assertThat(message.getName()).isEqualTo("Alice");
assertThat(message.getMessageType()).isEqualTo(MessageType.ASSISTANT);
}

@Test
void shouldCreateAssistantMessageWithNameAndProperties() {
Map<String, Object> properties = Map.of("key", "value");
AssistantMessage message = new AssistantMessage("Hello", properties, "Bob");
assertThat(message.getText()).isEqualTo("Hello");
assertThat(message.getName()).isEqualTo("Bob");
assertThat(message.getMetadata()).containsEntry("key", "value");
}

@Test
void shouldCreateAssistantMessageWithNameAndToolCalls() {
List<AssistantMessage.ToolCall> toolCalls = List
.of(new AssistantMessage.ToolCall("1", "function", "testTool", "{}"));
AssistantMessage message = new AssistantMessage("Hello", Map.of(), toolCalls, "Charlie");
assertThat(message.getText()).isEqualTo("Hello");
assertThat(message.getName()).isEqualTo("Charlie");
assertThat(message.getToolCalls()).hasSize(1);
assertThat(message.getToolCalls().get(0).name()).isEqualTo("testTool");
}

@Test
void shouldCreateAssistantMessageWithNameAndMedia() {
List<AssistantMessage.ToolCall> toolCalls = List.of();
List<Media> media = List.of();
AssistantMessage message = new AssistantMessage("Hello", Map.of(), toolCalls, media, "David");
assertThat(message.getText()).isEqualTo("Hello");
assertThat(message.getName()).isEqualTo("David");
assertThat(message.getToolCalls()).isEmpty();
assertThat(message.getMedia()).isEmpty();
}

@Test
void shouldHandleNullName() {
AssistantMessage message = new AssistantMessage("Hello", Map.of(), List.of(), List.of(), null);
assertThat(message.getText()).isEqualTo("Hello");
assertThat(message.getName()).isNull();
}

@Test
void shouldHandleEmptyName() {
AssistantMessage message = new AssistantMessage("Hello", "");
assertThat(message.getText()).isEqualTo("Hello");
assertThat(message.getName()).isEqualTo("");
}

@Test
void shouldBeEqualWithSameName() {
AssistantMessage message1 = new AssistantMessage("Hello", "Alice");
AssistantMessage message2 = new AssistantMessage("Hello", "Alice");
assertThat(message1).isEqualTo(message2);
assertThat(message1.hashCode()).isEqualTo(message2.hashCode());
}

@Test
void shouldNotBeEqualWithDifferentName() {
AssistantMessage message1 = new AssistantMessage("Hello", "Alice");
AssistantMessage message2 = new AssistantMessage("Hello", "Bob");
assertThat(message1).isNotEqualTo(message2);
}

@Test
void shouldIncludeNameInToString() {
AssistantMessage message = new AssistantMessage("Hello", "Alice");
String toString = message.toString();
assertThat(toString).contains("name=Alice");
}

@Test
void shouldHandleNullNameInToString() {
AssistantMessage message = new AssistantMessage("Hello", Map.of(), List.of(), List.of(), null);
String toString = message.toString();
assertThat(toString).contains("name=null");
}

}