Skip to content

Commit e46dccf

Browse files
committed
Add builder to ToolResponseMessage
Signed-off-by: Jemin Huh <[email protected]>
1 parent 66699a7 commit e46dccf

File tree

4 files changed

+60
-11
lines changed

4 files changed

+60
-11
lines changed

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ void whenChatResponseContentIsNull() {
784784
ChatModel chatModel = mock(ChatModel.class);
785785
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
786786
given(chatModel.call(promptCaptor.capture()))
787-
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))));
787+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))));
788788

789789
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
790790
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
@@ -814,7 +814,7 @@ void whenResponseEntityWithParameterizedTypeAndChatResponseContentNull() {
814814
ChatModel chatModel = mock(ChatModel.class);
815815
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
816816
given(chatModel.call(promptCaptor.capture()))
817-
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))));
817+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))));
818818

819819
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
820820
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
@@ -874,7 +874,7 @@ void whenResponseEntityWithConverterAndChatResponseContentNull() {
874874
ChatModel chatModel = mock(ChatModel.class);
875875
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
876876
given(chatModel.call(promptCaptor.capture()))
877-
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))));
877+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))));
878878

879879
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
880880
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
@@ -926,7 +926,7 @@ void whenResponseEntityWithTypeAndChatResponseContentNull() {
926926
ChatModel chatModel = mock(ChatModel.class);
927927
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
928928
given(chatModel.call(promptCaptor.capture()))
929-
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))));
929+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))));
930930

931931
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
932932
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
@@ -978,7 +978,7 @@ void whenEntityWithParameterizedTypeAndChatResponseContentNull() {
978978
ChatModel chatModel = mock(ChatModel.class);
979979
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
980980
given(chatModel.call(promptCaptor.capture()))
981-
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))));
981+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))));
982982

983983
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
984984
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
@@ -1076,7 +1076,7 @@ void whenEntityWithTypeAndChatResponseContentNull() {
10761076
ChatModel chatModel = mock(ChatModel.class);
10771077
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
10781078
given(chatModel.call(promptCaptor.capture()))
1079-
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))));
1079+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null)))));
10801080

10811081
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
10821082
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
@@ -1282,7 +1282,7 @@ void whenChatResponseContentIsNullThenReturnFlux() {
12821282
ChatModel chatModel = mock(ChatModel.class);
12831283
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
12841284
given(chatModel.stream(promptCaptor.capture()))
1285-
.willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))));
1285+
.willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage((String) null))))));
12861286

12871287
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
12881288
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient

spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package org.springframework.ai.chat.messages;
1818

19+
import org.springframework.lang.Nullable;
20+
21+
import java.util.Arrays;
22+
import java.util.HashMap;
1923
import java.util.List;
2024
import java.util.Map;
2125
import java.util.Objects;
@@ -24,6 +28,7 @@
2428
* The ToolResponseMessage class represents a message with a function content in a chat
2529
* application.
2630
*
31+
* @author Jemin Huh
2732
* @author Christian Tzolov
2833
* @since 1.0.0
2934
*/
@@ -35,7 +40,7 @@ public ToolResponseMessage(List<ToolResponse> responses) {
3540
this(responses, Map.of());
3641
}
3742

38-
public ToolResponseMessage(List<ToolResponse> responses, Map<String, Object> metadata) {
43+
private ToolResponseMessage(List<ToolResponse> responses, Map<String, Object> metadata) {
3944
super(MessageType.TOOL, "", metadata);
4045
this.responses = responses;
4146
}
@@ -73,4 +78,45 @@ public record ToolResponse(String id, String name, String responseData) {
7378

7479
}
7580

81+
public ToolResponseMessage copy() {
82+
return new ToolResponseMessage(getResponses(), Map.copyOf(this.metadata));
83+
}
84+
85+
public ToolResponseMessage.Builder mutate() {
86+
return new Builder().responses(getResponses()).metadata(Map.copyOf(this.metadata));
87+
}
88+
89+
public static ToolResponseMessage.Builder builder() {
90+
return new ToolResponseMessage.Builder();
91+
}
92+
93+
public static class Builder {
94+
95+
private List<ToolResponse> responses;
96+
97+
private Map<String, Object> metadata = new HashMap<>();
98+
99+
public ToolResponseMessage.Builder responses(List<ToolResponse> responses) {
100+
this.responses = responses;
101+
return this;
102+
}
103+
104+
public ToolResponseMessage.Builder media(@Nullable ToolResponse... responses) {
105+
if (responses != null) {
106+
this.responses = Arrays.asList(responses);
107+
}
108+
return this;
109+
}
110+
111+
public ToolResponseMessage.Builder metadata(Map<String, Object> metadata) {
112+
this.metadata = metadata;
113+
return this;
114+
}
115+
116+
public ToolResponseMessage build() {
117+
return new ToolResponseMessage(this.responses, this.metadata);
118+
}
119+
120+
}
121+
76122
}

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ else if (message instanceof AssistantMessage assistantMessage) {
181181
messagesCopy.add(assistantMessage.copy());
182182
}
183183
else if (message instanceof ToolResponseMessage toolResponseMessage) {
184-
messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()),
185-
new HashMap<>(toolResponseMessage.getMetadata())));
184+
messagesCopy.add(ToolResponseMessage.builder()
185+
.responses(toolResponseMessage.getResponses())
186+
.metadata(new HashMap<>(toolResponseMessage.getMetadata()))
187+
.build());
186188
}
187189
else {
188190
throw new IllegalArgumentException("Unsupported message type: " + message.getClass().getName());

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess
233233
toolCallResult != null ? toolCallResult : ""));
234234
}
235235

236-
return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
236+
return new InternalToolExecutionResult(ToolResponseMessage.builder().responses(toolResponses).build(),
237+
returnDirect);
237238
}
238239

239240
private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,

0 commit comments

Comments
 (0)