Skip to content

Commit 4ab1844

Browse files
committed
Add builder to ToolResponseMessage
Fixes #3243 Signed-off-by: Eric Bottard <[email protected]>
1 parent 01082b8 commit 4ab1844

File tree

13 files changed

+132
-67
lines changed

13 files changed

+132
-67
lines changed

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,11 @@ void addAndGet() {
115115
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator()
116116
.isEqualTo(media);
117117
memory.deleteByConversationId(sessionId);
118-
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(
119-
List.of(new ToolResponse("id", "name", "responseData"),
120-
new ToolResponse("id2", "name2", "responseData2")),
121-
Map.of("id", "id", "metadataKey", "metadata"));
118+
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
119+
.responses(List.of(new ToolResponse("id", "name", "responseData"),
120+
new ToolResponse("id2", "name2", "responseData2")))
121+
.metadata(Map.of("id", "id", "metadataKey", "metadata"))
122+
.build();
122123
memory.saveAll(sessionId, List.of(toolResponseMessage));
123124
messages = memory.findByConversationId(sessionId);
124125
assertThat(messages.size()).isEqualTo(1);

memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ private Message getMessage(UdtValue udt) {
216216
return SystemMessage.builder().text(content).metadata(props).build();
217217
case TOOL:
218218
// todo – persist ToolResponse somehow
219-
return new ToolResponseMessage(List.of(), props);
219+
return ToolResponseMessage.builder().responses(List.of()).metadata(props).build();
220220
default:
221221
throw new IllegalStateException(
222222
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));

memory/repository/spring-ai-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/chat/memory/repository/cosmosdb/CosmosDBChatMemoryRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ private Message mapToMessage(Map<String, Object> doc) {
235235
case ASSISTANT -> new AssistantMessage(content, metadata);
236236
case USER -> UserMessage.builder().text(content).metadata(metadata).build();
237237
case SYSTEM -> SystemMessage.builder().text(content).metadata(metadata).build();
238-
case TOOL -> new ToolResponseMessage(List.of(), metadata);
238+
case TOOL -> ToolResponseMessage.builder().responses(List.of()).metadata(metadata).build();
239239
default -> throw new IllegalStateException(String.format("Unknown message type: %s", messageTypeStr));
240240
};
241241
}

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
148148
// The content is always stored empty for ToolResponseMessages.
149149
// If we want to capture the actual content, we need to extend
150150
// AddBatchPreparedStatement to support it.
151-
case TOOL -> new ToolResponseMessage(List.of());
151+
case TOOL -> ToolResponseMessage.builder().responses(List.of()).build();
152152
};
153153
}
154154

memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.ai.chat.messages.MessageType;
3636
import org.springframework.ai.chat.messages.SystemMessage;
3737
import org.springframework.ai.chat.messages.ToolResponseMessage;
38+
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
3839
import org.springframework.ai.chat.messages.UserMessage;
3940
import org.springframework.ai.content.Media;
4041
import org.springframework.ai.content.MediaContent;
@@ -172,12 +173,12 @@ public Neo4jChatMemoryRepositoryConfig getConfig() {
172173

173174
private Message buildToolMessage(org.neo4j.driver.Record record) {
174175
Message message;
175-
message = new ToolResponseMessage(record.get("toolResponses").asList(v -> {
176+
message = ToolResponseMessage.builder().responses(record.get("toolResponses").asList(v -> {
176177
Map<String, Object> trMap = v.asMap();
177-
return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
178+
return new ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
178179
(String) trMap.get(ToolResponseAttributes.NAME.getValue()),
179180
(String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue()));
180-
}), record.get("metadata").asMap());
181+
})).metadata(record.get("metadata").asMap()).build();
181182
return message;
182183
}
183184

memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ void saveAndFindMultipleMessages() {
130130
List<Message> messages = List.of(new AssistantMessage("Message from assistant - " + conversationId),
131131
new UserMessage("Message from user - " + conversationId),
132132
new SystemMessage("Message from system - " + conversationId),
133-
new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))));
133+
ToolResponseMessage.builder()
134+
.responses(List.of(new ToolResponse("id", "name", "responseData")))
135+
.build());
134136

