Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
private @JsonProperty("temperature") Double temperature;
private @JsonProperty("top_p") Double topP;
private @JsonProperty("top_k") Integer topK;
private @JsonProperty("tool_choice") AnthropicApi.ToolChoice toolChoice;
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;

@JsonIgnore
Expand Down Expand Up @@ -117,6 +118,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.topK(fromOptions.getTopK())
.toolChoice(fromOptions.getToolChoice())
.thinking(fromOptions.getThinking())
.toolCallbacks(
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
Expand Down Expand Up @@ -190,6 +192,14 @@ public void setTopK(Integer topK) {
this.topK = topK;
}

public AnthropicApi.ToolChoice getToolChoice() {
return this.toolChoice;
}

public void setToolChoice(AnthropicApi.ToolChoice toolChoice) {
this.toolChoice = toolChoice;
}

public ChatCompletionRequest.ThinkingConfig getThinking() {
return this.thinking;
}
Expand Down Expand Up @@ -291,7 +301,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.metadata, that.metadata)
&& Objects.equals(this.stopSequences, that.stopSequences)
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP)
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.thinking, that.thinking)
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.toolChoice, that.toolChoice)
&& Objects.equals(this.thinking, that.thinking)
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
Expand All @@ -303,8 +314,8 @@ public boolean equals(Object o) {
@Override
public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.toolContext, this.httpHeaders, this.cacheOptions);
this.topK, this.toolChoice, this.thinking, this.toolCallbacks, this.toolNames,
this.internalToolExecutionEnabled, this.toolContext, this.httpHeaders, this.cacheOptions);
}

public static final class Builder {
Expand Down Expand Up @@ -351,6 +362,11 @@ public Builder topK(Integer topK) {
return this;
}

public Builder toolChoice(AnthropicApi.ToolChoice toolChoice) {
this.options.toolChoice = toolChoice;
return this;
}

public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) {
this.options.thinking = thinking;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,8 @@ public interface StreamEvent {
* return tool_use content blocks that represent the model's use of those tools. You
* can then run those tools using the tool input generated by the model and then
* optionally return results back to the model using tool_result content blocks.
* @param toolChoice How the model should use the provided tools. The model can use a
* specific tool, any available tool, decide by itself, or not use tools at all.
* @param thinking Configuration for the model's thinking mode. When enabled, the
* model can perform more in-depth reasoning before responding to a query.
*/
Expand All @@ -529,17 +531,19 @@ public record ChatCompletionRequest(
@JsonProperty("top_p") Double topP,
@JsonProperty("top_k") Integer topK,
@JsonProperty("tools") List<Tool> tools,
@JsonProperty("tool_choice") ToolChoice toolChoice,
@JsonProperty("thinking") ThinkingConfig thinking) {
// @formatter:on

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Object system, Integer maxTokens,
Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null);
this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null, null);
}

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Object system, Integer maxTokens,
List<String> stopSequences, Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null);
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null,
null);
}

