diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java index 3fbbd3ffb5b..a885612cd07 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.jdbc.core.BatchPreparedStatementSetter; @@ -129,6 +130,7 @@ public Message mapRow(ResultSet rs, int i) throws SQLException { case USER -> new UserMessage(content); case ASSISTANT -> new AssistantMessage(content); case SYSTEM -> new SystemMessage(content); + case DEVELOPER -> new DeveloperMessage(content); // The content is always stored empty for ToolResponseMessages. // If we want to capture the actual content, we need to extend // AddBatchPreparedStatement to support it. diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java index 5e5abc6ac41..c624ff062ff 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java @@ -24,6 +24,7 @@ import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -72,6 +73,7 @@ void saveMessagesSingleMessage(String content, MessageType messageType) { case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); case USER -> new UserMessage(content + " - " + conversationId); case SYSTEM -> new SystemMessage(content + " - " + conversationId); + case DEVELOPER -> new DeveloperMessage(content + " - " + conversationId); default -> throw new IllegalArgumentException("Type not supported: " + messageType); }; diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java b/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java index e01b9caebb9..3a040bbcc97 100644 --- a/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java +++ b/memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java @@ -28,11 +28,11 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.content.Media; import org.springframework.util.MimeType; import org.testcontainers.containers.Neo4jContainer; @@ -402,6 +402,7 @@ private Message createMessageByType(String content, MessageType messageType) { case ASSISTANT -> new AssistantMessage(content); case USER -> new UserMessage(content); case SYSTEM -> new SystemMessage(content); + case DEVELOPER -> new DeveloperMessage(content); case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))); }; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 1933f575300..eb9997ead50 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -49,6 +49,7 @@ import com.azure.ai.openai.models.ChatRequestAssistantMessage; import com.azure.ai.openai.models.ChatRequestMessage; import com.azure.ai.openai.models.ChatRequestSystemMessage; +import com.azure.ai.openai.models.ChatRequestDeveloperMessage; import com.azure.ai.openai.models.ChatRequestToolMessage; import com.azure.ai.openai.models.ChatRequestUserMessage; import com.azure.ai.openai.models.CompletionsFinishReason; @@ -575,6 +576,8 @@ private List fromSpringAiMessage(Message message) { return List.of(new ChatRequestUserMessage(items)); case SYSTEM: return List.of(new ChatRequestSystemMessage(message.getText())); + case DEVELOPER: + return List.of(new ChatRequestDeveloperMessage(message.getText())); case ASSISTANT: AssistantMessage assistantMessage = (AssistantMessage) message; List toolCalls = null; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index d7719b1418f..a1f8a718ae5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -104,6 +104,7 @@ * @author Ilayaperumal Gopinathan * @author Alexandros Pappas * @author Soby Chacko + * @author Andres da Silva Santos * @see ChatModel * @see StreamingChatModel * @see OpenAiApi @@ -552,7 +553,8 @@ private Map mergeHttpHeaders(Map runtimeHttpHead ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { - if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { + if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM + || message.getMessageType() == MessageType.DEVELOPER) { Object content = message.getText(); if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java index d6fe27574d1..b6a31345dac 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -53,6 +53,7 @@ * @author Josh Long * @author Arjen Poutsma * @author Thomas Vitale + * @author Andres da Silva Santos * @since 1.0.0 */ public interface ChatClient { @@ -133,6 +134,23 @@ interface PromptSystemSpec { } + /** + * Specification for a prompt developer. + */ + interface PromptDeveloperSpec { + + PromptDeveloperSpec text(String text); + + PromptDeveloperSpec text(Resource text, Charset charset); + + PromptDeveloperSpec text(Resource text); + + PromptDeveloperSpec params(Map p); + + PromptDeveloperSpec param(String k, Object v); + + } + interface AdvisorSpec { AdvisorSpec param(String k, Object v); @@ -232,6 +250,14 @@ interface ChatClientRequestSpec { ChatClientRequestSpec toolContext(Map toolContext); + ChatClientRequestSpec developer(String text); + + ChatClientRequestSpec developer(Resource textResource, Charset charset); + + ChatClientRequestSpec developer(Resource text); + + ChatClientRequestSpec developer(Consumer consumer); + ChatClientRequestSpec system(String text); ChatClientRequestSpec system(Resource textResource, Charset charset); @@ -277,6 +303,14 @@ interface Builder { Builder defaultUser(Consumer userSpecConsumer); + Builder defaultDeveloper(String text); + + Builder defaultDeveloper(Resource text, Charset charset); + + Builder defaultDeveloper(Resource text); + + Builder defaultDeveloper(Consumer developerSpecConsumer); + Builder defaultSystem(String text); Builder defaultSystem(Resource text, Charset charset); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 82327f43981..8ea789bf5a3 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -75,6 +75,7 @@ * @author Soby Chacko * @author Dariusz Jedrzejczyk * @author Thomas Vitale + * @author Andres da Silva Santos * @since 1.0.0 */ public class DefaultChatClient implements ChatClient { @@ -288,6 +289,68 @@ protected Map params() { } + public static class DefaultPromptDeveloperSpec implements PromptDeveloperSpec { + + private final Map params = new HashMap<>(); + + @Nullable + private String text; + + @Override + public PromptDeveloperSpec text(String text) { + Assert.hasText(text, "text cannot be null or empty"); + this.text = text; + return this; + } + + @Override + public PromptDeveloperSpec text(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); + try { + this.text(text.getContentAsString(charset)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override + public PromptDeveloperSpec text(Resource text) { + Assert.notNull(text, "text cannot be null"); + this.text(text, Charset.defaultCharset()); + return this; + } + + @Override + public PromptDeveloperSpec param(String key, Object value) { + Assert.hasText(key, "key cannot be null or empty"); + Assert.notNull(value, "value cannot be null"); + this.params.put(key, value); + return this; + } + + @Override + public PromptDeveloperSpec params(Map params) { + Assert.notNull(params, "params cannot be null"); + Assert.noNullElements(params.keySet(), "param keys cannot contain null elements"); + Assert.noNullElements(params.values(), "param values cannot contain null elements"); + this.params.putAll(params); + return this; + } + + @Nullable + protected String text() { + return this.text; + } + + protected Map params() { + return this.params; + } + + } + public static class DefaultAdvisorSpec implements AdvisorSpec { private final List advisors = new ArrayList<>(); @@ -577,6 +640,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map systemParams = new HashMap<>(); + private final Map developerParams = new HashMap<>(); + private final List advisors = new ArrayList<>(); private final Map advisorParams = new HashMap<>(); @@ -591,27 +656,32 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe @Nullable private String systemText; + @Nullable + private String developerText; + @Nullable private ChatOptions chatOptions; /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { - this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks, - ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, - ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer); + this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.developerText, + ccr.developerParams, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, + ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention, + ccr.toolContext, ccr.templateRenderer); } public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, List toolNames, List media, - @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, - ObservationRegistry observationRegistry, + @Nullable String developerText, Map developerParams, List toolCallbacks, + List messages, List toolNames, List media, @Nullable ChatOptions chatOptions, + List advisors, Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { Assert.notNull(chatModel, "chatModel cannot be null"); Assert.notNull(userParams, "userParams cannot be null"); Assert.notNull(systemParams, "systemParams cannot be null"); + Assert.notNull(developerParams, "developerParams cannot be null"); Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.notNull(messages, "messages cannot be null"); Assert.notNull(toolNames, "toolNames cannot be null"); @@ -629,6 +699,8 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.userParams.putAll(userParams); this.systemText = systemText; this.systemParams.putAll(systemParams); + this.developerText = developerText; + this.developerParams.putAll(developerParams); this.toolNames.addAll(toolNames); this.toolCallbacks.addAll(toolCallbacks); @@ -661,6 +733,15 @@ public Map getSystemParams() { return this.systemParams; } + @Nullable + public String getDeveloperText() { + return this.developerText; + } + + public Map getDeveloperParams() { + return this.developerParams; + } + @Nullable public ChatOptions getChatOptions() { return this.chatOptions; @@ -719,6 +800,10 @@ public Builder mutate() { builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams)); } + if (StringUtils.hasText(this.developerText)) { + builder.defaultDeveloper(s -> s.text(this.developerText).params(this.developerParams)); + } + if (this.chatOptions != null) { builder.defaultOptions(this.chatOptions); } @@ -821,6 +906,41 @@ public ChatClientRequestSpec toolContext(Map toolContext) { return this; } + public ChatClientRequestSpec developer(String text) { + Assert.hasText(text, "text cannot be null or empty"); + this.developerText = text; + return this; + } + + public ChatClientRequestSpec developer(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); + + try { + this.developerText = text.getContentAsString(charset); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + public ChatClientRequestSpec developer(Resource text) { + Assert.notNull(text, "text cannot be null"); + return this.developer(text, Charset.defaultCharset()); + } + + public ChatClientRequestSpec developer(Consumer consumer) { + Assert.notNull(consumer, "consumer cannot be null"); + + var developerSpec = new DefaultPromptDeveloperSpec(); + consumer.accept(developerSpec); + this.developerText = StringUtils.hasText(developerSpec.text()) ? developerSpec.text() : this.developerText; + this.developerParams.putAll(developerSpec.params()); + + return this; + } + public ChatClientRequestSpec system(String text) { Assert.hasText(text, "text cannot be null or empty"); this.systemText = text; diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 8d314b0ef59..725ae74034f 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.client.ChatClient.Builder; import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec; +import org.springframework.ai.chat.client.ChatClient.PromptDeveloperSpec; import org.springframework.ai.chat.client.ChatClient.PromptUserSpec; import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.advisor.api.Advisor; @@ -50,6 +51,7 @@ * @author Josh Long * @author Arjen Poutsma * @author Thomas Vitale + * @author Andres da Silva Santos * @since 1.0.0 */ public class DefaultChatClientBuilder implements Builder { @@ -64,8 +66,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa @Nullable ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); - this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(), - List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, + this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), null, + Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of(), null); } @@ -149,6 +151,32 @@ public Builder defaultSystem(Consumer systemSpecConsumer) { return this; } + public Builder defaultDeveloper(String text) { + this.defaultRequest.developer(text); + return this; + } + + public Builder defaultDeveloper(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); + try { + this.defaultRequest.developer(text.getContentAsString(charset)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + public Builder defaultDeveloper(Resource text) { + return this.defaultDeveloper(text, Charset.defaultCharset()); + } + + public Builder defaultDeveloper(Consumer developerSpecConsumer) { + this.defaultRequest.developer(developerSpecConsumer); + return this; + } + @Override public Builder defaultToolNames(String... toolNames) { this.defaultRequest.toolNames(toolNames); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 3b793d6a99a..117315ad508 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -16,6 +16,7 @@ package org.springframework.ai.chat.client; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; @@ -39,6 +40,7 @@ * Utilities for supporting the {@link DefaultChatClient} implementation. * * @author Thomas Vitale + * @author Andres da Silva Santos * @since 1.0.0 */ class DefaultChatClientUtils { @@ -66,6 +68,20 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient processedMessages.add(new SystemMessage(processedSystemText)); } + // Developer Text => First in the list + String processedDeveloperText = inputRequest.getDeveloperText(); + if (StringUtils.hasText(processedDeveloperText)) { + if (!CollectionUtils.isEmpty(inputRequest.getDeveloperParams())) { + processedDeveloperText = PromptTemplate.builder() + .template(processedDeveloperText) + .variables(inputRequest.getDeveloperParams()) + .renderer(inputRequest.getTemplateRenderer()) + .build() + .render(); + } + processedMessages.add(new DeveloperMessage(processedDeveloperText)); + } + // Messages => In the middle of the list if (!CollectionUtils.isEmpty(inputRequest.getMessages())) { processedMessages.addAll(inputRequest.getMessages()); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 783a7356c0a..0531c953216 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -34,6 +34,7 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -195,6 +196,136 @@ void defaultSystemTextLambda() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } + @Test + void defaultDeveloperText() { + + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( + () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { + sink.next(state); + sink.complete(); + return state; + })); + + var chatClient = ChatClient.builder(this.chatModel).defaultDeveloper("Default developer text").build(); + + var content = chatClient.prompt("What's Spring AI?").call().content(); + + assertThat(content).isEqualTo("response"); + + Message developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Default developer text"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + content = join(chatClient.prompt("What's Spring AI?").stream().content()); + + assertThat(content).isEqualTo("response"); + + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Default developer text"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // Override the default developer text with prompt developer + content = chatClient.prompt("What's Spring AI?").developer("Override default developer text").call().content(); + + assertThat(content).isEqualTo("response"); + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Override default developer text"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // Streaming + content = join( + chatClient.prompt("What's Spring AI?").developer("Override default developer text").stream().content()); + + assertThat(content).isEqualTo("response"); + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Override default developer text"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + + @Test + void defaultDeveloperTextLambda() { + + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate( + () -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> { + sink.next(state); + sink.complete(); + return state; + })); + + var chatClient = ChatClient.builder(this.chatModel) + .defaultDeveloper(s -> s.text("Default developer text {param1}, {param2}") + .param("param1", "value1") + .param("param2", "value2")) + .build(); + + var content = chatClient.prompt("What's Spring AI?").call().content(); + + assertThat(content).isEqualTo("response"); + + Message developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Default developer text value1, value2"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // Streaming + content = join(chatClient.prompt("What's Spring AI?").stream().content()); + + assertThat(content).isEqualTo("response"); + + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Default developer text value1, value2"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // Override single default developer parameter + content = chatClient.prompt("What's Spring AI?") + .developer(s -> s.param("param1", "value1New")) + .call() + .content(); + + assertThat(content).isEqualTo("response"); + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Default developer text value1New, value2"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // streaming + content = join(chatClient.prompt("What's Spring AI?") + .developer(s -> s.param("param1", "value1New")) + .stream() + .content()); + + assertThat(content).isEqualTo("response"); + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Default developer text value1New, value2"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // Override default developer text + content = chatClient.prompt("What's Spring AI?") + .developer(s -> s.text("Override default developer text {param3}").param("param3", "value3")) + .call() + .content(); + + assertThat(content).isEqualTo("response"); + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Override default developer text value3"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + + // Streaming + content = join(chatClient.prompt("What's Spring AI?") + .developer(s -> s.text("Override default developer text {param3}").param("param3", "value3")) + .stream() + .content()); + + assertThat(content).isEqualTo("response"); + developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Override default developer text value3"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + @Test void mutateDefaults() { @@ -216,6 +347,9 @@ void mutateDefaults() { .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) + .defaultDeveloper(s -> s.text("Default developer text {param1}, {param2}") + .param("param1", "value1") + .param("param2", "value2")) .defaultToolNames("fun1", "fun2") .defaultToolCallbacks(FunctionToolCallback.builder("fun3", mockFunction) .description("fun3description") @@ -239,7 +373,11 @@ void mutateDefaults() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); - UserMessage userMessage = (UserMessage) prompt.getInstructions().get(1); + Message developerMessage = prompt.getInstructions().get(1); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + assertThat(developerMessage.getText()).isEqualTo("Default developer text value1, value2"); + + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(2); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getText()).isEqualTo("Default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); @@ -261,7 +399,11 @@ void mutateDefaults() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Default system text value1, value2"); - userMessage = (UserMessage) prompt.getInstructions().get(1); + developerMessage = prompt.getInstructions().get(1); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + assertThat(developerMessage.getText()).isEqualTo("Default developer text value1, value2"); + + userMessage = (UserMessage) prompt.getInstructions().get(2); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getText()).isEqualTo("Default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); @@ -276,6 +418,7 @@ void mutateDefaults() { // @formatter:off chatClient = chatClient.mutate() .defaultSystem("Mutated default system text {param1}, {param2}") + .defaultDeveloper("Mutated default developer text {param1}, {param2}") .defaultToolNames("fun4") .defaultUser("Mutated default user text {uparam1}, {uparam2}") .build(); @@ -291,7 +434,11 @@ void mutateDefaults() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Mutated default system text value1, value2"); - userMessage = (UserMessage) prompt.getInstructions().get(1); + developerMessage = prompt.getInstructions().get(1); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + assertThat(developerMessage.getText()).isEqualTo("Mutated default developer text value1, value2"); + + userMessage = (UserMessage) prompt.getInstructions().get(2); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getText()).isEqualTo("Mutated default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); @@ -313,7 +460,11 @@ void mutateDefaults() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("Mutated default system text value1, value2"); - userMessage = (UserMessage) prompt.getInstructions().get(1); + developerMessage = prompt.getInstructions().get(1); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + assertThat(developerMessage.getText()).isEqualTo("Mutated default developer text value1, value2"); + + userMessage = (UserMessage) prompt.getInstructions().get(2); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getText()).isEqualTo("Mutated default user text value1, value2"); assertThat(userMessage.getMedia()).hasSize(1); @@ -346,6 +497,9 @@ void mutatePrompt() { .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) + .defaultDeveloper(s -> s.text("Default developer text {param1}, {param2}") + .param("param1", "value1") + .param("param2", "value2")) .defaultToolNames("fun1", "fun2") .defaultToolCallbacks(FunctionToolCallback.builder("fun3", mockFunction) .description("fun3description") @@ -361,6 +515,7 @@ void mutatePrompt() { var content = chatClient .prompt() .system("New default system text {param1}, {param2}") + .developer("New default developer text {param1}, {param2}") .user(u -> u.param("uparam1", "userValue1") .param("uparam2", "userValue2")) .toolNames("fun5") @@ -376,7 +531,11 @@ void mutatePrompt() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("New default system text value1, value2"); - UserMessage userMessage = (UserMessage) prompt.getInstructions().get(1); + Message developerMessage = prompt.getInstructions().get(1); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + assertThat(developerMessage.getText()).isEqualTo("New default developer text value1, value2"); + + UserMessage userMessage = (UserMessage) prompt.getInstructions().get(2); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getText()).isEqualTo("Default user text userValue1, userValue2"); assertThat(userMessage.getMedia()).hasSize(1); @@ -392,6 +551,7 @@ void mutatePrompt() { content = join(chatClient .prompt() .system("New default system text {param1}, {param2}") + .developer("New default developer text {param1}, {param2}") .user(u -> u.param("uparam1", "userValue1") .param("uparam2", "userValue2")) .toolNames("fun5") @@ -407,7 +567,11 @@ void mutatePrompt() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); assertThat(systemMessage.getText()).isEqualTo("New default system text value1, value2"); - userMessage = (UserMessage) prompt.getInstructions().get(1); + developerMessage = prompt.getInstructions().get(1); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + assertThat(developerMessage.getText()).isEqualTo("New default developer text value1, value2"); + + userMessage = (UserMessage) prompt.getInstructions().get(2); assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); assertThat(userMessage.getText()).isEqualTo("Default user text userValue1, userValue2"); assertThat(userMessage.getMedia()).hasSize(1); @@ -510,6 +674,27 @@ void simpleSystemPrompt() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } + @Test + void simpleDeveloperPrompt() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + String response = ChatClient.builder(this.chatModel) + .build() + .prompt("What's Spring AI?") + .developer("Developer prompt") + .call() + .content(); + + assertThat(response).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); + + Message developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("Developer prompt"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + @Test void complexCall() throws MalformedURLException { given(this.chatModel.call(this.promptCaptor.capture())) @@ -813,4 +998,122 @@ void whenMessagesWithSystemMessageAndSystemText() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } + // Prompt Tests - Developer + + @Test + void whenPromptWithMessagesAndDeveloperText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new UserMessage("my question"), new AssistantMessage("your answer")); + var content = chatClient.prompt(prompt).developer("instructions").user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("instructions"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + + @Test + void whenPromptWithDeveloperMessageAndNoDeveloperText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new DeveloperMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt(prompt).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); + var developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("instructions"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + + @Test + void whenPromptWithDeveloperMessageAndDeveloperText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new DeveloperMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt(prompt) + .developer("other instructions") + .user("another question") + .call() + .content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("other instructions"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + + @Test + void whenMessagesAndDeveloperText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new UserMessage("my question"), new AssistantMessage("your answer")); + var content = chatClient.prompt() + .messages(messages) + .developer("instructions") + .user("another question") + .call() + .content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("instructions"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + + @Test + void whenMessagesWithDeveloperMessageAndNoDeveloperText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new DeveloperMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt().messages(messages).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); + var developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("instructions"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + + @Test + void whenMessagesWithDeveloperMessageAndDeveloperText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new DeveloperMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt() + .messages(messages) + .developer("other instructions") + .user("another question") + .call() + .content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var developerMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(developerMessage.getText()).isEqualTo("other instructions"); + assertThat(developerMessage.getMessageType()).isEqualTo(MessageType.DEVELOPER); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java index a4cb02541a7..a7c646e6291 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java @@ -32,6 +32,7 @@ * Unit tests for {@link DefaultChatClientBuilder}. * * @author Thomas Vitale + * @author Andres da Silva Santos */ class DefaultChatClientBuilderTests { @@ -40,14 +41,17 @@ void whenCloneBuilder() { var chatModel = mock(ChatModel.class); var originalBuilder = new DefaultChatClientBuilder(chatModel); originalBuilder.defaultSystem("first instructions"); + originalBuilder.defaultDeveloper("first instructions"); var clonedBuilder = (DefaultChatClientBuilder) originalBuilder.clone(); originalBuilder.defaultSystem("second instructions"); + originalBuilder.defaultDeveloper("second instructions"); assertThat(clonedBuilder).isNotSameAs(originalBuilder); var clonedBuilderRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils .getField(clonedBuilder, "defaultRequest"); assertThat(clonedBuilderRequestSpec).isNotNull(); assertThat(clonedBuilderRequestSpec.getSystemText()).isEqualTo("first instructions"); + assertThat(clonedBuilderRequestSpec.getDeveloperText()).isEqualTo("first instructions"); } @Test @@ -87,6 +91,14 @@ void whenSystemResourceIsNullThenThrows() { .hasMessage("text cannot be null"); } + @Test + void whenDeveloperResourceIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultDeveloper(null, Charset.defaultCharset())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + @Test void whenSystemCharsetIsNullThenThrows() { DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); @@ -95,6 +107,14 @@ void whenSystemCharsetIsNullThenThrows() { .hasMessage("charset cannot be null"); } + @Test + void whenDeveloperCharsetIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultDeveloper(new ClassPathResource("system-prompt.txt"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + @Test void whenTemplateRendererIsNullThenThrows() { DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index fd795971a2e..1fbfc168ad0 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -39,6 +39,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -65,6 +66,7 @@ * Unit tests for {@link DefaultChatClient}. * * @author Thomas Vitale + * @author Andres da Silva Santos */ class DefaultChatClientTests { @@ -113,6 +115,18 @@ void whenPromptWithMessagesThenReturn() { assertThat(spec.getChatOptions()).isNull(); } + @Test + void whenPromptWithDeveloperAndMessagesThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + Prompt prompt = new Prompt(new DeveloperMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec spec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + assertThat(spec.getMessages()).hasSize(2); + assertThat(spec.getMessages().get(0).getText()).isEqualTo("instructions"); + assertThat(spec.getMessages().get(1).getText()).isEqualTo("my question"); + assertThat(spec.getChatOptions()).isNull(); + } + @Test void whenPromptWithOptionsThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); @@ -477,6 +491,127 @@ void whenSystemParamsThenReturn() { assertThat(spec.params()).containsEntry("key", "value"); } + // DefaultPromptDeveloperSpec + + @Test + void buildPromptDeveloperSpec() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThat(spec).isNotNull(); + assertThat(spec.params()).isNotNull(); + assertThat(spec.text()).isNull(); + } + + @Test + void whenDeveloperTextStringIsNullThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.text((String) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenDeveloperTextStringIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.text("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenDeveloperTextStringThenReturn() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + spec = (DefaultChatClient.DefaultPromptDeveloperSpec) spec.text("developer instructions"); + assertThat(spec.text()).isEqualTo("developer instructions"); + } + + @Test + void whenDeveloperTextResourceIsNullWithCharsetThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.text(null, Charset.defaultCharset())).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenDeveloperTextResourceAndCharsetThenReturn() throws Exception { + Resource textResource = new ClassPathResource("developer-prompt.txt"); + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + spec = (DefaultChatClient.DefaultPromptDeveloperSpec) spec.text(textResource, Charset.defaultCharset()); + assertThat(spec.text()).isEqualTo("instructions"); + } + + @Test + void whenDeveloperTextResourceIsNullThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.text((Resource) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenDeveloperTextResourceThenReturn() throws Exception { + Resource textResource = new ClassPathResource("developer-prompt.txt"); + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + spec = (DefaultChatClient.DefaultPromptDeveloperSpec) spec.text(textResource); + assertThat(spec.text()).isEqualTo("instructions"); + } + + @Test + void whenDeveloperParamKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.param(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenDeveloperParamKeyIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.param("", "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenDeveloperParamValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.param("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + + @Test + void whenDeveloperParamKeyValueThenReturn() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + spec = (DefaultChatClient.DefaultPromptDeveloperSpec) spec.param("key", "value"); + assertThat(spec.params()).containsEntry("key", "value"); + } + + @Test + void whenDeveloperParamsIsNullThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + assertThatThrownBy(() -> spec.params(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("params cannot be null"); + } + + @Test + void whenDeveloperParamsContainsNullKeyThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + Map params = new HashMap<>(); + params.put(null, "value"); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param keys cannot contain null elements"); + } + + @Test + void whenDeveloperParamsContainsNullValueThenThrow() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + Map params = new HashMap<>(); + params.put("key", null); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param values cannot contain null elements"); + } + + @Test + void whenDeveloperParamsThenReturn() { + DefaultChatClient.DefaultPromptDeveloperSpec spec = new DefaultChatClient.DefaultPromptDeveloperSpec(); + spec = (DefaultChatClient.DefaultPromptDeveloperSpec) spec.params(Map.of("key", "value")); + assertThat(spec.params()).containsEntry("key", "value"); + } + // DefaultAdvisorSpec @Test @@ -710,6 +845,30 @@ void whenFullPromptThenChatResponse() { assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); } + @Test + void whenFullPromptWithDeveloperThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new DeveloperMessage("developer instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(2); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("developer instructions"); + assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); + } + @Test void whenPromptAndUserTextThenChatResponse() { ChatModel chatModel = mock(ChatModel.class); @@ -736,6 +895,30 @@ void whenPromptAndUserTextThenChatResponse() { assertThat(actualPrompt.getInstructions().get(2).getText()).isEqualTo("another question"); } + @Test + void whenPromptWithDeveloperAndUserTextThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new DeveloperMessage("developer instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(2); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("developer instructions"); + assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); + } + @Test void whenUserTextAndMessagesThenChatResponse() { ChatModel chatModel = mock(ChatModel.class); @@ -763,6 +946,35 @@ void whenUserTextAndMessagesThenChatResponse() { assertThat(actualPrompt.getInstructions().get(2).getText()).isEqualTo("another question"); } + @Test + void whenDeveloperMessageAndMessagesThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + List messages = List.of(new DeveloperMessage("developer instructions"), new UserMessage("my question"), + new UserMessage("additional question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt() + .user("additional user message") + .messages(messages); + DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec + .call(); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(4); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("developer instructions"); + assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); + assertThat(actualPrompt.getInstructions().get(2).getText()).isEqualTo("additional question"); + assertThat(actualPrompt.getInstructions().get(3).getText()).isEqualTo("additional user message"); + } + @Test void whenChatResponseIsNull() { ChatModel chatModel = mock(ChatModel.class); @@ -1277,6 +1489,85 @@ void whenUserTextAndMessagesThenFluxChatResponse() { assertThat(actualPrompt.getInstructions().get(2).getText()).isEqualTo("another question"); } + @Test + void whenFullPromptWithDeveloperThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new DeveloperMessage("developer instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(2); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("developer instructions"); + assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); + } + + @Test + void whenPromptAndUserTextWithDeveloperThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new DeveloperMessage("developer instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt) + .user("another question"); + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(3); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("developer instructions"); + assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); + assertThat(actualPrompt.getInstructions().get(2).getText()).isEqualTo("another question"); + } + + @Test + void whenUserTextAndMessagesWithDeveloperThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + List messages = List.of(new DeveloperMessage("developer instructions"), + new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt() + .user("another question") + .messages(messages); + + DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec + .stream(); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(3); + assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("developer instructions"); + assertThat(actualPrompt.getInstructions().get(1).getText()).isEqualTo("my question"); + } + @Test void whenChatResponseContentIsNullThenReturnFlux() { ChatModel chatModel = mock(ChatModel.class); @@ -1300,15 +1591,15 @@ void whenChatResponseContentIsNullThenReturnFlux() { void buildChatClientRequestSpec() { ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( - chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), - Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); + chatModel, null, Map.of(), null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), + null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); assertThat(spec).isNotNull(); } @Test void whenChatModelIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), null, - Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); @@ -1317,8 +1608,8 @@ void whenChatModelIsNullThenThrow() { @Test void whenObservationRegistryIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, - Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), null, - null, Map.of(), null)) + Map.of(), null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), + Map.of(), null, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("observationRegistry cannot be null"); } @@ -1796,6 +2087,73 @@ void whenSystemConsumerWithoutSystemTextThenReturn() { assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); } + @Test + void whenDeveloperTextIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + assertThatThrownBy(() -> chatClient.prompt().developer("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenDeveloperTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + var spec = chatClient.prompt().developer(dev -> dev.text("instructions")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getDeveloperText()).isEqualTo("instructions"); + } + + @Test + void whenDeveloperResourceIsNullWithCharsetThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + assertThatThrownBy(() -> chatClient.prompt().developer(null, Charset.defaultCharset())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenDeveloperCharsetIsNullWithResourceThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + Resource textResource = new ClassPathResource("developer-prompt.txt"); + assertThatThrownBy(() -> chatClient.prompt().developer(textResource, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + + @Test + void whenDeveloperResourceAndCharsetThenReturn() { + Resource textResource = new ClassPathResource("developer-prompt.txt"); + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + var spec = chatClient.prompt().developer(textResource, Charset.defaultCharset()); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getDeveloperText()).isEqualTo("instructions"); + } + + @Test + void whenDeveloperResourceThenReturn() { + Resource textResource = new ClassPathResource("developer-prompt.txt"); + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + var spec = chatClient.prompt().developer(textResource); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getDeveloperText()).isEqualTo("instructions"); + } + + @Test + void whenDeveloperConsumerIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + assertThatThrownBy(() -> chatClient.prompt().developer((Consumer) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("consumer cannot be null"); + } + + @Test + void whenDeveloperConsumerThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + var spec = chatClient.prompt().developer(dev -> dev.text("my instruction about {topic}").param("topic", "AI")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getDeveloperText()).isEqualTo("my instruction about {topic}"); + assertThat(defaultSpec.getDeveloperParams()).containsEntry("topic", "AI"); + } + @Test void whenUserTextIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index a14f5d3fdce..8451959298e 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.DeveloperMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.content.Media; @@ -87,6 +88,42 @@ void whenSystemTextWithParamsIsProvidedThenSystemMessageIsRenderedAndAddedToProm assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System instructions for Spring AI"); } + @Test + void whenDeveloperTextIsProvidedThenDeveloperMessageIsAddedToPrompt() { + String developerText = "Developer instructions"; + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .developer(developerText); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(DeveloperMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(developerText); + } + + @Test + void whenDeveloperTextWithParamsIsProvidedThenDeveloperMessageIsRenderedAndAddedToPrompt() { + String developerText = "Developer instructions for {name}"; + Map developerParams = Map.of("name", "Spring Boot"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .developer(s -> s.text(developerText).params(developerParams)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(DeveloperMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()) + .isEqualTo("Developer instructions for Spring Boot"); + } + @Test void whenMessagesAreProvidedThenTheyAreAddedToPrompt() { List messages = List.of(new SystemMessage("System message"), new UserMessage("User message")); @@ -178,6 +215,25 @@ void whenSystemTextAndSystemMessageAreProvidedThenSystemTextIsFirst() { assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(systemText); } + @Test + void whenDeveloperTextAndMessagesAreProvidedThenDeveloperMessageIsFirst() { + String developerText = "Developer instructions"; + List messages = List.of(new DeveloperMessage("Developer message")); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .developer(developerText) + .messages(messages); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).hasSize(2); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(DeveloperMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(developerText); + } + @Test void whenUserTextAndUserMessageAreProvidedThenUserTextIsLast() { String userText = "User question"; @@ -413,6 +469,84 @@ void whenAllComponentsAreProvidedThenCompleteRequestIsCreated() { assertThat(result.context()).containsAllEntriesOf(advisorParams); } + @Test + void whenCustomTemplateRendererWithDeveloperThenItIsUsedForRendering() { + String developerText = "Instructions "; + Map developerParams = Map.of("name", "Spring AI"); + TemplateRenderer customRenderer = StTemplateRenderer.builder() + .startDelimiterToken('<') + .endDelimiterToken('>') + .build(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .developer(s -> s.text(developerText).params(developerParams)) + .templateRenderer(customRenderer); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(DeveloperMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("Instructions Spring AI"); + } + + @Test + void whenAllComponentsAreProvidedWithDeveloperThenCompleteRequestIsCreated() { + String developerText = "Developer instructions for {name}"; + Map developerParams = Map.of("name", "Spring AI"); + + String userText = "Question about {topic}"; + Map userParams = Map.of("topic", "Spring AI"); + Media media = mock(Media.class); + + List messages = List.of(new UserMessage("Intermediate message")); + + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build(); + List toolNames = List.of("tool1", "tool2"); + ToolCallback toolCallback = new TestToolCallback("tool3"); + Map toolContext = Map.of("toolKey", "toolValue"); + Map advisorParams = Map.of("advisorKey", "advisorValue"); + + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .developer(s -> s.text(developerText).params(developerParams)) + .user(u -> u.text(userText).params(userParams).media(media)) + .messages(messages) + .toolNames(toolNames.toArray(new String[0])) + .toolCallbacks(toolCallback) + .toolContext(toolContext) + .options(chatOptions) + .advisors(a -> a.params(advisorParams)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + + assertThat(result.prompt().getInstructions()).hasSize(3); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(DeveloperMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()) + .isEqualTo("Developer instructions for Spring AI"); + assertThat(result.prompt().getInstructions().get(1)).isInstanceOf(Message.class); + assertThat(result.prompt().getInstructions().get(1).getText()).isEqualTo("Intermediate message"); + assertThat(result.prompt().getInstructions().get(2)).isInstanceOf(UserMessage.class); + assertThat(result.prompt().getInstructions().get(2).getText()).isEqualTo("Question about Spring AI"); + UserMessage userMessage = (UserMessage) result.prompt().getInstructions().get(2); + assertThat(userMessage.getMedia()).contains(media); + + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames); + assertThat(resultOptions.getToolCallbacks()).contains(toolCallback); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext); + + assertThat(result.context()).containsAllEntriesOf(advisorParams); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition; diff --git a/spring-ai-client-chat/src/test/resources/developer-prompt.txt b/spring-ai-client-chat/src/test/resources/developer-prompt.txt new file mode 100644 index 00000000000..e468cde2e7d --- /dev/null +++ b/spring-ai-client-chat/src/test/resources/developer-prompt.txt @@ -0,0 +1 @@ +instructions \ No newline at end of file diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 6e37fd7548b..86b11791082 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -67,8 +67,9 @@ public abstract class AbstractMessage implements Message { */ protected AbstractMessage(MessageType messageType, @Nullable String textContent, Map metadata) { Assert.notNull(messageType, "Message type must not be null"); - if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { - Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); + if (messageType == MessageType.SYSTEM || messageType == MessageType.DEVELOPER + || messageType == MessageType.USER) { + Assert.notNull(textContent, "Content must not be null for SYSTEM, DEVELOPER or USER messages"); } Assert.notNull(metadata, "Metadata must not be null"); this.messageType = messageType; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/DeveloperMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/DeveloperMessage.java new file mode 100644 index 00000000000..f2b7a314592 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/DeveloperMessage.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.messages; + +import org.springframework.core.io.Resource; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * A message of the type 'developer' passed as input. The developer message gives + * instructions or requests from developers using the API. This role typically provides + * detailed instructions for the model to follow, such as specific actions or formats. + */ +public class DeveloperMessage extends AbstractMessage { + + public DeveloperMessage(String textContent) { + this(textContent, Map.of()); + } + + public DeveloperMessage(Resource resource) { + this(MessageUtils.readResource(resource), Map.of()); + } + + private DeveloperMessage(String textContent, Map metadata) { + super(MessageType.DEVELOPER, textContent, metadata); + } + + @Override + @NonNull + public String getText() { + return this.textContent; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof DeveloperMessage that)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return Objects.equals(this.textContent, that.textContent); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.textContent); + } + + @Override + public String toString() { + return "DeveloperMessage{" + "textContent='" + this.textContent + '\'' + ", messageType=" + this.messageType + + ", metadata=" + this.metadata + '}'; + } + + public DeveloperMessage copy() { + return new DeveloperMessage(getText(), Map.copyOf(this.metadata)); + } + + public Builder mutate() { + return new Builder().text(this.textContent).metadata(this.metadata); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + @Nullable + private String textContent; + + @Nullable + private Resource resource; + + private Map metadata = new HashMap<>(); + + public Builder text(String textContent) { + this.textContent = textContent; + return this; + } + + public Builder text(Resource resource) { + this.resource = resource; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public DeveloperMessage build() { + if (StringUtils.hasText(textContent) && resource != null) { + throw new IllegalArgumentException("textContent and resource cannot be set at the same time"); + } + else if (resource != null) { + this.textContent = MessageUtils.readResource(resource); + } + return new DeveloperMessage(this.textContent, this.metadata); + } + + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/MessageType.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/MessageType.java index 4ceb76b3a43..3a4b4eb0b0e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/MessageType.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/MessageType.java @@ -44,6 +44,13 @@ public enum MessageType { */ SYSTEM("system"), + /** + * A {@link Message} of type {@literal developer} passed as input {@link Message} + * Messages containing instructions or requests from developers using the API. + * @see DeveloperMessage + */ + DEVELOPER("developer"), + /** * A {@link Message} of type {@literal function} passed as input {@link Message * Messages} with function content in a chat application. diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DeveloperPromptTemplate.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DeveloperPromptTemplate.java new file mode 100644 index 00000000000..19b092264b4 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DeveloperPromptTemplate.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import java.util.Map; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.DeveloperMessage; +import org.springframework.core.io.Resource; + +public class DeveloperPromptTemplate extends PromptTemplate { + + public DeveloperPromptTemplate(String template) { + super(template); + } + + public DeveloperPromptTemplate(Resource resource) { + super(resource); + } + + @Override + public Message createMessage() { + return new DeveloperMessage(render()); + } + + @Override + public Message createMessage(Map model) { + return new DeveloperMessage(render(model)); + } + + @Override + public Prompt create() { + return new Prompt(new DeveloperMessage(render())); + } + + @Override + public Prompt create(Map model) { + return new Prompt(new DeveloperMessage(render(model))); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/DeveloperMessageTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/DeveloperMessageTests.java new file mode 100644 index 00000000000..0b13e463441 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/DeveloperMessageTests.java @@ -0,0 +1,90 @@ +package org.springframework.ai.chat.messages; + +import org.junit.jupiter.api.Test; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.*; +import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE; + +class DeveloperMessageTests { + + @Test + void developerMessageWithNullText() { + assertThrows(IllegalArgumentException.class, () -> new DeveloperMessage((String) null)); + } + + @Test + void developerMessageWithTextContent() { + String text = "Developer instructions for the model."; + DeveloperMessage message = new DeveloperMessage(text); + assertEquals(text, message.getText()); + assertEquals(MessageType.DEVELOPER, message.getMetadata().get(MESSAGE_TYPE)); + } + + @Test + void developerMessageWithNullResource() { + assertThrows(IllegalArgumentException.class, () -> new DeveloperMessage((Resource) null)); + } + + @Test + void developerMessageWithResource() { + DeveloperMessage message = new DeveloperMessage(new ClassPathResource("prompt-developer.txt")); + assertEquals("Tell me, did you sail across the sun?", message.getText()); + assertEquals(MessageType.DEVELOPER, message.getMetadata().get(MESSAGE_TYPE)); + } + + @Test + void developerMessageFromBuilderWithText() { + String text = "Developer instructions for {name}"; + DeveloperMessage message = DeveloperMessage.builder().text(text).metadata(Map.of("key", "value")).build(); + assertEquals(text, message.getText()); + assertThat(message.getMetadata()).hasSize(2) + .containsEntry(MESSAGE_TYPE, MessageType.DEVELOPER) + .containsEntry("key", "value"); + } + + @Test + void developerMessageFromBuilderWithResource() { + Resource resource = new ClassPathResource("prompt-developer.txt"); + DeveloperMessage message = DeveloperMessage.builder().text(resource).metadata(Map.of("key", "value")).build(); + assertEquals("Tell me, did you sail across the sun?", message.getText()); + assertThat(message.getMetadata()).hasSize(2) + .containsEntry(MESSAGE_TYPE, MessageType.DEVELOPER) + .containsEntry("key", "value"); + } + + @Test + void developerMessageCopy() { + String text1 = "Developer instructions"; + Map metadata1 = Map.of("key", "value"); + DeveloperMessage message1 = DeveloperMessage.builder().text(text1).metadata(metadata1).build(); + + DeveloperMessage message2 = message1.copy(); + + assertThat(message2.getText()).isEqualTo(text1); + assertThat(message2.getMetadata()).hasSize(2).isNotSameAs(metadata1); + } + + @Test + void developerMessageMutate() { + String text1 = "Developer instructions"; + Map metadata1 = Map.of("key", "value"); + DeveloperMessage message1 = DeveloperMessage.builder().text(text1).metadata(metadata1).build(); + + DeveloperMessage message2 = message1.mutate().build(); + + assertThat(message2.getText()).isEqualTo(text1); + assertThat(message2.getMetadata()).hasSize(2).isNotSameAs(metadata1); + + String newText = "Updated developer instructions"; + DeveloperMessage message3 = message2.mutate().text(newText).build(); + + assertThat(message3.getText()).isEqualTo(newText); + assertThat(message3.getMetadata()).hasSize(2); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java index 26bb59718bd..70adb99af49 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/messages/UserMessageTests.java @@ -38,7 +38,7 @@ class UserMessageTests { @Test void userMessageWithNullText() { assertThatThrownBy(() -> new UserMessage((String) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Content must not be null for SYSTEM or USER messages"); + .hasMessageContaining("Content must not be null for SYSTEM, DEVELOPER or USER messages"); ; } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java index a9e8f84ecef..7c43508a6da 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java @@ -36,11 +36,11 @@ class PromptTests { @Test void whenContentIsNullThenThrow() { assertThatThrownBy(() -> new Prompt((String) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Content must not be null for SYSTEM or USER messages"); + .hasMessageContaining("Content must not be null for SYSTEM, DEVELOPER or USER messages"); assertThatThrownBy(() -> new Prompt((String) null, ChatOptions.builder().build())) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Content must not be null for SYSTEM or USER messages"); + .hasMessageContaining("Content must not be null for SYSTEM, DEVELOPER or USER messages"); } @Test diff --git a/spring-ai-model/src/test/resources/prompt-developer.txt b/spring-ai-model/src/test/resources/prompt-developer.txt new file mode 100644 index 00000000000..b292fd2f45c --- /dev/null +++ b/spring-ai-model/src/test/resources/prompt-developer.txt @@ -0,0 +1 @@ +Tell me, did you sail across the sun? \ No newline at end of file