135137
this.chatMemoryRepository.saveAll(conversationId, messages);
136138
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
@@ -285,9 +287,11 @@ void handleAssistantMessageWithToolCalls() {
285287
void handleToolResponseMessage() {
286288
var conversationId = UUID.randomUUID().toString();
287289

288-
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List
289-
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
290-
Map.of("metadataKey", "metadataValue"));
290+
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
291+
.responses(List.of(new ToolResponse("id1", "name1", "responseData1"),
292+
new ToolResponse("id2", "name2", "responseData2")))
293+
.metadata(Map.of("metadataKey", "metadataValue"))
294+
.build();
291295

292296
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
293297

@@ -408,7 +412,9 @@ private Message createMessageByType(String content, MessageType messageType) {
408412
case ASSISTANT -> new AssistantMessage(content);
409413
case USER -> new UserMessage(content);
410414
case SYSTEM -> new SystemMessage(content);
411-
case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")));
415+
case TOOL -> ToolResponseMessage.builder()
416+
.responses(List.of(new ToolResponse("id", "name", "responseData")))
417+
.build();
412418
};
413419
}
414420

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ void createChatCompletionMessagesWithToolResponseMessage() {
199199
var toolResponse1 = createToolResponse(1);
200200
var toolResponse2 = createToolResponse(2);
201201
var toolResponse3 = createToolResponse(3);
202-
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse1, toolResponse2, toolResponse3));
202+
var toolResponseMessage = ToolResponseMessage.builder()
203+
.responses(List.of(toolResponse1, toolResponse2, toolResponse3))
204+
.build();
203205
var prompt = createPrompt(toolResponseMessage);
204206
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
205207
var chatCompletionMessages = chatCompletionRequest.messages();
@@ -212,7 +214,7 @@ void createChatCompletionMessagesWithToolResponseMessage() {
212214
@Test
213215
void createChatCompletionMessagesWithInvalidToolResponseMessage() {
214216
var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null);
215-
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse));
217+
var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(toolResponse)).build();
216218
var prompt = createPrompt(toolResponseMessage);
217219
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
218220
.isInstanceOf(IllegalArgumentException.class)

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.ai.chat.messages.Message;
2626
import org.springframework.ai.chat.messages.SystemMessage;
2727
import org.springframework.ai.chat.messages.ToolResponseMessage;
28+
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
2829
import org.springframework.ai.chat.messages.UserMessage;
2930
import org.springframework.ai.chat.prompt.ChatOptions;
3031
import org.springframework.ai.chat.prompt.Prompt;
@@ -256,11 +257,10 @@ private static List<Message> createMessagesWithAllMessageTypes() {
256257
var systemMessage = new SystemMessage("Test system message");
257258
var userMessage = new UserMessage("Test user message");
258259
// @formatter:off
259-
var toolResponseMessage = new ToolResponseMessage(List.of(
260-
new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"),
261-
new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"),
262-
new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3"))
263-
);
260+
var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(
261+
new ToolResponse("tool1", "Tool 1", "Test tool response 1"),
262+
new ToolResponse("tool2", "Tool 2", "Test tool response 2"),
263+
new ToolResponse("tool3", "Tool 3", "Test tool response 3"))).build();
264264
// @formatter:on
265265
var assistantMessage = new AssistantMessage("Test assistant message");
266266

spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,27 @@ public class ToolResponseMessage extends AbstractMessage {
3131

3232
protected final List<ToolResponse> responses;
3333

34+
/**
35+
* @deprecated in favor of using {@link ToolResponseMessage.Builder}
36+
*/
37+
@Deprecated
3438
public ToolResponseMessage(List<ToolResponse> responses) {
3539
this(responses, Map.of());
3640
}
3741

42+
/**
43+
* @deprecated in favor of using {@link ToolResponseMessage.Builder}
44+
*/
45+
@Deprecated
3846
public ToolResponseMessage(List<ToolResponse> responses, Map<String, Object> metadata) {
3947
super(MessageType.TOOL, "", metadata);
4048
this.responses = responses;
4149
}
4250

51+
public static Builder builder() {
52+
return new Builder();
53+
}
54+
4355
public List<ToolResponse> getResponses() {
4456
return this.responses;
4557
}
@@ -73,4 +85,29 @@ public record ToolResponse(String id, String name, String responseData) {
7385

7486
}
7587

88+
public static final class Builder {
89+
90+
private List<ToolResponse> responses;
91+
92+
private Map<String, Object> metadata = Map.of();
93+
94+
private Builder() {
95+
}
96+
97+
public Builder responses(List<ToolResponse> responses) {
98+
this.responses = responses;
99+
return this;
100+
}
101+
102+
public Builder metadata(Map<String, Object> metadata) {
103+
this.metadata = metadata;
104+
return this;
105+
}
106+
107+
public ToolResponseMessage build() {
108+
return new ToolResponseMessage(this.responses, this.metadata);
109+
}
110+
111+
}
112+
76113
}

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,10 @@ else if (message instanceof AssistantMessage assistantMessage) {
184184
.build());
185185
}
186186
else if (message instanceof ToolResponseMessage toolResponseMessage) {
187-
messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()),
188-
new HashMap<>(toolResponseMessage.getMetadata())));
187+
messagesCopy.add(ToolResponseMessage.builder()
188+
.responses(new ArrayList<>(toolResponseMessage.getResponses()))
189+
.metadata(new HashMap<>(toolResponseMessage.getMetadata()))
190+
.build());
189191
}
190192
else {
191193
throw new IllegalArgumentException("Unsupported message type: " + message.getClass().getName());

0 commit comments

Comments
 (0)