public static ChatCompletionRequestBuilder builder() {
Expand Down Expand Up @@ -613,6 +617,8 @@ public static final class ChatCompletionRequestBuilder {

private List<Tool> tools;

private ToolChoice toolChoice;

private ChatCompletionRequest.ThinkingConfig thinking;

private ChatCompletionRequestBuilder() {
Expand All @@ -630,6 +636,7 @@ private ChatCompletionRequestBuilder(ChatCompletionRequest request) {
this.topP = request.topP;
this.topK = request.topK;
this.tools = request.tools;
this.toolChoice = request.toolChoice;
this.thinking = request.thinking;
}

Expand Down Expand Up @@ -693,6 +700,11 @@ public ChatCompletionRequestBuilder tools(List<Tool> tools) {
return this;
}

public ChatCompletionRequestBuilder toolChoice(ToolChoice toolChoice) {
this.toolChoice = toolChoice;
return this;
}

public ChatCompletionRequestBuilder thinking(ChatCompletionRequest.ThinkingConfig thinking) {
this.thinking = thinking;
return this;
Expand All @@ -705,7 +717,8 @@ public ChatCompletionRequestBuilder thinking(ThinkingType type, Integer budgetTo

public ChatCompletionRequest build() {
return new ChatCompletionRequest(this.model, this.messages, this.system, this.maxTokens, this.metadata,
this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools, this.thinking);
this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools,
this.toolChoice, this.thinking);
}

}
Expand Down Expand Up @@ -1135,6 +1148,126 @@ public Tool(String name, String description, Map<String, Object> inputSchema) {

}

/**
* Base interface for tool choice options.
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type",
visible = true)
@JsonSubTypes({ @JsonSubTypes.Type(value = ToolChoiceAuto.class, name = "auto"),
@JsonSubTypes.Type(value = ToolChoiceAny.class, name = "any"),
@JsonSubTypes.Type(value = ToolChoiceTool.class, name = "tool"),
@JsonSubTypes.Type(value = ToolChoiceNone.class, name = "none") })
public interface ToolChoice {

@JsonProperty("type")
String type();

}

/**
* Auto tool choice - the model will automatically decide whether to use tools.
*
* @param type The type of tool choice, always "auto".
* @param disableParallelToolUse Whether to disable parallel tool use. Defaults to
* false. If set to true, the model will output at most one tool use.
*/
@JsonInclude(Include.NON_NULL)
public record ToolChoiceAuto(@JsonProperty("type") String type,
@JsonProperty("disable_parallel_tool_use") Boolean disableParallelToolUse) implements ToolChoice {

/**
* Create an auto tool choice with default settings.
*/
public ToolChoiceAuto() {
this("auto", null);
}

/**
* Create an auto tool choice with specific parallel tool use setting.
* @param disableParallelToolUse Whether to disable parallel tool use.
*/
public ToolChoiceAuto(Boolean disableParallelToolUse) {
this("auto", disableParallelToolUse);
}

}

/**
* Any tool choice - the model will use any available tools.
*
* @param type The type of tool choice, always "any".
* @param disableParallelToolUse Whether to disable parallel tool use. Defaults to
* false. If set to true, the model will output exactly one tool use.
*/
@JsonInclude(Include.NON_NULL)
public record ToolChoiceAny(@JsonProperty("type") String type,
@JsonProperty("disable_parallel_tool_use") Boolean disableParallelToolUse) implements ToolChoice {

/**
* Create an any tool choice with default settings.
*/
public ToolChoiceAny() {
this("any", null);
}

/**
* Create an any tool choice with specific parallel tool use setting.
* @param disableParallelToolUse Whether to disable parallel tool use.
*/
public ToolChoiceAny(Boolean disableParallelToolUse) {
this("any", disableParallelToolUse);
}

}

/**
* Tool choice - the model will use the specified tool.
*
* @param type The type of tool choice, always "tool".
* @param name The name of the tool to use.
* @param disableParallelToolUse Whether to disable parallel tool use. Defaults to
* false. If set to true, the model will output exactly one tool use.
*/
@JsonInclude(Include.NON_NULL)
public record ToolChoiceTool(@JsonProperty("type") String type, @JsonProperty("name") String name,
@JsonProperty("disable_parallel_tool_use") Boolean disableParallelToolUse) implements ToolChoice {

/**
* Create a tool choice for a specific tool.
* @param name The name of the tool to use.
*/
public ToolChoiceTool(String name) {
this("tool", name, null);
}

/**
* Create a tool choice for a specific tool with parallel tool use setting.
* @param name The name of the tool to use.
* @param disableParallelToolUse Whether to disable parallel tool use.
*/
public ToolChoiceTool(String name, Boolean disableParallelToolUse) {
this("tool", name, disableParallelToolUse);
}

}

/**
* None tool choice - the model will not be allowed to use tools.
*
* @param type The type of tool choice, always "none".
*/
@JsonInclude(Include.NON_NULL)
public record ToolChoiceNone(@JsonProperty("type") String type) implements ToolChoice {

/**
* Create a none tool choice.
*/
public ToolChoiceNone() {
this("none");
}

}

// CB START EVENT

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,102 @@ void testToolUseContentBlock() {
}
}

@Test
void testToolChoiceAny() {
// A user question that would not typically result in a tool request
UserMessage userMessage = new UserMessage("Say hi");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
.toolChoice(new AnthropicApi.ToolChoiceAny())
.internalToolExecutionEnabled(false)
.toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
.build())
.build();

ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);
assertThat(response.getResults()).isNotNull();
// When tool choice is "any", the model MUST use at least one tool
boolean hasToolCalls = response.getResults()
.stream()
.anyMatch(generation -> !generation.getOutput().getToolCalls().isEmpty());
assertThat(hasToolCalls).isTrue();
}

@Test
void testToolChoiceTool() {
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco? Return the result in Celsius.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
.toolChoice(new AnthropicApi.ToolChoiceTool("getFunResponse", true))
.internalToolExecutionEnabled(false)
.toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
.build(),
// Based on the user's question the model should want to call
// getCurrentWeather
// however we're going to force getFunResponse
FunctionToolCallback.builder("getFunResponse", new MockWeatherService())
.description("Get a fun response")
.inputType(MockWeatherService.Request.class)
.build())
.build();

ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);
assertThat(response.getResults()).isNotNull();
// When tool choice is a specific tool, the model MUST use that specific tool
List<AssistantMessage.ToolCall> allToolCalls = response.getResults()
.stream()
.flatMap(generation -> generation.getOutput().getToolCalls().stream())
.toList();
assertThat(allToolCalls).isNotEmpty();
assertThat(allToolCalls).hasSize(1);
assertThat(allToolCalls.get(0).name()).isEqualTo("getFunResponse");
}

@Test
void testToolChoiceNone() {
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco?");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
.toolChoice(new AnthropicApi.ToolChoiceNone())
.toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
.build())
.build();

ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);
assertThat(response.getResults()).isNotNull();
// When tool choice is "none", the model MUST NOT use any tools
List<AssistantMessage.ToolCall> allToolCalls = response.getResults()
.stream()
.flatMap(generation -> generation.getOutput().getToolCalls().stream())
.toList();
assertThat(allToolCalls).isEmpty();
}

record ActorsFilmsRecord(String actor, List<String> movies) {

}
Expand Down
Loading