|
15 | 15 | */
|
16 | 16 | package org.springframework.ai.anthropic;
|
17 | 17 |
|
| 18 | +import java.util.ArrayList; |
| 19 | +import java.util.Base64; |
| 20 | +import java.util.HashSet; |
| 21 | +import java.util.List; |
| 22 | +import java.util.Map; |
| 23 | +import java.util.Set; |
| 24 | +import java.util.stream.Collectors; |
| 25 | + |
18 | 26 | import org.slf4j.Logger;
|
19 | 27 | import org.slf4j.LoggerFactory;
|
20 | 28 | import org.springframework.ai.anthropic.api.AnthropicApi;
|
21 | 29 | import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage;
|
22 | 30 | import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
|
23 | 31 | import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
|
24 | 32 | import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
|
25 |
| -import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType; |
| 33 | +import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; |
26 | 34 | import org.springframework.ai.anthropic.api.AnthropicApi.Role;
|
27 | 35 | import org.springframework.ai.anthropic.metadata.AnthropicUsage;
|
28 | 36 | import org.springframework.ai.chat.messages.AssistantMessage;
|
29 |
| -import org.springframework.ai.chat.messages.Message; |
30 | 37 | import org.springframework.ai.chat.messages.MessageType;
|
31 | 38 | import org.springframework.ai.chat.messages.ToolResponseMessage;
|
32 | 39 | import org.springframework.ai.chat.messages.UserMessage;
|
33 | 40 | import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
34 | 41 | import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
| 42 | +import org.springframework.ai.chat.model.AbstractToolCallSupport; |
35 | 43 | import org.springframework.ai.chat.model.ChatModel;
|
36 | 44 | import org.springframework.ai.chat.model.ChatResponse;
|
37 | 45 | import org.springframework.ai.chat.model.Generation;
|
38 | 46 | import org.springframework.ai.chat.prompt.ChatOptions;
|
39 | 47 | import org.springframework.ai.chat.prompt.Prompt;
|
40 | 48 | import org.springframework.ai.model.ModelOptionsUtils;
|
41 |
| -import org.springframework.ai.chat.model.AbstractToolCallSupport; |
42 | 49 | import org.springframework.ai.model.function.FunctionCallbackContext;
|
43 | 50 | import org.springframework.ai.retry.RetryUtils;
|
44 | 51 | import org.springframework.http.ResponseEntity;
|
45 | 52 | import org.springframework.retry.support.RetryTemplate;
|
46 | 53 | import org.springframework.util.Assert;
|
47 | 54 | import org.springframework.util.CollectionUtils;
|
48 | 55 | import org.springframework.util.StringUtils;
|
| 56 | + |
49 | 57 | import reactor.core.publisher.Flux;
|
50 | 58 | import reactor.core.publisher.Mono;
|
51 | 59 |
|
52 |
| -import java.util.ArrayList; |
53 |
| -import java.util.Base64; |
54 |
| -import java.util.HashSet; |
55 |
| -import java.util.List; |
56 |
| -import java.util.Map; |
57 |
| -import java.util.Set; |
58 |
| -import java.util.stream.Collectors; |
59 |
| - |
60 | 60 | /**
|
61 | 61 | * The {@link ChatModel} implementation for the Anthropic service.
|
62 | 62 | *
|
@@ -150,86 +150,84 @@ public ChatResponse call(Prompt prompt) {
|
150 | 150 |
|
151 | 151 | ChatCompletionRequest request = createRequest(prompt, false);
|
152 | 152 |
|
153 |
| - return this.retryTemplate.execute(ctx -> { |
154 |
| - ResponseEntity<ChatCompletionResponse> completionEntity = this.anthropicApi.chatCompletionEntity(request); |
| 153 | + ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate |
| 154 | + .execute(ctx -> this.anthropicApi.chatCompletionEntity(request)); |
155 | 155 |
|
156 |
| - if (this.isToolFunctionCall(completionEntity.getBody())) { |
157 |
| - List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), |
158 |
| - completionEntity.getBody()); |
159 |
| - return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions())); |
160 |
| - } |
| 156 | + ChatResponse chatResponse = toChatResponse(completionEntity.getBody()); |
161 | 157 |
|
162 |
| - return toChatResponse(completionEntity.getBody()); |
163 |
| - }); |
| 158 | + if (this.isToolCall(chatResponse, Set.of("tool_use"))) { |
| 159 | + var toolCallConversation = handleToolCalls(prompt, chatResponse); |
| 160 | + return this.call(new Prompt(toolCallConversation, prompt.getOptions())); |
| 161 | + } |
| 162 | + |
| 163 | + return chatResponse; |
164 | 164 | }
|
165 | 165 |
|
166 | 166 | @Override
|
167 | 167 | public Flux<ChatResponse> stream(Prompt prompt) {
|
168 | 168 |
|
169 | 169 | ChatCompletionRequest request = createRequest(prompt, true);
|
170 | 170 |
|
171 |
| - return this.retryTemplate.execute(ctx -> { |
| 171 | + Flux<ChatCompletionResponse> response = this.retryTemplate |
| 172 | + .execute(ctx -> this.anthropicApi.chatCompletionStream(request)); |
172 | 173 |
|
173 |
| - Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request); |
| 174 | + return response.switchMap(chatCompletionResponse -> { |
174 | 175 |
|
175 |
| - return response.switchMap(chatCompletionResponse -> { |
| 176 | + ChatResponse chatResponse = toChatResponse(chatCompletionResponse); |
176 | 177 |
|
177 |
| - if (this.isToolFunctionCall(chatCompletionResponse)) { |
178 |
| - List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), |
179 |
| - chatCompletionResponse); |
180 |
| - return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); |
181 |
| - } |
| 178 | + if (this.isToolCall(chatResponse, Set.of("tool_use"))) { |
| 179 | + var toolCallConversation = handleToolCalls(prompt, chatResponse); |
| 180 | + return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); |
| 181 | + } |
182 | 182 |
|
183 |
| - return Mono.just(chatCompletionResponse).map(this::toChatResponse); |
184 |
| - }); |
| 183 | + return Mono.just(chatResponse); |
185 | 184 | });
|
186 | 185 | }
|
187 | 186 |
|
188 |
| - private List<Message> handleToolCallRequests(List<Message> previousMessages, |
189 |
| - ChatCompletionResponse chatCompletionResponse) { |
| 187 | + private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) { |
190 | 188 |
|
191 |
| - AnthropicMessage anthropicAssistantMessage = new AnthropicMessage(chatCompletionResponse.content(), |
192 |
| - Role.ASSISTANT); |
| 189 | + if (chatCompletion == null) { |
| 190 | + logger.warn("Null chat completion returned"); |
| 191 | + return new ChatResponse(List.of()); |
| 192 | + } |
193 | 193 |
|
194 |
| - List<ContentBlock> toolToUseList = anthropicAssistantMessage.content() |
| 194 | + List<Generation> generations = chatCompletion.content() |
195 | 195 | .stream()
|
196 |
| - .filter(c -> c.type() == ContentBlock.ContentBlockType.TOOL_USE) |
| 196 | + .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) |
| 197 | + .map(content -> { |
| 198 | + new AssistantMessage(content.text(), Map.of()); |
| 199 | + return new Generation(new AssistantMessage(content.text(), Map.of()), |
| 200 | + ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); |
| 201 | + }) |
197 | 202 | .toList();
|
198 | 203 |
|
199 |
| - List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
200 |
| - |
201 |
| - for (ContentBlock toolToUse : toolToUseList) { |
| 204 | + List<Generation> allGenerations = new ArrayList<>(generations); |
202 | 205 |
|
203 |
| - var functionCallId = toolToUse.id(); |
204 |
| - var functionName = toolToUse.name(); |
205 |
| - var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input()); |
| 206 | + List<ContentBlock> toolToUseList = chatCompletion.content() |
| 207 | + .stream() |
| 208 | + .filter(c -> c.type() == ContentBlock.Type.TOOL_USE) |
| 209 | + .toList(); |
206 | 210 |
|
207 |
| - toolCalls.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
208 |
| - } |
| 211 | + if (!CollectionUtils.isEmpty(toolToUseList)) { |
| 212 | + List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
209 | 213 |
|
210 |
| - AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); |
211 |
| - ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage); |
| 214 | + for (ContentBlock toolToUse : toolToUseList) { |
212 | 215 |
|
213 |
| - // History |
214 |
| - List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages); |
215 |
| - toolCallMessageConversation.add(assistantMessage); |
216 |
| - toolCallMessageConversation.add(toolResponseMessage); |
| 216 | + var functionCallId = toolToUse.id(); |
| 217 | + var functionName = toolToUse.name(); |
| 218 | + var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input()); |
217 | 219 |
|
218 |
| - return toolCallMessageConversation; |
219 |
| - } |
| 220 | + toolCalls |
| 221 | + .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); |
| 222 | + } |
220 | 223 |
|
221 |
| - private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) { |
222 |
| - if (chatCompletion == null) { |
223 |
| - logger.warn("Null chat completion returned"); |
224 |
| - return new ChatResponse(List.of()); |
| 224 | + AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); |
| 225 | + Generation toolCallGeneration = new Generation(assistantMessage, |
| 226 | + ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); |
| 227 | + allGenerations.add(toolCallGeneration); |
225 | 228 | }
|
226 | 229 |
|
227 |
| - List<Generation> generations = chatCompletion.content().stream().map(content -> { |
228 |
| - return new Generation(content.text(), Map.of()) |
229 |
| - .withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); |
230 |
| - }).toList(); |
231 |
| - |
232 |
| - return new ChatResponse(generations, from(chatCompletion)); |
| 230 | + return new ChatResponse(allGenerations, this.from(chatCompletion)); |
233 | 231 | }
|
234 | 232 |
|
235 | 233 | private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
|
@@ -288,16 +286,16 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
|
288 | 286 | }
|
289 | 287 | if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
|
290 | 288 | for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
|
291 |
| - contentBlocks.add(new ContentBlock(ContentBlockType.TOOL_USE, toolCall.id(), |
292 |
| - toolCall.name(), ModelOptionsUtils.jsonToMap(toolCall.arguments()))); |
| 289 | + contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(), |
| 290 | + ModelOptionsUtils.jsonToMap(toolCall.arguments()))); |
293 | 291 | }
|
294 | 292 | }
|
295 | 293 | return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
|
296 | 294 | }
|
297 | 295 | else if (message.getMessageType() == MessageType.TOOL) {
|
298 | 296 | List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
|
299 | 297 | .stream()
|
300 |
| - .map(toolResponse -> new ContentBlock(ContentBlockType.TOOL_RESULT, toolResponse.id(), |
| 298 | + .map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(), |
301 | 299 | toolResponse.responseData()))
|
302 | 300 | .toList();
|
303 | 301 | return new AnthropicMessage(toolResponses, Role.USER);
|
@@ -355,16 +353,6 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
|
355 | 353 | }).toList();
|
356 | 354 | }
|
357 | 355 |
|
358 |
| - @SuppressWarnings("null") |
359 |
| - protected boolean isToolFunctionCall(ChatCompletionResponse response) { |
360 |
| - if (response == null || CollectionUtils.isEmpty(response.content())) { |
361 |
| - return false; |
362 |
| - } |
363 |
| - return response.content() |
364 |
| - .stream() |
365 |
| - .anyMatch(content -> content.type() == ContentBlock.ContentBlockType.TOOL_USE); |
366 |
| - } |
367 |
| - |
368 | 356 | @Override
|
369 | 357 | public ChatOptions getDefaultOptions() {
|
370 | 358 | return AnthropicChatOptions.fromOptions(this.defaultOptions);
|
|
0 commit comments