Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -154,29 +154,29 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackContext the function callback context used to store the
* state of the function calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by bean name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Javadoc doesn't need to refer "bean" and just "by name" is enough. will fix when merging.

*/
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext) {
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver) {

this(anthropicApi, defaultOptions, retryTemplate, functionCallbackContext, List.of());
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, List.of());
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackContext the function callback context used to store the
* state of the function calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by bean name.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
*/
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext,
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver,
List<FunctionCallback> toolFunctionCallbacks) {
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackContext, toolFunctionCallbacks,
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, toolFunctionCallbacks,
ObservationRegistry.NOOP);
}

Expand All @@ -185,16 +185,16 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackContext the function callback context used to store the
* state of the function calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by bean name.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
*/
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext,
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver,
List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {

super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);

Assert.notNull(anthropicApi, "AnthropicApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -170,7 +170,8 @@ public AnthropicApi anthropicApi() {
public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi,
TestObservationRegistry observationRegistry) {
return new AnthropicChatModel(anthropicApi, AnthropicChatOptions.builder().build(),
RetryTemplate.defaultInstance(), new FunctionCallbackContext(), List.of(), observationRegistry);
RetryTemplate.defaultInstance(), new DefaultFunctionCallbackResolver(), List.of(),
observationRegistry);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -154,19 +154,19 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi
}

public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
this(openAIClientBuilder, options, functionCallbackContext, List.of());
FunctionCallbackResolver functionCallbackResolver) {
this(openAIClientBuilder, options, functionCallbackResolver, List.of());
}

public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
this(openAIClientBuilder, options, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks) {
this(openAIClientBuilder, options, functionCallbackResolver, toolFunctionCallbacks, ObservationRegistry.NOOP);
}

public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry) {
super(functionCallbackContext, options, toolFunctionCallbacks);
super(functionCallbackResolver, options, toolFunctionCallbacks);
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
this.openAIClient = openAIClientBuilder.buildClient();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackResolver;

/**
* @author Jihoon Kim
Expand All @@ -37,7 +37,7 @@ public class AzureOpenAiChatModelTests {
OpenAIClientBuilder mockClient;

@Mock
FunctionCallbackContext functionCallbackContext;
FunctionCallbackResolver functionCallbackResolver;

@Test
public void createAzureOpenAiChatModelTest() {
Expand All @@ -51,7 +51,7 @@ public void createAzureOpenAiChatModelTest() {
List<FunctionCallback> functionCallbacks = List.of(new TestFunctionCallback(callbackFromConstructorParam));

AzureOpenAiChatModel openAiChatModel = new AzureOpenAiChatModel(this.mockClient, chatOptions,
this.functionCallbackContext, functionCallbacks);
this.functionCallbackResolver, functionCallbacks);

assert 2 == openAiChatModel.getFunctionCallbackRegister().size();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
Expand Down Expand Up @@ -146,10 +146,10 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch

public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry) {

super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);

Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
Expand Down Expand Up @@ -609,7 +609,7 @@ public static final class Builder {

private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build();

private FunctionCallbackContext functionCallbackContext;
private FunctionCallbackResolver functionCallbackResolver;

private List<FunctionCallback> toolFunctionCallbacks;

Expand Down Expand Up @@ -648,8 +648,18 @@ public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) {
return this;
}

public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
this.functionCallbackContext = functionCallbackContext;
/**
* @deprecated Use {@link #functionCallbackResolver(FunctionCallbackResolver)}
* instead.
*/
@Deprecated
public Builder withFunctionCallbackContext(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}

public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}

Expand Down Expand Up @@ -708,7 +718,7 @@ public BedrockProxyChatModel build() {
}

var bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackContext,
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackResolver,
this.toolFunctionCallbacks, this.observationRegistry);

if (this.customObservationConvention != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
import org.springframework.ai.minimax.metadata.MiniMaxUsage;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -139,28 +139,30 @@ public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options) {
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
* MiniMax Chat API.
* @param options The MiniMaxChatOptions to configure the chat model.
* @param functionCallbackContext The function callback context.
* @param functionCallbackResolver The function callback resolver to resolve the
* function callback from the application context.
* @param retryTemplate The retry template.
*/
public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
this(miniMaxApi, options, functionCallbackContext, List.of(), retryTemplate, ObservationRegistry.NOOP);
FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
this(miniMaxApi, options, functionCallbackResolver, List.of(), retryTemplate, ObservationRegistry.NOOP);
}

/**
* Initializes a new instance of the MiniMaxChatModel.
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
* MiniMax Chat API.
* @param options The MiniMaxChatOptions to configure the chat model.
* @param functionCallbackContext The function callback context.
* @param functionCallbackResolver The function callback resolver to resolve the
* function callback from the application context.
* @param toolFunctionCallbacks The tool function callbacks.
* @param retryTemplate The retry template.
* @param observationRegistry The ObservationRegistry used for instrumentation.
*/
public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
super(functionCallbackContext, options, toolFunctionCallbacks);
super(functionCallbackResolver, options, toolFunctionCallbacks);
Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -170,8 +170,9 @@ public MiniMaxApi minimaxApi() {

@Bean
public MiniMaxChatModel minimaxChatModel(MiniMaxApi minimaxApi, TestObservationRegistry observationRegistry) {
return new MiniMaxChatModel(minimaxApi, MiniMaxChatOptions.builder().build(), new FunctionCallbackContext(),
List.of(), RetryTemplate.defaultInstance(), observationRegistry);
return new MiniMaxChatModel(minimaxApi, MiniMaxChatOptions.builder().build(),
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
observationRegistry);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import org.springframework.ai.mistralai.metadata.MistralAiUsage;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -120,21 +120,21 @@ public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions option
}

public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
this(mistralAiApi, options, functionCallbackContext, List.of(), retryTemplate);
FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
this(mistralAiApi, options, functionCallbackResolver, List.of(), retryTemplate);
}

public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
RetryTemplate retryTemplate) {
this(mistralAiApi, options, functionCallbackContext, toolFunctionCallbacks, retryTemplate,
this(mistralAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate,
ObservationRegistry.NOOP);
}

public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
super(functionCallbackContext, options, toolFunctionCallbacks);
super(functionCallbackResolver, options, toolFunctionCallbacks);
Assert.notNull(mistralAiApi, "mistralAiApi must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -181,7 +181,8 @@ public MistralAiApi mistralAiApi() {
public MistralAiChatModel openAiChatModel(MistralAiApi mistralAiApi,
TestObservationRegistry observationRegistry) {
return new MistralAiChatModel(mistralAiApi, MistralAiChatOptions.builder().build(),
new FunctionCallbackContext(), List.of(), RetryTemplate.defaultInstance(), observationRegistry);
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
observationRegistry);
}

}
Expand Down
Loading