Skip to content

Commit 5d40435

Browse files
feat(spring-ai): 添加 HunYuan 模型支持
- 在 pom.xml 中添加 spring-ai-starter-hunyuan 模块 - 实现 HunYuanRuntimeHints 类以注册 HunYuan API 的运行时提示 - 开发 HunYuanAuthApi 类处理身份验证和授权 - 创建 HunYuanApi 类提供与 HunYuan 平台交互的客户端库 - 定义 HunYuanConstants 存储常量 - 实现 HunYuanStreamFunctionCallingHelper 类支持流式函数调用
1 parent c5a6976 commit 5d40435

File tree

25 files changed

+1234
-318
lines changed

25 files changed

+1234
-318
lines changed

models/spring-ai-hunyuan/src/main/java/org/springframework/ai/hunyuan/HunYuanChatModel.java

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
2626
import org.slf4j.Logger;
2727
import org.slf4j.LoggerFactory;
28-
import org.springframework.ai.chat.messages.UserMessage;
28+
import org.springframework.ai.chat.messages.*;
2929
import org.springframework.ai.hunyuan.api.HunYuanApi;
3030
import org.springframework.ai.hunyuan.api.HunYuanApi.*;
3131
import org.springframework.ai.hunyuan.api.HunYuanApi.ChatCompletionMessage.*;
@@ -35,9 +35,6 @@
3535
import reactor.core.publisher.Flux;
3636
import reactor.core.publisher.Mono;
3737

38-
import org.springframework.ai.chat.messages.AssistantMessage;
39-
import org.springframework.ai.chat.messages.MessageType;
40-
import org.springframework.ai.chat.messages.ToolResponseMessage;
4138
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4239
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
4340
import org.springframework.ai.chat.metadata.EmptyUsage;
@@ -100,17 +97,17 @@ public class HunYuanChatModel extends AbstractToolCallSupport implements ChatMod
10097

10198
/**
10299
* Initializes a new instance of the HunYuanChatModel.
103-
* @param hunYuanApi The HunYuan instance to be used for interacting with the
104-
* HunYuan Chat API.
100+
* @param hunYuanApi The HunYuan instance to be used for interacting with the HunYuan
101+
* Chat API.
105102
*/
106103
public HunYuanChatModel(HunYuanApi hunYuanApi) {
107104
this(hunYuanApi, HunYuanChatOptions.builder().model(HunYuanApi.DEFAULT_CHAT_MODEL).build());
108105
}
109106

110107
/**
111108
* Initializes a new instance of the HunYuanChatModel.
112-
* @param hunYuanApi The HunYuan instance to be used for interacting with the
113-
* HunYuan Chat API.
109+
* @param hunYuanApi The HunYuan instance to be used for interacting with the HunYuan
110+
* Chat API.
114111
* @param options The HunYuanChatOptions to configure the chat client.
115112
*/
116113
public HunYuanChatModel(HunYuanApi hunYuanApi, HunYuanChatOptions options) {
@@ -119,31 +116,31 @@ public HunYuanChatModel(HunYuanApi hunYuanApi, HunYuanChatOptions options) {
119116

120117
/**
121118
* Initializes a new instance of the HunYuanChatModel.
122-
* @param hunYuanApi The HunYuan instance to be used for interacting with the
123-
* HunYuan Chat API.
119+
* @param hunYuanApi The HunYuan instance to be used for interacting with the HunYuan
120+
* Chat API.
124121
* @param options The HunYuanChatOptions to configure the chat client.
125122
* @param functionCallbackResolver The function callback resolver to resolve the
126123
* function by its name.
127124
* @param retryTemplate The retry template.
128125
*/
129126
public HunYuanChatModel(HunYuanApi hunYuanApi, HunYuanChatOptions options,
130-
FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
127+
FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
131128
this(hunYuanApi, options, functionCallbackResolver, List.of(), retryTemplate, ObservationRegistry.NOOP);
132129
}
133130

134131
/**
135132
* Initializes a new instance of the HunYuanChatModel.
136-
* @param hunYuanApi The HunYuan instance to be used for interacting with the
137-
* HunYuan Chat API.
133+
* @param hunYuanApi The HunYuan instance to be used for interacting with the HunYuan
134+
* Chat API.
138135
* @param options The HunYuanChatOptions to configure the chat client.
139136
* @param functionCallbackResolver resolves the function by its name.
140137
* @param toolFunctionCallbacks The tool function callbacks.
141138
* @param retryTemplate The retry template.
142139
* @param observationRegistry The ObservationRegistry used for instrumentation.
143140
*/
144141
public HunYuanChatModel(HunYuanApi hunYuanApi, HunYuanChatOptions options,
145-
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
146-
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
142+
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
143+
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
147144
super(functionCallbackResolver, options, toolFunctionCallbacks);
148145
Assert.notNull(hunYuanApi, "HunYuanApi must not be null");
149146
Assert.notNull(options, "Options must not be null");
@@ -167,7 +164,7 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
167164
.toList();
168165

169166
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
170-
String finishReason = (choice.finishReason() != null ? choice.finishReason(): "");
167+
String finishReason = (choice.finishReason() != null ? choice.finishReason() : "");
171168
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
172169
return new Generation(assistantMessage, generationMetadata);
173170
}
@@ -213,7 +210,8 @@ public ChatResponse call(Prompt prompt) {
213210
return buildGeneration(choice, metadata);
214211
}).toList();
215212

216-
ChatResponse chatResponse = new ChatResponse(generations, from(request,completionEntity.getBody().response()));
213+
ChatResponse chatResponse = new ChatResponse(generations,
214+
from(request, completionEntity.getBody().response()));
217215

218216
observationContext.setResponse(chatResponse);
219217

@@ -292,8 +290,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
292290
}));
293291

