
Commit a9ac271

feat(zhipuai): Add prompt_tokens_details and update default chat options for tests
- Introduced `prompt_tokens_details` with `cached_tokens` field to `ZhiPuAiApi.Usage`
- Updated test cases to replace inline `ChatOptions` with `DEFAULT_CHAT_OPTIONS`
- Refactored test models to ensure usage of `glm-4-flash` and `glm-4v-flash` as defaults
- Added metadata validations for `promptTokensDetails` in response

Signed-off-by: YunKui Lu <[email protected]>
1 parent 3fc1ed6 commit a9ac271

2 files changed, +57 -26 lines changed


models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java

Lines changed: 19 additions & 1 deletion
@@ -1077,13 +1077,31 @@ public record TopLogProbs(// @formatter:off
 	 * @param promptTokens Number of tokens in the prompt.
 	 * @param totalTokens Total number of tokens used in the request (prompt +
 	 * completion).
+	 * @param promptTokensDetails Details about the prompt tokens used. Support for
+	 * GLM-4.5 and later models.
 	 */
 	@JsonInclude(Include.NON_NULL)
 	@JsonIgnoreProperties(ignoreUnknown = true)
 	public record Usage(// @formatter:off
 		@JsonProperty("completion_tokens") Integer completionTokens,
 		@JsonProperty("prompt_tokens") Integer promptTokens,
-		@JsonProperty("total_tokens") Integer totalTokens) { // @formatter:on
+		@JsonProperty("total_tokens") Integer totalTokens,
+		@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails) { // @formatter:on
+
+		public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
+			this(completionTokens, promptTokens, totalTokens, null);
+		}
+
+		/**
+		 * Details about the prompt tokens used.
+		 *
+		 * @param cachedTokens Number of tokens in the prompt that were cached.
+		 */
+		@JsonInclude(Include.NON_NULL)
+		@JsonIgnoreProperties(ignoreUnknown = true)
+		public record PromptTokensDetails(// @formatter:off
+			@JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
+		}
 
 	}

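For context, a minimal sketch of how a caller sees the extended Usage record after this commit. The JSON shape in the comment mirrors the @JsonProperty names in the diff above; the UsageSketch class name and the literal token counts are illustrative assumptions, not code from the commit.

import org.springframework.ai.zhipuai.api.ZhiPuAiApi;

public class UsageSketch {

	public static void main(String[] args) {
		// A usage payload of the shape
		//   {"completion_tokens": 42, "prompt_tokens": 128, "total_tokens": 170,
		//    "prompt_tokens_details": {"cached_tokens": 64}}
		// now maps onto the nested records declared in ZhiPuAiApi.Usage.
		var details = new ZhiPuAiApi.Usage.PromptTokensDetails(64);
		var usage = new ZhiPuAiApi.Usage(42, 128, 170, details);

		// The added three-argument constructor keeps existing call sites compiling;
		// promptTokensDetails is simply null in that case.
		var legacyUsage = new ZhiPuAiApi.Usage(42, 128, 170);

		System.out.println(usage.promptTokensDetails().cachedTokens()); // 64
		System.out.println(legacyUsage.promptTokensDetails()); // null
	}

}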
models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java

Lines changed: 38 additions & 25 deletions
@@ -40,7 +40,6 @@
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.model.StreamingChatModel;
-import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.chat.prompt.PromptTemplate;
 import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -85,13 +84,22 @@ class ZhiPuAiChatModelIT {
 	@Value("classpath:/prompts/system-message.st")
 	private Resource systemResource;
 
+	/**
+	 * Default chat options to use for the tests.
+	 * <p>
+	 * glm-4-flash is a free model, so it is used by default on the tests.
+	 */
+	private static final ZhiPuAiChatOptions DEFAULT_CHAT_OPTIONS = ZhiPuAiChatOptions.builder()
+		.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
+		.build();
+
 	@Test
 	void roleTest() {
 		UserMessage userMessage = new UserMessage(
 				"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
 		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
 		Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
-		Prompt prompt = new Prompt(List.of(userMessage, systemMessage), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
 		ChatResponse response = this.chatModel.call(prompt);
 		assertThat(response.getResults()).hasSize(1);
 		assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
@@ -104,7 +112,7 @@ void streamRoleTest() {
 				"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
 		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
 		Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
-		Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
+		Prompt prompt = new Prompt(List.of(userMessage, systemMessage), DEFAULT_CHAT_OPTIONS);
 		Flux<ChatResponse> flux = this.streamingChatModel.stream(prompt);
 
 		List<ChatResponse> responses = flux.collectList().block();
@@ -135,7 +143,7 @@ void listOutputConverter() {
 			.template(template)
 			.variables(Map.of("subject", "ice cream flavors", "format", format))
 			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
 		Generation generation = this.chatModel.call(prompt).getResult();
 
 		List<String> list = outputConverter.convert(generation.getOutput().getText());
@@ -157,8 +165,9 @@ void mapOutputConverter() {
 			.variables(Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format",
 					format))
 			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
-		Generation generation = this.chatModel.call(prompt).getResult();
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
+		ChatResponse chatResponse = this.chatModel.call(prompt);
+		Generation generation = chatResponse.getResult();
 
 		Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
 		assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -179,7 +188,7 @@ void beanOutputConverter() {
 			.template(template)
 			.variables(Map.of("format", format))
 			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
 		Generation generation = this.chatModel.call(prompt).getResult();
 
 		ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
@@ -198,7 +207,7 @@ void beanOutputConverterRecords() {
 			.template(template)
 			.variables(Map.of("format", format))
 			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage(), ChatOptions.builder().build());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
 		Generation generation = this.chatModel.call(prompt).getResult();
 
 		ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
@@ -221,7 +230,7 @@ void beanStreamOutputConverterRecords() {
 			.template(template)
 			.variables(Map.of("format", format))
 			.build();
-		Prompt prompt = new Prompt(promptTemplate.createMessage());
+		Prompt prompt = new Prompt(promptTemplate.createMessage(), DEFAULT_CHAT_OPTIONS);
 
 		String generationTextFromStream = Objects
 			.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
@@ -253,7 +262,10 @@ void jsonObjectResponseFormatOutputConverterRecords() {
 			.variables(Map.of("format", format))
 			.build();
 		Prompt prompt = new Prompt(promptTemplate.createMessage(),
-				ZhiPuAiChatOptions.builder().responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject()).build());
+				ZhiPuAiChatOptions.builder()
+					.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
+					.responseFormat(ChatCompletionRequest.ResponseFormat.jsonObject())
+					.build());
 
 		String generationTextFromStream = Objects
 			.requireNonNull(this.streamingChatModel.stream(prompt).collectList().block())
@@ -281,7 +293,7 @@ void functionCallTest() {
 		List<Message> messages = new ArrayList<>(List.of(userMessage));
 
 		var promptOptions = ZhiPuAiChatOptions.builder()
-			.model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
+			.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
 			.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
 				.description("Get the weather in location")
 				.inputType(MockWeatherService.Request.class)
@@ -306,7 +318,7 @@ void streamFunctionCallTest() {
 		List<Message> messages = new ArrayList<>(List.of(userMessage));
 
 		var promptOptions = ZhiPuAiChatOptions.builder()
-			.model(ZhiPuAiApi.ChatModel.GLM_4.getValue())
+			.model(ZhiPuAiApi.ChatModel.GLM_4_Flash.getValue())
 			.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
 				.description("Get the weather in location")
 				.inputType(MockWeatherService.Request.class)
@@ -332,8 +344,7 @@ void streamFunctionCallTest() {
 	@ParameterizedTest(name = "{0} : {displayName} ")
 	@ValueSource(strings = { "glm-4.5-flash" })
 	void enabledThinkingTest(String modelName) {
-		UserMessage userMessage = new UserMessage(
-				"Are there an infinite number of prime numbers such that n mod 4 == 3?");
+		UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");
 
 		var promptOptions = ZhiPuAiChatOptions.builder()
 			.model(modelName)
@@ -344,14 +355,16 @@ void enabledThinkingTest(String modelName) {
 		ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions));
 		logger.info("Response: {}", response);
 
-		for (Generation generation : response.getResults()) {
-			AssistantMessage message = generation.getOutput();
+		Generation generation = response.getResult();
+		AssistantMessage message = generation.getOutput();
 
-			assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);
+		assertThat(message).isInstanceOf(ZhiPuAiAssistantMessage.class);
 
-			assertThat(message.getText()).isNotBlank();
-			assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
-		}
+		assertThat(message.getText()).isNotBlank();
+		assertThat(((ZhiPuAiAssistantMessage) message).getReasoningContent()).isNotBlank();
+
+		ZhiPuAiApi.Usage nativeUsage = (ZhiPuAiApi.Usage) response.getMetadata().getUsage().getNativeUsage();
+		assertThat(nativeUsage.promptTokensDetails()).isNotNull();
 	}
 
 	@ParameterizedTest(name = "{0} : {displayName} ")
@@ -382,8 +395,7 @@ void disabledThinkingTest(String modelName) {
 	@ParameterizedTest(name = "{0} : {displayName} ")
 	@ValueSource(strings = { "glm-4.5-flash" })
 	void streamAndEnableThinkingTest(String modelName) {
-		UserMessage userMessage = new UserMessage(
-				"Are there an infinite number of prime numbers such that n mod 4 == 3?");
+		UserMessage userMessage = new UserMessage("9.11 and 9.8, which is greater?");
 
 		var promptOptions = ZhiPuAiChatOptions.builder()
 			.model(modelName)
@@ -408,6 +420,7 @@ void streamAndEnableThinkingTest(String modelName) {
 				}
 				return message.getText();
 			})
+			.filter(StringUtils::hasText)
 			.collect(Collectors.joining());
 
 		logger.info("reasoningContent: {}", reasoningContent);
@@ -420,7 +433,7 @@ void streamAndEnableThinkingTest(String modelName) {
 	}
 
 	@ParameterizedTest(name = "{0} : {displayName} ")
-	@ValueSource(strings = { "glm-4v" })
+	@ValueSource(strings = { "glm-4v-flash" })
 	void multiModalityEmbeddedImage(String modelName) throws IOException {
 
 		var imageData = new ClassPathResource("/test.png");
@@ -461,7 +474,7 @@ void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IO
 	}
 
 	@ParameterizedTest(name = "{0} : {displayName} ")
-	@ValueSource(strings = { "glm-4v", "glm-4.1v-thinking-flash" })
+	@ValueSource(strings = { "glm-4v-flash", "glm-4.1v-thinking-flash" })
 	void multiModalityImageUrl(String modelName) throws IOException {
 
 		var userMessage = UserMessage.builder()
@@ -505,7 +518,7 @@ void reasonerMultiModalityImageUrl(String modelName) throws IOException {
 	}
 
 	@ParameterizedTest(name = "{0} : {displayName} ")
-	@ValueSource(strings = { "glm-4v" })
+	@ValueSource(strings = { "glm-4v-flash" })
 	void streamingMultiModalityImageUrl(String modelName) throws IOException {
 
 		var userMessage = UserMessage.builder()

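The metadata assertion added to enabledThinkingTest suggests how the cached-token count can be read at runtime. A minimal sketch under the assumption that a ZhiPuAiChatModel has already been configured elsewhere; the CachedTokensSketch class, the cachedTokens method, and the prompt wiring are placeholders, not part of the commit.

import java.util.List;

import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;

public class CachedTokensSketch {

	// chatModel is assumed to be an already-wired ZhiPuAiChatModel (API key configured).
	static Integer cachedTokens(ZhiPuAiChatModel chatModel) {
		ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder()
			.model("glm-4.5-flash") // the commit documents prompt_tokens_details for GLM-4.5 and later models
			.build();

		ChatResponse response = chatModel
			.call(new Prompt(List.of(new UserMessage("9.11 and 9.8, which is greater?")), options));

		// The provider-native usage object now carries the prompt_tokens_details field.
		ZhiPuAiApi.Usage usage = (ZhiPuAiApi.Usage) response.getMetadata().getUsage().getNativeUsage();
		ZhiPuAiApi.Usage.PromptTokensDetails details = usage.promptTokensDetails();

		// promptTokensDetails may be null for models that do not report it.
		return (details != null) ? details.cachedTokens() : null;
	}

}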