Skip to content

Commit 6363352

Browse files
committed
Improve Anthropic function calling
- factor out the common function calling logic from AnthropicChatModel to the AbstractToolCallSupport. - improve the AbstractToolCallSupport isToolCall to handle OpenAi and Anthropic. - fix an issue with the function calling streaming aggregation leading to lost usage statistics. - small code improvements for OpenAiChatModel.
1 parent 554fbcd commit 6363352

File tree

8 files changed

+149
-135
lines changed

8 files changed

+149
-135
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 63 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -15,48 +15,48 @@
1515
*/
1616
package org.springframework.ai.anthropic;
1717

18+
import java.util.ArrayList;
19+
import java.util.Base64;
20+
import java.util.HashSet;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.Set;
24+
import java.util.stream.Collectors;
25+
1826
import org.slf4j.Logger;
1927
import org.slf4j.LoggerFactory;
2028
import org.springframework.ai.anthropic.api.AnthropicApi;
2129
import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage;
2230
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
2331
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
2432
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
25-
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
33+
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
2634
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
2735
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
2836
import org.springframework.ai.chat.messages.AssistantMessage;
29-
import org.springframework.ai.chat.messages.Message;
3037
import org.springframework.ai.chat.messages.MessageType;
3138
import org.springframework.ai.chat.messages.ToolResponseMessage;
3239
import org.springframework.ai.chat.messages.UserMessage;
3340
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3441
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
42+
import org.springframework.ai.chat.model.AbstractToolCallSupport;
3543
import org.springframework.ai.chat.model.ChatModel;
3644
import org.springframework.ai.chat.model.ChatResponse;
3745
import org.springframework.ai.chat.model.Generation;
3846
import org.springframework.ai.chat.prompt.ChatOptions;
3947
import org.springframework.ai.chat.prompt.Prompt;
4048
import org.springframework.ai.model.ModelOptionsUtils;
41-
import org.springframework.ai.chat.model.AbstractToolCallSupport;
4249
import org.springframework.ai.model.function.FunctionCallbackContext;
4350
import org.springframework.ai.retry.RetryUtils;
4451
import org.springframework.http.ResponseEntity;
4552
import org.springframework.retry.support.RetryTemplate;
4653
import org.springframework.util.Assert;
4754
import org.springframework.util.CollectionUtils;
4855
import org.springframework.util.StringUtils;
56+
4957
import reactor.core.publisher.Flux;
5058
import reactor.core.publisher.Mono;
5159

