diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java index bf74ab6335c..3b445aa0d14 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -114,6 +114,10 @@ interface PromptUserSpec { PromptUserSpec media(MimeType mimeType, Resource resource); + PromptUserSpec metadata(Map metadata); + + PromptUserSpec metadata(String k, Object v); + } /** @@ -131,6 +135,10 @@ interface PromptSystemSpec { PromptSystemSpec param(String k, Object v); + PromptSystemSpec metadata(Map metadata); + + PromptSystemSpec metadata(String k, Object v); + } interface AdvisorSpec { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 5cfe18ac9b4..6452ccee6d9 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -134,6 +134,8 @@ public static class DefaultPromptUserSpec implements PromptUserSpec { private final Map params = new HashMap<>(); + private final Map metadata = new HashMap<>(); + private final List media = new ArrayList<>(); @Nullable @@ -212,6 +214,23 @@ public PromptUserSpec params(Map params) { return this; } + @Override + public PromptUserSpec metadata(Map metadata) { + Assert.notNull(metadata, "metadata cannot be null"); + Assert.noNullElements(metadata.keySet(), "metadata keys cannot contain null elements"); + Assert.noNullElements(metadata.values(), "metadata values cannot contain null elements"); + this.metadata.putAll(metadata); + return this; + } + + @Override + public PromptUserSpec metadata(String key, Object value) { + Assert.hasText(key, "metadata key cannot be null or empty"); + Assert.notNull(value, "metadata value cannot be null"); + this.metadata.put(key, value); + return this; + } + @Nullable protected String text() { return this.text; @@ -225,12 +244,18 @@ protected List media() { return this.media; } + protected Map metadata() { + return this.metadata; + } + } public static class DefaultPromptSystemSpec implements PromptSystemSpec { private final Map params = new HashMap<>(); + private final Map metadata = new HashMap<>(); + @Nullable private String text; @@ -278,6 +303,23 @@ public PromptSystemSpec params(Map params) { return this; } + @Override + public PromptSystemSpec metadata(Map metadata) { + Assert.notNull(metadata, "metadata cannot be null"); + Assert.noNullElements(metadata.keySet(), "metadata keys cannot contain null elements"); + Assert.noNullElements(metadata.values(), "metadata values cannot contain null elements"); + this.metadata.putAll(metadata); + return this; + } + + @Override + public PromptSystemSpec metadata(String key, Object value) { + Assert.hasText(key, "metadata key cannot be null or empty"); + Assert.notNull(value, "metadata value cannot be null"); + this.metadata.put(key, value); + return this; + } + @Nullable protected String text() { return this.text; @@ -287,6 +329,10 @@ protected Map params() { return this.params; } + protected Map metadata() { + return this.metadata; + } + } public static class DefaultAdvisorSpec implements AdvisorSpec { @@ -579,8 +625,12 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map userParams = new HashMap<>(); + private final Map userMetadata = new HashMap<>(); + private final Map systemParams = new HashMap<>(); + private final Map systemMetadata = new HashMap<>(); + private final List advisors = new ArrayList<>(); private final Map advisorParams = new HashMap<>(); @@ -600,22 +650,25 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { - this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks, - ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, - ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer); + this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams, + ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, + ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention, + ccr.toolContext, ccr.templateRenderer); } public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, - Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, List toolNames, List media, - @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, - ObservationRegistry observationRegistry, + Map userParams, Map userMetadata, @Nullable String systemText, + Map systemParams, Map systemMetadata, List toolCallbacks, + List messages, List toolNames, List media, @Nullable ChatOptions chatOptions, + List advisors, Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { Assert.notNull(chatModel, "chatModel cannot be null"); Assert.notNull(userParams, "userParams cannot be null"); + Assert.notNull(userMetadata, "userMetadata cannot be null"); Assert.notNull(systemParams, "systemParams cannot be null"); + Assert.notNull(systemMetadata, "systemMetadata cannot be null"); Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.notNull(messages, "messages cannot be null"); Assert.notNull(toolNames, "toolNames cannot be null"); @@ -631,8 +684,11 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.userText = userText; this.userParams.putAll(userParams); + this.userMetadata.putAll(userMetadata); + this.systemText = systemText; this.systemParams.putAll(systemParams); + this.systemMetadata.putAll(systemMetadata); this.toolNames.addAll(toolNames); this.toolCallbacks.addAll(toolCallbacks); @@ -656,6 +712,10 @@ public Map getUserParams() { return this.userParams; } + public Map getUserMetadata() { + return this.userMetadata; + } + @Nullable public String getSystemText() { return this.systemText; @@ -665,6 +725,10 @@ public Map getSystemParams() { return this.systemParams; } + public Map getSystemMetadata() { + return this.systemMetadata; + } + @Nullable public ChatOptions getChatOptions() { return this.chatOptions; @@ -720,12 +784,15 @@ public Builder mutate() { } if (StringUtils.hasText(this.userText)) { - builder.defaultUser( - u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0]))); + builder.defaultUser(u -> u.text(this.userText) + .params(this.userParams) + .media(this.media.toArray(new Media[0])) + .metadata(this.userMetadata)); } if (StringUtils.hasText(this.systemText)) { - builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams)); + builder.defaultSystem( + s -> s.text(this.systemText).params(this.systemParams).metadata(this.systemMetadata)); } if (this.chatOptions != null) { @@ -737,6 +804,7 @@ public Builder mutate() { return builder; } + @Override public ChatClientRequestSpec advisors(Consumer consumer) { Assert.notNull(consumer, "consumer cannot be null"); var advisorSpec = new DefaultAdvisorSpec(); @@ -746,6 +814,7 @@ public ChatClientRequestSpec advisors(Consumer consumer) return this; } + @Override public ChatClientRequestSpec advisors(Advisor... advisors) { Assert.notNull(advisors, "advisors cannot be null"); Assert.noNullElements(advisors, "advisors cannot contain null elements"); @@ -753,6 +822,7 @@ public ChatClientRequestSpec advisors(Advisor... advisors) { return this; } + @Override public ChatClientRequestSpec advisors(List advisors) { Assert.notNull(advisors, "advisors cannot be null"); Assert.noNullElements(advisors, "advisors cannot contain null elements"); @@ -760,6 +830,7 @@ public ChatClientRequestSpec advisors(List advisors) { return this; } + @Override public ChatClientRequestSpec messages(Message... messages) { Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); @@ -767,6 +838,7 @@ public ChatClientRequestSpec messages(Message... messages) { return this; } + @Override public ChatClientRequestSpec messages(List messages) { Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); @@ -822,6 +894,7 @@ public ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackP return this; } + @Override public ChatClientRequestSpec toolContext(Map toolContext) { Assert.notNull(toolContext, "toolContext cannot be null"); Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements"); @@ -830,12 +903,14 @@ public ChatClientRequestSpec toolContext(Map toolContext) { return this; } + @Override public ChatClientRequestSpec system(String text) { Assert.hasText(text, "text cannot be null or empty"); this.systemText = text; return this; } + @Override public ChatClientRequestSpec system(Resource text, Charset charset) { Assert.notNull(text, "text cannot be null"); Assert.notNull(charset, "charset cannot be null"); @@ -849,11 +924,13 @@ public ChatClientRequestSpec system(Resource text, Charset charset) { return this; } + @Override public ChatClientRequestSpec system(Resource text) { Assert.notNull(text, "text cannot be null"); return this.system(text, Charset.defaultCharset()); } + @Override public ChatClientRequestSpec system(Consumer consumer) { Assert.notNull(consumer, "consumer cannot be null"); @@ -861,16 +938,18 @@ public ChatClientRequestSpec system(Consumer consumer) { consumer.accept(systemSpec); this.systemText = StringUtils.hasText(systemSpec.text()) ? systemSpec.text() : this.systemText; this.systemParams.putAll(systemSpec.params()); - + this.systemMetadata.putAll(systemSpec.metadata()); return this; } + @Override public ChatClientRequestSpec user(String text) { Assert.hasText(text, "text cannot be null or empty"); this.userText = text; return this; } + @Override public ChatClientRequestSpec user(Resource text, Charset charset) { Assert.notNull(text, "text cannot be null"); Assert.notNull(charset, "charset cannot be null"); @@ -884,11 +963,13 @@ public ChatClientRequestSpec user(Resource text, Charset charset) { return this; } + @Override public ChatClientRequestSpec user(Resource text) { Assert.notNull(text, "text cannot be null"); return this.user(text, Charset.defaultCharset()); } + @Override public ChatClientRequestSpec user(Consumer consumer) { Assert.notNull(consumer, "consumer cannot be null"); @@ -897,21 +978,25 @@ public ChatClientRequestSpec user(Consumer consumer) { this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText; this.userParams.putAll(us.params()); this.media.addAll(us.media()); + this.userMetadata.putAll(us.metadata()); return this; } + @Override public ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer) { Assert.notNull(templateRenderer, "templateRenderer cannot be null"); this.templateRenderer = templateRenderer; return this; } + @Override public CallResponseSpec call() { BaseAdvisorChain advisorChain = buildAdvisorChain(); return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, this.observationRegistry, this.observationConvention); } + @Override public StreamResponseSpec stream() { BaseAdvisorChain advisorChain = buildAdvisorChain(); return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 7d42596ef1e..a937356e543 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -64,8 +64,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa @Nullable ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); - this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(), - List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, + this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(), + Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of(), null); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 10f623e2b70..fe413734679 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -67,7 +67,10 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient .build() .render(); } - processedMessages.add(new SystemMessage(processedSystemText)); + processedMessages.add(SystemMessage.builder() + .text(processedSystemText) + .metadata(inputRequest.getSystemMetadata()) + .build()); } // Messages => In the middle of the list @@ -86,7 +89,11 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient .build() .render(); } - processedMessages.add(UserMessage.builder().text(processedUserText).media(inputRequest.getMedia()).build()); + processedMessages.add(UserMessage.builder() + .text(processedUserText) + .media(inputRequest.getMedia()) + .metadata(inputRequest.getUserMetadata()) + .build()); } /* diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 783a7356c0a..44ddb0aeccc 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -19,6 +19,7 @@ import java.net.MalformedURLException; import java.net.URL; import java.util.List; +import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; @@ -49,6 +50,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; +import static org.springframework.ai.chat.messages.MessageType.USER; /** * @author Christian Tzolov @@ -92,6 +94,7 @@ void defaultSystemText() { Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); content = join(chatClient.prompt("What's Spring AI?").stream().content()); @@ -100,6 +103,7 @@ void defaultSystemText() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); // Override the default system text with prompt system content = chatClient.prompt("What's Spring AI?").system("Override default system text").call().content(); @@ -108,6 +112,7 @@ void defaultSystemText() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Override default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); // Streaming content = join( @@ -117,6 +122,7 @@ void defaultSystemText() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Override default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -135,7 +141,9 @@ void defaultSystemTextLambda() { var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") - .param("param2", "value2")) + .param("param2", "value2") + .metadata("metadata1", "svalue1") + .metadata("metadata2", "svalue2")) .build(); var content = chatClient.prompt("What's Spring AI?").call().content(); @@ -145,6 +153,10 @@ void defaultSystemTextLambda() { Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("metadata1", "svalue1") + .containsEntry("metadata2", "svalue2"); // Streaming content = join(chatClient.prompt("What's Spring AI?").stream().content()); @@ -154,6 +166,10 @@ void defaultSystemTextLambda() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("metadata1", "svalue1") + .containsEntry("metadata2", "svalue2"); // Override single default system parameter content = chatClient.prompt("What's Spring AI?").system(s -> s.param("param1", "value1New")).call().content(); @@ -162,6 +178,24 @@ void defaultSystemTextLambda() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Default system text value1New, value2"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("metadata1", "svalue1") + .containsEntry("metadata2", "svalue2"); + + // Override default system metadata + content = chatClient.prompt("What's Spring AI?") + .system(s -> s.metadata("metadata1", "svalue1New")) + .call() + .content(); + assertThat(content).isEqualTo("response"); + systemMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("metadata1", "svalue1New") + .containsEntry("metadata2", "svalue2"); // streaming content = join( @@ -182,10 +216,16 @@ void defaultSystemTextLambda() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Override default system text value3"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("metadata1", "svalue1") + .containsEntry("metadata2", "svalue2"); // Streaming content = join(chatClient.prompt("What's Spring AI?") - .system(s -> s.text("Override default system text {param3}").param("param3", "value3")) + .system(s -> s.text("Override default system text {param3}") + .param("param3", "value3") + .metadata("metadata3", "svalue3")) .stream() .content()); @@ -193,6 +233,11 @@ void defaultSystemTextLambda() { systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("Override default system text value3"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(4) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("metadata1", "svalue1") + .containsEntry("metadata2", "svalue2") + .containsEntry("metadata3", "svalue3"); } @Test @@ -215,7 +260,9 @@ void mutateDefaults() { var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") - .param("param2", "value2")) + .param("param2", "value2") + .metadata("smetadata1", "svalue1") + .metadata("smetadata2", "svalue2")) .defaultToolNames("fun1", "fun2") .defaultToolCallbacks(FunctionToolCallback.builder("fun3", mockFunction) .description("fun3description") @@ -225,7 +272,10 @@ void mutateDefaults() { .param("uparam1", "value1") .param("uparam2", "value2") .media(MimeTypeUtils.IMAGE_JPEG, - new DefaultResourceLoader().getResource("classpath:/bikes.json"))) + new DefaultResourceLoader().getResource("classpath:/bikes.json")) + .metadata("umetadata1", "udata1") + .metadata("umetadata2", "udata2") + ) .build(); // @formatter:on @@ -238,12 +288,20 @@ void mutateDefaults() { Message systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("smetadata1", "svalue1") + .containsEntry("smetadata2", "svalue2"); UserMessage userMessage = (UserMessage) prompt.getInstructions().get(1); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("Default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); + assertThat(userMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1") + .containsEntry("umetadata2", "udata2"); var fco = (ToolCallingChatOptions) prompt.getOptions(); @@ -260,12 +318,20 @@ void mutateDefaults() { systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("smetadata1", "svalue1") + .containsEntry("smetadata2", "svalue2"); userMessage = (UserMessage) prompt.getInstructions().get(1); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("Default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); + assertThat(userMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1") + .containsEntry("umetadata2", "udata2"); fco = (ToolCallingChatOptions) prompt.getOptions(); @@ -290,12 +356,20 @@ void mutateDefaults() { systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Mutated default system text value1, value2"); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("smetadata1", "svalue1") + .containsEntry("smetadata2", "svalue2"); userMessage = (UserMessage) prompt.getInstructions().get(1); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("Mutated default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); + assertThat(userMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1") + .containsEntry("umetadata2", "udata2"); fco = (ToolCallingChatOptions) prompt.getOptions(); @@ -312,12 +386,20 @@ void mutateDefaults() { systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Mutated default system text value1, value2"); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("smetadata1", "svalue1") + .containsEntry("smetadata2", "svalue2"); userMessage = (UserMessage) prompt.getInstructions().get(1); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("Mutated default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); + assertThat(userMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1") + .containsEntry("umetadata2", "udata2"); fco = (ToolCallingChatOptions) prompt.getOptions(); @@ -345,7 +427,9 @@ void mutatePrompt() { var chatClient = ChatClient.builder(this.chatModel) .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") - .param("param2", "value2")) + .param("param2", "value2") + .metadata("smetadata1", "svalue1") + .metadata("smetadata2", "svalue2")) .defaultToolNames("fun1", "fun2") .defaultToolCallbacks(FunctionToolCallback.builder("fun3", mockFunction) .description("fun3description") @@ -354,6 +438,8 @@ void mutatePrompt() { .defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}") .param("uparam1", "value1") .param("uparam2", "value2") + .metadata("umetadata1", "udata1") + .metadata("umetadata2", "udata2") .media(MimeTypeUtils.IMAGE_JPEG, new DefaultResourceLoader().getResource("classpath:/bikes.json"))) .build(); @@ -362,7 +448,8 @@ void mutatePrompt() { .prompt() .system("New default system text {param1}, {param2}") .user(u -> u.param("uparam1", "userValue1") - .param("uparam2", "userValue2")) + .param("uparam2", "userValue2") + .metadata("umetadata2", "userData2")) .toolNames("fun5") .mutate().build() // mutate and build new prompt .prompt().call().content(); @@ -375,12 +462,20 @@ void mutatePrompt() { Message systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("New default system text value1, value2"); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("smetadata1", "svalue1") + .containsEntry("smetadata2", "svalue2"); UserMessage userMessage = (UserMessage) prompt.getInstructions().get(1); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("Default user text userValue1, userValue2"); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); + assertThat(userMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1") + .containsEntry("umetadata2", "userData2"); var tco = (ToolCallingChatOptions) prompt.getOptions(); @@ -393,7 +488,8 @@ void mutatePrompt() { .prompt() .system("New default system text {param1}, {param2}") .user(u -> u.param("uparam1", "userValue1") - .param("uparam2", "userValue2")) + .param("uparam2", "userValue2") + .metadata("umetadata2", "userData2")) .toolNames("fun5") .mutate().build() // mutate and build new prompt .prompt().stream().content()); @@ -406,12 +502,20 @@ void mutatePrompt() { systemMessage = prompt.getInstructions().get(0); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("New default system text value1, value2"); + assertThat(systemMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", MessageType.SYSTEM) + .containsEntry("smetadata1", "svalue1") + .containsEntry("smetadata2", "svalue2"); userMessage = (UserMessage) prompt.getInstructions().get(1); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("Default user text userValue1, userValue2"); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); + assertThat(userMessage.getMetadata()).hasSize(3) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1") + .containsEntry("umetadata2", "userData2"); var tcoptions = (ToolCallingChatOptions) prompt.getOptions(); @@ -433,7 +537,8 @@ void defaultUserText() { Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualTo("Default user text"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); // Override the default system text with prompt system content = chatClient.prompt().user("Override default user text").call().content(); @@ -441,7 +546,8 @@ void defaultUserText() { assertThat(content).isEqualTo("response"); userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualTo("Override default user text"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -454,7 +560,8 @@ void simpleUserPromptAsString() { Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualTo("User prompt"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -467,7 +574,8 @@ void simpleUserPrompt() { Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualTo("User prompt"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -478,15 +586,22 @@ void simpleUserPromptObject() { var media = new Media(MimeTypeUtils.IMAGE_JPEG, new DefaultResourceLoader().getResource("classpath:/bikes.json")); - UserMessage message = UserMessage.builder().text("User prompt").media(List.of(media)).build(); + UserMessage message = UserMessage.builder() + .text("User prompt") + .media(List.of(media)) + .metadata(Map.of("umetadata1", "udata1")) + .build(); Prompt prompt = new Prompt(message); assertThat(ChatClient.builder(this.chatModel).build().prompt(prompt).call().content()).isEqualTo("response"); assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); Message userMessage = this.promptCaptor.getValue().getInstructions().get(0); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getText()).isEqualTo("User prompt"); assertThat(((UserMessage) userMessage).getMedia()).hasSize(1); + assertThat(((UserMessage) userMessage).getMetadata()).hasSize(2) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1"); } @Test @@ -508,6 +623,7 @@ void simpleSystemPrompt() { Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("System prompt"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -527,7 +643,7 @@ void complexCall() throws MalformedURLException { .build(); String response = client.prompt() - .user(u -> u.text("User text {music}").param("music", "Rock").media(MimeTypeUtils.IMAGE_PNG, url)) + .user(u -> u.text("User text {music}").param("music", "Rock").media(MimeTypeUtils.IMAGE_PNG, url).metadata(Map.of("umetadata1", "udata1"))) .call() .content(); // @formatter:on @@ -541,11 +657,14 @@ void complexCall() throws MalformedURLException { UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getText()).isEqualTo("User text Rock"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_PNG); assertThat(userMessage.getMedia().iterator().next().getData()) .isEqualTo("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); + assertThat(userMessage.getMetadata()).hasSize(2) + .containsEntry("messageType", USER) + .containsEntry("umetadata1", "udata1"); ToolCallingChatOptions runtieOptions = (ToolCallingChatOptions) this.promptCaptor.getValue().getOptions(); @@ -596,7 +715,7 @@ void whenPromptWithStringContent() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); var userMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualTo("my question"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); } @Test @@ -613,7 +732,8 @@ void whenPromptWithMessages() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); var userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getText()).isEqualTo("my question"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -629,7 +749,8 @@ void whenPromptWithStringContentAndUserText() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); var userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getText()).isEqualTo("another question"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -646,7 +767,8 @@ void whenPromptWithHistoryAndUserText() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); var userMessage = this.promptCaptor.getValue().getInstructions().get(2); assertThat(userMessage.getText()).isEqualTo("another question"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -663,7 +785,8 @@ void whenPromptWithUserMessageAndUserText() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); var userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getText()).isEqualTo("another question"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -680,6 +803,8 @@ void whenMessagesWithHistoryAndUserText() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); var userMessage = this.promptCaptor.getValue().getInstructions().get(2); assertThat(userMessage.getText()).isEqualTo("another question"); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } @Test @@ -696,7 +821,8 @@ void whenMessagesWithUserMessageAndUserText() { assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); var userMessage = this.promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getText()).isEqualTo("another question"); - assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getMessageType()).isEqualTo(USER); + assertThat(userMessage.getMetadata()).hasSize(1).containsEntry("messageType", USER); } // Prompt Tests - System @@ -716,6 +842,7 @@ void whenPromptWithMessagesAndSystemText() { var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -733,6 +860,7 @@ void whenPromptWithSystemMessageAndNoSystemText() { var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -750,6 +878,7 @@ void whenPromptWithSystemMessageAndSystemText() { var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("other instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -772,6 +901,7 @@ void whenMessagesAndSystemText() { var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -789,6 +919,7 @@ void whenMessagesWithSystemMessageAndNoSystemText() { var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } @Test @@ -811,6 +942,7 @@ void whenMessagesWithSystemMessageAndSystemText() { var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("other instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + assertThat(systemMessage.getMetadata()).hasSize(1).containsEntry("messageType", MessageType.SYSTEM); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 1c72596f490..0eb64fe9b9a 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -143,7 +143,10 @@ void testMutate() { defaultChatClientBuilder.addMessages(List.of(userMessage1, userMessage2)); ChatClient originalChatClient = defaultChatClientBuilder.defaultAdvisors(advisor) .defaultOptions(chatOptions) - .defaultUser(u -> u.text("original user {userParams}").param("userParams", "user value2").media(media)) + .defaultUser(u -> u.text("original user {userParams}") + .param("userParams", "user value2") + .media(media) + .metadata("userMetadata", "user data3")) .defaultSystem(s -> s.text("original system {sysParams}").param("sysParams", "system value1")) .defaultTemplateRenderer(templateRenderer) .defaultToolNames("toolName1", "toolName2") @@ -162,6 +165,7 @@ void testMutate() { assertThat(mutateSpec.getChatOptions()).isEqualTo(copyChatOptions); assertThat(mutateSpec.getUserText()).isEqualTo("original user {userParams}"); assertThat(mutateSpec.getUserParams()).containsEntry("userParams", "user value2"); + assertThat(mutateSpec.getUserMetadata()).containsEntry("userMetadata", "user data3"); assertThat(mutateSpec.getMedia()).hasSize(1).containsOnly(media); assertThat(mutateSpec.getSystemText()).isEqualTo("original system {sysParams}"); assertThat(mutateSpec.getSystemParams()).containsEntry("sysParams", "system value1"); @@ -196,6 +200,7 @@ void buildPromptUserSpec() { assertThat(spec).isNotNull(); assertThat(spec.media()).isNotNull(); assertThat(spec.params()).isNotNull(); + assertThat(spec.metadata()).isNotNull(); assertThat(spec.text()).isNull(); } @@ -395,6 +400,66 @@ void whenUserParamsThenReturn() { assertThat(spec.params()).containsEntry("key", "value"); } + @Test + void whenUserMetadataKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.metadata(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata key cannot be null or empty"); + } + + @Test + void whenUserMetadataKeyIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.metadata("", "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata key cannot be null or empty"); + } + + @Test + void whenUserMetadataValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.metadata("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata value cannot be null"); + } + + @Test + void whenUserMetadataKeyValueThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.metadata("key", "value"); + assertThat(spec.metadata()).containsEntry("key", "value"); + } + + @Test + void whenUserMetadataIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.metadata(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata cannot be null"); + } + + @Test + void whenUserMetadataMapKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Map metadata = new HashMap<>(); + metadata.put(null, "value"); + assertThatThrownBy(() -> spec.metadata(metadata)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata keys cannot contain null elements"); + } + + @Test + void whenUserMetadataMapValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Map metadata = new HashMap<>(); + metadata.put("key", null); + assertThatThrownBy(() -> spec.metadata(metadata)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata values cannot contain null elements"); + } + + @Test + void whenUserMetadataThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.metadata(Map.of("key", "value")); + assertThat(spec.metadata()).containsEntry("key", "value"); + } + // DefaultPromptSystemSpec @Test @@ -402,6 +467,7 @@ void buildPromptSystemSpec() { DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); assertThat(spec).isNotNull(); assertThat(spec.params()).isNotNull(); + assertThat(spec.metadata()).isNotNull(); assertThat(spec.text()).isNull(); } @@ -524,6 +590,66 @@ void whenSystemParamsThenReturn() { assertThat(spec.params()).containsEntry("key", "value"); } + @Test + void whenSystemMetadataKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.metadata(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata key cannot be null or empty"); + } + + @Test + void whenSystemMetadataKeyIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.metadata("", "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata key cannot be null or empty"); + } + + @Test + void whenSystemMetadataValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.metadata("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata value cannot be null"); + } + + @Test + void whenSystemMetadataKeyValueThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.metadata("key", "value"); + assertThat(spec.metadata()).containsEntry("key", "value"); + } + + @Test + void whenSystemMetadataIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.metadata(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata cannot be null"); + } + + @Test + void whenSystemMetadataMapKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Map metadata = new HashMap<>(); + metadata.put(null, "value"); + assertThatThrownBy(() -> spec.metadata(metadata)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata keys cannot contain null elements"); + } + + @Test + void whenSystemMetadataMapValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Map metadata = new HashMap<>(); + metadata.put("key", null); + assertThatThrownBy(() -> spec.metadata(metadata)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("metadata values cannot contain null elements"); + } + + @Test + void whenSystemMetadataThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.metadata(Map.of("key", "value")); + assertThat(spec.metadata()).containsEntry("key", "value"); + } + // DefaultAdvisorSpec @Test @@ -1347,15 +1473,15 @@ void whenChatResponseContentIsNullThenReturnFlux() { void buildChatClientRequestSpec() { ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( - chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), - Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); + chatModel, null, Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), + List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); assertThat(spec).isNotNull(); } @Test void whenChatModelIsNullThenThrow() { - assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), null, - Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), Map.of(), + null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); @@ -1364,8 +1490,8 @@ void whenChatModelIsNullThenThrow() { @Test void whenObservationRegistryIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, - Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), null, - null, Map.of(), null)) + Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, + List.of(), Map.of(), null, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("observationRegistry cannot be null"); } @@ -1817,30 +1943,37 @@ void whenSystemConsumerIsNullThenThrow() { void whenSystemConsumerThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.system(system -> system.text("my instruction about {topic}").param("topic", "AI")); + spec = spec.system(system -> system.text("my instruction about {topic}") + .param("topic", "AI") + .metadata("msgId", "uuid-xxx")); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getSystemText()).isEqualTo("my instruction about {topic}"); assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); + assertThat(defaultSpec.getSystemMetadata()).containsEntry("msgId", "uuid-xxx"); } @Test void whenSystemConsumerWithExistingSystemTextThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt().system("my instruction"); - spec = spec.system(system -> system.text("my instruction about {topic}").param("topic", "AI")); + spec = spec.system(system -> system.text("my instruction about {topic}") + .param("topic", "AI") + .metadata("msgId", "uuid-xxx")); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getSystemText()).isEqualTo("my instruction about {topic}"); assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); + assertThat(defaultSpec.getSystemMetadata()).containsEntry("msgId", "uuid-xxx"); } @Test void whenSystemConsumerWithoutSystemTextThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt().system("my instruction about {topic}"); - spec = spec.system(system -> system.param("topic", "AI")); + spec = spec.system(system -> system.param("topic", "AI").metadata("msgId", "uuid-xxx")); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getSystemText()).isEqualTo("my instruction about {topic}"); assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); + assertThat(defaultSpec.getSystemMetadata()).containsEntry("msgId", "uuid-xxx"); } @Test @@ -1926,11 +2059,13 @@ void whenUserConsumerThenReturn() { ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); spec = spec.user(user -> user.text("my question about {topic}") .param("topic", "AI") + .metadata("msgId", "uuid-xxx") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("tabby-cat.png"))); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getUserText()).isEqualTo("my question about {topic}"); assertThat(defaultSpec.getUserParams()).containsEntry("topic", "AI"); assertThat(defaultSpec.getMedia()).hasSize(1); + assertThat(defaultSpec.getUserMetadata()).containsEntry("msgId", "uuid-xxx"); } @Test @@ -1939,11 +2074,13 @@ void whenUserConsumerWithExistingUserTextThenReturn() { ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("my question"); spec = spec.user(user -> user.text("my question about {topic}") .param("topic", "AI") + .metadata("msgId", "uuid-xxx") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("tabby-cat.png"))); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getUserText()).isEqualTo("my question about {topic}"); assertThat(defaultSpec.getUserParams()).containsEntry("topic", "AI"); assertThat(defaultSpec.getMedia()).hasSize(1); + assertThat(defaultSpec.getUserMetadata()).containsEntry("msgId", "uuid-xxx"); } @Test @@ -1951,11 +2088,13 @@ void whenUserConsumerWithoutUserTextThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("my question about {topic}"); spec = spec.user(user -> user.param("topic", "AI") + .metadata("msgId", "uuid-xxx") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("tabby-cat.png"))); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getUserText()).isEqualTo("my question about {topic}"); assertThat(defaultSpec.getUserParams()).containsEntry("topic", "AI"); assertThat(defaultSpec.getMedia()).hasSize(1); + assertThat(defaultSpec.getUserMetadata()).containsEntry("msgId", "uuid-xxx"); } record Person(String name) {