|
30 | 30 | import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; |
31 | 31 | import org.slf4j.Logger; |
32 | 32 | import org.slf4j.LoggerFactory; |
| 33 | +import org.springframework.ai.model.tool.LegacyToolCallingManager; |
| 34 | +import org.springframework.ai.model.tool.ToolCallingChatOptions; |
| 35 | +import org.springframework.ai.model.tool.ToolCallingManager; |
| 36 | +import org.springframework.ai.model.tool.ToolExecutionResult; |
| 37 | +import org.springframework.ai.tool.definition.ToolDefinition; |
| 38 | +import org.springframework.ai.util.json.JsonParser; |
| 39 | +import org.springframework.lang.Nullable; |
| 40 | + |
33 | 41 | import reactor.core.publisher.Flux; |
34 | 42 | import reactor.core.publisher.Mono; |
35 | 43 | import reactor.core.scheduler.Schedulers; |
@@ -279,46 +287,49 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage |
279 | 287 | return new ChatResponse(List.of()); |
280 | 288 | } |
281 | 289 |
|
282 | | - List<Generation> generations = chatCompletion.content() |
283 | | - .stream() |
284 | | - .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) |
285 | | - .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), |
286 | | - ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())) |
287 | | - .toList(); |
288 | | - |
289 | | - List<Generation> allGenerations = new ArrayList<>(generations); |
| 290 | + List<Generation> generations = new ArrayList<>(); |
| 291 | + List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
| 292 | + for (ContentBlock content : chatCompletion.content()) { |
| 293 | + switch (content.type()) { |
| 294 | + case TEXT, TEXT_DELTA: |
| 295 | + generations.add(new Generation(new AssistantMessage(content.text(), Map.of()), |
| 296 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 297 | + break; |
| 298 | + case THINKING, THINKING_DELTA: |
| 299 | + Map<String, Object> thinkingProperties = new HashMap<>(); |
| 300 | + thinkingProperties.put("signature", content.signature()); |
| 301 | + generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties), |
| 302 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 303 | + break; |
| 304 | + case REDACTED_THINKING: |
| 305 | + Map<String, Object> redactedProperties = new HashMap<>(); |
| 306 | + redactedProperties.put("data", content.data()); |
| 307 | + generations.add(new Generation(new AssistantMessage(null, redactedProperties), |
| 308 | + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); |
| 309 | + break; |
| 310 | + case TOOL_USE: |
| 311 | + var functionCallId = content.id(); |
| 312 | + var functionName = content.name(); |
| 313 | + var functionArguments = JsonParser.toJson(content.input()); |
| 314 | + toolCalls.add( |
| 315 | + new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
| 316 | + break; |
| 317 | + } |
| 318 | + } |
290 | 319 |
|
291 | 320 | if (chatCompletion.stopReason() != null && generations.isEmpty()) { |
292 | 321 | Generation generation = new Generation(new AssistantMessage(null, Map.of()), |
293 | 322 | ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); |
294 | | - allGenerations.add(generation); |
| 323 | + generations.add(generation); |
295 | 324 | } |
296 | 325 |
|
297 | | - List<ContentBlock> toolToUseList = chatCompletion.content() |
298 | | - .stream() |
299 | | - .filter(c -> c.type() == ContentBlock.Type.TOOL_USE) |
300 | | - .toList(); |
301 | | - |
302 | | - if (!CollectionUtils.isEmpty(toolToUseList)) { |
303 | | - List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
304 | | - |
305 | | - for (ContentBlock toolToUse : toolToUseList) { |
306 | | - |
307 | | - var functionCallId = toolToUse.id(); |
308 | | - var functionName = toolToUse.name(); |
309 | | - var functionArguments = JsonParser.toJson(toolToUse.input()); |
310 | | - |
311 | | - toolCalls |
312 | | - .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
313 | | - } |
314 | | - |
| 326 | + if (!CollectionUtils.isEmpty(toolCalls)) { |
315 | 327 | AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); |
316 | 328 | Generation toolCallGeneration = new Generation(assistantMessage, |
317 | 329 | ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); |
318 | | - allGenerations.add(toolCallGeneration); |
| 330 | + generations.add(toolCallGeneration); |
319 | 331 | } |
320 | | - |
321 | | - return new ChatResponse(allGenerations, this.from(chatCompletion, usage)); |
| 332 | + return new ChatResponse(generations, this.from(chatCompletion, usage)); |
322 | 333 | } |
323 | 334 |
|
324 | 335 | private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { |
@@ -490,7 +501,7 @@ else if (message.getMessageType() == MessageType.TOOL) { |
490 | 501 | List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); |
491 | 502 | if (!CollectionUtils.isEmpty(toolDefinitions)) { |
492 | 503 | request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); |
493 | | - request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build(); |
| 504 | + request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); |
494 | 505 | } |
495 | 506 |
|
496 | 507 | return request; |
|
0 commit comments