294292
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
295-
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
296-
Set.of(ChatCompletionFinishReason.TOOL_CALLS.getJsonValue(), ChatCompletionFinishReason.STOP.getJsonValue()))) {
293+
if (!isProxyToolCalls(prompt, this.defaultOptions)
294+
&& isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.getJsonValue(),
295+
ChatCompletionFinishReason.STOP.getJsonValue()))) {
297296
var toolCallConversation = handleToolCalls(prompt, response);
298297
// Recursively call the stream method with the tool call message
299298
// conversation that contains the call responses.
@@ -325,21 +324,22 @@ private ChatResponseMetadata from(ChatCompletionRequest request, ChatCompletion
325324
* @return the ChatCompletion
326325
*/
327326
private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
328-
List<ChatCompletion.Choice> choices = chunk.choices()
329-
.stream()
330-
.map(chunkChoice -> {
331-
ChatCompletionMessage chatCompletionMessage = null;
332-
ChatCompletionDelta delta = chunkChoice.delta();
333-
if (delta == null) {
334-
chatCompletionMessage = new ChatCompletionMessage("", Role.assistant);
335-
}else {
336-
chatCompletionMessage = new ChatCompletionMessage(delta.content(), delta.role(),delta.toolCalls());
337-
}
338-
return new ChatCompletion.Choice(chunkChoice.index(), chatCompletionMessage, chunkChoice.finishReason(),delta);
339-
})
340-
.toList();
327+
List<ChatCompletion.Choice> choices = chunk.choices().stream().map(chunkChoice -> {
328+
ChatCompletionMessage chatCompletionMessage = null;
329+
ChatCompletionDelta delta = chunkChoice.delta();
330+
if (delta == null) {
331+
chatCompletionMessage = new ChatCompletionMessage("", Role.assistant);
332+
}
333+
else {
334+
chatCompletionMessage = new ChatCompletionMessage(delta.content(), delta.role(), delta.toolCalls());
335+
}
336+
return new ChatCompletion.Choice(chunkChoice.index(), chatCompletionMessage, chunkChoice.finishReason(),
337+
delta);
338+
}).toList();
341339

342-
return new ChatCompletion(chunk.id(), chunk.errorMsg(), chunk.created(), chunk.note(), choices, chunk.usage(), chunk.moderationLevel(), chunk.searchInfo(), chunk.replaces(), chunk.recommendedQuestions(), chunk.requestId());
340+
return new ChatCompletion(chunk.id(), chunk.errorMsg(), chunk.created(), chunk.note(), choices, chunk.usage(),
341+
chunk.moderationLevel(), chunk.searchInfo(), chunk.replaces(), chunk.recommendedQuestions(),
342+
chunk.requestId());
343343
}
344344

345345
/**
@@ -362,22 +362,22 @@ public HunYuanApi.ChatCompletionRequest createRequest(Prompt prompt, boolean str
362362
List<ChatContent> contentList = new ArrayList<>(List.of(new ChatContent(message.getText())));
363363

364364
contentList.addAll(userMessage.getMedia()
365-
.stream()
366-
.map(media -> new ChatContent(new ImageUrl(
367-
this.fromMediaData(media.getMimeType(), media.getData()))))
368-
.toList());
369-
return List.of(new ChatCompletionMessage(Role.user,contentList));
365+
.stream()
366+
.map(media -> new ChatContent(
367+
new ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
368+
.toList());
369+
return List.of(new ChatCompletionMessage(Role.user, contentList));
370370
}
371371
}
372-
return List.of(new ChatCompletionMessage(content,Role.user));
372+
return List.of(new ChatCompletionMessage(content, Role.user));
373373
}
374374
else if (message.getMessageType() == MessageType.ASSISTANT) {
375375
var assistantMessage = (AssistantMessage) message;
376376
List<ToolCall> toolCalls = null;
377377
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
378378
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
379379
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
380-
return new ToolCall(toolCall.id(), toolCall.type(),null, function);
380+
return new ToolCall(toolCall.id(), toolCall.type(), null, function);
381381
}).toList();
382382
}
383383
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
@@ -438,6 +438,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
438438

439439
return request;
440440
}
441+
441442
private String fromMediaData(MimeType mimeType, Object mediaContentData) {
442443
if (mediaContentData instanceof byte[] bytes) {
443444
// Assume the bytes are an image. So, convert the bytes to a base64 encoded
@@ -453,12 +454,13 @@ else if (mediaContentData instanceof String text) {
453454
"Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
454455
}
455456
}
457+
456458
private ChatOptions buildRequestOptions(HunYuanApi.ChatCompletionRequest request) {
457459
return ChatOptions.builder()
458460
.model(request.model())
459-
// .frequencyPenalty(request.frequencyPenalty())
460-
// .maxTokens(request.maxTokens())
461-
// .presencePenalty(request.presencePenalty())
461+
// .frequencyPenalty(request.frequencyPenalty())
462+
// .maxTokens(request.maxTokens())
463+
// .presencePenalty(request.presencePenalty())
462464
.stopSequences(request.stop())
463465
.temperature(request.temperature())
464466
.topP(request.topP())

models/spring-ai-hunyuan/src/main/java/org/springframework/ai/hunyuan/HunYuanChatOptions.java

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,37 @@ public class HunYuanChatOptions implements FunctionCallingOptions {
3838

3939
private @JsonProperty("Model") String model;
4040

41-
4241
private @JsonProperty("Temperature") Double temperature;
4342

44-
4543
private @JsonProperty("TopP") Double topP;
4644

47-
4845
private @JsonProperty("Seed") Integer seed;
4946

50-
5147
private @JsonProperty("EnableEnhancement") Boolean enableEnhancement;
5248

5349
private @JsonProperty("StreamModeration") Boolean streamModeration;
5450

55-
5651
private @JsonProperty("Stop") List<String> stop;
5752

5853
private @JsonProperty("Tools") List<HunYuanApi.FunctionTool> tools;
5954

6055
private @JsonProperty("ToolChoice") String toolChoice;
56+
6157
private @JsonProperty("CustomTool") HunYuanApi.FunctionTool customTool;
58+
6259
private @JsonProperty("SearchInfo") Boolean searchInfo;
60+
6361
private @JsonProperty("Citation") Boolean citation;
62+
6463
private @JsonProperty("EnableSpeedSearch") Boolean enableSpeedSearch;
64+
6565
private @JsonProperty("EnableMultimedia") Boolean enableMultimedia;
66+
6667
private @JsonProperty("EnableDeepSearch") Boolean enableDeepSearch;
68+
6769
private @JsonProperty("ForceSearchEnhancement") Boolean ForceSearchEnhancement;
68-
private @JsonProperty("EnableRecommendedQuestions") Boolean enableRecommendedQuestions;
6970

71+
private @JsonProperty("EnableRecommendedQuestions") Boolean enableRecommendedQuestions;
7072

7173
/**
7274
* HunYuan Tool Function Callbacks to register with the ChatModel. For Prompt Options
@@ -318,25 +320,37 @@ public void setToolContext(Map<String, Object> toolContext) {
318320

319321
@Override
320322
public boolean equals(Object o) {
321-
if (this == o) return true;
322-
if (o == null || getClass() != o.getClass()) return false;
323+
if (this == o)
324+
return true;
325+
if (o == null || getClass() != o.getClass())
326+
return false;
323327

324328
HunYuanChatOptions that = (HunYuanChatOptions) o;
325329

326-
if (!model.equals(that.model)) return false;
327-
if (!Objects.equals(temperature, that.temperature)) return false;
328-
if (!Objects.equals(topP, that.topP)) return false;
329-
if (!Objects.equals(seed, that.seed)) return false;
330+
if (!model.equals(that.model))
331+
return false;
332+
if (!Objects.equals(temperature, that.temperature))
333+
return false;
334+
if (!Objects.equals(topP, that.topP))
335+
return false;
336+
if (!Objects.equals(seed, that.seed))
337+
return false;
330338
if (!Objects.equals(enableEnhancement, that.enableEnhancement))
331339
return false;
332340
if (!Objects.equals(streamModeration, that.streamModeration))
333341
return false;
334-
if (!Objects.equals(stop, that.stop)) return false;
335-
if (!Objects.equals(tools, that.tools)) return false;
336-
if (!Objects.equals(toolChoice, that.toolChoice)) return false;
337-
if (!Objects.equals(customTool, that.customTool)) return false;
338-
if (!Objects.equals(searchInfo, that.searchInfo)) return false;
339-
if (!Objects.equals(citation, that.citation)) return false;
342+
if (!Objects.equals(stop, that.stop))
343+
return false;
344+
if (!Objects.equals(tools, that.tools))
345+
return false;
346+
if (!Objects.equals(toolChoice, that.toolChoice))
347+
return false;
348+
if (!Objects.equals(customTool, that.customTool))
349+
return false;
350+
if (!Objects.equals(searchInfo, that.searchInfo))
351+
return false;
352+
if (!Objects.equals(citation, that.citation))
353+
return false;
340354
if (!Objects.equals(enableSpeedSearch, that.enableSpeedSearch))
341355
return false;
342356
if (!Objects.equals(enableMultimedia, that.enableMultimedia))
@@ -349,7 +363,8 @@ public boolean equals(Object o) {
349363
return false;
350364
if (!Objects.equals(functionCallbacks, that.functionCallbacks))
351365
return false;
352-
if (!Objects.equals(functions, that.functions)) return false;
366+
if (!Objects.equals(functions, that.functions))
367+
return false;
353368
if (!Objects.equals(proxyToolCalls, that.proxyToolCalls))
354369
return false;
355370
return Objects.equals(toolContext, that.toolContext);
@@ -381,6 +396,7 @@ public int hashCode() {
381396
result = prime * result + (toolContext != null ? toolContext.hashCode() : 0);
382397
return result;
383398
}
399+
384400
public static class Builder {
385401

386402
private final HunYuanChatOptions options = new HunYuanChatOptions();
@@ -480,6 +496,15 @@ public Builder functions(Set<String> functions) {
480496
return this;
481497
}
482498

499+
public Builder function(String functionName) {
500+
Assert.hasText(functionName, "Function name must not be empty");
501+
if (this.options.functions == null) {
502+
this.options.functions = new HashSet<>();
503+
}
504+
this.options.functions.add(functionName);
505+
return this;
506+
}
507+
483508
public Builder proxyToolCalls(Boolean proxyToolCalls) {
484509
options.setProxyToolCalls(proxyToolCalls);
485510
return this;
@@ -493,6 +518,7 @@ public Builder toolContext(Map<String, Object> toolContext) {
493518
public HunYuanChatOptions build() {
494519
return options;
495520
}
521+
496522
}
497523

498524
}

models/spring-ai-hunyuan/src/main/java/org/springframework/ai/hunyuan/aot/HunYuanRuntimeHints.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
2727

2828
/**
29-
* The HunYuanRuntimeHints class is responsible for registering runtime hints for
30-
* HunYuan API classes.
29+
* The HunYuanRuntimeHints class is responsible for registering runtime hints for HunYuan
30+
* API classes.
3131
*
3232
* @author Guo Junyu
3333
*/

0 commit comments

Comments
 (0)