Skip to content

Commit 936505c

Browse files
Implemented a web search tool, provided by Anthropic
This web search tool is categorized as a `server tool`. related doc: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool Signed-off-by: jonghoonpark <[email protected]>
1 parent 0a1cf81 commit 936505c

File tree

13 files changed

+552
-47
lines changed

13 files changed

+552
-47
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import java.util.Set;
2525
import java.util.stream.Collectors;
26+
import java.util.stream.Stream;
2627

2728
import com.fasterxml.jackson.core.type.TypeReference;
2829
import io.micrometer.observation.Observation;
@@ -42,6 +43,7 @@
4243
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4344
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4445
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
46+
import org.springframework.ai.anthropic.api.tool.Tool;
4547
import org.springframework.ai.chat.messages.AssistantMessage;
4648
import org.springframework.ai.chat.messages.MessageType;
4749
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -342,11 +344,11 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
342344
return new ChatResponse(generations, this.from(chatCompletion, usage));
343345
}
344346

345-
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
347+
private ChatResponseMetadata from(ChatCompletionResponse result) {
346348
return from(result, this.getDefaultUsage(result.usage()));
347349
}
348350

349-
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
351+
private ChatResponseMetadata from(ChatCompletionResponse result, Usage usage) {
350352
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
351353
return ChatResponseMetadata.builder()
352354
.id(result.id())
@@ -443,20 +445,32 @@ Prompt buildRequestPrompt(Prompt prompt) {
443445
this.defaultOptions.getToolCallbacks()));
444446
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
445447
this.defaultOptions.getToolContext()));
448+
requestOptions.setServerTools(
449+
mergeServerTools(runtimeOptions.getServerTools(), this.defaultOptions.getServerTools()));
446450
}
447451
else {
448452
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
449453
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
450454
requestOptions.setToolNames(this.defaultOptions.getToolNames());
451455
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
452456
requestOptions.setToolContext(this.defaultOptions.getToolContext());
457+
requestOptions.setServerTools(this.defaultOptions.getServerTools());
453458
}
454459

455460
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
456461

457462
return new Prompt(prompt.getInstructions(), requestOptions);
458463
}
459464

465+
static List<Tool> mergeServerTools(List<Tool> runtimeServerTools, List<Tool> defaultToolNames) {
466+
Assert.notNull(runtimeServerTools, "runtimeServerTools cannot be null");
467+
Assert.notNull(defaultToolNames, "defaultToolNames cannot be null");
468+
if (CollectionUtils.isEmpty(runtimeServerTools)) {
469+
return new ArrayList<>(defaultToolNames);
470+
}
471+
return new ArrayList<>(runtimeServerTools);
472+
}
473+
460474
private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
461475
Map<String, String> defaultHttpHeaders) {
462476
var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);
@@ -526,22 +540,31 @@ else if (message.getMessageType() == MessageType.TOOL) {
526540

527541
// Add the tool definitions to the request's tools parameter.
528542
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
543+
Stream<Tool> toolStream = Stream.empty();
529544
if (!CollectionUtils.isEmpty(toolDefinitions)) {
530545
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
531-
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
546+
toolStream = getFunctionToolStream(toolDefinitions);
547+
}
548+
if (!CollectionUtils.isEmpty(requestOptions.getServerTools())) {
549+
toolStream = Stream.concat(toolStream, requestOptions.getServerTools().stream());
550+
}
551+
552+
List<Tool> tools = toolStream.toList();
553+
if (!tools.isEmpty()) {
554+
request = ChatCompletionRequest.from(request).tools(tools).build();
532555
}
533556

534557
return request;
535558
}
536559

537-
private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
560+
private Stream<Tool> getFunctionToolStream(List<ToolDefinition> toolDefinitions) {
538561
return toolDefinitions.stream().map(toolDefinition -> {
539562
var name = toolDefinition.name();
540563
var description = toolDefinition.description();
541564
String inputSchema = toolDefinition.inputSchema();
542-
return new AnthropicApi.Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
565+
return new Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
543566
}));
544-
}).toList();
567+
});
545568
}
546569

547570
@Override

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import org.springframework.ai.anthropic.api.AnthropicApi;
3434
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
35+
import org.springframework.ai.anthropic.api.tool.Tool;
3536
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3637
import org.springframework.ai.tool.ToolCallback;
3738
import org.springframework.lang.Nullable;
@@ -44,6 +45,7 @@
4445
* @author Thomas Vitale
4546
* @author Alexandros Pappas
4647
* @author Ilayaperumal Gopinathan
48+
* @author Jonghoon Park
4749
* @since 1.0.0
4850
*/
4951
@JsonInclude(Include.NON_NULL)
@@ -82,6 +84,8 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
8284
@JsonIgnore
8385
private Map<String, Object> toolContext = new HashMap<>();
8486

