diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiPropertiesTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiPropertiesTests.java index e18351d80a8..7d3becba0d0 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiPropertiesTests.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiPropertiesTests.java @@ -37,6 +37,7 @@ * {@link ZhiPuAiEmbeddingProperties}. * * @author Geng Rong + * @author YunKui Lu */ public class ZhiPuAiPropertiesTests { @@ -243,7 +244,9 @@ public void chatOptionsTest() { "required": ["location", "lat", "lon", "unit"] } """, - "spring.ai.zhipuai.chat.options.user=userXYZ" + "spring.ai.zhipuai.chat.options.user=userXYZ", + "spring.ai.zhipuai.chat.options.response-format.type=json_object", + "spring.ai.zhipuai.chat.options.thinking.type=disabled" ) // @formatter:on .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, @@ -262,6 +265,8 @@ public void chatOptionsTest() { assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56); assertThat(chatProperties.getOptions().getRequestId()).isEqualTo("RequestId"); assertThat(chatProperties.getOptions().getDoSample()).isEqualTo(Boolean.TRUE); + assertThat(chatProperties.getOptions().getResponseFormat().type()).isEqualTo("json_object"); + assertThat(chatProperties.getOptions().getThinking().type()).isEqualTo("disabled"); JSONAssert.assertEquals("{\"type\":\"function\",\"function\":{\"name\":\"toolChoiceFunctionName\"}}", chatProperties.getOptions().getToolChoice(), JSONCompareMode.LENIENT); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index c31320defe1..d79557d534f 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -30,9 +31,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -42,6 +45,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author YunKui Lu * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -104,6 +108,16 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("do_sample") Boolean doSample; + /** + * Control the format of the model output. Set to `json_object` to ensure the message is a valid JSON object. + */ + private @JsonProperty("response_format") ChatCompletionRequest.ResponseFormat responseFormat; + + /** + * Control whether to enable the large model's chain of thought. Available options: (default) enabled, disabled. + */ + private @JsonProperty("thinking") ChatCompletionRequest.Thinking thinking; + /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. */ @@ -146,6 +160,8 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext()) + .responseFormat(fromOptions.getResponseFormat()) + .thinking(fromOptions.getThinking()) .build(); } @@ -244,6 +260,24 @@ public void setDoSample(Boolean doSample) { this.doSample = doSample; } + public ChatCompletionRequest.ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public ZhiPuAiChatOptions setResponseFormat(ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public ChatCompletionRequest.Thinking getThinking() { + return this.thinking; + } + + public ZhiPuAiChatOptions setThinking(ChatCompletionRequest.Thinking thinking) { + this.thinking = thinking; + return this; + } + @Override @JsonIgnore public Double getFrequencyPenalty() { @@ -311,138 +345,53 @@ public Map getToolContext() { @Override public void setToolContext(Map toolContext) { + Assert.notNull(toolContext, "toolContext cannot be null"); this.toolContext = toolContext; } + @Override + public final boolean equals(Object o) { + if (!(o instanceof ZhiPuAiChatOptions that)) { + return false; + } + + return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.stop, that.stop) && Objects.equals(this.temperature, that.temperature) + && Objects.equals(this.topP, that.topP) && Objects.equals(this.tools, that.tools) + && Objects.equals(this.toolChoice, that.toolChoice) && Objects.equals(this.user, that.user) + && Objects.equals(this.requestId, that.requestId) && Objects.equals(this.doSample, that.doSample) + && Objects.equals(this.responseFormat, that.responseFormat) + && Objects.equals(this.thinking, that.thinking) + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolContext, that.toolContext); + } + @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); - result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); - result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); - result = prime * result - + ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode()); - result = prime * result + ((this.toolCallbacks == null) ? 0 : this.toolCallbacks.hashCode()); - result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode()); - result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); + int result = Objects.hashCode(this.model); + result = 31 * result + Objects.hashCode(this.maxTokens); + result = 31 * result + Objects.hashCode(this.stop); + result = 31 * result + Objects.hashCode(this.temperature); + result = 31 * result + Objects.hashCode(this.topP); + result = 31 * result + Objects.hashCode(this.tools); + result = 31 * result + Objects.hashCode(this.toolChoice); + result = 31 * result + Objects.hashCode(this.user); + result = 31 * result + Objects.hashCode(this.requestId); + result = 31 * result + Objects.hashCode(this.doSample); + result = 31 * result + Objects.hashCode(this.responseFormat); + result = 31 * result + Objects.hashCode(this.thinking); + result = 31 * result + Objects.hashCode(this.toolCallbacks); + result = 31 * result + Objects.hashCode(this.toolNames); + result = 31 * result + Objects.hashCode(this.internalToolExecutionEnabled); + result = 31 * result + Objects.hashCode(this.toolContext); return result; } @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.tools == null) { - if (other.tools != null) { - return false; - } - } - else if (!this.tools.equals(other.tools)) { - return false; - } - if (this.toolChoice == null) { - if (other.toolChoice != null) { - return false; - } - } - else if (!this.toolChoice.equals(other.toolChoice)) { - return false; - } - if (this.user == null) { - if (other.user != null) { - return false; - } - } - else if (!this.user.equals(other.user)) { - return false; - } - if (this.requestId == null) { - if (other.requestId != null) { - return false; - } - } - else if (!this.requestId.equals(other.requestId)) { - return false; - } - if (this.doSample == null) { - if (other.doSample != null) { - return false; - } - } - else if (!this.doSample.equals(other.doSample)) { - return false; - } - if (this.internalToolExecutionEnabled == null) { - if (other.internalToolExecutionEnabled != null) { - return false; - } - } - else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEnabled)) { - return false; - } - if (this.toolContext == null) { - if (other.toolContext != null) { - return false; - } - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + public String toString() { + return "ZhiPuAiChatOptions: " + ModelOptionsUtils.toJsonString(this); } @Override @@ -610,6 +559,16 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder responseFormat(ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder thinking(ChatCompletionRequest.Thinking thinking) { + this.options.thinking = thinking; + return this; + } + public ZhiPuAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index ec22794756d..6dfde61ba28 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -652,6 +652,9 @@ public void setJsonSchema(String jsonSchema) { * logged and can be used for debugging purposes. * @param doSample If set, the model will use sampling to generate the next token. If * not set, the model will use greedy decoding to generate the next token. + * @param responseFormat Control the format of the model output. Set to `json_object` + * to ensure the message is a valid JSON object. + * @param thinking Control whether to enable the large model's chain of thought. */ @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest(// @formatter:off @@ -664,9 +667,11 @@ public record ChatCompletionRequest(// @formatter:off @JsonProperty("top_p") Double topP, @JsonProperty("tools") List tools, @JsonProperty("tool_choice") Object toolChoice, - @JsonProperty("user") String user, + @JsonProperty("user_id") String user, @JsonProperty("request_id") String requestId, - @JsonProperty("do_sample") Boolean doSample) { // @formatter:on + @JsonProperty("do_sample") Boolean doSample, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("thinking") Thinking thinking) { // @formatter:on /** * Shortcut constructor for a chat completion request with the given messages and @@ -676,7 +681,7 @@ public record ChatCompletionRequest(// @formatter:off * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String model, Double temperature) { - this(messages, model, null, null, false, temperature, null, null, null, null, null, null); + this(messages, model, null, null, false, temperature, null, null, null, null, null, null, null, null); } /** @@ -691,7 +696,7 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { - this(messages, model, null, null, stream, temperature, null, null, null, null, null, null); + this(messages, model, null, null, stream, temperature, null, null, null, null, null, null, null, null); } /** @@ -706,7 +711,7 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { - this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null); + this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null, null, null); } /** @@ -719,7 +724,7 @@ public ChatCompletionRequest(List messages, String model, * terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, Boolean stream) { - this(messages, null, null, null, stream, null, null, null, null, null, null, null); + this(messages, null, null, null, stream, null, null, null, null, null, null, null, null, null); } /** @@ -754,7 +759,32 @@ public static Object function(String functionName) { */ @JsonInclude(Include.NON_NULL) public record ResponseFormat(@JsonProperty("type") String type) { + + public static ResponseFormat text() { + return new ResponseFormat("text"); + } + + public static ResponseFormat jsonObject() { + return new ResponseFormat("json_object"); + } + } + + /** + * Control whether to enable the large model's chain of thought + * + * @param type Available options: (default) enabled, disabled + */ + @JsonInclude(Include.NON_NULL) + public record Thinking(@JsonProperty("type") String type) { + public static Thinking enabled() { + return new Thinking("enabled"); + } + + public static Thinking disabled() { + return new Thinking("disabled"); + } } + } /** diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java new file mode 100644 index 00000000000..435ca9c24f4 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java @@ -0,0 +1,376 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.zhipuai; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ZhiPuAiChatOptions}. + * + * @author YunKui Lu + */ +class ZhiPuAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + List stopSequences = List.of("stop1", "stop2"); + List tools = new ArrayList<>(); + String toolChoice = "auto"; + Map toolContext = Map.of("keyA", "valueA"); + List toolCallbacks = new ArrayList<>(); + Set toolNames = Set.of("tool1", "tool2"); + ChatCompletionRequest.ResponseFormat responseFormat = new ChatCompletionRequest.ResponseFormat("json_object"); + ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking("enabled"); + + ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stop(stopSequences) + .temperature(0.7) + .topP(0.9) + .tools(tools) + .toolChoice(toolChoice) + .user("test-user") + .requestId("test-request-id") + .doSample(true) + .toolCallbacks(toolCallbacks) + .toolNames(toolNames) + .internalToolExecutionEnabled(false) + .toolContext(toolContext) + .responseFormat(responseFormat) + .thinking(thinking) + .build(); + + assertThat(options) + .extracting("model", "maxTokens", "stop", "temperature", "topP", "tools", "toolChoice", "user", "requestId", + "doSample", "toolCallbacks", "toolNames", "internalToolExecutionEnabled", "toolContext", + "responseFormat", "thinking") + .containsExactly("test-model", 100, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", + "test-request-id", true, toolCallbacks, toolNames, false, toolContext, responseFormat, thinking); + } + + @Test + void testCopy() { + List stopSequences = List.of("stop1"); + List tools = new ArrayList<>(); + String toolChoice = "none"; + List toolCallbacks = new ArrayList<>(); + Set toolNames = Set.of("tool1"); + ChatCompletionRequest.ResponseFormat responseFormat = new ChatCompletionRequest.ResponseFormat("json_object"); + ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking("disabled"); + + ZhiPuAiChatOptions originalOptions = ZhiPuAiChatOptions.builder() + .model("test-model") + .maxTokens(50) + .stop(stopSequences) + .temperature(0.7) + .topP(0.9) + .tools(tools) + .toolChoice(toolChoice) + .user("test-user") + .requestId("test-request-id") + .doSample(true) + .toolCallbacks(toolCallbacks) + .toolNames(toolNames) + .internalToolExecutionEnabled(true) + .toolContext(Map.of("key1", "value1")) + .responseFormat(responseFormat) + .thinking(thinking) + .build(); + + ZhiPuAiChatOptions copiedOptions = originalOptions.copy(); + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testSetters() { + List stopSequences = List.of("stop1", "stop2"); + List tools = new ArrayList<>(); + String toolChoice = "auto"; + Map toolContext = Map.of("key2", "value2"); + List toolCallbacks = new ArrayList<>(); + Set toolNames = Set.of("tool1", "tool2"); + ChatCompletionRequest.ResponseFormat responseFormat = new ChatCompletionRequest.ResponseFormat("json_object"); + ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking("enabled"); + + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(100); + options.setStop(stopSequences); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setTools(tools); + options.setToolChoice(toolChoice); + options.setUser("test-user"); + options.setRequestId("test-request-id"); + options.setDoSample(true); + options.setToolCallbacks(toolCallbacks); + options.setToolNames(toolNames); + options.setInternalToolExecutionEnabled(false); + options.setToolContext(toolContext); + options.setResponseFormat(responseFormat); + options.setThinking(thinking); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getStop()).isEqualTo(stopSequences); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getTools()).isEqualTo(tools); + assertThat(options.getToolChoice()).isEqualTo(toolChoice); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getRequestId()).isEqualTo("test-request-id"); + assertThat(options.getDoSample()).isEqualTo(true); + assertThat(options.getToolCallbacks()).isEqualTo(toolCallbacks); + assertThat(options.getToolNames()).isEqualTo(toolNames); + assertThat(options.getInternalToolExecutionEnabled()).isEqualTo(false); + assertThat(options.getToolContext()).isEqualTo(toolContext); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getThinking()).isEqualTo(thinking); + assertThat(options.getStopSequences()).isEqualTo(stopSequences); + } + + @Test + void testDefaultValues() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTools()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getRequestId()).isNull(); + assertThat(options.getDoSample()).isNull(); + assertThat(options.getToolCallbacks()).isNotNull().isEmpty(); + assertThat(options.getToolNames()).isNotNull().isEmpty(); + assertThat(options.getInternalToolExecutionEnabled()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getThinking()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getTopK()).isNull(); + } + + @Test + void testEqualsAndHashCode() { + ZhiPuAiChatOptions options1 = ZhiPuAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .maxTokens(100) + .build(); + + ZhiPuAiChatOptions options2 = ZhiPuAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .maxTokens(100) + .build(); + + ZhiPuAiChatOptions options3 = ZhiPuAiChatOptions.builder() + .model("different-model") + .temperature(0.7) + .maxTokens(100) + .build(); + + // Test equals + assertThat(options1).isEqualTo(options2); + assertThat(options1).isNotEqualTo(options3); + assertThat(options1).isNotEqualTo(null); + assertThat(options1).isEqualTo(options1); + + // Test hashCode + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); + } + + @Test + void testBuilderWithNullValues() { + ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder().temperature(null).stop(null).tools(null).build(); + + assertThat(options.getModel()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTools()).isNull(); + } + + @Test + void testBuilderChaining() { + ZhiPuAiChatOptions.Builder builder = ZhiPuAiChatOptions.builder(); + + ZhiPuAiChatOptions.Builder result = builder.model("test-model").temperature(0.7).maxTokens(100); + + assertThat(result).isSameAs(builder); + + ZhiPuAiChatOptions options = result.build(); + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getMaxTokens()).isEqualTo(100); + } + + @Test + void testNullAndEmptyCollections() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + + // Test setting null collections + options.setStop(null); + options.setTools(null); + + assertThat(options.getStop()).isNull(); + assertThat(options.getTools()).isNull(); + assertThat(options.getToolCallbacks()).isEmpty(); + assertThat(options.getToolNames()).isEmpty(); + assertThat(options.getToolContext()).isEmpty(); + + // Test setting empty collections + options.setStop(new ArrayList<>()); + options.setTools(new ArrayList<>()); + options.setToolCallbacks(new ArrayList<>()); + options.setToolNames(new HashSet<>()); + options.setToolContext(new HashMap<>()); + + assertThat(options.getStop()).isEmpty(); + assertThat(options.getTools()).isEmpty(); + assertThat(options.getToolCallbacks()).isEmpty(); + assertThat(options.getToolNames()).isEmpty(); + assertThat(options.getToolContext()).isEmpty(); + } + + @Test + void testStopSequencesAlias() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + List stopSequences = List.of("stop1", "stop2"); + + // Setting stopSequences should also set stop + options.setStopSequences(stopSequences); + assertThat(options.getStopSequences()).isEqualTo(stopSequences); + assertThat(options.getStop()).isEqualTo(stopSequences); + + // Setting stop should also update stopSequences + List newStop = List.of("stop3", "stop4"); + options.setStop(newStop); + assertThat(options.getStop()).isEqualTo(newStop); + assertThat(options.getStopSequences()).isEqualTo(newStop); + } + + @Test + void testFromOptions() { + ZhiPuAiChatOptions source = ZhiPuAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .maxTokens(100) + .doSample(true) + .requestId("test-request-id") + .build(); + + ZhiPuAiChatOptions result = ZhiPuAiChatOptions.fromOptions(source); + assertThat(result.getModel()).isEqualTo("test-model"); + assertThat(result.getTemperature()).isEqualTo(0.7); + assertThat(result.getMaxTokens()).isEqualTo(100); + assertThat(result.getDoSample()).isEqualTo(true); + assertThat(result.getRequestId()).isEqualTo("test-request-id"); + } + + @Test + void testCopyChangeIndependence() { + ZhiPuAiChatOptions original = ZhiPuAiChatOptions.builder().model("original-model").temperature(0.5).build(); + + ZhiPuAiChatOptions copied = original.copy(); + + // Modify original + original.setModel("modified-model"); + original.setTemperature(0.9); + + // Verify copy is unchanged + assertThat(copied.getModel()).isEqualTo("original-model"); + assertThat(copied.getTemperature()).isEqualTo(0.5); + } + + @Test + void testResponseFormatAndThinkingSetters() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + + ChatCompletionRequest.ResponseFormat responseFormat = new ChatCompletionRequest.ResponseFormat("json_object"); + ChatCompletionRequest.Thinking thinking = new ChatCompletionRequest.Thinking("enabled"); + + // Test fluent setters + ZhiPuAiChatOptions result1 = options.setResponseFormat(responseFormat); + assertThat(result1).isSameAs(options); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + + ZhiPuAiChatOptions result2 = options.setThinking(thinking); + assertThat(result2).isSameAs(options); + assertThat(options.getThinking()).isEqualTo(thinking); + } + + @Test + void testToolCallbacksValidation() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + + // Test setting valid tool callbacks + List toolCallbacks = new ArrayList<>(); + options.setToolCallbacks(toolCallbacks); + assertThat(options.getToolCallbacks()).isEqualTo(toolCallbacks); + } + + @Test + void testToolNamesValidation() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + + // Test setting valid tool names + Set toolNames = Set.of("tool1", "tool2"); + options.setToolNames(toolNames); + assertThat(options.getToolNames()).isEqualTo(toolNames); + } + + @Test + void testBuilderWithToolCallbacksAndNames() { + ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder() + .toolCallbacks(List.of()) + .toolNames(Set.of("tool1", "tool2")) + .build(); + + assertThat(options.getToolCallbacks()).isNotNull().isEmpty(); + assertThat(options.getToolNames()).isEqualTo(Set.of("tool1", "tool2")); + } + + @Test + void testToString() { + ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder().model("test-model").temperature(0.7).build(); + + String toString = options.toString(); + assertThat(toString).startsWith("ZhiPuAiChatOptions: "); + assertThat(toString).contains("test-model"); + assertThat(toString).contains("0.7"); + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java index 27a376ab0e7..44958b9d157 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java @@ -36,6 +36,7 @@ /** * @author Geng Rong + * @author YunKui Lu */ @EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") public class ZhiPuAiApiIT { @@ -57,7 +58,7 @@ void chatCompletionEntityWithMoreParams() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); ResponseEntity response = this.zhiPuAiApi .chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 1024, null, - false, 0.95, 0.7, null, null, null, "test_request_id", false)); + false, 0.95, 0.7, null, null, null, "test_request_id", false, null, null)); assertThat(response).isNotNull(); assertThat(response.getBody()).isNotNull(); diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java index 15f5ac4195c..e06abdf5d3c 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java @@ -54,6 +54,7 @@ import org.springframework.ai.zhipuai.ZhiPuAiTestConfiguration; import org.springframework.ai.zhipuai.api.MockWeatherService; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -67,6 +68,7 @@ /** * @author Geng Rong + * @author YunKui Lu */ @SpringBootTest(classes = ZhiPuAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") @@ -185,7 +187,6 @@ void beanOutputConverter() { @Test void beanOutputConverterRecords() { - BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); String format = outputConverter.getFormat(); @@ -201,14 +202,14 @@ void beanOutputConverterRecords() { Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText()); - logger.info("" + actorsFilms); + logger.info("actorsFilms:{}", actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @Test void beanStreamOutputConverterRecords() { - BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); String format = outputConverter.getFormat(); @@ -232,7 +233,41 @@ void beanStreamOutputConverterRecords() { .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); - logger.info("" + actorsFilms); + logger.info("actorsFilms:{}", actorsFilms); + + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void jsonObjectResponseFormatOutputConverterRecords() { + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = PromptTemplate.builder() + .template(template) + .variables(Map.of("format", format)) + .build(); + Prompt prompt = new Prompt(promptTemplate.createMessage(), + ZhiPuAiChatOptions.builder().responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject()).build()); + + String generationTextFromStream = Objects + .requireNonNull(this.streamingChatModel.stream(prompt).collectList().block()) + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("generationTextFromStream:{}", generationTextFromStream); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + logger.info("actorsFilms:{}", actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -294,6 +329,96 @@ void streamFunctionCallTest() { assertThat(content).containsAnyOf("15.0", "15"); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4.5-flash" }) + void enabledThinkingTest(String modelName) { + UserMessage userMessage = new UserMessage( + "Are there an infinite number of prime numbers such that n mod 4 == 3?"); + + var promptOptions = ZhiPuAiChatOptions.builder() + .model(modelName) + .maxTokens(8192) + .thinking(new ChatCompletionRequest.Thinking("enabled")) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + logger.info("Response: {}", response); + + for (Generation generation : response.getResults()) { + AssistantMessage message = generation.getOutput(); + + assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class); + + assertThat(message.getText()).isNotBlank(); + assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4.5-flash" }) + void disabledThinkingTest(String modelName) { + UserMessage userMessage = new UserMessage( + "Are there an infinite number of prime numbers such that n mod 4 == 3?"); + + var promptOptions = ZhiPuAiChatOptions.builder() + .model(modelName) + .maxTokens(8192) + .thinking(new ChatCompletionRequest.Thinking("disabled")) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + logger.info("Response: {}", response); + + for (Generation generation : response.getResults()) { + AssistantMessage message = generation.getOutput(); + + assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class); + + assertThat(message.getText()).isNotBlank(); + assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isBlank(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4.5-flash" }) + void streamAndEnableThinkingTest(String modelName) { + UserMessage userMessage = new UserMessage( + "Are there an infinite number of prime numbers such that n mod 4 == 3?"); + + var promptOptions = ZhiPuAiChatOptions.builder() + .model(modelName) + .maxTokens(8192) + .thinking(new ChatCompletionRequest.Thinking("enabled")) + .build(); + + Flux response = this.streamingChatModel.stream(new Prompt(userMessage, promptOptions)); + + StringBuilder reasoningContent = new StringBuilder(); + String content = Objects.requireNonNull(response.collectList().block()) + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(message -> { + if (message instanceof ZhiPuAiAssistantMessage zhiPuAiAssistantMessage) { + if (StringUtils.hasText(zhiPuAiAssistantMessage.getReasoningContent())) { + reasoningContent.append(zhiPuAiAssistantMessage.getReasoningContent()); + return ""; + } + } + return message.getText(); + }) + .collect(Collectors.joining()); + + logger.info("reasoningContent: {}", reasoningContent); + logger.info("content: {}", content); + + // assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class); + + assertThat(reasoningContent).isNotBlank(); + assertThat(content).isNotBlank(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "glm-4v" }) void multiModalityEmbeddedImage(String modelName) throws IOException { diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc index c13552b190e..d0ee2c9209e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc @@ -144,6 +144,8 @@ The prefix `spring.ai.zhipuai.chat` is the property prefix that lets you configu | spring.ai.zhipuai.chat.options.requestId | The parameter is passed by the client and must ensure uniqueness. It is used to distinguish the unique identifier for each request. If the client does not provide it, the platform will generate it by default. | - | spring.ai.zhipuai.chat.options.doSample | When do_sample is set to true, the sampling strategy is enabled. If do_sample is false, the sampling strategy parameters temperature and top_p will not take effect. | true | spring.ai.zhipuai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false +| spring.ai.zhipuai.chat.options.response-format.type | Control the format of the model output. Set to `json_object` to ensure the message is a valid JSON object. Available options: `text` or `json_object`. | - +| spring.ai.zhipuai.chat.options.thinking.type | Control whether to enable the large model's chain of thought. Available options: `enabled` or `disabled`. | - |==== NOTE: You can override the common `spring.ai.zhipuai.base-url` and `spring.ai.zhipuai.api-key` for the `ChatModel` implementations.