Skip to content

Commit e547661

Browse files
authored
feat: Add metadata support to user and system messages in ChatClient (#3989)
- Add metadata methods to PromptUserSpec and PromptSystemSpec interfaces - Update DefaultChatClientRequestSpec to handle user and system metadata - Modify DefaultChatClientUtils to include metadata in SystemMessage and UserMessage builders - Add comprehensive test coverage for metadata functionality Signed-off-by: YunKui Lu <[email protected]>
1 parent 3ee33c1 commit e547661

File tree

6 files changed

+423
-52
lines changed

6 files changed

+423
-52
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ interface PromptUserSpec {
114114

115115
PromptUserSpec media(MimeType mimeType, Resource resource);
116116

117+
PromptUserSpec metadata(Map<String, Object> metadata);
118+
119+
PromptUserSpec metadata(String k, Object v);
120+
117121
}
118122

119123
/**
@@ -131,6 +135,10 @@ interface PromptSystemSpec {
131135

132136
PromptSystemSpec param(String k, Object v);
133137

138+
PromptSystemSpec metadata(Map<String, Object> metadata);
139+
140+
PromptSystemSpec metadata(String k, Object v);
141+
134142
}
135143

136144
interface AdvisorSpec {

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ public static class DefaultPromptUserSpec implements PromptUserSpec {
134134

135135
private final Map<String, Object> params = new HashMap<>();
136136

137+
private final Map<String, Object> metadata = new HashMap<>();
138+
137139
private final List<Media> media = new ArrayList<>();
138140

139141
@Nullable
@@ -212,6 +214,23 @@ public PromptUserSpec params(Map<String, Object> params) {
212214
return this;
213215
}
214216

217+
@Override
218+
public PromptUserSpec metadata(Map<String, Object> metadata) {
219+
Assert.notNull(metadata, "metadata cannot be null");
220+
Assert.noNullElements(metadata.keySet(), "metadata keys cannot contain null elements");
221+
Assert.noNullElements(metadata.values(), "metadata values cannot contain null elements");
222+
this.metadata.putAll(metadata);
223+
return this;
224+
}
225+
226+
@Override
227+
public PromptUserSpec metadata(String key, Object value) {
228+
Assert.hasText(key, "metadata key cannot be null or empty");
229+
Assert.notNull(value, "metadata value cannot be null");
230+
this.metadata.put(key, value);
231+
return this;
232+
}
233+
215234
@Nullable
216235
protected String text() {
217236
return this.text;
@@ -225,12 +244,18 @@ protected List<Media> media() {
225244
return this.media;
226245
}
227246

247+
protected Map<String, Object> metadata() {
248+
return this.metadata;
249+
}
250+
228251
}
229252

230253
public static class DefaultPromptSystemSpec implements PromptSystemSpec {
231254

232255
private final Map<String, Object> params = new HashMap<>();
233256

257+
private final Map<String, Object> metadata = new HashMap<>();
258+
234259
@Nullable
235260
private String text;
236261

@@ -278,6 +303,23 @@ public PromptSystemSpec params(Map<String, Object> params) {
278303
return this;
279304
}
280305

306+
@Override
307+
public PromptSystemSpec metadata(Map<String, Object> metadata) {
308+
Assert.notNull(metadata, "metadata cannot be null");
309+
Assert.noNullElements(metadata.keySet(), "metadata keys cannot contain null elements");
310+
Assert.noNullElements(metadata.values(), "metadata values cannot contain null elements");
311+
this.metadata.putAll(metadata);
312+
return this;
313+
}
314+
315+
@Override
316+
public PromptSystemSpec metadata(String key, Object value) {
317+
Assert.hasText(key, "metadata key cannot be null or empty");
318+
Assert.notNull(value, "metadata value cannot be null");
319+
this.metadata.put(key, value);
320+
return this;
321+
}
322+
281323
@Nullable
282324
protected String text() {
283325
return this.text;
@@ -287,6 +329,10 @@ protected Map<String, Object> params() {
287329
return this.params;
288330
}
289331

332+
protected Map<String, Object> metadata() {
333+
return this.metadata;
334+
}
335+
290336
}
291337

292338
public static class DefaultAdvisorSpec implements AdvisorSpec {
@@ -576,8 +622,12 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
576622

577623
private final Map<String, Object> userParams = new HashMap<>();
578624

625+
private final Map<String, Object> userMetadata = new HashMap<>();
626+
579627
private final Map<String, Object> systemParams = new HashMap<>();
580628

629+
private final Map<String, Object> systemMetadata = new HashMap<>();
630+
581631
private final List<Advisor> advisors = new ArrayList<>();
582632

583633
private final Map<String, Object> advisorParams = new HashMap<>();
@@ -597,22 +647,25 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
597647

598648
/* copy constructor */
599649
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
600-
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
601-
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
602-
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
650+
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams,
651+
ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
652+
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
653+
ccr.toolContext, ccr.templateRenderer);
603654
}
604655

