Skip to content

Commit 33f658b

Browse files
committed
feat: Enhance Anthropic integration with Thinking
- Added the `thinking` option to `AnthropicChatOptions` and `ChatCompletionRequest`.
- `AnthropicApi` and `AnthropicChatModel` now handle `THINKING` and `REDACTED_THINKING` content blocks in responses. New tests verify parsing of these blocks.
- Updated the method signatures on `ChatCompletionRequestBuilder`: the old `with*`-prefixed builder methods are deprecated in favor of unprefixed equivalents (e.g. `withTools(...)` becomes `tools(...)`).

Signed-off-by: Alexandros Pappas <[email protected]>
1 parent bc375ab commit 33f658b

File tree

7 files changed

+423
-67
lines changed

7 files changed

+423
-67
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
3131
import org.slf4j.Logger;
3232
import org.slf4j.LoggerFactory;
33+
import org.springframework.ai.model.tool.LegacyToolCallingManager;
34+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
35+
import org.springframework.ai.model.tool.ToolCallingManager;
36+
import org.springframework.ai.model.tool.ToolExecutionResult;
37+
import org.springframework.ai.tool.definition.ToolDefinition;
38+
import org.springframework.ai.util.json.JsonParser;
39+
import org.springframework.lang.Nullable;
40+
3341
import reactor.core.publisher.Flux;
3442
import reactor.core.publisher.Mono;
3543
import reactor.core.scheduler.Schedulers;
@@ -279,46 +287,49 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
279287
return new ChatResponse(List.of());
280288
}
281289

282-
List<Generation> generations = chatCompletion.content()
283-
.stream()
284-
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
285-
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
286-
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()))
287-
.toList();
288-
289-
List<Generation> allGenerations = new ArrayList<>(generations);
290+
List<Generation> generations = new ArrayList<>();
291+
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
292+
for (ContentBlock content : chatCompletion.content()) {
293+
switch (content.type()) {
294+
case TEXT, TEXT_DELTA:
295+
generations.add(new Generation(new AssistantMessage(content.text(), Map.of()),
296+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
297+
break;
298+
case THINKING, THINKING_DELTA:
299+
Map<String, Object> thinkingProperties = new HashMap<>();
300+
thinkingProperties.put("signature", content.signature());
301+
generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties),
302+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
303+
break;
304+
case REDACTED_THINKING:
305+
Map<String, Object> redactedProperties = new HashMap<>();
306+
redactedProperties.put("data", content.data());
307+
generations.add(new Generation(new AssistantMessage(null, redactedProperties),
308+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
309+
break;
310+
case TOOL_USE:
311+
var functionCallId = content.id();
312+
var functionName = content.name();
313+
var functionArguments = JsonParser.toJson(content.input());
314+
toolCalls.add(
315+
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
316+
break;
317+
}
318+
}
290319

291320
if (chatCompletion.stopReason() != null && generations.isEmpty()) {
292321
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
293322
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
294-
allGenerations.add(generation);
323+
generations.add(generation);
295324
}
296325

297-
List<ContentBlock> toolToUseList = chatCompletion.content()
298-
.stream()
299-
.filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
300-
.toList();
301-
302-
if (!CollectionUtils.isEmpty(toolToUseList)) {
303-
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
304-
305-
for (ContentBlock toolToUse : toolToUseList) {
306-
307-
var functionCallId = toolToUse.id();
308-
var functionName = toolToUse.name();
309-
var functionArguments = JsonParser.toJson(toolToUse.input());
310-
311-
toolCalls
312-
.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
313-
}
314-
326+
if (!CollectionUtils.isEmpty(toolCalls)) {
315327
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
316328
Generation toolCallGeneration = new Generation(assistantMessage,
317329
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
318-
allGenerations.add(toolCallGeneration);
330+
generations.add(toolCallGeneration);
319331
}
320-
321-
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
332+
return new ChatResponse(generations, this.from(chatCompletion, usage));
322333
}
323334

324335
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
@@ -490,7 +501,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
490501
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
491502
if (!CollectionUtils.isEmpty(toolDefinitions)) {
492503
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
493-
request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build();
504+
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
494505
}
495506

496507
return request;

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
5757
private @JsonProperty("temperature") Double temperature;
5858
private @JsonProperty("top_p") Double topP;
5959
private @JsonProperty("top_k") Integer topK;
60+
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;
6061

6162
/**
6263
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
@@ -103,6 +104,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
103104
.temperature(fromOptions.getTemperature())
104105
.topP(fromOptions.getTopP())
105106
.topK(fromOptions.getTopK())
107+
.thinking(fromOptions.getThinking())
106108
.toolCallbacks(
107109
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
108110
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
@@ -174,6 +176,14 @@ public void setTopK(Integer topK) {
174176
this.topK = topK;
175177
}
176178

179+
public ChatCompletionRequest.ThinkingConfig getThinking() {
180+
return this.thinking;
181+
}
182+
183+
public void setThinking(ChatCompletionRequest.ThinkingConfig thinking) {
184+
this.thinking = thinking;
185+
}
186+
177187
@Override
178188
@JsonIgnore
179189
public List<FunctionCallback> getToolCallbacks() {
@@ -308,7 +318,8 @@ public boolean equals(Object o) {
308318
&& Objects.equals(this.metadata, that.metadata)
309319
&& Objects.equals(this.stopSequences, that.stopSequences)
310320
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP)
311-
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.toolCallbacks, that.toolCallbacks)
321+
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.thinking, that.thinking)
322+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
312323
&& Objects.equals(this.toolNames, that.toolNames)
313324
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
314325
&& Objects.equals(this.toolContext, that.toolContext)
@@ -317,7 +328,7 @@ public boolean equals(Object o) {
317328

318329
@Override
319330
public int hashCode() {
320-
return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, toolCallbacks,
331+
return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, thinking, toolCallbacks,
321332
toolNames, internalToolExecutionEnabled, toolContext, httpHeaders);
322333
}
323334

@@ -365,6 +376,16 @@ public Builder topK(Integer topK) {
365376
return this;
366377
}
367378

379+
public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) {
380+
this.options.thinking = thinking;
381+
return this;
382+
}
383+
384+
public Builder thinking(AnthropicApi.ThinkingType type, Integer budgetTokens) {
385+
this.options.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens);
386+
return this;
387+
}
388+
368389
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
369390
this.options.setToolCallbacks(toolCallbacks);
370391
return this;

0 commit comments

Comments
 (0)