Skip to content

Commit 3d215d6

Browse files
committed
feat(zhipuai): ZhipuAI add thinking and response_format parameter support
- Add `thinking` and `response_format` fields to `ZhiPuAiApi` and `ZhiPuAiChatOptions` - Add ZhiPuAiChatOptionsTests with 16 test methods covering all aspects of the class - Test builder pattern with all fields including responseFormat and thinking - Test copy functionality, setters, default values, and equals/hashCode - Test tool callbacks, tool names validation, and collection handling - Test stop sequences alias and fluent setters - Add documentation for response-format.type and thinking.type properties Signed-off-by: YunKui Lu <[email protected]>
1 parent ae0c418 commit 3d215d6

File tree

7 files changed

+632
-135
lines changed

7 files changed

+632
-135
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiPropertiesTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
* {@link ZhiPuAiEmbeddingProperties}.
3838
*
3939
* @author Geng Rong
40+
* @author YunKui Lu
4041
*/
4142
public class ZhiPuAiPropertiesTests {
4243

@@ -243,7 +244,9 @@ public void chatOptionsTest() {
243244
"required": ["location", "lat", "lon", "unit"]
244245
}
245246
""",
246-
"spring.ai.zhipuai.chat.options.user=userXYZ"
247+
"spring.ai.zhipuai.chat.options.user=userXYZ",
248+
"spring.ai.zhipuai.chat.options.response-format.type=json_object",
249+
"spring.ai.zhipuai.chat.options.thinking.type=disabled"
247250
)
248251
// @formatter:on
249252
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
@@ -262,6 +265,8 @@ public void chatOptionsTest() {
262265
assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56);
263266
assertThat(chatProperties.getOptions().getRequestId()).isEqualTo("RequestId");
264267
assertThat(chatProperties.getOptions().getDoSample()).isEqualTo(Boolean.TRUE);
268+
assertThat(chatProperties.getOptions().getResponseFormat().type()).isEqualTo("json_object");
269+
assertThat(chatProperties.getOptions().getThinking().type()).isEqualTo("disabled");
265270

266271
JSONAssert.assertEquals("{\"type\":\"function\",\"function\":{\"name\":\"toolChoiceFunctionName\"}}",
267272
chatProperties.getOptions().getToolChoice(), JSONCompareMode.LENIENT);

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java

Lines changed: 81 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashSet;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.Objects;
2526
import java.util.Set;
2627

2728
import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -30,9 +31,11 @@
3031
import com.fasterxml.jackson.annotation.JsonProperty;
3132

3233
import org.springframework.ai.chat.prompt.ChatOptions;
34+
import org.springframework.ai.model.ModelOptionsUtils;
3335
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3436
import org.springframework.ai.tool.ToolCallback;
3537
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
38+
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest;
3639
import org.springframework.lang.Nullable;
3740
import org.springframework.util.Assert;
3841

@@ -42,6 +45,7 @@
4245
* @author Geng Rong
4346
* @author Thomas Vitale
4447
* @author Ilayaperumal Gopinathan
48+
* @author YunKui Lu
4549
* @since 1.0.0 M1
4650
*/
4751
@JsonInclude(Include.NON_NULL)
@@ -104,6 +108,16 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions {
104108
*/
105109
private @JsonProperty("do_sample") Boolean doSample;
106110

111+
/**
112+
* Control the format of the model output. Set to `json_object` to ensure the message is a valid JSON object.
113+
*/
114+
private @JsonProperty("response_format") ChatCompletionRequest.ResponseFormat responseFormat;
115+
116+
/**
117+
* Control whether to enable the large model's chain of thought. Available options: (default) enabled, disabled.
118+
*/
119+
private @JsonProperty("thinking") ChatCompletionRequest.Thinking thinking;
120+
107121
/**
108122
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
109123
*/
@@ -146,6 +160,8 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) {
146160
.toolNames(fromOptions.getToolNames())
147161
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
148162
.toolContext(fromOptions.getToolContext())
163+
.responseFormat(fromOptions.getResponseFormat())
164+
.thinking(fromOptions.getThinking())
149165
.build();
150166
}
151167

@@ -244,6 +260,24 @@ public void setDoSample(Boolean doSample) {
244260
this.doSample = doSample;
245261
}
246262

263+
public ChatCompletionRequest.ResponseFormat getResponseFormat() {
264+
return this.responseFormat;
265+
}
266+
267+
public ZhiPuAiChatOptions setResponseFormat(ChatCompletionRequest.ResponseFormat responseFormat) {
268+
this.responseFormat = responseFormat;
269+
return this;
270+
}
271+
272+
public ChatCompletionRequest.Thinking getThinking() {
273+
return this.thinking;
274+
}
275+
276+
public ZhiPuAiChatOptions setThinking(ChatCompletionRequest.Thinking thinking) {
277+
this.thinking = thinking;
278+
return this;
279+
}
280+
247281
@Override
248282
@JsonIgnore
249283
public Double getFrequencyPenalty() {
@@ -311,138 +345,52 @@ public Map<String, Object> getToolContext() {
311345

312346
@Override
313347
public void setToolContext(Map<String, Object> toolContext) {
348+
Assert.notNull(toolContext, "toolContext cannot be null");
314349
this.toolContext = toolContext;
315350
}
316351

352+
@Override
353+
public final boolean equals(Object o) {
354+
if (!(o instanceof ZhiPuAiChatOptions that))
355+
return false;
356+
357+
return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens)
358+
&& Objects.equals(this.stop, that.stop) && Objects.equals(this.temperature, that.temperature)
359+
&& Objects.equals(this.topP, that.topP) && Objects.equals(this.tools, that.tools)
360+
&& Objects.equals(this.toolChoice, that.toolChoice) && Objects.equals(this.user, that.user)
361+
&& Objects.equals(this.requestId, that.requestId) && Objects.equals(this.doSample, that.doSample)
362+
&& Objects.equals(this.responseFormat, that.responseFormat)
363+
&& Objects.equals(this.thinking, that.thinking)
364+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
365+
&& Objects.equals(this.toolNames, that.toolNames)
366+
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
367+
&& Objects.equals(this.toolContext, that.toolContext);
368+
}
369+
317370
@Override
318371
public int hashCode() {
319-
final int prime = 31;
320-
int result = 1;
321-
result = prime * result + ((this.model == null) ? 0 : this.model.hashCode());
322-
result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode());
323-
result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode());
324-
result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode());
325-
result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode());
326-
result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode());
327-
result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode());
328-
result = prime * result + ((this.user == null) ? 0 : this.user.hashCode());
329-
result = prime * result
330-
+ ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode());
331-
result = prime * result + ((this.toolCallbacks == null) ? 0 : this.toolCallbacks.hashCode());
332-
result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode());
333-
result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode());
372+
int result = Objects.hashCode(this.model);
373+
result = 31 * result + Objects.hashCode(this.maxTokens);
374+
result = 31 * result + Objects.hashCode(this.stop);
375+
result = 31 * result + Objects.hashCode(this.temperature);
376+
result = 31 * result + Objects.hashCode(this.topP);
377+
result = 31 * result + Objects.hashCode(this.tools);
378+
result = 31 * result + Objects.hashCode(this.toolChoice);
379+
result = 31 * result + Objects.hashCode(this.user);
380+
result = 31 * result + Objects.hashCode(this.requestId);
381+
result = 31 * result + Objects.hashCode(this.doSample);
382+
result = 31 * result + Objects.hashCode(this.responseFormat);
383+
result = 31 * result + Objects.hashCode(this.thinking);
384+
result = 31 * result + Objects.hashCode(this.toolCallbacks);
385+
result = 31 * result + Objects.hashCode(this.toolNames);
386+
result = 31 * result + Objects.hashCode(this.internalToolExecutionEnabled);
387+
result = 31 * result + Objects.hashCode(this.toolContext);
334388
return result;
335389
}
336390

337391
@Override
338-
public boolean equals(Object obj) {
339-
if (this == obj) {
340-
return true;
341-
}
342-
if (obj == null) {
343-
return false;
344-
}
345-
if (getClass() != obj.getClass()) {
346-
return false;
347-
}
348-
ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj;
349-
if (this.model == null) {
350-
if (other.model != null) {
351-
return false;
352-
}
353-
}
354-
else if (!this.model.equals(other.model)) {
355-
return false;
356-
}
357-
if (this.maxTokens == null) {
358-
if (other.maxTokens != null) {
359-
return false;
360-
}
361-
}
362-
else if (!this.maxTokens.equals(other.maxTokens)) {
363-
return false;
364-
}
365-
if (this.stop == null) {
366-
if (other.stop != null) {
367-
return false;
368-
}
369-
}
370-
else if (!this.stop.equals(other.stop)) {
371-
return false;
372-
}
373-
if (this.temperature == null) {
374-
if (other.temperature != null) {
375-
return false;
376-
}
377-
}
378-
else if (!this.temperature.equals(other.temperature)) {
379-
return false;
380-
}
381-
if (this.topP == null) {
382-
if (other.topP != null) {
383-
return false;
384-
}
385-
}
386-
else if (!this.topP.equals(other.topP)) {
387-
return false;
388-
}
389-
if (this.tools == null) {
390-
if (other.tools != null) {
391-
return false;
392-
}
393-
}
394-
else if (!this.tools.equals(other.tools)) {
395-
return false;
396-
}
397-
if (this.toolChoice == null) {
398-
if (other.toolChoice != null) {
399-
return false;
400-
}
401-
}
402-
else if (!this.toolChoice.equals(other.toolChoice)) {
403-
return false;
404-
}
405-
if (this.user == null) {
406-
if (other.user != null) {
407-
return false;
408-
}
409-
}
410-
else if (!this.user.equals(other.user)) {
411-
return false;
412-
}
413-
if (this.requestId == null) {
414-
if (other.requestId != null) {
415-
return false;
416-
}
417-
}
418-
else if (!this.requestId.equals(other.requestId)) {
419-
return false;
420-
}
421-
if (this.doSample == null) {
422-
if (other.doSample != null) {
423-
return false;
424-
}
425-
}
426-
else if (!this.doSample.equals(other.doSample)) {
427-
return false;
428-
}
429-
if (this.internalToolExecutionEnabled == null) {
430-
if (other.internalToolExecutionEnabled != null) {
431-
return false;
432-
}
433-
}
434-
else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEnabled)) {
435-
return false;
436-
}
437-
if (this.toolContext == null) {
438-
if (other.toolContext != null) {
439-
return false;
440-
}
441-
}
442-
else if (!this.toolContext.equals(other.toolContext)) {
443-
return false;
444-
}
445-
return true;
392+
public String toString() {
393+
return "ZhiPuAiChatOptions: " + ModelOptionsUtils.toJsonString(this);
446394
}
447395

448396
@Override
@@ -610,6 +558,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
610558
return this;
611559
}
612560

561+
public Builder responseFormat(ChatCompletionRequest.ResponseFormat responseFormat) {
562+
this.options.responseFormat = responseFormat;
563+
return this;
564+
}
565+
566+
public Builder thinking(ChatCompletionRequest.Thinking thinking) {
567+
this.options.thinking = thinking;
568+
return this;
569+
}
570+
613571
public ZhiPuAiChatOptions build() {
614572
return this.options;
615573
}

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ public void setJsonSchema(String jsonSchema) {
652652
* logged and can be used for debugging purposes.
653653
* @param doSample If set, the model will use sampling to generate the next token. If
654654
* not set, the model will use greedy decoding to generate the next token.
655+
* @param responseFormat Control the format of the model output. Set to `json_object`
656+
* to ensure the message is a valid JSON object.
657+
* @param thinking Control whether to enable the large model's chain of thought.
655658
*/
656659
@JsonInclude(Include.NON_NULL)
657660
public record ChatCompletionRequest(// @formatter:off
@@ -664,9 +667,11 @@ public record ChatCompletionRequest(// @formatter:off
664667
@JsonProperty("top_p") Double topP,
665668
@JsonProperty("tools") List<FunctionTool> tools,
666669
@JsonProperty("tool_choice") Object toolChoice,
667-
@JsonProperty("user") String user,
670+
@JsonProperty("user_id") String user,
668671
@JsonProperty("request_id") String requestId,
669-
@JsonProperty("do_sample") Boolean doSample) { // @formatter:on
672+
@JsonProperty("do_sample") Boolean doSample,
673+
@JsonProperty("response_format") ResponseFormat responseFormat,
674+
@JsonProperty("thinking") Thinking thinking) { // @formatter:on
670675

671676
/**
672677
* Shortcut constructor for a chat completion request with the given messages and
@@ -676,7 +681,7 @@ public record ChatCompletionRequest(// @formatter:off
676681
* @param temperature What sampling temperature to use, between 0 and 1.
677682
*/
678683
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
679-
this(messages, model, null, null, false, temperature, null, null, null, null, null, null);
684+
this(messages, model, null, null, false, temperature, null, null, null, null, null, null, null, null);
680685
}
681686

682687
/**
@@ -691,7 +696,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
691696
*/
692697
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature,
693698
boolean stream) {
694-
this(messages, model, null, null, stream, temperature, null, null, null, null, null, null);
699+
this(messages, model, null, null, stream, temperature, null, null, null, null, null, null, null, null);
695700
}
696701

697702
/**
@@ -706,7 +711,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
706711
*/
707712
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools,
708713
Object toolChoice) {
709-
this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null);
714+
this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null, null, null);
710715
}
711716

712717
/**
@@ -719,7 +724,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
719724
* terminated by a data: [DONE] message.
720725
*/
721726
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
722-
this(messages, null, null, null, stream, null, null, null, null, null, null, null);
727+
this(messages, null, null, null, stream, null, null, null, null, null, null, null, null, null);
723728
}
724729

725730
/**
@@ -754,7 +759,32 @@ public static Object function(String functionName) {
754759
*/
755760
@JsonInclude(Include.NON_NULL)
756761
public record ResponseFormat(@JsonProperty("type") String type) {
762+
763+
public static ResponseFormat text() {
764+
return new ResponseFormat("text");
765+
}
766+
767+
public static ResponseFormat jsonObject() {
768+
return new ResponseFormat("json_object");
769+
}
770+
}
771+
772+
/**
773+
* Control whether to enable the large model's chain of thought
774+
*
775+
* @param type Available options: (default) enabled, disabled
776+
*/
777+
@JsonInclude(Include.NON_NULL)
778+
public record Thinking(@JsonProperty("type") String type) {
779+
public static Thinking enabled() {
780+
return new Thinking("enabled");
781+
}
782+
783+
public static Thinking disabled() {
784+
return new Thinking("disabled");
785+
}
757786
}
787+
758788
}
759789

760790
/**

0 commit comments

Comments
 (0)