52-
import java.util.ArrayList;
53-
import java.util.Base64;
54-
import java.util.HashSet;
55-
import java.util.List;
56-
import java.util.Map;
57-
import java.util.Set;
58-
import java.util.stream.Collectors;
59-
6060
/**
6161
* The {@link ChatModel} implementation for the Anthropic service.
6262
*
@@ -150,86 +150,84 @@ public ChatResponse call(Prompt prompt) {
150150

151151
ChatCompletionRequest request = createRequest(prompt, false);
152152

153-
return this.retryTemplate.execute(ctx -> {
154-
ResponseEntity<ChatCompletionResponse> completionEntity = this.anthropicApi.chatCompletionEntity(request);
153+
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
154+
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
155155

156-
if (this.isToolFunctionCall(completionEntity.getBody())) {
157-
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
158-
completionEntity.getBody());
159-
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
160-
}
156+
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
161157

162-
return toChatResponse(completionEntity.getBody());
163-
});
158+
if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
159+
var toolCallConversation = handleToolCalls(prompt, chatResponse);
160+
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
161+
}
162+
163+
return chatResponse;
164164
}
165165

166166
@Override
167167
public Flux<ChatResponse> stream(Prompt prompt) {
168168

169169
ChatCompletionRequest request = createRequest(prompt, true);
170170

171-
return this.retryTemplate.execute(ctx -> {
171+
Flux<ChatCompletionResponse> response = this.retryTemplate
172+
.execute(ctx -> this.anthropicApi.chatCompletionStream(request));
172173

173-
Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);
174+
return response.switchMap(chatCompletionResponse -> {
174175

175-
return response.switchMap(chatCompletionResponse -> {
176+
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
176177

177-
if (this.isToolFunctionCall(chatCompletionResponse)) {
178-
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
179-
chatCompletionResponse);
180-
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
181-
}
178+
if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
179+
var toolCallConversation = handleToolCalls(prompt, chatResponse);
180+
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
181+
}
182182

183-
return Mono.just(chatCompletionResponse).map(this::toChatResponse);
184-
});
183+
return Mono.just(chatResponse);
185184
});
186185
}
187186

188-
private List<Message> handleToolCallRequests(List<Message> previousMessages,
189-
ChatCompletionResponse chatCompletionResponse) {
187+
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
190188

191-
AnthropicMessage anthropicAssistantMessage = new AnthropicMessage(chatCompletionResponse.content(),
192-
Role.ASSISTANT);
189+
if (chatCompletion == null) {
190+
logger.warn("Null chat completion returned");
191+
return new ChatResponse(List.of());
192+
}
193193

194-
List<ContentBlock> toolToUseList = anthropicAssistantMessage.content()
194+
List<Generation> generations = chatCompletion.content()
195195
.stream()
196-
.filter(c -> c.type() == ContentBlock.ContentBlockType.TOOL_USE)
196+
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
197+
.map(content -> {
198+
new AssistantMessage(content.text(), Map.of());
199+
return new Generation(new AssistantMessage(content.text(), Map.of()),
200+
ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
201+
})
197202
.toList();
198203

199-
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
200-
201-
for (ContentBlock toolToUse : toolToUseList) {
204+
List<Generation> allGenerations = new ArrayList<>(generations);
202205

203-
var functionCallId = toolToUse.id();
204-
var functionName = toolToUse.name();
205-
var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input());
206+
List<ContentBlock> toolToUseList = chatCompletion.content()
207+
.stream()
208+
.filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
209+
.toList();
206210

207-
toolCalls.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
208-
}
211+
if (!CollectionUtils.isEmpty(toolToUseList)) {
212+
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
209213

210-
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
211-
ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage);
214+
for (ContentBlock toolToUse : toolToUseList) {
212215

213-
// History
214-
List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages);
215-
toolCallMessageConversation.add(assistantMessage);
216-
toolCallMessageConversation.add(toolResponseMessage);
216+
var functionCallId = toolToUse.id();
217+
var functionName = toolToUse.name();
218+
var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input());
217219

218-
return toolCallMessageConversation;
219-
}
220+
toolCalls
221+
.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
222+
}
220223

221-
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
222-
if (chatCompletion == null) {
223-
logger.warn("Null chat completion returned");
224-
return new ChatResponse(List.of());
224+
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
225+
Generation toolCallGeneration = new Generation(assistantMessage,
226+
ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
227+
allGenerations.add(toolCallGeneration);
225228
}
226229

227-
List<Generation> generations = chatCompletion.content().stream().map(content -> {
228-
return new Generation(content.text(), Map.of())
229-
.withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
230-
}).toList();
231-
232-
return new ChatResponse(generations, from(chatCompletion));
230+
return new ChatResponse(allGenerations, this.from(chatCompletion));
233231
}
234232

235233
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
@@ -288,16 +286,16 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
288286
}
289287
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
290288
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
291-
contentBlocks.add(new ContentBlock(ContentBlockType.TOOL_USE, toolCall.id(),
292-
toolCall.name(), ModelOptionsUtils.jsonToMap(toolCall.arguments())));
289+
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
290+
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
293291
}
294292
}
295293
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
296294
}
297295
else if (message.getMessageType() == MessageType.TOOL) {
298296
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
299297
.stream()
300-
.map(toolResponse -> new ContentBlock(ContentBlockType.TOOL_RESULT, toolResponse.id(),
298+
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
301299
toolResponse.responseData()))
302300
.toList();
303301
return new AnthropicMessage(toolResponses, Role.USER);
@@ -355,16 +353,6 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
355353
}).toList();
356354
}
357355

358-
@SuppressWarnings("null")
359-
protected boolean isToolFunctionCall(ChatCompletionResponse response) {
360-
if (response == null || CollectionUtils.isEmpty(response.content())) {
361-
return false;
362-
}
363-
return response.content()
364-
.stream()
365-
.anyMatch(content -> content.type() == ContentBlock.ContentBlockType.TOOL_USE);
366-
}
367-
368356
@Override
369357
public ChatOptions getDefaultOptions() {
370358
return AnthropicChatOptions.fromOptions(this.defaultOptions);

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.function.Consumer;
2424
import java.util.function.Predicate;
2525

26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
2628
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
2729
import org.springframework.ai.model.ChatModelDescription;
2830
import org.springframework.ai.model.ModelOptionsUtils;
@@ -53,6 +55,8 @@
5355
*/
5456
public class AnthropicApi {
5557

58+
private static final Logger logger = LoggerFactory.getLogger(AnthropicApi.class);
59+
5660
private static final String HEADER_X_API_KEY = "x-api-key";
5761

5862
private static final String HEADER_ANTHROPIC_VERSION = "anthropic-version";
@@ -415,7 +419,7 @@ public record AnthropicMessage( // @formatter:off
415419
*/
416420
@JsonInclude(Include.NON_NULL)
417421
public record ContentBlock( // @formatter:off
418-
@JsonProperty("type") ContentBlockType type,
422+
@JsonProperty("type") Type type,
419423
@JsonProperty("source") Source source,
420424
@JsonProperty("text") String text,
421425

@@ -438,67 +442,77 @@ public ContentBlock(String mediaType, String data) {
438442
}
439443

440444
public ContentBlock(Source source) {
441-
this(ContentBlockType.IMAGE, source, null, null, null, null, null, null, null);
445+
this(Type.IMAGE, source, null, null, null, null, null, null, null);
442446
}
443447

444448
public ContentBlock(String text) {
445-
this(ContentBlockType.TEXT, null, text, null, null, null, null, null, null);
449+
this(Type.TEXT, null, text, null, null, null, null, null, null);
446450
}
447451

448452
// Tool result
449-
public ContentBlock(ContentBlockType type, String toolUseId, String content) {
453+
public ContentBlock(Type type, String toolUseId, String content) {
450454
this(type, null, null, null, null, null, null, toolUseId, content);
451455
}
452456

453-
public ContentBlock(ContentBlockType type, Source source, String text, Integer index) {
457+
public ContentBlock(Type type, Source source, String text, Integer index) {
454458
this(type, source, text, index, null, null, null, null, null);
455459
}
456460

457461
// Tool use input JSON delta streaming
458-
public ContentBlock(ContentBlockType type, String id, String name, Map<String, Object> input) {
462+
public ContentBlock(Type type, String id, String name, Map<String, Object> input) {
459463
this(type, null, null, null, id, name, input, null, null);
460464
}
461465

462466
/**
463-
* The type of this message.
467+
* The ContentBlock type.
464468
*/
465-
public enum ContentBlockType {
469+
public enum Type {
466470

467471
/**
468472
* Tool request
469473
*/
470474
@JsonProperty("tool_use")
471-
TOOL_USE,
475+
TOOL_USE("tool_use"),
472476

473477
/**
474478
* Send tool result back to LLM.
475479
*/
476480
@JsonProperty("tool_result")
477-
TOOL_RESULT,
481+
TOOL_RESULT("tool_result"),
478482

479483
/**
480484
* Text message.
481485
*/
482486
@JsonProperty("text")
483-
TEXT,
487+
TEXT("text"),
484488

485489
/**
486490
* Text delta message. Returned from the streaming response.
487491
*/
488492
@JsonProperty("text_delta")
489-
TEXT_DELTA,
493+
TEXT_DELTA("text_delta"),
490494

491495
/**
492496
* Tool use input partial JSON delta streaming.
493497
*/
494498
@JsonProperty("input_json_delta")
495-
INPUT_JSON_DELTA,
499+
INPUT_JSON_DELTA("input_json_delta"),
496500

497501
/**
498502
* Image message.
499503
*/
500504
@JsonProperty("image")
501-
IMAGE;
505+
IMAGE("image");
506+
507+
public final String value;
508+
509+
Type(String value) {
510+
this.value = value;
511+
}
512+
513+
public String getValue() {
514+
return this.value;
515+
}
502516

503517
}
504518

@@ -902,6 +916,7 @@ public Flux<ChatCompletionResponse> chatCompletionStream(ChatCompletionRequest c
902916
.takeUntil(SSE_DONE_PREDICATE)
903917
.filter(SSE_DONE_PREDICATE.negate())
904918
.map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class))
919+
.filter(event -> event.type() != EventType.PING)
905920
// Detect if the chunk is part of a streaming function call.
906921
.map(event -> {
907922
if (this.streamHelper.isToolUseStart(event)) {

0 commit comments

Comments
 (0)