Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
package org.springframework.ai.anthropic;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -30,7 +32,9 @@
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -42,7 +46,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions implements FunctionCallingOptions {
public class AnthropicChatOptions implements ToolCallingChatOptions {

// @formatter:off
private @JsonProperty("model") String model;
Expand All @@ -54,34 +58,27 @@ public class AnthropicChatOptions implements FunctionCallingOptions {
private @JsonProperty("top_k") Integer topK;

/**
* Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatModel chat completion requests.
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
* completion requests.
*/
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests. Functions with those names must exist in the
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
* are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat
* completion requests. This could impact the token count and the billing. If the
* functions is set in a prompt options, then the enabled functions are only active
* for the duration of this prompt execution.
* Collection of tool names to be resolved at runtime and used for tool calling in the
* chat completion requests.
*/
@JsonIgnore
private Set<String> functions = new HashSet<>();
private Set<String> toolNames = new HashSet<>();

/**
* Whether to enable the tool execution lifecycle internally in ChatModel.
*/
@JsonIgnore
private Boolean proxyToolCalls;
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private Map<String, Object> toolContext;
private Map<String, Object> toolContext = new HashMap<>();

// @formatter:on

Expand All @@ -97,9 +94,9 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.topK(fromOptions.getTopK())
.functionCallbacks(fromOptions.getFunctionCallbacks())
.functions(fromOptions.getFunctions())
.proxyToolCalls(fromOptions.getProxyToolCalls())
.toolCallbacks(fromOptions.getToolCallbacks())
.toolNames(fromOptions.getToolNames())
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext())
.build();
}
Expand Down Expand Up @@ -167,25 +164,73 @@ public void setTopK(Integer topK) {
}

@Override
@JsonIgnore
public List<FunctionCallback> getToolCallbacks() {
return this.toolCallbacks;
}

@Override
@JsonIgnore
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = toolCallbacks;
}

@Override
@JsonIgnore
public Set<String> getToolNames() {
return this.toolNames;
}

@Override
@JsonIgnore
public void setToolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
this.toolNames = toolNames;
}

@Override
@Nullable
@JsonIgnore
public Boolean isInternalToolExecutionEnabled() {
return internalToolExecutionEnabled;
}

@Override
@JsonIgnore
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@Override
@Deprecated
@JsonIgnore
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
return this.getToolCallbacks();
}

@Override
@Deprecated
@JsonIgnore
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
this.functionCallbacks = functionCallbacks;
this.setToolCallbacks(functionCallbacks);
}

@Override
@Deprecated
@JsonIgnore
public Set<String> getFunctions() {
return this.functions;
return this.getToolNames();
}

@Override
public void setFunctions(Set<String> functions) {
Assert.notNull(functions, "Function must not be null");
this.functions = functions;
@Deprecated
@JsonIgnore
public void setFunctions(Set<String> functionNames) {
this.setToolNames(functionNames);
}

@Override
Expand All @@ -201,20 +246,26 @@ public Double getPresencePenalty() {
}

@Override
@Deprecated
@JsonIgnore
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
}

@Deprecated
@JsonIgnore
public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
}

@Override
@JsonIgnore
public Map<String, Object> getToolContext() {
return this.toolContext;
}

@Override
@JsonIgnore
public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}
Expand Down Expand Up @@ -268,25 +319,54 @@ public Builder topK(Integer topK) {
return this;
}

public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
}

public Builder functions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
return this;
}

public Builder function(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
public Builder toolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.setToolNames(toolNames);
return this;
}

public Builder toolNames(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.toolNames.addAll(Set.of(toolNames));
return this;
}

public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
return this;
}

@Deprecated
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
return toolCallbacks(functionCallbacks);
}

@Deprecated
public Builder functions(Set<String> functionNames) {
return toolNames(functionNames);
}

@Deprecated
public Builder function(String functionName) {
return toolNames(functionName);
}

@Deprecated
public Builder proxyToolCalls(Boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
if (proxyToolCalls != null) {
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
/**
* @author Christian Tzolov
* @author Alexandros Pappas
* @author Thomas Vitale
*/
public class ChatCompletionRequestTests {

Expand All @@ -35,16 +36,20 @@ public void createRequestWithChatOptions() {
var client = new AnthropicChatModel(new AnthropicApi("TEST"),
AnthropicChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build());

var request = client.createRequest(new Prompt("Test message content"), false);
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));

var request = client.createRequest(prompt, false);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();

assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.temperature()).isEqualTo(66.6);

request = client.createRequest(new Prompt("Test message content",
AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()), true);
prompt = client.buildRequestPrompt(new Prompt("Test message content",
AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()));

request = client.createRequest(prompt, true);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.model.ModelResponse;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
Expand Down Expand Up @@ -111,6 +112,21 @@ public boolean hasToolCalls() {
return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls());
}

/**
* Whether the model has finished with any of the given finish reasons.
*/
public boolean hasFinishReasons(Set<String> finishReasons) {
Assert.notNull(finishReasons, "finishReasons cannot be null");
if (CollectionUtils.isEmpty(generations)) {
return false;
}
return generations.stream().anyMatch(generation -> {
var finishReason = (generation.getMetadata().getFinishReason() != null)
? generation.getMetadata().getFinishReason() : "";
return finishReasons.stream().map(String::toLowerCase).toList().contains(finishReason.toLowerCase());
});
}

@Override
public String toString() {
return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@

import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Unit tests for {@link ChatResponse}.
Expand All @@ -48,4 +51,32 @@ void whenNoToolCallsArePresentThenReturnFalse() {
assertThat(chatResponse.hasToolCalls()).isFalse();
}

@Test
void whenFinishReasonIsNullThenThrow() {
var chatResponse = ChatResponse.builder()
.generations(List.of(new Generation(new AssistantMessage("Result"),
ChatGenerationMetadata.builder().finishReason("completed").build())))
.build();
assertThatThrownBy(() -> chatResponse.hasFinishReasons(null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("finishReasons cannot be null");
}

@Test
void whenFinishReasonIsPresent() {
ChatResponse chatResponse = ChatResponse.builder()
.generations(List.of(new Generation(new AssistantMessage("Result"),
ChatGenerationMetadata.builder().finishReason("completed").build())))
.build();
assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isTrue();
}

@Test
void whenFinishReasonIsNotPresent() {
ChatResponse chatResponse = ChatResponse.builder()
.generations(List.of(new Generation(new AssistantMessage("Result"),
ChatGenerationMetadata.builder().finishReason("failed").build())))
.build();
assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isFalse();
}

}
Loading