Skip to content

Commit 66699a7

Browse files
committed
Add builder to AssistantMessage and rename properties to metadata
Signed-off-by: Jemin Huh <[email protected]>
1 parent 36c5977 commit 66699a7

File tree

8 files changed

+160
-42
lines changed

8 files changed

+160
-42
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,28 @@
1616

1717
package org.springframework.ai.chat.messages;
1818

19+
import java.util.ArrayList;
20+
import java.util.Arrays;
21+
import java.util.HashMap;
1922
import java.util.List;
2023
import java.util.Map;
2124
import java.util.Objects;
2225

2326
import org.springframework.ai.content.Media;
2427
import org.springframework.ai.content.MediaContent;
28+
import org.springframework.core.io.Resource;
29+
import org.springframework.lang.Nullable;
2530
import org.springframework.util.Assert;
2631
import org.springframework.util.CollectionUtils;
32+
import org.springframework.util.StringUtils;
2733

2834
/**
2935
* Lets the generative know the content was generated as a response to the user. This role
3036
* indicates messages that the generative has previously generated in the conversation. By
3137
* including assistant messages in the series, you provide context to the generative about
3238
* prior exchanges in the conversation.
3339
*
40+
* @author Jemin Huh
3441
* @author Mark Pollack
3542
* @author Christian Tzolov
3643
* @since 1.0.0
@@ -42,20 +49,16 @@ public class AssistantMessage extends AbstractMessage implements MediaContent {
4249
protected final List<Media> media;
4350

4451
public AssistantMessage(String content) {
45-
this(content, Map.of());
52+
this(content, Map.of(), List.of(), List.of());
4653
}
4754

48-
public AssistantMessage(String content, Map<String, Object> properties) {
49-
this(content, properties, List.of());
55+
public AssistantMessage(Resource resource) {
56+
this(MessageUtils.readResource(resource));
5057
}
5158

52-
public AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls) {
53-
this(content, properties, toolCalls, List.of());
54-
}
55-
56-
public AssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls,
59+
private AssistantMessage(String content, Map<String, Object> metadata, List<ToolCall> toolCalls,
5760
List<Media> media) {
58-
super(MessageType.ASSISTANT, content, properties);
61+
super(MessageType.ASSISTANT, content, metadata);
5962
Assert.notNull(toolCalls, "Tool calls must not be null");
6063
Assert.notNull(media, "Media must not be null");
6164
this.toolCalls = toolCalls;
@@ -104,4 +107,88 @@ public record ToolCall(String id, String type, String name, String arguments) {
104107

105108
}
106109

110+
public AssistantMessage copy() {
111+
return new Builder().text(getText())
112+
.metadata(Map.copyOf(getMetadata()))
113+
.toolCalls(List.copyOf(getToolCalls()))
114+
.media(List.copyOf(getMedia()))
115+
.build();
116+
}
117+
118+
public AssistantMessage.Builder mutate() {
119+
return new Builder().text(getText())
120+
.metadata(Map.copyOf(getMetadata()))
121+
.toolCalls(List.copyOf(getToolCalls()))
122+
.media(List.copyOf(getMedia()));
123+
}
124+
125+
public static AssistantMessage.Builder builder() {
126+
return new AssistantMessage.Builder();
127+
}
128+
129+
public static class Builder {
130+
131+
@Nullable
132+
private String textContent;
133+
134+
@Nullable
135+
private Resource resource;
136+
137+
private Map<String, Object> metadata = new HashMap<>();
138+
139+
private List<ToolCall> toolCalls = new ArrayList<>();
140+
141+
private List<Media> media = new ArrayList<>();
142+
143+
public AssistantMessage.Builder text(String textContent) {
144+
this.textContent = textContent;
145+
return this;
146+
}
147+
148+
public AssistantMessage.Builder text(Resource resource) {
149+
this.resource = resource;
150+
return this;
151+
}
152+
153+
public AssistantMessage.Builder metadata(Map<String, Object> metadata) {
154+
this.metadata = metadata;
155+
return this;
156+
}
157+
158+
public AssistantMessage.Builder toolCalls(List<ToolCall> toolCalls) {
159+
this.toolCalls = toolCalls;
160+
return this;
161+
}
162+
163+
public AssistantMessage.Builder toolCalls(@Nullable ToolCall... toolCalls) {
164+
if (media != null) {
165+
this.toolCalls = Arrays.asList(toolCalls);
166+
}
167+
return this;
168+
}
169+
170+
public AssistantMessage.Builder media(List<Media> media) {
171+
this.media = media;
172+
return this;
173+
}
174+
175+
public AssistantMessage.Builder media(@Nullable Media... media) {
176+
if (media != null) {
177+
this.media = Arrays.asList(media);
178+
}
179+
return this;
180+
}
181+
182+
public AssistantMessage build() {
183+
if (StringUtils.hasText(this.textContent) && this.resource != null) {
184+
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
185+
}
186+
else if (this.resource != null) {
187+
this.textContent = MessageUtils.readResource(this.resource);
188+
}
189+
return new AssistantMessage(this.textContent, this.metadata, this.toolCalls, this.media);
190+
}
191+
192+
}
193+
107194
}

spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
* Helper that for streaming chat responses, aggregate the chat response messages into a
4040
* single AssistantMessage. Job is performed in parallel to the chat response processing.
4141
*
42+
* @author Jemin Huh
4243
* @author Christian Tzolov
4344
* @author Alexandros Pappas
4445
* @author Thomas Vitale
@@ -133,9 +134,10 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
133134
.promptMetadata(metadataPromptMetadataRef.get())
134135
.build();
135136

136-
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(
137-
new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()),
138-
generationMetadataRef.get())), chatResponseMetadata));
137+
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(AssistantMessage.builder()
138+
.text(messageTextContentRef.get().toString())
139+
.metadata(messageMetadataMapRef.get())
140+
.build(), generationMetadataRef.get())), chatResponseMetadata));
139141

140142
messageTextContentRef.set(new StringBuilder());
141143
messageMetadataMapRef.set(new HashMap<>());

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
* The Prompt class represents a prompt used in AI model requests. A prompt consists of
4040
* one or more messages and additional chat options.
4141
*
42+
* @author Jemin Huh
4243
* @author Mark Pollack
4344
* @author luocongqiu
4445
* @author Thomas Vitale
@@ -177,8 +178,7 @@ else if (message instanceof SystemMessage systemMessage) {
177178
messagesCopy.add(systemMessage.copy());
178179
}
179180
else if (message instanceof AssistantMessage assistantMessage) {
180-
messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
181-
assistantMessage.getToolCalls()));
181+
messagesCopy.add(assistantMessage.copy());
182182
}
183183
else if (message instanceof ToolResponseMessage toolResponseMessage) {
184184
messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()),

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
/**
5151
* Default implementation of {@link ToolCallingManager}.
5252
*
53+
* @author Jemin Huh
5354
* @author Thomas Vitale
5455
* @since 1.0.0
5556
*/
@@ -154,8 +155,7 @@ private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assi
154155
toolContextMap = new HashMap<>(toolCallingChatOptions.getToolContext());
155156

