|
18 | 18 |
|
19 | 19 | import java.util.ArrayList; |
20 | 20 | import java.util.Base64; |
| 21 | +import java.util.HashMap; |
21 | 22 | import java.util.List; |
22 | 23 | import java.util.Map; |
23 | 24 | import java.util.Set; |
|
36 | 37 | import org.springframework.ai.tool.definition.ToolDefinition; |
37 | 38 | import org.springframework.ai.util.json.JsonParser; |
38 | 39 | import org.springframework.lang.Nullable; |
| 40 | + |
39 | 41 | import reactor.core.publisher.Flux; |
40 | 42 | import reactor.core.publisher.Mono; |
41 | 43 |
|
@@ -379,46 +381,51 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage |
379 | 381 | return new ChatResponse(List.of()); |
380 | 382 | } |
381 | 383 |
|
382 | | - List<Generation> generations = chatCompletion.content() |
383 | | - .stream() |
384 | | - .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) |
385 | | - .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), |
386 | | - ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())) |
387 | | - .toList(); |
388 | | - |
389 | | - List<Generation> allGenerations = new ArrayList<>(generations); |
| 384 | + List<Generation> generations = new ArrayList<>(); |
| 385 | + List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
| 386 | + for (ContentBlock content : chatCompletion.content()) { |
| 387 | + switch (content.type()) { |
| 388 | + case TEXT, TEXT_DELTA: |
| 389 | + generations.add(new Generation(new AssistantMessage(content.text(), Map.of()), |
| 390 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 391 | + break; |
| 392 | + case THINKING, THINKING_DELTA: |
| 393 | + System.out.println( |
| 394 | + "THINKINGTHINKINGTHINKINGTHINKINGTHINKINGTHINKINGTHINKINGcontent type: " + content.type()); |
| 395 | + Map<String, Object> thinkingProperties = new HashMap<>(); |
| 396 | + thinkingProperties.put("signature", content.signature()); |
| 397 | + generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties), |
| 398 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 399 | + break; |
| 400 | + case REDACTED_THINKING: |
| 401 | + Map<String, Object> redactedProperties = new HashMap<>(); |
| 402 | + redactedProperties.put("data", content.data()); |
| 403 | + generations.add(new Generation(new AssistantMessage(null, redactedProperties), |
| 404 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 405 | + break; |
| 406 | + case TOOL_USE: |
| 407 | + var functionCallId = content.id(); |
| 408 | + var functionName = content.name(); |
| 409 | + var functionArguments = JsonParser.toJson(content.input()); |
| 410 | + toolCalls.add( |
| 411 | + new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
| 412 | + break; |
| 413 | + } |
| 414 | + } |
390 | 415 |
|
391 | 416 | if (chatCompletion.stopReason() != null && generations.isEmpty()) { |
392 | 417 | Generation generation = new Generation(new AssistantMessage(null, Map.of()), |
393 | 418 | ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); |
394 | | - allGenerations.add(generation); |
| 419 | + generations.add(generation); |
395 | 420 | } |
396 | 421 |
|
397 | | - List<ContentBlock> toolToUseList = chatCompletion.content() |
398 | | - .stream() |
399 | | - .filter(c -> c.type() == ContentBlock.Type.TOOL_USE) |
400 | | - .toList(); |
401 | | - |
402 | | - if (!CollectionUtils.isEmpty(toolToUseList)) { |
403 | | - List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
404 | | - |
405 | | - for (ContentBlock toolToUse : toolToUseList) { |
406 | | - |
407 | | - var functionCallId = toolToUse.id(); |
408 | | - var functionName = toolToUse.name(); |
409 | | - var functionArguments = JsonParser.toJson(toolToUse.input()); |
410 | | - |
411 | | - toolCalls |
412 | | - .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
413 | | - } |
414 | | - |
| 422 | + if (!CollectionUtils.isEmpty(toolCalls)) { |
415 | 423 | AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); |
416 | 424 | Generation toolCallGeneration = new Generation(assistantMessage, |
417 | 425 | ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); |
418 | | - allGenerations.add(toolCallGeneration); |
| 426 | + generations.add(toolCallGeneration); |
419 | 427 | } |
420 | | - |
421 | | - return new ChatResponse(allGenerations, this.from(chatCompletion, usage)); |
| 428 | + return new ChatResponse(generations, this.from(chatCompletion, usage)); |
422 | 429 | } |
423 | 430 |
|
424 | 431 | private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { |
@@ -575,7 +582,7 @@ else if (message.getMessageType() == MessageType.TOOL) { |
575 | 582 | List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); |
576 | 583 | if (!CollectionUtils.isEmpty(toolDefinitions)) { |
577 | 584 | request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); |
578 | | - request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build(); |
| 585 | + request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); |
579 | 586 | } |
580 | 587 |
|
581 | 588 | return request; |
|
0 commit comments