605656
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
606-
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
607-
List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
608-
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
609-
ObservationRegistry observationRegistry,
657+
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
658+
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
659+
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
660+
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
610661
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
611662
@Nullable TemplateRenderer templateRenderer) {
612663

613664
Assert.notNull(chatModel, "chatModel cannot be null");
614665
Assert.notNull(userParams, "userParams cannot be null");
666+
Assert.notNull(userMetadata, "userMetadata cannot be null");
615667
Assert.notNull(systemParams, "systemParams cannot be null");
668+
Assert.notNull(systemMetadata, "systemMetadata cannot be null");
616669
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
617670
Assert.notNull(messages, "messages cannot be null");
618671
Assert.notNull(toolNames, "toolNames cannot be null");
@@ -628,8 +681,11 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
628681

629682
this.userText = userText;
630683
this.userParams.putAll(userParams);
684+
this.userMetadata.putAll(userMetadata);
685+
631686
this.systemText = systemText;
632687
this.systemParams.putAll(systemParams);
688+
this.systemMetadata.putAll(systemMetadata);
633689

634690
this.toolNames.addAll(toolNames);
635691
this.toolCallbacks.addAll(toolCallbacks);
@@ -653,6 +709,10 @@ public Map<String, Object> getUserParams() {
653709
return this.userParams;
654710
}
655711

712+
public Map<String, Object> getUserMetadata() {
713+
return this.userMetadata;
714+
}
715+
656716
@Nullable
657717
public String getSystemText() {
658718
return this.systemText;
@@ -662,6 +722,10 @@ public Map<String, Object> getSystemParams() {
662722
return this.systemParams;
663723
}
664724

725+
public Map<String, Object> getSystemMetadata() {
726+
return this.systemMetadata;
727+
}
728+
665729
@Nullable
666730
public ChatOptions getChatOptions() {
667731
return this.chatOptions;
@@ -717,12 +781,15 @@ public Builder mutate() {
717781
}
718782

719783
if (StringUtils.hasText(this.userText)) {
720-
builder.defaultUser(
721-
u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0])));
784+
builder.defaultUser(u -> u.text(this.userText)
785+
.params(this.userParams)
786+
.media(this.media.toArray(new Media[0]))
787+
.metadata(this.userMetadata));
722788
}
723789

724790
if (StringUtils.hasText(this.systemText)) {
725-
builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams));
791+
builder.defaultSystem(
792+
s -> s.text(this.systemText).params(this.systemParams).metadata(this.systemMetadata));
726793
}
727794

728795
if (this.chatOptions != null) {
@@ -734,6 +801,7 @@ public Builder mutate() {
734801
return builder;
735802
}
736803

804+
@Override
737805
public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer) {
738806
Assert.notNull(consumer, "consumer cannot be null");
739807
var advisorSpec = new DefaultAdvisorSpec();
@@ -743,27 +811,31 @@ public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer)
743811
return this;
744812
}
745813

814+
@Override
746815
public ChatClientRequestSpec advisors(Advisor... advisors) {
747816
Assert.notNull(advisors, "advisors cannot be null");
748817
Assert.noNullElements(advisors, "advisors cannot contain null elements");
749818
this.advisors.addAll(Arrays.asList(advisors));
750819
return this;
751820
}
752821

822+
@Override
753823
public ChatClientRequestSpec advisors(List<Advisor> advisors) {
754824
Assert.notNull(advisors, "advisors cannot be null");
755825
Assert.noNullElements(advisors, "advisors cannot contain null elements");
756826
this.advisors.addAll(advisors);
757827
return this;
758828
}
759829

830+
@Override
760831
public ChatClientRequestSpec messages(Message... messages) {
761832
Assert.notNull(messages, "messages cannot be null");
762833
Assert.noNullElements(messages, "messages cannot contain null elements");
763834
this.messages.addAll(List.of(messages));
764835
return this;
765836
}
766837

