Skip to content

Commit e389df6

Browse files
Merge remote-tracking branch 'origin/main'
1 parent e7d8071 commit e389df6

File tree

10 files changed

+215
-133
lines changed

10 files changed

+215
-133
lines changed

models/spring-ai-hunyuan/pom.xml

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,7 @@
101101
<artifactId>activation</artifactId>
102102
<version>1.1.1</version>
103103
</dependency>
104-
<dependency>
105-
<groupId>com.alibaba.fastjson2</groupId>
106-
<artifactId>fastjson2</artifactId>
107-
<version>${fastjson.version}</version>
108-
</dependency>
109-
<dependency>
110-
<groupId>com.fasterxml.jackson.core</groupId>
111-
<artifactId>jackson-databind</artifactId>
112-
<version>${jackson.version}</version>
113-
</dependency>
104+
114105
</dependencies>
115106

116107
</project>

models/spring-ai-hunyuan/src/main/java/org/springframework/ai/hunyuan/HunYuanChatModel.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
168168
.toList();
169169

170170
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
171-
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
171+
String finishReason = (choice.finishReason() != null ? choice.finishReason(): "");
172172
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
173173
return new Generation(assistantMessage, generationMetadata);
174174
}
@@ -208,7 +208,7 @@ public ChatResponse call(Prompt prompt) {
208208
Map<String, Object> metadata = Map.of(
209209
"id", chatCompletion.id(),
210210
"role", choice.message().role() != null ? choice.message().role().name() : "",
211-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
211+
"finishReason", choice.finishReason() != null ? choice.finishReason() : ""
212212
);
213213
// @formatter:on
214214
return buildGeneration(choice, metadata);
@@ -222,7 +222,7 @@ public ChatResponse call(Prompt prompt) {
222222
});
223223

224224
if (!isProxyToolCalls(prompt, this.defaultOptions)
225-
&& isToolCall(response, Set.of(HunYuanApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
225+
&& isToolCall(response, Set.of(HunYuanApi.ChatCompletionFinishReason.TOOL_CALLS.getJsonValue(),
226226
HunYuanApi.ChatCompletionFinishReason.STOP.name()))) {
227227
var toolCallConversation = handleToolCalls(prompt, response);
228228
// Recursively call the call method with the tool call message
@@ -277,7 +277,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
277277
Map<String, Object> metadata = Map.of(
278278
"id", chatCompletion2.id(),
279279
"role", roleMap.getOrDefault(id, ""),
280-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
280+
"finishReason", choice.finishReason() != null ? choice.finishReason() : ""
281281
);
282282
// @formatter:on
283283
return buildGeneration(choice, metadata);
@@ -294,7 +294,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
294294

295295
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
296296
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
297-
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
297+
Set.of(ChatCompletionFinishReason.TOOL_CALLS.getJsonValue(), ChatCompletionFinishReason.STOP.getJsonValue()))) {
298298
var toolCallConversation = handleToolCalls(prompt, response);
299299
// Recursively call the stream method with the tool call message
300300
// conversation that contains the call responses.
@@ -327,14 +327,14 @@ private ChatResponseMetadata from(ChatCompletionRequest request, ChatCompletion
327327
*/
328328
private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
329329
List<ChatCompletion.Choice> choices = chunk.choices().stream().map(cc -> {
330-
ChatCompletionMessage delta = cc.delta();
330+
ChatCompletionMessage delta = cc.message();
331331
if (delta == null) {
332332
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.assistant);
333333
}
334334
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(),null);
335335
}).toList();
336336

337-
return new ChatCompletion(chunk.id(), null, chunk.created(), chunk.model(), choices, null, null, null, null, null, null);
337+
return new ChatCompletion(chunk.id(), null, chunk.created(), chunk.note(), choices, null, null, null, null, null, null);
338338
}
339339