156157
List<Message> messageHistory = new ArrayList<>(prompt.copy().getInstructions());
157-
messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
158-
assistantMessage.getToolCalls()));
158+
messageHistory.add(assistantMessage.copy());
159159

160160
toolContextMap.put(ToolContext.TOOL_CALL_HISTORY,
161161
buildConversationHistoryBeforeToolExecution(prompt, assistantMessage));
@@ -167,8 +167,7 @@ private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assi
167167
private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt prompt,
168168
AssistantMessage assistantMessage) {
169169
List<Message> messageHistory = new ArrayList<>(prompt.copy().getInstructions());
170-
messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
171-
assistantMessage.getToolCalls()));
170+
messageHistory.add(assistantMessage.copy());
172171
return messageHistory;
173172
}
174173

spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ class ChatResponseTests {
3838
@Test
3939
void whenToolCallsArePresentThenReturnTrue() {
4040
ChatResponse chatResponse = ChatResponse.builder()
41-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
42-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
41+
.generations(List.of(new Generation(AssistantMessage.builder()
42+
.text("")
43+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))
44+
.build())))
4345
.build();
4446
assertThat(chatResponse.hasToolCalls()).isTrue();
4547
}

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerIT.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ void observationForToolCall() {
7373
.build();
7474

7575
ChatResponse chatResponse = ChatResponse.builder()
76-
.generations(List.of(new Generation(new AssistantMessage("Answer", Map.of(),
77-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
76+
.generations(List.of(new Generation(AssistantMessage.builder()
77+
.text("Answer")
78+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))
79+
.build())))
7880
.build();
7981

8082
ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,10 @@ void whenSingleToolCallInChatResponseThenExecute() {
158158

159159
Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build());
160160
ChatResponse chatResponse = ChatResponse.builder()
161-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
162-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
161+
.generations(List.of(new Generation(AssistantMessage.builder()
162+
.text("")
163+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))
164+
.build())))
163165
.build();
164166

165167
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
@@ -180,8 +182,10 @@ void whenSingleToolCallWithReturnDirectInChatResponseThenExecute() {
180182

181183
Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build());
182184
ChatResponse chatResponse = ChatResponse.builder()
183-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
184-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
185+
.generations(List.of(new Generation(AssistantMessage.builder()
186+
.text("")
187+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))
188+
.build())))
185189
.build();
186190

