|
18 | 18 |
|
19 | 19 | import java.util.ArrayList; |
20 | 20 | import java.util.Base64; |
| 21 | +import java.util.HashMap; |
21 | 22 | import java.util.List; |
22 | 23 | import java.util.Map; |
23 | 24 | import java.util.Set; |
@@ -379,46 +380,45 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage |
379 | 380 | return new ChatResponse(List.of()); |
380 | 381 | } |
381 | 382 |
|
382 | | - List<Generation> generations = chatCompletion.content() |
383 | | - .stream() |
384 | | - .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) |
385 | | - .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), |
386 | | - ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())) |
387 | | - .toList(); |
388 | | - |
389 | | - List<Generation> allGenerations = new ArrayList<>(generations); |
| 383 | + List<Generation> generations = new ArrayList<>(); |
| 384 | + for (ContentBlock content : chatCompletion.content()) { |
| 385 | + switch (content.type()) { |
| 386 | + case TEXT: |
| 387 | + generations.add(new Generation(new AssistantMessage(content.text(), Map.of()), |
| 388 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 389 | + break; |
| 390 | + case THINKING: |
| 391 | + Map<String, Object> thinkingProperties = new HashMap<>(); |
| 392 | + thinkingProperties.put("signature", content.signature()); |
| 393 | + generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties), |
| 394 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 395 | + break; |
| 396 | + case REDACTED_THINKING: |
| 397 | + Map<String, Object> redactedProperties = new HashMap<>(); |
| 398 | + redactedProperties.put("data", content.data()); |
| 399 | + generations.add(new Generation(new AssistantMessage(null, redactedProperties), |
| 400 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 401 | + break; |
| 402 | + case TOOL_USE: |
| 403 | + var functionCallId = content.id(); |
| 404 | + var functionName = content.name(); |
| 405 | + var functionArguments = JsonParser.toJson(content.input()); |
| 406 | + AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), |
| 407 | + List.of(new AssistantMessage.ToolCall(functionCallId, "function", functionName, |
| 408 | + functionArguments))); |
| 409 | + generations.add(new Generation(assistantMessage, |
| 410 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 411 | + break; |
| 412 | + } |
| 413 | + } |
390 | 414 |
|
391 | 415 | if (chatCompletion.stopReason() != null && generations.isEmpty()) { |
392 | 416 | Generation generation = new Generation(new AssistantMessage(null, Map.of()), |
393 | 417 | ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); |
394 | | - allGenerations.add(generation); |
395 | | - } |
396 | | - |
397 | | - List<ContentBlock> toolToUseList = chatCompletion.content() |
398 | | - .stream() |
399 | | - .filter(c -> c.type() == ContentBlock.Type.TOOL_USE) |
400 | | - .toList(); |
401 | | - |
402 | | - if (!CollectionUtils.isEmpty(toolToUseList)) { |
403 | | - List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
404 | | - |
405 | | - for (ContentBlock toolToUse : toolToUseList) { |
406 | | - |
407 | | - var functionCallId = toolToUse.id(); |
408 | | - var functionName = toolToUse.name(); |
409 | | - var functionArguments = JsonParser.toJson(toolToUse.input()); |
410 | | - |
411 | | - toolCalls |
412 | | - .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
413 | | - } |
414 | | - |
415 | | - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); |
416 | | - Generation toolCallGeneration = new Generation(assistantMessage, |
417 | | - ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); |
418 | | - allGenerations.add(toolCallGeneration); |
| 418 | + generations.add(generation); |
419 | 419 | } |
420 | 420 |
|
421 | | - return new ChatResponse(allGenerations, this.from(chatCompletion, usage)); |
| 421 | + return new ChatResponse(generations, this.from(chatCompletion, usage)); |
422 | 422 | } |
423 | 423 |
|
424 | 424 | private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { |
@@ -575,7 +575,7 @@ else if (message.getMessageType() == MessageType.TOOL) { |
575 | 575 | List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); |
576 | 576 | if (!CollectionUtils.isEmpty(toolDefinitions)) { |
577 | 577 | request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); |
578 | | - request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build(); |
| 578 | + request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); |
579 | 579 | } |
580 | 580 |
|
581 | 581 | return request; |
|
0 commit comments