838+
@Override
767839
public ChatClientRequestSpec messages(List<Message> messages) {
768840
Assert.notNull(messages, "messages cannot be null");
769841
Assert.noNullElements(messages, "messages cannot contain null elements");
@@ -819,6 +891,7 @@ public ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackP
819891
return this;
820892
}
821893

894+
@Override
822895
public ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
823896
Assert.notNull(toolContext, "toolContext cannot be null");
824897
Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements");
@@ -827,12 +900,14 @@ public ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
827900
return this;
828901
}
829902

903+
@Override
830904
public ChatClientRequestSpec system(String text) {
831905
Assert.hasText(text, "text cannot be null or empty");
832906
this.systemText = text;
833907
return this;
834908
}
835909

910+
@Override
836911
public ChatClientRequestSpec system(Resource text, Charset charset) {
837912
Assert.notNull(text, "text cannot be null");
838913
Assert.notNull(charset, "charset cannot be null");
@@ -846,28 +921,32 @@ public ChatClientRequestSpec system(Resource text, Charset charset) {
846921
return this;
847922
}
848923

924+
@Override
849925
public ChatClientRequestSpec system(Resource text) {
850926
Assert.notNull(text, "text cannot be null");
851927
return this.system(text, Charset.defaultCharset());
852928
}
853929

930+
@Override
854931
public ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer) {
855932
Assert.notNull(consumer, "consumer cannot be null");
856933

857934
var systemSpec = new DefaultPromptSystemSpec();
858935
consumer.accept(systemSpec);
859936
this.systemText = StringUtils.hasText(systemSpec.text()) ? systemSpec.text() : this.systemText;
860937
this.systemParams.putAll(systemSpec.params());
861-
938+
this.systemMetadata.putAll(systemSpec.metadata());
862939
return this;
863940
}
864941

942+
@Override
865943
public ChatClientRequestSpec user(String text) {
866944
Assert.hasText(text, "text cannot be null or empty");
867945
this.userText = text;
868946
return this;
869947
}
870948

949+
@Override
871950
public ChatClientRequestSpec user(Resource text, Charset charset) {
872951
Assert.notNull(text, "text cannot be null");
873952
Assert.notNull(charset, "charset cannot be null");
@@ -881,11 +960,13 @@ public ChatClientRequestSpec user(Resource text, Charset charset) {
881960
return this;
882961
}
883962

963+
@Override
884964
public ChatClientRequestSpec user(Resource text) {
885965
Assert.notNull(text, "text cannot be null");
886966
return this.user(text, Charset.defaultCharset());
887967
}
888968

969+
@Override
889970
public ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer) {
890971
Assert.notNull(consumer, "consumer cannot be null");
891972

@@ -894,21 +975,25 @@ public ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer) {
894975
this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText;
895976
this.userParams.putAll(us.params());
896977
this.media.addAll(us.media());
978+
this.userMetadata.putAll(us.metadata());
897979
return this;
898980
}
899981

982+
@Override
900983
public ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer) {
901984
Assert.notNull(templateRenderer, "templateRenderer cannot be null");
902985
this.templateRenderer = templateRenderer;
903986
return this;
904987
}
905988

989+
@Override
906990
public CallResponseSpec call() {
907991
BaseAdvisorChain advisorChain = buildAdvisorChain();
908992
return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
909993
this.observationRegistry, this.observationConvention);
910994
}
911995

996+
@Override
912997
public StreamResponseSpec stream() {
913998
BaseAdvisorChain advisorChain = buildAdvisorChain();
914999
return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
6464
@Nullable ChatClientObservationConvention customObservationConvention) {
6565
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
6666
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
67-
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(),
68-
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
67+
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(),
68+
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
6969
customObservationConvention, Map.of(), null);
7070
}
7171

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
6767
.build()
6868
.render();
6969
}
70-
processedMessages.add(new SystemMessage(processedSystemText));
70+
processedMessages.add(SystemMessage.builder()
71+
.text(processedSystemText)
72+
.metadata(inputRequest.getSystemMetadata())
73+
.build());
7174
}
7275

7376
// Messages => In the middle of the list
@@ -86,7 +89,11 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient
8689
.build()
8790
.render();
8891
}
89-
processedMessages.add(UserMessage.builder().text(processedUserText).media(inputRequest.getMedia()).build());
92+
processedMessages.add(UserMessage.builder()
93+
.text(processedUserText)
94+
.media(inputRequest.getMedia())
95+
.metadata(inputRequest.getUserMetadata())
96+
.build());
9097
}
9198

9299
/*

0 commit comments

Comments
 (0)