Skip to content

Commit 4eb24fd

Browse files
committed
feat: Enhance Anthropic integration with Thinking
- Add the `thinking` option to `AnthropicChatOptions` and `ChatCompletionRequest`.
- `AnthropicApi` and `AnthropicChatModel` now handle `THINKING` and `REDACTED_THINKING` content blocks in responses. New tests verify the parsing of these blocks.
- Update the method signatures on `ChatCompletionRequestBuilder`, deprecating the old `with*`-prefixed builder methods in favor of their unprefixed equivalents.

Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 092bbae commit 4eb24fd

File tree

7 files changed

+413
-65
lines changed

7 files changed

+413
-65
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.tool.definition.ToolDefinition;
3838
import org.springframework.ai.util.json.JsonParser;
3939
import org.springframework.lang.Nullable;
40+
4041
import reactor.core.publisher.Flux;
4142
import reactor.core.publisher.Mono;
4243

@@ -382,46 +383,49 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
382383
return new ChatResponse(List.of());
383384
}
384385

385-
List<Generation> generations = chatCompletion.content()
386-
.stream()
387-
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
388-
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
389-
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()))
390-
.toList();
391-
392-
List<Generation> allGenerations = new ArrayList<>(generations);
386+
List<Generation> generations = new ArrayList<>();
387+
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
388+
for (ContentBlock content : chatCompletion.content()) {
389+
switch (content.type()) {
390+
case TEXT, TEXT_DELTA:
391+
generations.add(new Generation(new AssistantMessage(content.text(), Map.of()),
392+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
393+
break;
394+
case THINKING, THINKING_DELTA:
395+
Map<String, Object> thinkingProperties = new HashMap<>();
396+
thinkingProperties.put("signature", content.signature());
397+
generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties),
398+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
399+
break;
400+
case REDACTED_THINKING:
401+
Map<String, Object> redactedProperties = new HashMap<>();
402+
redactedProperties.put("data", content.data());
403+
generations.add(new Generation(new AssistantMessage(null, redactedProperties),
404+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
405+
break;
406+
case TOOL_USE:
407+
var functionCallId = content.id();
408+
var functionName = content.name();
409+
var functionArguments = JsonParser.toJson(content.input());
410+
toolCalls.add(
411+
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
412+
break;
413+
}
414+
}
393415

394416
if (chatCompletion.stopReason() != null && generations.isEmpty()) {
395417
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
396418
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
397-
allGenerations.add(generation);
419+
generations.add(generation);
398420
}
399421

400-
List<ContentBlock> toolToUseList = chatCompletion.content()
401-
.stream()
402-
.filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
403-
.toList();
404-
405-
if (!CollectionUtils.isEmpty(toolToUseList)) {
406-
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
407-
408-
for (ContentBlock toolToUse : toolToUseList) {
409-
410-
var functionCallId = toolToUse.id();
411-
var functionName = toolToUse.name();
412-
var functionArguments = JsonParser.toJson(toolToUse.input());
413-
414-
toolCalls
415-
.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
416-
}
417-
422+
if (!CollectionUtils.isEmpty(toolCalls)) {
418423
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
419424
Generation toolCallGeneration = new Generation(assistantMessage,
420425
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
421-
allGenerations.add(toolCallGeneration);
426+
generations.add(toolCallGeneration);
422427
}
423-
424-
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
428+
return new ChatResponse(generations, this.from(chatCompletion, usage));
425429
}
426430

427431
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
@@ -597,7 +601,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
597601
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
598602
if (!CollectionUtils.isEmpty(toolDefinitions)) {
599603
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
600-
request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build();
604+
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
601605
}
602606

603607
return request;

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
5757
private @JsonProperty("temperature") Double temperature;
5858
private @JsonProperty("top_p") Double topP;
5959
private @JsonProperty("top_k") Integer topK;
60+
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;
6061

6162
/**
6263
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
@@ -103,6 +104,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
103104
.temperature(fromOptions.getTemperature())
104105
.topP(fromOptions.getTopP())
105106
.topK(fromOptions.getTopK())
107+
.thinking(fromOptions.getThinking())
106108
.toolCallbacks(
107109
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
108110
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
@@ -174,6 +176,14 @@ public void setTopK(Integer topK) {
174176
this.topK = topK;
175177
}
176178

179+
public ChatCompletionRequest.ThinkingConfig getThinking() {
180+
return this.thinking;
181+
}
182+
183+
public void setThinking(ChatCompletionRequest.ThinkingConfig thinking) {
184+
this.thinking = thinking;
185+
}
186+
177187
@Override
178188
@JsonIgnore
179189
public List<FunctionCallback> getToolCallbacks() {
@@ -365,6 +375,16 @@ public Builder topK(Integer topK) {
365375
return this;
366376
}
367377

378+
public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) {
379+
this.options.thinking = thinking;
380+
return this;
381+
}
382+
383+
public Builder thinking(AnthropicApi.ThinkingType type, Integer budgetTokens) {
384+
this.options.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens);
385+
return this;
386+
}
387+
368388
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
369389
this.options.setToolCallbacks(toolCallbacks);
370390
return this;

0 commit comments

Comments
 (0)