Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.tool.method.MethodToolCallback;
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.ai.tool.support.ToolDefinitions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
Expand Down Expand Up @@ -175,7 +175,7 @@ public ToolCallbackProvider blabla() {
public ToolCallback toolCallbacks6() {
var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getAlert", String.class);
return MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(toolMethod).build())
.toolDefinition(ToolDefinitions.builder(toolMethod).build())
.toolMethod(toolMethod)
.toolObject(new WeatherService())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;

/**
Expand Down Expand Up @@ -84,7 +85,7 @@ public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) {
*/
@Override
public ToolDefinition getToolDefinition() {
return ToolDefinition.builder()
return DefaultToolDefinition.builder()
.name(McpToolUtils.prefixedToolName(this.asyncMcpClient.getClientInfo().name(), this.tool.name()))
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.util.ToolUtils;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.util.CollectionUtils;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;

/**
Expand Down Expand Up @@ -88,7 +89,7 @@ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
*/
@Override
public ToolDefinition getToolDefinition() {
return ToolDefinition.builder()
return DefaultToolDefinition.builder()
.name(McpToolUtils.prefixedToolName(this.mcpClient.getClientInfo().name(), this.tool.name()))
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.mcp;

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiPredicate;

Expand All @@ -25,7 +24,7 @@

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.util.ToolUtils;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import reactor.test.StepVerifier;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -193,7 +194,7 @@ void toAsyncToolSpecificationShouldConvertMultipleCallbacks() {

private ToolCallback createMockToolCallback(String name, String result) {
ToolCallback callback = mock(ToolCallback.class);
ToolDefinition definition = ToolDefinition.builder()
ToolDefinition definition = DefaultToolDefinition.builder()
.name(name)
.description("Test tool")
.inputSchema("{}")
Expand All @@ -205,7 +206,7 @@ private ToolCallback createMockToolCallback(String name, String result) {

private ToolCallback createMockToolCallback(String name, RuntimeException error) {
ToolCallback callback = mock(ToolCallback.class);
ToolDefinition definition = ToolDefinition.builder()
ToolDefinition definition = DefaultToolDefinition.builder()
.name(name)
.description("Test tool")
.inputSchema("{}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand Down Expand Up @@ -194,7 +194,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons

Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(completionResponse.usage())
: new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);

ChatResponse chatResponse = toChatResponse(completionEntity.getBody(), accumulatedUsage);
observationContext.setResponse(chatResponse);
Expand Down Expand Up @@ -256,7 +257,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
AnthropicApi.Usage usage = chatCompletionResponse.usage();
Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.method.MethodToolCallback;
import org.springframework.ai.tool.support.ToolDefinitions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
Expand Down Expand Up @@ -68,7 +68,7 @@ void methodGetWeatherGeneratedDescription() {
String response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
.toolCallbacks(MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(toolMethod).build())
.toolDefinition(ToolDefinitions.builder(toolMethod).build())
.toolMethod(toolMethod)
.build())
.call()
Expand All @@ -90,7 +90,7 @@ void methodGetWeatherStatic() {
String response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
.toolCallbacks(MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(toolMethod)
.toolDefinition(ToolDefinitions.builder(toolMethod)
.description("Get the weather in location")
.build())
.toolMethod(toolMethod)
Expand All @@ -117,7 +117,7 @@ void methodTurnLightNoResponse() {
String response = ChatClient.create(this.chatModel).prompt()
.user("Turn light on in the living room.")
.toolCallbacks(MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(turnLightMethod)
.toolDefinition(ToolDefinitions.builder(turnLightMethod)
.description("Turn light on in the living room.")
.build())
.toolMethod(turnLightMethod)
Expand Down Expand Up @@ -145,7 +145,7 @@ void methodGetWeatherNonStatic() {
String response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
.toolCallbacks(MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(toolMethod)
.toolDefinition(ToolDefinitions.builder(toolMethod)
.description("Get the weather in location")
.build())
.toolMethod(toolMethod)
Expand All @@ -172,7 +172,7 @@ void methodGetWeatherToolContext() {
String response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
.toolCallbacks(MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(toolMethod)
.toolDefinition(ToolDefinitions.builder(toolMethod)
.description("Get the weather in location")
.build())
.toolMethod(toolMethod)
Expand Down Expand Up @@ -203,7 +203,7 @@ void methodGetWeatherWithContextMethodButMissingContext() {
assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
.toolCallbacks(MethodToolCallback.builder()
.toolDefinition(ToolDefinition.builder(toolMethod)
.toolDefinition(ToolDefinitions.builder(toolMethod)
.description("Get the weather in location")
.build())
.toolMethod(toolMethod)
Expand All @@ -229,7 +229,7 @@ void methodNoParameters() {
.user("Turn light on in the living room.")
.toolCallbacks(MethodToolCallback.builder()
.toolMethod(toolMethod)
.toolDefinition(ToolDefinition.builder(toolMethod)
.toolDefinition(ToolDefinitions.builder(toolMethod)
.description("Can turn lights on in the Living Room")
.build())
.toolObject(targetObject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand Down Expand Up @@ -357,15 +357,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
// Accumulate the usage from the previous chat response
CompletionsUsage usage = chatCompletion.getUsage();
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);
return toChatResponse(chatCompletion, accumulatedUsage);
}).buffer(2, 1).map(bufferList -> {
ChatResponse chatResponse1 = bufferList.get(0);
if (options.getStreamOptions() != null && options.getStreamOptions().isIncludeUsage()) {
if (bufferList.size() == 2) {
ChatResponse chatResponse2 = bufferList.get(1);
if (chatResponse2 != null && chatResponse2.getMetadata() != null
&& !UsageUtils.isEmpty(chatResponse2.getMetadata().getUsage())) {
&& !UsageCalculator.isEmpty(chatResponse2.getMetadata().getUsage())) {
return toChatResponse(chatResponse1, chatResponse2.getMetadata().getUsage());
}
}
Expand Down Expand Up @@ -462,7 +463,7 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatRespons
if (chatCompletions.getUsage() != null) {
currentUsage = getDefaultUsage(chatCompletions.getUsage());
}
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.*;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
Expand Down Expand Up @@ -179,7 +180,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
// Current usage
DeepSeekApi.Usage usage = completionEntity.getBody().usage();
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), accumulatedUsage));

Expand Down Expand Up @@ -256,7 +258,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
}).toList();
DeepSeekApi.Usage usage = chatCompletion2.usage();
Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);

return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand Down Expand Up @@ -216,7 +216,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
}).toList();

DefaultUsage usage = getDefaultUsage(completionEntity.getBody().usage());
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(usage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), cumulativeUsage));

Expand Down Expand Up @@ -298,7 +298,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha

if (chatCompletion2.usage() != null) {
DefaultUsage usage = getDefaultUsage(chatCompletion2.usage());
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(usage, previousChatResponse);
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.boot.test.context.SpringBootTest;

Expand Down Expand Up @@ -106,7 +107,7 @@ static class TestToolCallback implements ToolCallback {
private final ToolDefinition toolDefinition;

TestToolCallback(String name) {
this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build();
this.toolDefinition = DefaultToolDefinition.builder().name(name).inputSchema("{}").build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -167,7 +168,7 @@ static class TestToolCallback implements ToolCallback {
private final ToolDefinition toolDefinition;

TestToolCallback(String name) {
this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build();
this.toolDefinition = DefaultToolDefinition.builder().name(name).inputSchema("{}").build();
}

@Override
Expand Down
Loading