diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index fd282fd319d..d53e998967a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -38,6 +38,8 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.MimeType; /** @@ -49,6 +51,7 @@ * @author Christian Tzolov * @author Josh Long * @author Arjen Poutsma + * @author Thomas Vitale * @since 1.0.0 */ public interface ChatClient { @@ -62,7 +65,9 @@ static ChatClient create(ChatModel chatModel, ObservationRegistry observationReg } static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry, - ChatClientObservationConvention observationConvention) { + @Nullable ChatClientObservationConvention observationConvention) { + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); return builder(chatModel, observationRegistry, observationConvention).build(); } @@ -71,7 +76,9 @@ static Builder builder(ChatModel chatModel) { } static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry, - ChatClientObservationConvention customObservationConvention) { + @Nullable ChatClientObservationConvention customObservationConvention) { + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); return new DefaultChatClientBuilder(chatModel, observationRegistry, customObservationConvention); } @@ -136,14 +143,19 @@ interface AdvisorSpec { interface CallResponseSpec { + @Nullable T entity(ParameterizedTypeReference type); + @Nullable T entity(StructuredOutputConverter structuredOutputConverter); + @Nullable T entity(Class type); + @Nullable ChatResponse chatResponse(); + @Nullable String content(); ResponseEntity responseEntity(Class type); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index c0b2d5a8d22..6f46f2749b0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -21,6 +21,7 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -63,6 +64,7 @@ import org.springframework.core.Ordered; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; @@ -88,18 +90,44 @@ public class DefaultChatClient implements ChatClient { private final DefaultChatClientRequestSpec defaultChatClientRequest; public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) { + Assert.notNull(defaultChatClientRequest, "defaultChatClientRequest cannot be null"); this.defaultChatClientRequest = defaultChatClientRequest; } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, + @Nullable String formatParam) { + Assert.notNull(inputRequest, "inputRequest cannot be null"); + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); if (StringUtils.hasText(formatParam)) { advisorContext.put("formatParam", formatParam); } - return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, - inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, - inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, + // Process userText, media and messages before creating the AdvisedRequest. + String userText = inputRequest.userText; + List media = inputRequest.media; + List messages = inputRequest.messages; + + // If the userText is empty, then try extracting the userText from the last + // message + // in the messages list and remove it from the messages list. + if (!StringUtils.hasText(userText) && !CollectionUtils.isEmpty(messages)) { + Message lastMessage = messages.get(messages.size() - 1); + if (lastMessage.getMessageType() == MessageType.USER) { + UserMessage userMessage = (UserMessage) lastMessage; + if (StringUtils.hasText(userMessage.getContent())) { + userText = lastMessage.getContent(); + } + Collection messageMedia = userMessage.getMedia(); + if (!CollectionUtils.isEmpty(messageMedia)) { + media.addAll(messageMedia); + } + messages = messages.subList(0, messages.size() - 1); + } + } + + return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions, + media, inputRequest.functionNames, inputRequest.functionCallbacks, messages, inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, inputRequest.toolContext); } @@ -122,10 +150,13 @@ public ChatClientRequestSpec prompt() { @Override public ChatClientRequestSpec prompt(String content) { + Assert.hasText(content, "content cannot be null or empty"); return prompt(new Prompt(content)); } + @Override public ChatClientRequestSpec prompt(Prompt prompt) { + Assert.notNull(prompt, "prompt cannot be null"); DefaultChatClientRequestSpec spec = new DefaultChatClientRequestSpec(this.defaultChatClientRequest); @@ -135,26 +166,10 @@ public ChatClientRequestSpec prompt(Prompt prompt) { } // Messages - List messages = prompt.getInstructions(); - - if (!CollectionUtils.isEmpty(messages)) { - var lastMessage = messages.get(messages.size() - 1); - if (lastMessage.getMessageType() == MessageType.USER) { - // Unzip the last message - var userMessage = (UserMessage) lastMessage; - if (StringUtils.hasText(userMessage.getContent())) { - spec.user(lastMessage.getContent()); - } - var media = userMessage.getMedia(); - if (!CollectionUtils.isEmpty(media)) { - spec.user(u -> u.media(media.toArray(new Media[media.size()]))); - } - messages = messages.subList(0, messages.size() - 1); - } + if (prompt.getInstructions() != null) { + spec.messages(prompt.getInstructions()); } - spec.messages(messages); - return spec; } @@ -173,34 +188,44 @@ public static class DefaultPromptUserSpec implements PromptUserSpec { private final List media = new ArrayList<>(); - private String text = ""; + @Nullable + private String text; @Override public PromptUserSpec media(Media... media) { + Assert.notNull(media, "media cannot be null"); + Assert.noNullElements(media, "media cannot contain null elements"); this.media.addAll(Arrays.asList(media)); return this; } @Override public PromptUserSpec media(MimeType mimeType, URL url) { + Assert.notNull(mimeType, "mimeType cannot be null"); + Assert.notNull(url, "url cannot be null"); this.media.add(new Media(mimeType, url)); return this; } @Override public PromptUserSpec media(MimeType mimeType, Resource resource) { + Assert.notNull(mimeType, "mimeType cannot be null"); + Assert.notNull(resource, "resource cannot be null"); this.media.add(new Media(mimeType, resource)); return this; } @Override public PromptUserSpec text(String text) { + Assert.hasText(text, "text cannot be null or empty"); this.text = text; return this; } @Override public PromptUserSpec text(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); try { this.text(text.getContentAsString(charset)); } @@ -212,22 +237,29 @@ public PromptUserSpec text(Resource text, Charset charset) { @Override public PromptUserSpec text(Resource text) { + Assert.notNull(text, "text cannot be null"); this.text(text, Charset.defaultCharset()); return this; } @Override - public PromptUserSpec param(String k, Object v) { - this.params.put(k, v); + public PromptUserSpec param(String key, Object value) { + Assert.hasText(key, "key cannot be null or empty"); + Assert.notNull(value, "value cannot be null"); + this.params.put(key, value); return this; } @Override - public PromptUserSpec params(Map p) { - this.params.putAll(p); + public PromptUserSpec params(Map params) { + Assert.notNull(params, "params cannot be null"); + Assert.noNullElements(params.keySet(), "param keys cannot contain null elements"); + Assert.noNullElements(params.values(), "param values cannot contain null elements"); + this.params.putAll(params); return this; } + @Nullable protected String text() { return this.text; } @@ -246,16 +278,20 @@ public static class DefaultPromptSystemSpec implements PromptSystemSpec { private final Map params = new HashMap<>(); - private String text = ""; + @Nullable + private String text; @Override public PromptSystemSpec text(String text) { + Assert.hasText(text, "text cannot be null or empty"); this.text = text; return this; } @Override public PromptSystemSpec text(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); try { this.text(text.getContentAsString(charset)); } @@ -267,22 +303,29 @@ public PromptSystemSpec text(Resource text, Charset charset) { @Override public PromptSystemSpec text(Resource text) { + Assert.notNull(text, "text cannot be null"); this.text(text, Charset.defaultCharset()); return this; } @Override - public PromptSystemSpec param(String k, Object v) { - this.params.put(k, v); + public PromptSystemSpec param(String key, Object value) { + Assert.hasText(key, "key cannot be null or empty"); + Assert.notNull(value, "value cannot be null"); + this.params.put(key, value); return this; } @Override - public PromptSystemSpec params(Map p) { - this.params.putAll(p); + public PromptSystemSpec params(Map params) { + Assert.notNull(params, "params cannot be null"); + Assert.noNullElements(params.keySet(), "param keys cannot contain null elements"); + Assert.noNullElements(params.values(), "param values cannot contain null elements"); + this.params.putAll(params); return this; } + @Nullable protected String text() { return this.text; } @@ -299,22 +342,35 @@ public static class DefaultAdvisorSpec implements AdvisorSpec { private final Map params = new HashMap<>(); - public AdvisorSpec param(String k, Object v) { - this.params.put(k, v); + @Override + public AdvisorSpec param(String key, Object value) { + Assert.hasText(key, "key cannot be null or empty"); + Assert.notNull(value, "value cannot be null"); + this.params.put(key, value); return this; } - public AdvisorSpec params(Map p) { - this.params.putAll(p); + @Override + public AdvisorSpec params(Map params) { + Assert.notNull(params, "params cannot be null"); + Assert.noNullElements(params.keySet(), "param keys cannot contain null elements"); + Assert.noNullElements(params.values(), "param values cannot contain null elements"); + this.params.putAll(params); return this; } + @Override public AdvisorSpec advisors(Advisor... advisors) { + Assert.notNull(advisors, "advisors cannot be null"); + Assert.noNullElements(advisors, "advisors cannot contain null elements"); this.advisors.addAll(List.of(advisors)); return this; } + @Override public AdvisorSpec advisors(List advisors) { + Assert.notNull(advisors, "advisors cannot be null"); + Assert.noNullElements(advisors, "advisors cannot contain null elements"); this.advisors.addAll(advisors); return this; } @@ -334,57 +390,80 @@ public static class DefaultCallResponseSpec implements CallResponseSpec { private final DefaultChatClientRequestSpec request; public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) { + Assert.notNull(request, "request cannot be null"); this.request = request; } + @Override public ResponseEntity responseEntity(Class type) { - Assert.notNull(type, "the class must be non-null"); + Assert.notNull(type, "type cannot be null"); return doResponseEntity(new BeanOutputConverter(type)); } + @Override public ResponseEntity responseEntity(ParameterizedTypeReference type) { + Assert.notNull(type, "type cannot be null"); return doResponseEntity(new BeanOutputConverter(type)); } + @Override public ResponseEntity responseEntity( StructuredOutputConverter structuredOutputConverter) { + Assert.notNull(structuredOutputConverter, "structuredOutputConverter cannot be null"); return doResponseEntity(structuredOutputConverter); } - protected ResponseEntity doResponseEntity(StructuredOutputConverter boc) { - var chatResponse = doGetObservableChatResponse(this.request, boc.getFormat()); - var responseContent = chatResponse.getResult().getOutput().getContent(); - T entity = boc.convert(responseContent); - + protected ResponseEntity doResponseEntity(StructuredOutputConverter outputConverter) { + Assert.notNull(outputConverter, "structuredOutputConverter cannot be null"); + var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat()); + var responseContent = getContentFromChatResponse(chatResponse); + if (responseContent == null) { + return new ResponseEntity<>(chatResponse, null); + } + T entity = outputConverter.convert(responseContent); return new ResponseEntity<>(chatResponse, entity); } + @Override + @Nullable public T entity(ParameterizedTypeReference type) { - return doSingleWithBeanOutputConverter(new BeanOutputConverter(type)); + Assert.notNull(type, "type cannot be null"); + return doSingleWithBeanOutputConverter(new BeanOutputConverter<>(type)); } + @Override + @Nullable public T entity(StructuredOutputConverter structuredOutputConverter) { + Assert.notNull(structuredOutputConverter, "structuredOutputConverter cannot be null"); return doSingleWithBeanOutputConverter(structuredOutputConverter); } - private T doSingleWithBeanOutputConverter(StructuredOutputConverter boc) { - var chatResponse = doGetObservableChatResponse(this.request, boc.getFormat()); - var stringResponse = chatResponse.getResult().getOutput().getContent(); - return boc.convert(stringResponse); + @Override + @Nullable + public T entity(Class type) { + Assert.notNull(type, "type cannot be null"); + var outputConverter = new BeanOutputConverter<>(type); + return doSingleWithBeanOutputConverter(outputConverter); } - public T entity(Class type) { - Assert.notNull(type, "the class must be non-null"); - var boc = new BeanOutputConverter(type); - return doSingleWithBeanOutputConverter(boc); + @Nullable + private T doSingleWithBeanOutputConverter(StructuredOutputConverter outputConverter) { + var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat()); + var stringResponse = getContentFromChatResponse(chatResponse); + if (stringResponse == null) { + return null; + } + return outputConverter.convert(stringResponse); } + @Nullable private ChatResponse doGetChatResponse() { - return this.doGetObservableChatResponse(this.request, ""); + return this.doGetObservableChatResponse(this.request, null); } + @Nullable private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec inputRequest, - String formatParam) { + @Nullable String formatParam) { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() .withRequest(inputRequest) @@ -395,19 +474,15 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation( inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, inputRequest.getObservationRegistry()); - return observation.observe(() -> { - ChatResponse chatResponse = doGetChatResponse(inputRequest, formatParam, observation); - return chatResponse; - }); - + return observation.observe(() -> doGetChatResponse(inputRequest, formatParam, observation)); } - private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam, - Observation parentObservation) { + private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, + @Nullable String formatParam, Observation parentObservation) { AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec, formatParam); - // Apply the around advisor chain that terminates with the, last, model call + // Apply the around advisor chain that terminates with the last model call // advisor. AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChainBuilder.build() .nextAroundCall(advisedRequest); @@ -415,12 +490,26 @@ private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest return advisedResponse.response(); } + @Nullable + private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) { + if (chatResponse == null || chatResponse.getResult() == null || chatResponse.getResult().getOutput() == null + || chatResponse.getResult().getOutput().getContent() == null) { + return null; + } + return chatResponse.getResult().getOutput().getContent(); + } + + @Override + @Nullable public ChatResponse chatResponse() { return doGetChatResponse(); } + @Override + @Nullable public String content() { - return doGetChatResponse().getResult().getOutput().getContent(); + ChatResponse chatResponse = doGetChatResponse(); + return getContentFromChatResponse(chatResponse); } } @@ -430,6 +519,7 @@ public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) { + Assert.notNull(request, "request cannot be null"); this.request = request; } @@ -448,11 +538,10 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)) .start(); - var initialAdvisedRequest = toAdvisedRequest(inputRequest, ""); + var initialAdvisedRequest = toAdvisedRequest(inputRequest, null); // @formatter:off - // Apply the around advisor chain that terminates with the, last, - // model call advisor. + // Apply the around advisor chain that terminates with the last model call advisor. Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); return stream @@ -464,10 +553,12 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ }); } + @Override public Flux chatResponse() { return doGetObservableFluxChatResponse(this.request); } + @Override public Flux content() { return doGetObservableFluxChatResponse(this.request).map(r -> { if (r.getResult() == null || r.getResult().getOutput() == null @@ -508,10 +599,13 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map toolContext = new HashMap<>(); - private String userText = ""; + @Nullable + private String userText; - private String systemText = ""; + @Nullable + private String systemText; + @Nullable private ChatOptions chatOptions; /* copy constructor */ @@ -521,11 +615,25 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext); } - public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map userParams, - String systemText, Map systemParams, List functionCallbacks, - List messages, List functionNames, List media, ChatOptions chatOptions, - List advisors, Map advisorParams, ObservationRegistry observationRegistry, - ChatClientObservationConvention customObservationConvention, Map toolContext) { + public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, + Map userParams, @Nullable String systemText, Map systemParams, + List functionCallbacks, List messages, List functionNames, + List media, @Nullable ChatOptions chatOptions, List advisors, + Map advisorParams, ObservationRegistry observationRegistry, + @Nullable ChatClientObservationConvention customObservationConvention, + Map toolContext) { + + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.notNull(userParams, "userParams cannot be null"); + Assert.notNull(systemParams, "systemParams cannot be null"); + Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + Assert.notNull(messages, "messages cannot be null"); + Assert.notNull(functionNames, "functionNames cannot be null"); + Assert.notNull(media, "media cannot be null"); + Assert.notNull(advisors, "advisors cannot be null"); + Assert.notNull(advisorParams, "advisorParams cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolContext, "toolContext cannot be null"); this.chatModel = chatModel; this.chatOptions = chatOptions != null ? chatOptions.copy() @@ -543,7 +651,8 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map getUserParams() { return this.userParams; } + @Nullable public String getSystemText() { return this.systemText; } @@ -616,6 +727,7 @@ public Map getSystemParams() { return this.systemParams; } + @Nullable public ChatOptions getChatOptions() { return this.chatOptions; } @@ -655,13 +767,21 @@ public Map getToolContext() { public Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient .builder(this.chatModel, this.observationRegistry, this.customObservationConvention) - .defaultSystem(s -> s.text(this.systemText).params(this.systemParams)) - .defaultUser(u -> u.text(this.userText) - .params(this.userParams) - .media(this.media.toArray(new Media[this.media.size()]))) - .defaultOptions(this.chatOptions) .defaultFunctions(StringUtils.toStringArray(this.functionNames)); + if (StringUtils.hasText(this.userText)) { + builder.defaultUser( + u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0]))); + } + + if (StringUtils.hasText(this.systemText)) { + builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams)); + } + + if (this.chatOptions != null) { + builder.defaultOptions(this.chatOptions); + } + // workaround to set the missing fields. builder.defaultRequest.getMessages().addAll(this.messages); builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks); @@ -671,43 +791,47 @@ public Builder mutate() { } public ChatClientRequestSpec advisors(Consumer consumer) { - Assert.notNull(consumer, "the consumer must be non-null"); - var as = new DefaultAdvisorSpec(); - consumer.accept(as); - this.advisorParams.putAll(as.getParams()); - this.advisors.addAll(as.getAdvisors()); - this.aroundAdvisorChainBuilder.pushAll(as.getAdvisors()); + Assert.notNull(consumer, "consumer cannot be null"); + var advisorSpec = new DefaultAdvisorSpec(); + consumer.accept(advisorSpec); + this.advisorParams.putAll(advisorSpec.getParams()); + this.advisors.addAll(advisorSpec.getAdvisors()); + this.aroundAdvisorChainBuilder.pushAll(advisorSpec.getAdvisors()); return this; } public ChatClientRequestSpec advisors(Advisor... advisors) { - Assert.notNull(advisors, "the advisors must be non-null"); + Assert.notNull(advisors, "advisors cannot be null"); + Assert.noNullElements(advisors, "advisors cannot contain null elements"); this.advisors.addAll(Arrays.asList(advisors)); this.aroundAdvisorChainBuilder.pushAll(Arrays.asList(advisors)); return this; } public ChatClientRequestSpec advisors(List advisors) { - Assert.notNull(advisors, "the advisors must be non-null"); + Assert.notNull(advisors, "advisors cannot be null"); + Assert.noNullElements(advisors, "advisors cannot contain null elements"); this.advisors.addAll(advisors); this.aroundAdvisorChainBuilder.pushAll(advisors); return this; } public ChatClientRequestSpec messages(Message... messages) { - Assert.notNull(messages, "the messages must be non-null"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); this.messages.addAll(List.of(messages)); return this; } public ChatClientRequestSpec messages(List messages) { - Assert.notNull(messages, "the messages must be non-null"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); this.messages.addAll(messages); return this; } public ChatClientRequestSpec options(T options) { - Assert.notNull(options, "the options must be non-null"); + Assert.notNull(options, "options cannot be null"); this.chatOptions = options; return this; } @@ -720,9 +844,9 @@ public ChatClientRequestSpec function(String name, String description, public ChatClientRequestSpec function(String name, String description, java.util.function.BiFunction biFunction) { - Assert.hasText(name, "the name must be non-null and non-empty"); - Assert.hasText(description, "the description must be non-null and non-empty"); - Assert.notNull(biFunction, "the biFunction must be non-null"); + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(biFunction, "biFunction cannot be null"); FunctionCallbackWrapper fcw = FunctionCallbackWrapper.builder(biFunction) .withDescription(description) @@ -733,12 +857,12 @@ public ChatClientRequestSpec function(String name, String description, return this; } - public ChatClientRequestSpec function(String name, String description, Class inputType, + public ChatClientRequestSpec function(String name, String description, @Nullable Class inputType, java.util.function.Function function) { - Assert.hasText(name, "the name must be non-null and non-empty"); - Assert.hasText(description, "the description must be non-null and non-empty"); - Assert.notNull(function, "the function must be non-null"); + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(function, "function cannot be null"); var fcw = FunctionCallbackWrapper.builder(function) .withDescription(description) @@ -751,36 +875,39 @@ public ChatClientRequestSpec function(String name, String description, Cl } public ChatClientRequestSpec functions(String... functionBeanNames) { - Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null"); + Assert.notNull(functionBeanNames, "functionBeanNames cannot be null"); + Assert.noNullElements(functionBeanNames, "functionBeanNames cannot contain null elements"); this.functionNames.addAll(List.of(functionBeanNames)); return this; } public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) { - Assert.notNull(functionCallbacks, "the functionCallbacks must be non-null"); + Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements"); this.functionCallbacks.addAll(Arrays.asList(functionCallbacks)); return this; } public ChatClientRequestSpec toolContext(Map toolContext) { - Assert.notNull(toolContext, "the toolContext must be non-null"); + Assert.notNull(toolContext, "toolContext cannot be null"); + Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements"); + Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements"); this.toolContext.putAll(toolContext); return this; } public ChatClientRequestSpec system(String text) { - Assert.notNull(text, "the text must be non-null"); + Assert.hasText(text, "text cannot be null or empty"); this.systemText = text; return this; } - public ChatClientRequestSpec system(Resource textResource, Charset charset) { - - Assert.notNull(textResource, "the text resource must be non-null"); - Assert.notNull(charset, "the charset must be non-null"); + public ChatClientRequestSpec system(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); try { - this.systemText = textResource.getContentAsString(charset); + this.systemText = text.getContentAsString(charset); } catch (IOException e) { throw new RuntimeException(e); @@ -789,32 +916,30 @@ public ChatClientRequestSpec system(Resource textResource, Charset charset) { } public ChatClientRequestSpec system(Resource text) { - Assert.notNull(text, "the text resource must be non-null"); + Assert.notNull(text, "text cannot be null"); return this.system(text, Charset.defaultCharset()); } public ChatClientRequestSpec system(Consumer consumer) { + Assert.notNull(consumer, "consumer cannot be null"); - Assert.notNull(consumer, "the consumer must be non-null"); - - var ss = new DefaultPromptSystemSpec(); - consumer.accept(ss); - this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText; - this.systemParams.putAll(ss.params()); + var systemSpec = new DefaultPromptSystemSpec(); + consumer.accept(systemSpec); + this.systemText = StringUtils.hasText(systemSpec.text()) ? systemSpec.text() : this.systemText; + this.systemParams.putAll(systemSpec.params()); return this; } public ChatClientRequestSpec user(String text) { - Assert.notNull(text, "the text must be non-null"); + Assert.hasText(text, "text cannot be null or empty"); this.userText = text; return this; } public ChatClientRequestSpec user(Resource text, Charset charset) { - - Assert.notNull(text, "the text resource must be non-null"); - Assert.notNull(charset, "the charset must be non-null"); + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); try { this.userText = text.getContentAsString(charset); @@ -826,12 +951,12 @@ public ChatClientRequestSpec user(Resource text, Charset charset) { } public ChatClientRequestSpec user(Resource text) { - Assert.notNull(text, "the text resource must be non-null"); + Assert.notNull(text, "text cannot be null"); return this.user(text, Charset.defaultCharset()); } public ChatClientRequestSpec user(Consumer consumer) { - Assert.notNull(consumer, "the consumer must be non-null"); + Assert.notNull(consumer, "consumer cannot be null"); var us = new DefaultPromptUserSpec(); consumer.accept(us); @@ -860,6 +985,8 @@ public static class DefaultCallPromptResponseSpec implements CallPromptResponseS private final Prompt prompt; public DefaultCallPromptResponseSpec(ChatModel chatModel, Prompt prompt) { + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.notNull(prompt, "prompt cannot be null"); this.chatModel = chatModel; this.prompt = prompt; } @@ -889,6 +1016,8 @@ public static class DefaultStreamPromptResponseSpec implements StreamPromptRespo private final StreamingChatModel chatModel; public DefaultStreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) { + Assert.notNull(streamingChatModel, "streamingChatModel cannot be null"); + Assert.notNull(prompt, "prompt cannot be null"); this.chatModel = streamingChatModel; this.prompt = prompt; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 6b03d8e2a40..4ae3833d868 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -35,6 +35,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.core.io.Resource; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -46,6 +47,7 @@ * @author Christian Tzolov * @author Josh Long * @author Arjen Poutsma + * @author Thomas Vitale * @since 1.0.0 */ public class DefaultChatClientBuilder implements Builder { @@ -57,10 +59,10 @@ public class DefaultChatClientBuilder implements Builder { } public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry, - ChatClientObservationConvention customObservationConvention) { + @Nullable ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); - this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), + this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); } @@ -69,8 +71,8 @@ public ChatClient build() { return new DefaultChatClient(this.defaultRequest); } - public Builder defaultAdvisors(Advisor... advisor) { - this.defaultRequest.advisors(advisor); + public Builder defaultAdvisors(Advisor... advisors) { + this.defaultRequest.advisors(advisors); return this; } @@ -95,6 +97,8 @@ public Builder defaultUser(String text) { } public Builder defaultUser(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); try { this.defaultRequest.user(text.getContentAsString(charset)); } @@ -119,6 +123,8 @@ public Builder defaultSystem(String text) { } public Builder defaultSystem(Resource text, Charset charset) { + Assert.notNull(text, "text cannot be null"); + Assert.notNull(charset, "charset cannot be null"); try { this.defaultRequest.system(text.getContentAsString(charset)); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java index b6ab8fedda7..bd454d55999 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java @@ -16,6 +16,8 @@ package org.springframework.ai.chat.client; +import org.springframework.lang.Nullable; + /** * Represents a {@link org.springframework.ai.model.Model} response that includes the * entire response along withe specified response entity type. @@ -25,14 +27,17 @@ * @param response the entire response object. * @param entity the converted entity object. * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ -public record ResponseEntity(R response, E entity) { +public record ResponseEntity(@Nullable R response, @Nullable E entity) { + @Nullable public R getResponse() { return this.response; } + @Nullable public E getEntity() { return this.entity; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 08e1aa276e7..7bc48c7723e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -86,15 +86,30 @@ public record AdvisedRequest( Assert.notNull(chatModel, "chatModel cannot be null"); Assert.hasText(userText, "userText cannot be null or empty"); Assert.notNull(media, "media cannot be null"); + Assert.noNullElements(media, "media cannot contain null elements"); Assert.notNull(functionNames, "functionNames cannot be null"); + Assert.noNullElements(functionNames, "functionNames cannot contain null elements"); Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements"); Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); Assert.notNull(userParams, "userParams cannot be null"); + Assert.noNullElements(userParams.keySet(), "userParams keys cannot contain null elements"); + Assert.noNullElements(userParams.values(), "userParams values cannot contain null elements"); Assert.notNull(systemParams, "systemParams cannot be null"); + Assert.noNullElements(systemParams.keySet(), "systemParams keys cannot contain null elements"); + Assert.noNullElements(systemParams.values(), "systemParams values cannot contain null elements"); Assert.notNull(advisors, "advisors cannot be null"); + Assert.noNullElements(advisors, "advisors cannot contain null elements"); Assert.notNull(advisorParams, "advisorParams cannot be null"); + Assert.noNullElements(advisorParams.keySet(), "advisorParams keys cannot contain null elements"); + Assert.noNullElements(advisorParams.values(), "advisorParams values cannot contain null elements"); Assert.notNull(adviseContext, "adviseContext cannot be null"); + Assert.noNullElements(adviseContext.keySet(), "adviseContext keys cannot contain null elements"); + Assert.noNullElements(adviseContext.values(), "adviseContext values cannot contain null elements"); Assert.notNull(toolContext, "toolContext cannot be null"); + Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements"); + Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements"); } public static Builder builder() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java index 8bb383c1cc5..d6e329310ea 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -22,6 +22,7 @@ import java.util.function.Function; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -31,11 +32,12 @@ * @author Thomas Vitale * @since 1.0.0 */ -public record AdvisedResponse(ChatResponse response, Map adviseContext) { +public record AdvisedResponse(@Nullable ChatResponse response, Map adviseContext) { public AdvisedResponse { - Assert.notNull(response, "response cannot be null"); Assert.notNull(adviseContext, "adviseContext cannot be null"); + Assert.noNullElements(adviseContext.keySet(), "adviseContext keys cannot be null"); + Assert.noNullElements(adviseContext.values(), "adviseContext values cannot be null"); } public static Builder builder() { @@ -55,6 +57,7 @@ public AdvisedResponse updateContext(Function, Map adviseContext; @@ -62,7 +65,7 @@ public static final class Builder { private Builder() { } - public Builder withResponse(ChatResponse response) { + public Builder withResponse(@Nullable ChatResponse response) { this.response = response; return this; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/package-info.java new file mode 100644 index 00000000000..9e66e8de1b4 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.client; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 743314f8d4a..99a6f89c135 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -17,6 +17,7 @@ package org.springframework.ai.chat.prompt; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -54,6 +55,10 @@ public Prompt(List messages) { this(messages, null); } + public Prompt(Message... messages) { + this(Arrays.asList(messages), null); + } + public Prompt(String contents, ChatOptions chatOptions) { this(new UserMessage(contents), chatOptions); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 44824f10323..1f9d407c38a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -33,6 +33,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -46,6 +47,7 @@ import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; /** @@ -69,7 +71,7 @@ private String join(Flux fluxContent) { // ChatClient Builder Tests @Test - public void defaultSystemText() { + void defaultSystemText() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -118,7 +120,7 @@ public void defaultSystemText() { } @Test - public void defaultSystemTextLambda() { + void defaultSystemTextLambda() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -194,7 +196,7 @@ public void defaultSystemTextLambda() { } @Test - public void mutateDefaults() { + void mutateDefaults() { PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); given(this.chatModel.getDefaultOptions()).willReturn(options); @@ -322,7 +324,7 @@ public void mutateDefaults() { } @Test - public void mutatePrompt() { + void mutatePrompt() { PortableFunctionCallingOptions options = new FunctionCallingOptionsBuilder().build(); given(this.chatModel.getDefaultOptions()).willReturn(options); @@ -412,7 +414,7 @@ public void mutatePrompt() { } @Test - public void defaultUserText() { + void defaultUserText() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -437,7 +439,7 @@ public void defaultUserText() { } @Test - public void simpleUserPromptAsString() { + void simpleUserPromptAsString() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -450,7 +452,7 @@ public void simpleUserPromptAsString() { } @Test - public void simpleUserPrompt() { + void simpleUserPrompt() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -463,7 +465,7 @@ public void simpleUserPrompt() { } @Test - public void simpleUserPromptObject() throws MalformedURLException { + void simpleUserPromptObject() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -482,7 +484,7 @@ public void simpleUserPromptObject() throws MalformedURLException { } @Test - public void simpleSystemPrompt() throws MalformedURLException { + void simpleSystemPrompt() { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -503,7 +505,7 @@ public void simpleSystemPrompt() throws MalformedURLException { } @Test - public void complexCall() throws MalformedURLException { + void complexCall() throws MalformedURLException { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); @@ -545,4 +547,264 @@ public void complexCall() throws MalformedURLException { assertThat(options.getFunctions()).isEmpty(); } + // Constructors + + @Test + void whenCreateAndChatModelIsNullThenThrow() { + assertThatThrownBy(() -> ChatClient.create(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatModel cannot be null"); + } + + @Test + void whenCreateAndObservationRegistryIsNullThenThrow() { + assertThatThrownBy(() -> ChatClient.create(this.chatModel, null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("observationRegistry cannot be null"); + } + + @Test + void whenBuilderAndChatModelIsNullThenThrow() { + assertThatThrownBy(() -> ChatClient.builder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatModel cannot be null"); + } + + @Test + void whenBuilderAndObservationRegistryIsNullThenThrow() { + assertThatThrownBy(() -> ChatClient.builder(this.chatModel, null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("observationRegistry cannot be null"); + } + + // Prompt Tests - User + + @Test + void whenPromptWithStringContent() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var content = chatClient.prompt("my question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); + var userMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(userMessage.getContent()).isEqualTo("my question"); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + } + + @Test + void whenPromptWithMessages() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt(prompt).call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); + var userMessage = this.promptCaptor.getValue().getInstructions().get(1); + assertThat(userMessage.getContent()).isEqualTo("my question"); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + } + + @Test + void whenPromptWithStringContentAndUserText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var content = chatClient.prompt("my question").user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); + var userMessage = this.promptCaptor.getValue().getInstructions().get(1); + assertThat(userMessage.getContent()).isEqualTo("another question"); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + } + + @Test + void whenPromptWithHistoryAndUserText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new UserMessage("my question"), new AssistantMessage("your answer")); + var content = chatClient.prompt(prompt).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); + var userMessage = this.promptCaptor.getValue().getInstructions().get(2); + assertThat(userMessage.getContent()).isEqualTo("another question"); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + } + + @Test + void whenPromptWithUserMessageAndUserText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new UserMessage("my question")); + var content = chatClient.prompt(prompt).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); + var userMessage = this.promptCaptor.getValue().getInstructions().get(1); + assertThat(userMessage.getContent()).isEqualTo("another question"); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + } + + @Test + void whenMessagesWithHistoryAndUserText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new UserMessage("my question"), new AssistantMessage("your answer")); + var content = chatClient.prompt().messages(messages).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); + var userMessage = this.promptCaptor.getValue().getInstructions().get(2); + assertThat(userMessage.getContent()).isEqualTo("another question"); + } + + @Test + void whenMessagesWithUserMessageAndUserText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new UserMessage("my question")); + var content = chatClient.prompt().messages(messages).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); + var userMessage = this.promptCaptor.getValue().getInstructions().get(1); + assertThat(userMessage.getContent()).isEqualTo("another question"); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + } + + // Prompt Tests - System + + @Test + void whenPromptWithMessagesAndSystemText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new UserMessage("my question"), new AssistantMessage("your answer")); + var content = chatClient.prompt(prompt).system("instructions").user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + assertThat(systemMessage.getContent()).isEqualTo("instructions"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + } + + @Test + void whenPromptWithSystemMessageAndNoSystemText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt(prompt).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(systemMessage.getContent()).isEqualTo("instructions"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + } + + @Test + void whenPromptWithSystemMessageAndSystemText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + var prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt(prompt).system("other instructions").user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + assertThat(systemMessage.getContent()).isEqualTo("other instructions"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + } + + @Test + void whenMessagesAndSystemText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new UserMessage("my question"), new AssistantMessage("your answer")); + var content = chatClient.prompt() + .messages(messages) + .system("instructions") + .user("another question") + .call() + .content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + assertThat(systemMessage.getContent()).isEqualTo("instructions"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + } + + @Test + void whenMessagesWithSystemMessageAndNoSystemText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new SystemMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt().messages(messages).user("another question").call().content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(3); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); + assertThat(systemMessage.getContent()).isEqualTo("instructions"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + } + + @Test + void whenMessagesWithSystemMessageAndSystemText() { + given(this.chatModel.call(this.promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + var chatClient = ChatClient.builder(this.chatModel).build(); + List messages = List.of(new SystemMessage("instructions"), new UserMessage("my question")); + var content = chatClient.prompt() + .messages(messages) + .system("other instructions") + .user("another question") + .call() + .content(); + + assertThat(content).isEqualTo("response"); + + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + assertThat(systemMessage.getContent()).isEqualTo("other instructions"); + assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java new file mode 100644 index 00000000000..d17998e706f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import java.nio.charset.Charset; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.core.io.ClassPathResource; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link DefaultChatClientBuilder}. + * + * @author Thomas Vitale + */ +class DefaultChatClientBuilderTests { + + @Test + void whenChatModelIsNullThenThrows() { + assertThatThrownBy(() -> new DefaultChatClientBuilder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("the org.springframework.ai.chat.model.ChatModel must be non-null"); + } + + @Test + void whenObservationRegistryIsNullThenThrows() { + assertThatThrownBy(() -> new DefaultChatClientBuilder(mock(ChatModel.class), null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("the io.micrometer.observation.ObservationRegistry must be non-null"); + } + + @Test + void whenUserResourceIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultUser(null, Charset.defaultCharset())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenUserCharsetIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultUser(new ClassPathResource("user-prompt.txt"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + + @Test + void whenSystemResourceIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultSystem(null, Charset.defaultCharset())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenSystemCharsetIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultSystem(new ClassPathResource("system-prompt.txt"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java new file mode 100644 index 00000000000..269fa7812dd --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -0,0 +1,1771 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.StructuredOutputConverter; +import org.springframework.ai.model.Media; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeTypeUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link DefaultChatClient}. + * + * @author Thomas Vitale + */ +class DefaultChatClientTests { + + // Constructor + + @Test + void whenChatClientRequestIsNullThenThrow() { + assertThatThrownBy(() -> new DefaultChatClient(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("defaultChatClientRequest cannot be null"); + } + + // ChatClient + + @Test + void whenPromptThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThat(spec).isNotNull(); + } + + @Test + void whenPromptContentIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + assertThatThrownBy(() -> chatClient.prompt("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("content cannot be null or empty"); + } + + @Test + void whenPromptContentThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec spec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + assertThat(spec.getMessages()).hasSize(1); + assertThat(spec.getMessages().get(0).getContent()).isEqualTo("my question"); + } + + @Test + void whenPromptWithMessagesThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec spec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + assertThat(spec.getMessages()).hasSize(2); + assertThat(spec.getMessages().get(0).getContent()).isEqualTo("instructions"); + assertThat(spec.getMessages().get(1).getContent()).isEqualTo("my question"); + assertThat(spec.getChatOptions()).isNull(); + } + + @Test + void whenPromptWithOptionsThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatOptions chatOptions = ChatOptionsBuilder.builder().build(); + Prompt prompt = new Prompt(List.of(), chatOptions); + DefaultChatClient.DefaultChatClientRequestSpec spec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + assertThat(spec.getMessages()).isEmpty(); + assertThat(spec.getChatOptions()).isEqualTo(chatOptions); + } + + @Test + void whenMutateChatClientRequest() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec spec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt() + .user("my question"); + + ChatClient.Builder newChatClientBuilder = spec.mutate(); + newChatClientBuilder.defaultUser("another question"); + ChatClient newChatClient = newChatClientBuilder.build(); + DefaultChatClient.DefaultChatClientRequestSpec newSpec = (DefaultChatClient.DefaultChatClientRequestSpec) newChatClient + .prompt(); + + assertThat(spec.getUserText()).isEqualTo("my question"); + assertThat(newSpec.getUserText()).isEqualTo("another question"); + } + + // DefaultPromptUserSpec + + @Test + void buildPromptUserSpec() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThat(spec).isNotNull(); + assertThat(spec.media()).isNotNull(); + assertThat(spec.params()).isNotNull(); + assertThat(spec.text()).isNull(); + } + + @Test + void whenUserMediaIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.media((Media[]) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("media cannot be null"); + } + + @Test + void whenUserMediaContainsNullElementsThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.media(null, (Media) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("media cannot contain null elements"); + } + + @Test + void whenUserMediaThenReturn() throws MalformedURLException { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + URL mediaUrl = URI.create("http://example.com/image.png").toURL(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.media(new Media(MimeTypeUtils.IMAGE_PNG, mediaUrl)); + assertThat(spec.media()).hasSize(1); + assertThat(spec.media().get(0).getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_PNG); + assertThat(spec.media().get(0).getData()).isEqualTo(mediaUrl.toString()); + } + + @Test + void whenUserMediaMimeTypeIsNullWithUrlThenThrow() throws MalformedURLException { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + URL mediaUrl = URI.create("http://example.com/image.png").toURL(); + assertThatThrownBy(() -> spec.media(null, mediaUrl)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("mimeType cannot be null"); + } + + @Test + void whenUserMediaUrlIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.media(MimeTypeUtils.IMAGE_PNG, (URL) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("url cannot be null"); + } + + @Test + void whenUserMediaMimeTypeAndUrlThenReturn() throws MalformedURLException { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + URL mediaUrl = URI.create("http://example.com/image.png").toURL(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.media(MimeTypeUtils.IMAGE_PNG, mediaUrl); + assertThat(spec.media()).hasSize(1); + assertThat(spec.media().get(0).getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_PNG); + assertThat(spec.media().get(0).getData()).isEqualTo(mediaUrl.toString()); + } + + @Test + void whenUserMediaMimeTypeIsNullWithResourceThenThrow() throws MalformedURLException { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.media(null, new ClassPathResource("image.png"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("mimeType cannot be null"); + } + + @Test + void whenUserMediaResourceIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.media(MimeTypeUtils.IMAGE_PNG, (Resource) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("resource cannot be null"); + } + + @Test + void whenUserMediaMimeTypeAndResourceThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Resource imageResource = new ClassPathResource("tabby-cat.png"); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.media(MimeTypeUtils.IMAGE_PNG, imageResource); + assertThat(spec.media()).hasSize(1); + assertThat(spec.media().get(0).getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_PNG); + assertThat(spec.media().get(0).getData()).isNotNull(); + } + + @Test + void whenUserTextStringIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.text((String) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenUserTextStringIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.text("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenUserTextStringThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.text("my question"); + assertThat(spec.text()).isEqualTo("my question"); + } + + @Test + void whenUserTextResourceIsNullWithCharsetThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.text(null, Charset.defaultCharset())).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenUserTextCharsetIsNullWithResourceThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Resource textResource = new ClassPathResource("user-prompt.txt"); + assertThatThrownBy(() -> spec.text(textResource, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + + @Test + void whenUserTextResourceAndCharsetThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Resource textResource = new ClassPathResource("user-prompt.txt"); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.text(textResource, Charset.defaultCharset()); + assertThat(spec.text()).isEqualTo("my question"); + } + + @Test + void whenUserTextResourceIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.text((Resource) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenUserTextResourceThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Resource textResource = new ClassPathResource("user-prompt.txt"); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.text(textResource); + assertThat(spec.text()).isEqualTo("my question"); + } + + @Test + void whenUserParamKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.param(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenUserParamKeyIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.param("", "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenUserParamValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.param("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + + @Test + void whenUserParamKeyValueThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.param("key", "value"); + assertThat(spec.params()).containsEntry("key", "value"); + } + + @Test + void whenUserParamsIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + assertThatThrownBy(() -> spec.params(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("params cannot be null"); + } + + @Test + void whenUserParamsKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Map params = new HashMap<>(); + params.put(null, "value"); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param keys cannot contain null elements"); + } + + @Test + void whenUserParamsValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + Map params = new HashMap<>(); + params.put("key", null); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param values cannot contain null elements"); + } + + @Test + void whenUserParamsThenReturn() { + DefaultChatClient.DefaultPromptUserSpec spec = new DefaultChatClient.DefaultPromptUserSpec(); + spec = (DefaultChatClient.DefaultPromptUserSpec) spec.params(Map.of("key", "value")); + assertThat(spec.params()).containsEntry("key", "value"); + } + + // DefaultPromptSystemSpec + + @Test + void buildPromptSystemSpec() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThat(spec).isNotNull(); + assertThat(spec.params()).isNotNull(); + assertThat(spec.text()).isNull(); + } + + @Test + void whenSystemTextStringIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.text((String) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenSystemTextStringIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.text("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenSystemTextStringThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.text("instructions"); + assertThat(spec.text()).isEqualTo("instructions"); + } + + @Test + void whenSystemTextResourceIsNullWithCharsetThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.text(null, Charset.defaultCharset())).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenSystemTextCharsetIsNullWithResourceThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Resource textResource = new ClassPathResource("system-prompt.txt"); + assertThatThrownBy(() -> spec.text(textResource, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + + @Test + void whenSystemTextResourceAndCharsetThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Resource textResource = new ClassPathResource("system-prompt.txt"); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.text(textResource, Charset.defaultCharset()); + assertThat(spec.text()).isEqualTo("instructions"); + } + + @Test + void whenSystemTextResourceIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.text((Resource) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenSystemTextResourceThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Resource textResource = new ClassPathResource("system-prompt.txt"); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.text(textResource); + assertThat(spec.text()).isEqualTo("instructions"); + } + + @Test + void whenSystemParamKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.param(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenSystemParamKeyIsEmptyThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.param("", "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenSystemParamValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.param("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + + @Test + void whenSystemParamKeyValueThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.param("key", "value"); + assertThat(spec.params()).containsEntry("key", "value"); + } + + @Test + void whenSystemParamsIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + assertThatThrownBy(() -> spec.params(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("params cannot be null"); + } + + @Test + void whenSystemParamsKeyIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Map params = new HashMap<>(); + params.put(null, "value"); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param keys cannot contain null elements"); + } + + @Test + void whenSystemParamsValueIsNullThenThrow() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + Map params = new HashMap<>(); + params.put("key", null); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param values cannot contain null elements"); + } + + @Test + void whenSystemParamsThenReturn() { + DefaultChatClient.DefaultPromptSystemSpec spec = new DefaultChatClient.DefaultPromptSystemSpec(); + spec = (DefaultChatClient.DefaultPromptSystemSpec) spec.params(Map.of("key", "value")); + assertThat(spec.params()).containsEntry("key", "value"); + } + + // DefaultAdvisorSpec + + @Test + void buildAdvisorSpec() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThat(spec).isNotNull(); + assertThat(spec.getAdvisors()).isNotNull(); + assertThat(spec.getParams()).isNotNull(); + } + + @Test + void whenAdvisorParamKeyIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.param(null, "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenAdvisorParamKeyIsEmptyThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.param("", "value")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("key cannot be null or empty"); + } + + @Test + void whenAdvisorParamValueIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.param("key", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("value cannot be null"); + } + + @Test + void whenAdvisorParamKeyValueThenReturn() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + spec = (DefaultChatClient.DefaultAdvisorSpec) spec.param("key", "value"); + assertThat(spec.getParams()).containsEntry("key", "value"); + } + + @Test + void whenAdvisorParamsIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.params(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("params cannot be null"); + } + + @Test + void whenAdvisorKeyIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + Map params = new HashMap<>(); + params.put(null, "value"); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param keys cannot contain null elements"); + } + + @Test + void whenAdvisorParamsValueIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + Map params = new HashMap<>(); + params.put("key", null); + assertThatThrownBy(() -> spec.params(params)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("param values cannot contain null elements"); + } + + @Test + void whenAdvisorParamsThenReturn() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + spec = (DefaultChatClient.DefaultAdvisorSpec) spec.params(Map.of("key", "value")); + assertThat(spec.getParams()).containsEntry("key", "value"); + } + + @Test + void whenAdvisorsIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.advisors((Advisor[]) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot be null"); + } + + @Test + void whenAdvisorsContainsNullElementsThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.advisors(null, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot contain null elements"); + } + + @Test + void whenAdvisorsThenReturn() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + Advisor advisor = new SimpleLoggerAdvisor(); + spec = (DefaultChatClient.DefaultAdvisorSpec) spec.advisors(advisor); + assertThat(spec.getAdvisors()).hasSize(1); + assertThat(spec.getAdvisors().get(0)).isEqualTo(advisor); + } + + @Test + void whenAdvisorListIsNullThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + assertThatThrownBy(() -> spec.advisors((List) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot be null"); + } + + @Test + void whenAdvisorListContainsNullElementsThenThrow() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + List advisors = new ArrayList<>(); + advisors.add(null); + assertThatThrownBy(() -> spec.advisors(advisors)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot contain null elements"); + } + + @Test + void whenAdvisorListThenReturn() { + DefaultChatClient.DefaultAdvisorSpec spec = new DefaultChatClient.DefaultAdvisorSpec(); + Advisor advisor = new SimpleLoggerAdvisor(); + spec = (DefaultChatClient.DefaultAdvisorSpec) spec.advisors(List.of(advisor)); + assertThat(spec.getAdvisors()).hasSize(1); + assertThat(spec.getAdvisors().get(0)).isEqualTo(advisor); + } + + // DefaultCallResponseSpec + + @Test + void buildCallResponseSpec() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + assertThat(spec).isNotNull(); + } + + @Test + void buildCallResponseSpecWithNullRequest() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("request cannot be null"); + } + + @Test + void whenSimplePromptThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(1); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("my question"); + } + + @Test + void whenFullPromptThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(2); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("instructions"); + assertThat(actualPrompt.getInstructions().get(1).getContent()).isEqualTo("my question"); + } + + @Test + void whenPromptAndUserTextThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt) + .user("another question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(3); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("instructions"); + assertThat(actualPrompt.getInstructions().get(1).getContent()).isEqualTo("my question"); + assertThat(actualPrompt.getInstructions().get(2).getContent()).isEqualTo("another question"); + } + + @Test + void whenUserTextAndMessagesThenChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + List messages = List.of(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt() + .user("another question") + .messages(messages); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(3); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("instructions"); + assertThat(actualPrompt.getInstructions().get(1).getContent()).isEqualTo("my question"); + assertThat(actualPrompt.getInstructions().get(2).getContent()).isEqualTo("another question"); + } + + @Test + void whenChatResponseIsNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())).willReturn(null); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse(); + assertThat(chatResponse).isNull(); + } + + @Test + void whenChatResponseContentIsNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + String content = spec.content(); + assertThat(content).isNull(); + } + + @Test + void whenResponseEntityWithParameterizedTypeIsNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void whenResponseEntityWithParameterizedTypeAndChatResponseContentNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ResponseEntity> responseEntity = spec + .responseEntity(new ParameterizedTypeReference<>() { + }); + assertThat(responseEntity).isNotNull(); + assertThat(responseEntity.response()).isNotNull(); + assertThat(responseEntity.entity()).isNull(); + } + + @Test + void whenResponseEntityWithParameterizedType() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(""" + [ + { "name": "James Bond" }, + { "name": "Ethan Hunt" }, + { "name": "Jason Bourne" } + ] + """))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ResponseEntity> responseEntity = spec + .responseEntity(new ParameterizedTypeReference<>() { + }); + assertThat(responseEntity.response()).isNotNull(); + assertThat(responseEntity.entity()).hasSize(3); + } + + @Test + void whenResponseEntityWithConverterIsNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("structuredOutputConverter cannot be null"); + } + + @Test + void whenResponseEntityWithConverterAndChatResponseContentNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ResponseEntity> responseEntity = spec + .responseEntity(new ListOutputConverter(new DefaultConversionService())); + assertThat(responseEntity.response()).isNotNull(); + assertThat(responseEntity.entity()).isNull(); + } + + @Test + void whenResponseEntityWithConverter() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(""" + James Bond, Ethan Hunt, Jason Bourne + """))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ResponseEntity> responseEntity = spec + .responseEntity(new ListOutputConverter(new DefaultConversionService())); + assertThat(responseEntity.response()).isNotNull(); + assertThat(responseEntity.entity()).hasSize(3); + } + + @Test + void whenResponseEntityWithTypeIsNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + assertThatThrownBy(() -> spec.responseEntity((Class) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void whenResponseEntityWithTypeAndChatResponseContentNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ResponseEntity responseEntity = spec.responseEntity(String.class); + assertThat(responseEntity.response()).isNotNull(); + assertThat(responseEntity.entity()).isNull(); + } + + @Test + void whenResponseEntityWithType() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(""" + { "name": "James Bond" } + """))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + ResponseEntity responseEntity = spec.responseEntity(Person.class); + assertThat(responseEntity.response()).isNotNull(); + assertThat(responseEntity.entity()).isNotNull(); + assertThat(responseEntity.entity().name).isEqualTo("James Bond"); + } + + @Test + void whenEntityWithParameterizedTypeIsNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void whenEntityWithParameterizedTypeAndChatResponseContentNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + List entity = spec.entity(new ParameterizedTypeReference<>() { + }); + assertThat(entity).isNull(); + } + + @Test + void whenEntityWithParameterizedType() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(""" + [ + { "name": "James Bond" }, + { "name": "Ethan Hunt" }, + { "name": "Jason Bourne" } + ] + """))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + List entity = spec.entity(new ParameterizedTypeReference<>() { + }); + assertThat(entity).hasSize(3); + } + + @Test + void whenEntityWithConverterIsNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + assertThatThrownBy(() -> spec.entity((StructuredOutputConverter) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("structuredOutputConverter cannot be null"); + } + + @Test + void whenEntityWithConverterAndChatResponseContentNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + List entity = spec.entity(new ListOutputConverter(new DefaultConversionService())); + assertThat(entity).isNull(); + } + + @Test + void whenEntityWithConverter() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(""" + James Bond, Ethan Hunt, Jason Bourne + """))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + List entity = spec.entity(new ListOutputConverter(new DefaultConversionService())); + assertThat(entity).hasSize(3); + } + + @Test + void whenEntityWithTypeIsNull() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + assertThatThrownBy(() -> spec.entity((Class) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void whenEntityWithTypeAndChatResponseContentNull() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(null))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + String entity = spec.entity(String.class); + assertThat(entity).isNull(); + } + + @Test + void whenEntityWithType() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(""" + { "name": "James Bond" } + """))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec( + chatClientRequestSpec); + + Person entity = spec.entity(Person.class); + assertThat(entity).isNotNull(); + assertThat(entity.name()).isEqualTo("James Bond"); + } + + // DefaultStreamResponseSpec + + @Test + void buildStreamResponseSpec() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(); + DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( + chatClientRequestSpec); + assertThat(spec).isNotNull(); + } + + @Test + void buildStreamResponseSpecWithNullRequest() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("request cannot be null"); + } + + @Test + void whenSimplePromptThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(1); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("my question"); + } + + @Test + void whenFullPromptThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt); + DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(2); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("instructions"); + assertThat(actualPrompt.getInstructions().get(1).getContent()).isEqualTo("my question"); + } + + @Test + void whenPromptAndUserTextThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt(prompt) + .user("another question"); + DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(3); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("instructions"); + assertThat(actualPrompt.getInstructions().get(1).getContent()).isEqualTo("my question"); + assertThat(actualPrompt.getInstructions().get(2).getContent()).isEqualTo("another question"); + } + + @Test + void whenUserTextAndMessagesThenFluxChatResponse() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + List messages = List.of(new SystemMessage("instructions"), new UserMessage("my question")); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt() + .user("another question") + .messages(messages); + DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( + chatClientRequestSpec); + + ChatResponse chatResponse = spec.chatResponse().blockLast(); + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("response"); + + Prompt actualPrompt = promptCaptor.getValue(); + assertThat(actualPrompt.getInstructions()).hasSize(3); + assertThat(actualPrompt.getInstructions().get(0).getContent()).isEqualTo("instructions"); + assertThat(actualPrompt.getInstructions().get(1).getContent()).isEqualTo("my question"); + assertThat(actualPrompt.getInstructions().get(2).getContent()).isEqualTo("another question"); + } + + @Test + void whenChatResponseContentIsNullThenReturnFlux() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.stream(promptCaptor.capture())) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage(null)))))); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient + .prompt("my question"); + DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec( + chatClientRequestSpec); + + String content = spec.content().blockLast(); + assertThat(content).isNull(); + } + + // DefaultChatClientRequestSpec + + @Test + void buildChatClientRequestSpec() { + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( + chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), + Map.of(), ObservationRegistry.NOOP, null, Map.of()); + assertThat(spec).isNotNull(); + } + + @Test + void whenChatModelIsNullThenThrow() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), null, + Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + ObservationRegistry.NOOP, null, Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatModel cannot be null"); + } + + @Test + void whenObservationRegistryIsNullThenThrow() { + assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, + Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), null, + null, Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("observationRegistry cannot be null"); + } + + @Test + void whenAdvisorConsumerIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.advisors((Consumer) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("consumer cannot be null"); + } + + @Test + void whenAdvisorConsumerThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Advisor loggerAdvisor = new SimpleLoggerAdvisor(); + spec = spec.advisors(advisor -> advisor.advisors(loggerAdvisor).param("topic", "AI")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getAdvisors()).contains(loggerAdvisor); + assertThat(defaultSpec.getAdvisorParams()).containsEntry("topic", "AI"); + } + + @Test + void whenRequestAdvisorsWithNullElementsThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.advisors((Advisor) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot contain null elements"); + } + + @Test + void whenRequestAdvisorsThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Advisor advisor = new SimpleLoggerAdvisor(); + spec = spec.advisors(advisor); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getAdvisors()).contains(advisor); + } + + @Test + void whenRequestAdvisorListIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.advisors((List) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot be null"); + } + + @Test + void whenRequestAdvisorListWithNullElementsThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + List advisors = new ArrayList<>(); + advisors.add(null); + assertThatThrownBy(() -> spec.advisors(advisors)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot contain null elements"); + } + + @Test + void whenRequestAdvisorListThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + List advisors = List.of(new SimpleLoggerAdvisor()); + spec = spec.advisors(advisors); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getAdvisors()).containsAll(advisors); + } + + @Test + void whenMessagesWithNullElementsThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.messages((Message) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("messages cannot contain null elements"); + } + + @Test + void whenMessagesThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Message message = new UserMessage("question"); + spec = spec.messages(message); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getMessages()).contains(message); + } + + @Test + void whenMessageListIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.messages((List) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("messages cannot be null"); + } + + @Test + void whenMessageListWithNullElementsThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + List messages = new ArrayList<>(); + messages.add(null); + assertThatThrownBy(() -> spec.messages(messages)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("messages cannot contain null elements"); + } + + @Test + void whenMessageListThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + List messages = List.of(new UserMessage("question")); + spec = spec.messages(messages); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getMessages()).containsAll(messages); + } + + @Test + void whenOptionsIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.options(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("options cannot be null"); + } + + @Test + void whenOptionsThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + ChatOptions options = ChatOptionsBuilder.builder().build(); + spec = spec.options(options); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getChatOptions()).isEqualTo(options); + } + + @Test + void whenFunctionNameIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function(null, "description", input -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void whenFunctionNameIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("", "description", input -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void whenFunctionDescriptionIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("name", null, input -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void whenFunctionDescriptionIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("name", "", input -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void whenFunctionLambdaIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("name", "description", (Function) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("function cannot be null"); + } + + @Test + void whenFunctionThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.function("name", "description", input -> "hello"); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + + @Test + void whenFunctionAndInputTypeThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.function("name", "description", String.class, input -> "hello"); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + + @Test + void whenBiFunctionNameIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function(null, "description", (input, ctx) -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void whenBiFunctionNameIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("", "description", (input, ctx) -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void whenBiFunctionDescriptionIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("name", null, (input, ctx) -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void whenBiFunctionDescriptionIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("name", "", (input, ctx) -> "hello")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void whenBiFunctionLambdaIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.function("name", "description", (BiFunction) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("biFunction cannot be null"); + } + + @Test + void whenBiFunctionThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.function("name", "description", (input, ctx) -> "hello"); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + + @Test + void whenFunctionBeanNamesElementIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.functions("myFunction", null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("functionBeanNames cannot contain null elements"); + } + + @Test + void whenFunctionBeanNamesThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + String functionBeanName = "myFunction"; + spec = spec.functions(functionBeanName); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionNames()).contains(functionBeanName); + } + + @Test + void whenFunctionCallbacksElementIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.functions(mock(FunctionCallback.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("functionCallbacks cannot contain null elements"); + } + + @Test + void whenFunctionCallbacksThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + FunctionCallback functionCallback = mock(FunctionCallback.class); + spec = spec.functions(functionCallback); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).contains(functionCallback); + } + + @Test + void whenToolContextIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.toolContext(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolContext cannot be null"); + } + + @Test + void whenToolContextKeyIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Map toolContext = new HashMap<>(); + toolContext.put(null, "value"); + assertThatThrownBy(() -> spec.toolContext(toolContext)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolContext keys cannot contain null elements"); + } + + @Test + void whenToolContextValueIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Map toolContext = new HashMap<>(); + toolContext.put("key", null); + assertThatThrownBy(() -> spec.toolContext(toolContext)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolContext values cannot contain null elements"); + } + + @Test + void whenToolContextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Map toolContext = Map.of("key", "value"); + spec = spec.toolContext(toolContext); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getToolContext()).containsEntry("key", "value"); + } + + @Test + void whenSystemTextIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.system((String) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenSystemTextIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.system("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenSystemTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.system(system -> system.text("instructions")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getSystemText()).isEqualTo("instructions"); + } + + @Test + void whenSystemResourceIsNullWithCharsetThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.system(null, Charset.defaultCharset())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenSystemCharsetIsNullWithResourceThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.system(new ClassPathResource("system-prompt.txt"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + + @Test + void whenSystemResourceAndCharsetThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.system(system -> system.text(new ClassPathResource("system-prompt.txt"), Charset.defaultCharset())); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getSystemText()).isEqualTo("instructions"); + } + + @Test + void whenSystemResourceIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.system((Resource) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenSystemResourceThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.system(systemSpec -> systemSpec.text(new ClassPathResource("system-prompt.txt"))); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getSystemText()).isEqualTo("instructions"); + } + + @Test + void whenSystemConsumerIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.system((Consumer) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("consumer cannot be null"); + } + + @Test + void whenSystemConsumerThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.system(system -> system.text("my instruction about {topic}").param("topic", "AI")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getSystemText()).isEqualTo("my instruction about {topic}"); + assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); + } + + @Test + void whenSystemConsumerWithExistingSystemTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().system("my instruction"); + spec = spec.system(system -> system.text("my instruction about {topic}").param("topic", "AI")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getSystemText()).isEqualTo("my instruction about {topic}"); + assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); + } + + @Test + void whenSystemConsumerWithoutSystemTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().system("my instruction about {topic}"); + spec = spec.system(system -> system.param("topic", "AI")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getSystemText()).isEqualTo("my instruction about {topic}"); + assertThat(defaultSpec.getSystemParams()).containsEntry("topic", "AI"); + } + + @Test + void whenUserTextIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.user((String) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenUserTextIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.user("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null or empty"); + } + + @Test + void whenUserTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.user(user -> user.text("my question")); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getUserText()).isEqualTo("my question"); + } + + @Test + void whenUserResourceIsNullWithCharsetThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.user(null, Charset.defaultCharset())).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenUserCharsetIsNullWithResourceThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.user(new ClassPathResource("user-prompt.txt"), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("charset cannot be null"); + } + + @Test + void whenUserResourceAndCharsetThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.user(user -> user.text(new ClassPathResource("user-prompt.txt"), Charset.defaultCharset())); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getUserText()).isEqualTo("my question"); + } + + @Test + void whenUserResourceIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.user((Resource) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("text cannot be null"); + } + + @Test + void whenUserResourceThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.user(user -> user.text(new ClassPathResource("user-prompt.txt"))); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getUserText()).isEqualTo("my question"); + } + + @Test + void whenUserConsumerIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.user((Consumer) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("consumer cannot be null"); + } + + @Test + void whenUserConsumerThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.user(user -> user.text("my question about {topic}") + .param("topic", "AI") + .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("tabby-cat.png"))); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getUserText()).isEqualTo("my question about {topic}"); + assertThat(defaultSpec.getUserParams()).containsEntry("topic", "AI"); + assertThat(defaultSpec.getMedia()).hasSize(1); + } + + @Test + void whenUserConsumerWithExistingUserTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("my question"); + spec = spec.user(user -> user.text("my question about {topic}") + .param("topic", "AI") + .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("tabby-cat.png"))); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getUserText()).isEqualTo("my question about {topic}"); + assertThat(defaultSpec.getUserParams()).containsEntry("topic", "AI"); + assertThat(defaultSpec.getMedia()).hasSize(1); + } + + @Test + void whenUserConsumerWithoutUserTextThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().user("my question about {topic}"); + spec = spec.user(user -> user.param("topic", "AI") + .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("tabby-cat.png"))); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getUserText()).isEqualTo("my question about {topic}"); + assertThat(defaultSpec.getUserParams()).containsEntry("topic", "AI"); + assertThat(defaultSpec.getMedia()).hasSize(1); + } + + record Person(String name) { + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java new file mode 100644 index 00000000000..8d392814f4f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.model.ChatModel; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link AdvisedRequest}. + * + * @author Thomas Vitale + */ +class AdvisedRequestTests { + + @Test + void buildAdvisedRequest() { + AdvisedRequest request = new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()); + assertThat(request).isNotNull(); + } + + @Test + void whenChatModelIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(null, "user", null, null, List.of(), List.of(), List.of(), + List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatModel cannot be null"); + } + + @Test + void whenUserTextIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("userText cannot be null or empty"); + } + + @Test + void whenUserTextIsEmptyThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("userText cannot be null or empty"); + } + + @Test + void whenMediaIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("media cannot be null"); + } + + @Test + void whenFunctionNamesIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), null, + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("functionNames cannot be null"); + } + + @Test + void whenFunctionCallbacksIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + null, List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("functionCallbacks cannot be null"); + } + + @Test + void whenMessagesIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), null, Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("messages cannot be null"); + } + + @Test + void whenUserParamsIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), null, Map.of(), List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("userParams cannot be null"); + } + + @Test + void whenSystemParamsIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), null, List.of(), Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("systemParams cannot be null"); + } + + @Test + void whenAdvisorsIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), null, Map.of(), Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisors cannot be null"); + } + + @Test + void whenAdvisorParamsIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), null, Map.of(), Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisorParams cannot be null"); + } + + @Test + void whenAdviseContextIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), null, Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("adviseContext cannot be null"); + } + + @Test + void whenToolContextIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), + List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolContext cannot be null"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java new file mode 100644 index 00000000000..ccf3637da82 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.model.ChatResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link AdvisedResponse}. + * + * @author Thomas Vitale + */ +class AdvisedResponseTests { + + @Test + void buildAdvisedResponse() { + AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of()); + assertThat(advisedResponse).isNotNull(); + } + + @Test + void whenAdviseContextIsNullThenThrows() { + assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("adviseContext cannot be null"); + } + + @Test + void whenAdviseContextKeysIsNullThenThrows() { + Map adviseContext = new HashMap<>(); + adviseContext.put(null, "value"); + assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), adviseContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("adviseContext keys cannot be null"); + } + + @Test + void whenAdviseContextValuesIsNullThenThrows() { + Map adviseContext = new HashMap<>(); + adviseContext.put("key", null); + assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), adviseContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("adviseContext values cannot be null"); + } + + @Test + void whenBuildFromNullAdvisedResponseThenThrows() { + assertThatThrownBy(() -> AdvisedResponse.from(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("advisedResponse cannot be null"); + } + + @Test + void buildFromAdvisedResponse() { + AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of()); + AdvisedResponse.Builder builder = AdvisedResponse.from(advisedResponse); + assertThat(builder).isNotNull(); + } + + @Test + void whenUpdateFromNullContextThenThrows() { + AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of()); + assertThatThrownBy(() -> advisedResponse.updateContext(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextTransform cannot be null"); + } + +} diff --git a/spring-ai-core/src/test/resources/system-prompt.txt b/spring-ai-core/src/test/resources/system-prompt.txt new file mode 100644 index 00000000000..e468cde2e7d --- /dev/null +++ b/spring-ai-core/src/test/resources/system-prompt.txt @@ -0,0 +1 @@ +instructions \ No newline at end of file diff --git a/spring-ai-core/src/test/resources/tabby-cat.png b/spring-ai-core/src/test/resources/tabby-cat.png new file mode 100644 index 00000000000..4c564e4fb74 Binary files /dev/null and b/spring-ai-core/src/test/resources/tabby-cat.png differ diff --git a/spring-ai-core/src/test/resources/user-prompt.txt b/spring-ai-core/src/test/resources/user-prompt.txt new file mode 100644 index 00000000000..822f5a48082 --- /dev/null +++ b/spring-ai-core/src/test/resources/user-prompt.txt @@ -0,0 +1 @@ +my question \ No newline at end of file