187191
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
@@ -205,9 +209,11 @@ void whenMultipleToolCallsInChatResponseThenExecute() {
205209

206210
Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build());
207211
ChatResponse chatResponse = ChatResponse.builder()
208-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
209-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
210-
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))))))
212+
.generations(List.of(new Generation(AssistantMessage.builder()
213+
.text("")
214+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
215+
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))
216+
.build())))
211217
.build();
212218

213219
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
@@ -229,8 +235,10 @@ void whenDuplicateMixedToolCallsInChatResponseThenExecute() {
229235
.toolNames("toolA")
230236
.build());
231237
ChatResponse chatResponse = ChatResponse.builder()
232-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
233-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
238+
.generations(List.of(new Generation(AssistantMessage.builder()
239+
.text("")
240+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))
241+
.build())))
234242
.build();
235243

236244
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
@@ -253,9 +261,11 @@ void whenMultipleToolCallsWithReturnDirectInChatResponseThenExecute() {
253261

254262
Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build());
255263
ChatResponse chatResponse = ChatResponse.builder()
256-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
257-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
258-
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))))))
264+
.generations(List.of(new Generation(AssistantMessage.builder()
265+
.text("")
266+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
267+
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))
268+
.build())))
259269
.build();
260270

261271
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
@@ -280,9 +290,11 @@ void whenMultipleToolCallsWithMixedReturnDirectInChatResponseThenExecute() {
280290

281291
Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build());
282292
ChatResponse chatResponse = ChatResponse.builder()
283-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
284-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
285-
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))))))
293+
.generations(List.of(new Generation(AssistantMessage.builder()
294+
.text("")
295+
.toolCalls(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
296+
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))
297+
.build())))
286298
.build();
287299

288300
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
@@ -305,8 +317,10 @@ void whenToolCallWithExceptionThenReturnError() {
305317

306318
Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build());
307319
ChatResponse chatResponse = ChatResponse.builder()
308-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
309-
List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}"))))))
320+
.generations(List.of(new Generation(AssistantMessage.builder()
321+
.text("")
322+
.toolCalls(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}"))
323+
.build())))
310324
.build();
311325

312326
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ void whenToolExecutionEnabledAndHasToolCalls() {
4444

4545
// Create a ChatResponse with tool calls
4646
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}");
47-
AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall));
47+
AssistantMessage assistantMessage = AssistantMessage.builder()
48+
.text("test")
49+
.metadata(Map.of())
50+
.toolCalls(List.of(toolCall))
51+
.build();
4852
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage)));
4953

5054
// Test the predicate
@@ -73,7 +77,11 @@ void whenToolExecutionDisabledAndHasToolCalls() {
7377

7478
// Create a ChatResponse with tool calls
7579
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}");
76-
AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall));
80+
AssistantMessage assistantMessage = AssistantMessage.builder()
81+
.text("test")
82+
.metadata(Map.of())
83+
.toolCalls(List.of(toolCall))
84+
.build();
7785
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage)));
7886

7987
// Test the predicate
@@ -102,7 +110,11 @@ void whenRegularChatOptionsAndHasToolCalls() {
102110

103111
// Create a ChatResponse with tool calls
104112
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}");
105-
AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall));
113+
AssistantMessage assistantMessage = AssistantMessage.builder()
114+
.text("test")
115+
.metadata(Map.of())
116+
.toolCalls(List.of(toolCall))
117+
.build();
106118
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage)));
107119

108120
// Test the predicate - should use default value (true) for internal tool

0 commit comments

Comments
 (0)