87+
@JsonIgnore
88+
private List<Tool> serverTools = new ArrayList<>();
8589

8690
/**
8791
* Optional HTTP headers to be added to the chat completion request.
@@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
110114
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
111115
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
112116
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
117+
.serverTools(fromOptions.getServerTools() != null ? new ArrayList<>(fromOptions.getServerTools()) : null)
113118
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
114119
.build();
115120
}
@@ -250,6 +255,17 @@ public void setToolContext(Map<String, Object> toolContext) {
250255
this.toolContext = toolContext;
251256
}
252257

258+
@JsonIgnore
259+
public List<Tool> getServerTools() {
260+
return this.serverTools;
261+
}
262+
263+
public void setServerTools(List<Tool> serverTools) {
264+
Assert.notNull(serverTools, "serverTools cannot be null");
265+
Assert.noNullElements(serverTools, "serverTools cannot contain null elements");
266+
this.serverTools = serverTools;
267+
}
268+
253269
@JsonIgnore
254270
public Map<String, String> getHttpHeaders() {
255271
return this.httpHeaders;
@@ -282,14 +298,15 @@ public boolean equals(Object o) {
282298
&& Objects.equals(this.toolNames, that.toolNames)
283299
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
284300
&& Objects.equals(this.toolContext, that.toolContext)
301+
&& Objects.equals(this.serverTools, that.serverTools)
285302
&& Objects.equals(this.httpHeaders, that.httpHeaders);
286303
}
287304

288305
@Override
289306
public int hashCode() {
290307
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
291308
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
292-
this.toolContext, this.httpHeaders);
309+
this.toolContext, this.serverTools, this.httpHeaders);
293310
}
294311

295312
public static class Builder {
@@ -384,6 +401,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
384401
return this;
385402
}
386403

404+
public Builder serverTools(List<Tool> serverTools) {
405+
if (this.options.serverTools == null) {
406+
this.options.serverTools = serverTools;
407+
}
408+
else {
409+
this.options.serverTools.addAll(serverTools);
410+
}
411+
return this;
412+
}
413+
387414
public Builder httpHeaders(Map<String, String> httpHeaders) {
388415
this.options.setHttpHeaders(httpHeaders);
389416
return this;

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,16 @@
3030
import com.fasterxml.jackson.annotation.JsonProperty;
3131
import com.fasterxml.jackson.annotation.JsonSubTypes;
3232
import com.fasterxml.jackson.annotation.JsonTypeInfo;
33+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
34+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
3335
import reactor.core.publisher.Flux;
3436
import reactor.core.publisher.Mono;
3537

3638
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
3739
import org.springframework.ai.model.ApiKey;
40+
import org.springframework.ai.anthropic.api.tool.Tool;
41+
import org.springframework.ai.anthropic.util.ContentFieldDeserializer;
42+
import org.springframework.ai.anthropic.util.ContentFieldSerializer;
3843
import org.springframework.ai.model.ChatModelDescription;
3944
import org.springframework.ai.model.ModelOptionsUtils;
4045
import org.springframework.ai.model.SimpleApiKey;
@@ -750,7 +755,11 @@ public record ContentBlock(
750755

751756
// tool_result response only
752757
@JsonProperty("tool_use_id") String toolUseId,
753-
@JsonProperty("content") String content,
758+
759+
@JsonSerialize(using = ContentFieldSerializer.class)
760+
@JsonDeserialize(using = ContentFieldDeserializer.class)
761+
@JsonProperty("content")
762+
Object content,
754763

755764
// Thinking only
756765
@JsonProperty("signature") String signature,
@@ -761,6 +770,15 @@ public record ContentBlock(
761770
) {
762771
// @formatter:on
763772

773+
@JsonInclude(Include.NON_NULL)
774+
@JsonIgnoreProperties(ignoreUnknown = true)
775+
public record WebSearchToolContentBlock(@JsonProperty("type") String type, @JsonProperty("title") String title,
776+
@JsonProperty("url") String url, @JsonProperty("encrypted_content") String EncryptedContent,
777+
@JsonProperty("page_age") String pageAge) {
778+
779+
}
780+
// @formatter:on
781+
764782
/**
765783
* Create content block
766784
* @param mediaType The media type of the content.
@@ -846,6 +864,18 @@ public enum Type {
846864
@JsonProperty("tool_result")
847865
TOOL_RESULT("tool_result"),
848866

867+
/**
868+
* Server Tool request
869+
*/
870+
@JsonProperty("server_tool_use")
871+
SERVER_TOOL_USE("server_tool_use"),
872+
873+
/**
874+
* Web search tool result
875+
*/
876+
@JsonProperty("web_search_tool_result")
877+
WEB_SEARCH_TOOL_RESULT("web_search_tool_result"),
878+
849879
/**
850880
* Text message.
851881
*/
@@ -959,22 +989,6 @@ public Source(String url) {
959989
/// CONTENT_BLOCK EVENTS
960990
///////////////////////////////////////
961991

962-
/**
963-
* Tool description.
964-
*
965-
* @param name The name of the tool.
966-
* @param description A description of the tool.
967-
* @param inputSchema The input schema of the tool.
968-
*/
969-
@JsonInclude(Include.NON_NULL)
970-
public record Tool(
971-
// @formatter:off
972-
@JsonProperty("name") String name,
973-
@JsonProperty("description") String description,
974-
@JsonProperty("input_schema") Map<String, Object> inputSchema) {
975-
// @formatter:on
976-
}
977-
978992
// CB START EVENT
979993

980994
/**
@@ -1020,16 +1034,25 @@ public record ChatCompletionResponse(
10201034
public record Usage(
10211035
// @formatter:off
10221036
@JsonProperty("input_tokens") Integer inputTokens,
1023-
@JsonProperty("output_tokens") Integer outputTokens) {
1024-
// @formatter:off
1037+
@JsonProperty("output_tokens") Integer outputTokens,
1038+
@JsonProperty("server_tool_use") ServerToolUse serverToolUse) {
1039+
// @formatter:on
1040+
}
1041+
1042+
@JsonInclude(Include.NON_NULL)
1043+
@JsonIgnoreProperties(ignoreUnknown = true)
1044+
public record ServerToolUse(
1045+
// @formatter:off
1046+
@JsonProperty("web_search_requests") Integer webSearchRequests) {
1047+
// @formatter:on
10251048
}
10261049

1027-
/// ECB STOP
1050+
/// ECB STOP
10281051

10291052
/**
10301053
* Special event used to aggregate multiple tool use events into a single event with
10311054
* list of aggregated ContentBlockToolUse.
1032-
*/
1055+
*/
10331056
public static class ToolUseAggregationEvent implements StreamEvent {
10341057

10351058
private Integer index;
@@ -1048,17 +1071,17 @@ public EventType type() {
10481071
}
10491072

10501073
/**
1051-
* Get tool content blocks.
1052-
* @return The tool content blocks.
1053-
*/
1074+
* Get tool content blocks.
1075+
* @return The tool content blocks.
1076+
*/
10541077
public List<ContentBlockStartEvent.ContentBlockToolUse> getToolContentBlocks() {
10551078
return this.toolContentBlocks;
10561079
}
10571080

10581081
/**
1059-
* Check if the event is empty.
1060-
* @return True if the event is empty, false otherwise.
1061-
*/
1082+
* Check if the event is empty.
1083+
* @return True if the event is empty, false otherwise.
1084+
*/
10621085
public boolean isEmpty() {
10631086
return (this.index == null || this.id == null || this.name == null
10641087
|| !StringUtils.hasText(this.partialJson));
@@ -1087,7 +1110,8 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
10871110
void squashIntoContentBlock() {
10881111
Map<String, Object> map = (StringUtils.hasText(this.partialJson))
10891112
? ModelOptionsUtils.jsonToMap(this.partialJson) : Map.of();
1090-
this.toolContentBlocks.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
1113+
this.toolContentBlocks
1114+
.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
10911115
this.index = null;
10921116
this.id = null;
10931117
this.name = null;
@@ -1096,28 +1120,29 @@ void squashIntoContentBlock() {
10961120

10971121
@Override
10981122
public String toString() {
1099-
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson="
1100-
+ this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
1123+
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name
1124+
+ ", partialJson=" + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
11011125
}
11021126

11031127
}
11041128

1105-
///////////////////////////////////////
1106-
/// MESSAGE EVENTS
1107-
///////////////////////////////////////
1129+
///////////////////////////////////////
1130+
/// MESSAGE EVENTS
1131+
///////////////////////////////////////
11081132

1109-
// MESSAGE START EVENT
1133+
// MESSAGE START EVENT
11101134

11111135
/**
11121136
* Content block start event.
1137+
*
11131138
* @param type The event type.
11141139
* @param index The index of the content block.
11151140
* @param contentBlock The content block body.
1116-
*/
1141+
*/
11171142
@JsonInclude(Include.NON_NULL)
11181143
@JsonIgnoreProperties(ignoreUnknown = true)
11191144
public record ContentBlockStartEvent(
1120-
// @formatter:off
1145+
// @formatter:off
11211146
@JsonProperty("type") EventType type,
11221147
@JsonProperty("index") Integer index,
11231148
@JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -174,7 +174,7 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
174174

175175
if (messageDeltaEvent.usage() != null) {
176176
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
177-
messageDeltaEvent.usage().outputTokens());
177+
messageDeltaEvent.usage().outputTokens(), contentBlockReference.get().usage.serverToolUse());
178178
contentBlockReference.get().withUsage(totalUsage);
179179
}
180180
}

0 commit comments

Comments
 (0)