|
16 | 16 |
|
17 | 17 | package org.springframework.ai.hunyuan; |
18 | 18 |
|
19 | | -import java.util.HashSet; |
20 | | -import java.util.List; |
21 | | -import java.util.Map; |
22 | | -import java.util.Set; |
| 19 | +import java.util.*; |
23 | 20 | import java.util.concurrent.ConcurrentHashMap; |
| 21 | +import java.util.stream.Collectors; |
24 | 22 |
|
25 | 23 | import io.micrometer.observation.Observation; |
26 | 24 | import io.micrometer.observation.ObservationRegistry; |
27 | 25 | import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; |
28 | 26 | import org.slf4j.Logger; |
29 | 27 | import org.slf4j.LoggerFactory; |
| 28 | +import org.springframework.ai.chat.messages.UserMessage; |
30 | 29 | import org.springframework.ai.hunyuan.api.HunYuanApi; |
31 | 30 | import org.springframework.ai.hunyuan.api.HunYuanApi.*; |
32 | 31 | import org.springframework.ai.hunyuan.api.HunYuanApi.ChatCompletionMessage.*; |
33 | 32 | import org.springframework.ai.hunyuan.api.HunYuanApi.ChatCompletion.*; |
34 | 33 | import org.springframework.ai.hunyuan.metadata.HunYuanUsage; |
| 34 | +import org.springframework.util.MimeType; |
35 | 35 | import reactor.core.publisher.Flux; |
36 | 36 | import reactor.core.publisher.Mono; |
37 | 37 |
|
@@ -326,27 +326,51 @@ private ChatResponseMetadata from(ChatCompletionRequest request, ChatCompletion |
326 | 326 | * @return the ChatCompletion |
327 | 327 | */ |
328 | 328 | private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) { |
329 | | - List<ChatCompletion.Choice> choices = chunk.choices().stream().map(cc -> { |
330 | | - ChatCompletionMessage delta = cc.message(); |
331 | | - if (delta == null) { |
332 | | - delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.assistant); |
333 | | - } |
334 | | - return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(),null); |
335 | | - }).toList(); |
| 329 | + List<ChatCompletion.Choice> choices = chunk.choices() |
| 330 | + .stream() |
| 331 | + .map(chunkChoice -> { |
| 332 | + ChatCompletionMessage chatCompletionMessage = null; |
| 333 | + ChatCompletionDelta delta = chunkChoice.delta(); |
| 334 | + if (delta == null) { |
| 335 | + chatCompletionMessage = new ChatCompletionMessage("", Role.assistant); |
| 336 | + }else { |
| 337 | + chatCompletionMessage = new ChatCompletionMessage(delta.content(), delta.role(),delta.toolCalls()); |
| 338 | + } |
| 339 | + return new ChatCompletion.Choice(chunkChoice.index(), chatCompletionMessage, chunkChoice.finishReason(),delta); |
| 340 | + }) |
| 341 | + .toList(); |
336 | 342 |
|
337 | | - return new ChatCompletion(chunk.id(), null, chunk.created(), chunk.note(), choices, null, null, null, null, null, null); |
| 343 | + return new ChatCompletion(chunk.id(), chunk.errorMsg(), chunk.created(), chunk.note(), choices, chunk.usage(), chunk.moderationLevel(), chunk.searchInfo(), chunk.replaces(), chunk.recommendedQuestions(), chunk.requestId()); |
338 | 344 | } |
339 | 345 |
|
340 | 346 | /** |
341 | 347 | * Accessible for testing. |
342 | 348 | */ |
343 | 349 | public HunYuanApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { |
344 | | - |
345 | | - List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> { |
346 | | - if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { |
| 350 | + List<ChatCompletionMessage> systemMessages = new ArrayList<>(); |
| 351 | + List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().filter(message -> { |
| 352 | + if (message.getMessageType() == MessageType.SYSTEM) { |
347 | 353 | Object content = message.getText(); |
348 | | - return List.of(new ChatCompletionMessage(content, |
349 | | - ChatCompletionMessage.Role.valueOf(message.getMessageType().name()))); |
| 354 | + systemMessages.add(new ChatCompletionMessage(content, Role.system)); |
| 355 | + return false; |
| 356 | + } |
| 357 | + return true; |
| 358 | + }).map(message -> { |
| 359 | + if (message.getMessageType() == MessageType.USER) { |
| 360 | + Object content = message.getText(); |
| 361 | + if (message instanceof UserMessage userMessage) { |
| 362 | + if (!CollectionUtils.isEmpty(userMessage.getMedia())) { |
| 363 | + List<ChatContent> contentList = new ArrayList<>(List.of(new ChatContent(message.getText()))); |
| 364 | + |
| 365 | + contentList.addAll(userMessage.getMedia() |
| 366 | + .stream() |
| 367 | + .map(media -> new ChatContent(new ImageUrl( |
| 368 | + this.fromMediaData(media.getMimeType(), media.getData())))) |
| 369 | + .toList()); |
| 370 | + return List.of(new ChatCompletionMessage(Role.user,contentList)); |
| 371 | + } |
| 372 | + } |
| 373 | + return List.of(new ChatCompletionMessage(content,Role.user)); |
350 | 374 | } |
351 | 375 | else if (message.getMessageType() == MessageType.ASSISTANT) { |
352 | 376 | var assistantMessage = (AssistantMessage) message; |
@@ -375,8 +399,10 @@ else if (message.getMessageType() == MessageType.TOOL) { |
375 | 399 | else { |
376 | 400 | throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); |
377 | 401 | } |
378 | | - }).flatMap(List::stream).toList(); |
379 | | - |
| 402 | + }).flatMap(List::stream).collect(Collectors.toList()); |
| 403 | + systemMessages.stream().forEach(systemMessage -> { |
| 404 | + chatCompletionMessages.add(0, systemMessage); |
| 405 | + }); |
380 | 406 | ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); |
381 | 407 |
|
382 | 408 | Set<String> enabledToolsToUse = new HashSet<>(); |
@@ -413,7 +439,21 @@ else if (message.getMessageType() == MessageType.TOOL) { |
413 | 439 |
|
414 | 440 | return request; |
415 | 441 | } |
416 | | - |
| 442 | + private String fromMediaData(MimeType mimeType, Object mediaContentData) { |
| 443 | + if (mediaContentData instanceof byte[] bytes) { |
| 444 | + // Assume the bytes are an image. So, convert the bytes to a base64 encoded |
| 445 | + // following the prefix pattern. |
| 446 | + return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); |
| 447 | + } |
| 448 | + else if (mediaContentData instanceof String text) { |
| 449 | + // Assume the text is a URLs or a base64 encoded image prefixed by the user. |
| 450 | + return text; |
| 451 | + } |
| 452 | + else { |
| 453 | + throw new IllegalArgumentException( |
| 454 | + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); |
| 455 | + } |
| 456 | + } |
417 | 457 | private ChatOptions buildRequestOptions(HunYuanApi.ChatCompletionRequest request) { |
418 | 458 | return ChatOptions.builder() |
419 | 459 | .model(request.model()) |
|
0 commit comments