340340
/**

models/spring-ai-hunyuan/src/main/java/org/springframework/ai/hunyuan/api/HunYuanApi.java

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
import java.util.function.Predicate;
2323
import java.util.stream.Collectors;
2424

25-
import com.alibaba.fastjson2.JSONObject;
2625
import com.fasterxml.jackson.annotation.JsonInclude;
2726
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2827
import com.fasterxml.jackson.annotation.JsonProperty;
2928
import com.fasterxml.jackson.core.JsonProcessingException;
3029
import com.fasterxml.jackson.databind.DeserializationFeature;
3130
import com.fasterxml.jackson.databind.ObjectMapper;
31+
import org.slf4j.Logger;
32+
import org.slf4j.LoggerFactory;
3233
import org.springframework.ai.hunyuan.api.auth.HunYuanAuthApi;
3334
import org.springframework.util.CollectionUtils;
3435
import org.springframework.util.MultiValueMap;
@@ -57,6 +58,8 @@
5758
*/
5859
public class HunYuanApi {
5960

61+
private static final Logger logger = LoggerFactory.getLogger(HunYuanApi.class);
62+
6063
public static final String DEFAULT_CHAT_MODEL = ChatModel.HUNYUAN_PRO.getValue();
6164

6265
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
@@ -69,7 +72,6 @@ public class HunYuanApi {
6972

7073
private final HunYuanStreamFunctionCallingHelper chunkMerger = new HunYuanStreamFunctionCallingHelper();
7174

72-
private final ObjectMapper objectMapper;
7375
/**
7476
* Create a new client api with DEFAULT_BASE_URL
7577
* @param secretId Hunyuan SecretId.
@@ -103,7 +105,6 @@ public HunYuanApi(String baseUrl, String secretId, String secretKey, RestClient.
103105
headers.setContentType(MediaType.APPLICATION_JSON);
104106
};
105107
hunyuanAuthApi = new HunYuanAuthApi(secretId, secretKey);
106-
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
107108
this.restClient = restClientBuilder.baseUrl(baseUrl)
108109
.defaultHeaders(jsonContentHeaders)
109110
.defaultStatusHandler(responseErrorHandler)
@@ -121,7 +122,6 @@ public HunYuanApi(String baseUrl, String secretId, String secretKey, RestClient.
121122
public ResponseEntity<ChatCompletionResponse> chatCompletionEntity(ChatCompletionRequest chatRequest) {
122123

123124
Assert.notNull(chatRequest, "The request body can not be null.");
124-
// Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
125125
String service = HunYuanConstants.DEFAULT_SERVICE;
126126
String host = HunYuanConstants.DEFAULT_CHAT_HOST;
127127
// String region = "ap-guangzhou";
@@ -136,12 +136,8 @@ public ResponseEntity<ChatCompletionResponse> chatCompletionEntity(ChatCompletio
136136
.retrieve()
137137
.toEntity(String.class);
138138
// 使用 ObjectMapper 将响应体字符串转换为 ChatCompletionResponse 对象
139-
ChatCompletionResponse chatCompletionResponse = null;
140-
try {
141-
chatCompletionResponse = objectMapper.readValue(retrieve.getBody(), ChatCompletionResponse.class);
142-
} catch (JsonProcessingException e) {
143-
throw new RuntimeException(e);
144-
}
139+
logger.info("Response body: {}", retrieve.getBody());
140+
ChatCompletionResponse chatCompletionResponse = ModelOptionsUtils.jsonToObject(retrieve.getBody(), ChatCompletionResponse.class);
145141
return ResponseEntity.ok(chatCompletionResponse);
146142
}
147143

@@ -155,17 +151,27 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
155151
Assert.notNull(chatRequest, "The request body can not be null.");
156152
Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true.");
157153
AtomicBoolean isInsideTool = new AtomicBoolean(false);
158-
154+
String service = HunYuanConstants.DEFAULT_SERVICE;
155+
String host = HunYuanConstants.DEFAULT_CHAT_HOST;
156+
// String region = "ap-guangzhou";
157+
String action = HunYuanConstants.DEFAULT_CHAT_ACTION;
158+
MultiValueMap<String, String> jsonContentHeaders = hunyuanAuthApi.getHttpHeadersConsumer(host, action, service, chatRequest);
159159
return this.webClient.post()
160-
.uri("/v1/chat/completions")
160+
.uri("/")
161+
.headers(headers -> {
162+
headers.addAll(jsonContentHeaders);
163+
})
161164
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
162165
.retrieve()
163166
.bodyToFlux(String.class)
164167
// cancels the flux stream after the "[DONE]" is received.
165168
.takeUntil(SSE_DONE_PREDICATE)
166169
// filters out the "[DONE]" message.
167170
.filter(SSE_DONE_PREDICATE.negate())
168-
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
171+
.map(content ->{
172+
// logger.info(content);
173+
return ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class);
174+
})
169175
// Detect is the chunk is part of a streaming function call.
170176
.map(chunk -> {
171177
if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
@@ -188,7 +194,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
188194
// Flux<Flux<ChatCompletionChunk>> -> Flux<Mono<ChatCompletionChunk>>
189195
.concatMapIterable(window -> {
190196
Mono<ChatCompletionChunk> monoChunk = window.reduce(
191-
new ChatCompletionChunk(null, null, null, null, null),
197+
new ChatCompletionChunk(null, null,null,null,null,null,null,null, null, null, null),
192198
(previous, current) -> this.chunkMerger.merge(previous, current));
193199
return List.of(monoChunk);
194200
})
@@ -225,8 +231,16 @@ public enum ChatCompletionFinishReason {
225231
* Only for compatibility with Mistral AI API.
226232
*/
227233
@JsonProperty("tool_call")
228-
TOOL_CALL
234+
TOOL_CALL;
235+
236+
private final String jsonValue;
229237

238+
ChatCompletionFinishReason() {
239+
this.jsonValue = this.name().toLowerCase();
240+
}
241+
public String getJsonValue() {
242+
return this.jsonValue;
243+
}
230244
}
231245

232246
/**
@@ -455,7 +469,7 @@ public static Object function(String functionName) {
455469
* @param rawContent The raw contents of the message.
456470
* @param role The role of the message's author. Could be one of the {@link Role}
457471
* types.
458-
* @param name The name of the message's author.
472+
* @param chatContent The name of the message's author.
459473
* @param toolCallId The ID of the tool call associated with the message.
460474
* @param toolCalls The list of tool calls associated with the message.
461475
*/
@@ -464,7 +478,7 @@ public record ChatCompletionMessage(
464478
// @formatter:off
465479
@JsonProperty("Content") Object rawContent,
466480
@JsonProperty("Role") Role role,
467-
@JsonProperty("Contents") List<ChatContent> name,
481+
@JsonProperty("Contents") List<ChatContent> chatContents,
468482
@JsonProperty("ToolCallId") String toolCallId,
469483
@JsonProperty("ToolCalls") List<ToolCall> toolCalls
470484
// @formatter:on
@@ -480,6 +494,10 @@ public ChatCompletionMessage(Object content, Role role) {
480494
this(content, role, null, null, null);
481495
}
482496

497+
public ChatCompletionMessage(Role role,List<ChatContent> chatContent) {
498+
this(null, role, chatContent, null, null);
499+
}
500+
483501
/**
484502
* Get message content as String.
485503
*/
@@ -533,14 +551,20 @@ public enum Role {
533551
* @param function The function definition.
534552
*/
535553
@JsonInclude(Include.NON_NULL)
536-
public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("Index") Integer index,
537-
@JsonProperty("function") ChatCompletionFunction function) {
554+
public record ToolCall(@JsonProperty("Id") String id, @JsonProperty("Type") String type, @JsonProperty("Index") Integer index,
555+
@JsonProperty("Function") ChatCompletionFunction function) {
538556

539557
}
540558

541559
@JsonInclude(Include.NON_NULL)
542560
public record ChatContent(@JsonProperty("Type") String type, @JsonProperty("Text") String text,
543561
@JsonProperty("ImageUrl") ImageUrl imageUrl) {
562+
public ChatContent(String type, String text) {
563+
this(type, text, null);
564+
}
565+
public ChatContent(String type, ImageUrl imageUrl) {
566+
this(type, null, imageUrl);
567+
}
544568

545569
}
546570
@JsonInclude(Include.NON_NULL)
@@ -704,29 +728,18 @@ public record Choice(
704728
// @formatter:off
705729
@JsonProperty("Index") Integer index,
706730
@JsonProperty("Message") ChatCompletionMessage message,
707-
@JsonProperty("FinishReason") ChatCompletionFinishReason finishReason,
708-
@JsonProperty("Delta") ChatCompletionDelta chatCompletionDelta
731+
@JsonProperty("FinishReason") String finishReason,
732+
@JsonProperty("Delta") ChatCompletionDelta delta
709733
) {
710734
// @formatter:on
711735
}
712736

713737
@JsonInclude(Include.NON_NULL)
714738
public record ChatCompletionDelta(
715739
// @formatter:off
716-
@JsonProperty("Role") String role,
740+
@JsonProperty("Role") ChatCompletionMessage.Role role,
717741
@JsonProperty("Content") String content,
718-
@JsonProperty("ToolCalls") ChatCompletionToolCall chatCompletionToolCall
719-
) {
720-
// @formatter:on
721-
}
722-
723-
@JsonInclude(Include.NON_NULL)
724-
public record ChatCompletionToolCall(
725-
// @formatter:off
726-
@JsonProperty("Id") String role,
727-
@JsonProperty("Type") String content,
728-
@JsonProperty("Function") ChatCompletionMessage.ChatCompletionFunction chatCompletionToolCall,
729-
@JsonProperty("Index") Integer index
742+
@JsonProperty("ToolCalls") List<ChatCompletionMessage.ToolCall> toolCalls
730743
) {
731744
// @formatter:on
732745
}
@@ -758,33 +771,17 @@ public record ErrorMsg(
758771
@JsonInclude(Include.NON_NULL)
759772
public record ChatCompletionChunk(
760773
// @formatter:off
761-
@JsonProperty("id") String id,
762-
@JsonProperty("object") String object,
763-
@JsonProperty("created") Long created,
764-
@JsonProperty("model") String model,
765-
@JsonProperty("choices") List<ChunkChoice> choices) {
766-
// @formatter:on
767-
768-
/**
769-
* Chat completion choice.
770-
*
771-
* @param index The index of the choice in the list of choices.
772-
* @param delta A chat completion delta generated by streamed model responses.
773-
* @param finishReason The reason the model stopped generating tokens.
774-
* @param usage Usage statistics for the completion request.
775-
*/
776-
@JsonInclude(Include.NON_NULL)
777-
public record ChunkChoice(
778-
// @formatter:off
779-
@JsonProperty("index") Integer index,
780-
@JsonProperty("delta") ChatCompletionMessage delta,
781-
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
782-
@JsonProperty("usage") Usage usage
783-
// @formatter:on
784-
) {
785-
786-
}
787-
774+
@JsonProperty("Id") String id,
775+
@JsonProperty("Error") ChatCompletion.ErrorMsg errorMsg,
776+
@JsonProperty("Created") Long created,
777+
@JsonProperty("Note") String note,
778+
@JsonProperty("Choices") List<ChatCompletion.Choice> choices,
779+
@JsonProperty("Usage") Usage usage,
780+
@JsonProperty("ModerationLevel") String moderationLevel,
781+
@JsonProperty("SearchInfo") ChatCompletion.SearchInfo searchInfo,
782+
@JsonProperty("Replaces") List<ChatCompletion.Replace> replaces,
783+
@JsonProperty("RecommendedQuestions") List<String> recommendedQuestions,
784+
@JsonProperty("RequestId") String requestId) {
788785
}
789786

790787
/**

0 commit comments

Comments
 (0)