Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
package org.springframework.ai.autoconfigure.anthropic;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.ai.anthropic.AnthropicChatModel;
import org.springframework.ai.anthropic.api.AnthropicApi;
Expand Down Expand Up @@ -66,12 +69,15 @@ public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, Anthropi
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext,
List<FunctionCallback> toolFunctionCallbacks) {

AnthropicChatModel chatModel = new AnthropicChatModel(anthropicApi, chatProperties.getOptions(), retryTemplate, functionCallbackContext);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return new AnthropicChatModel(anthropicApi, chatProperties.getOptions(), retryTemplate,
functionCallbackContext);
return chatModel;
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import org.springframework.util.StringUtils;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* @author Piotr Olaszewski
Expand Down Expand Up @@ -101,11 +104,15 @@ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient,
AzureOpenAiChatProperties chatProperties, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackContext functionCallbackContext) {

AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAIClient, chatProperties.getOptions(), functionCallbackContext);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return new AzureOpenAiChatModel(openAIClient, chatProperties.getOptions(), functionCallbackContext);
return chatModel;
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
import org.springframework.web.client.RestClient;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* @author Geng Rong
Expand All @@ -59,11 +62,15 @@ public MiniMaxChatModel miniMaxChatModel(MiniMaxConnectionProperties commonPrope
var miniMaxApi = miniMaxApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);

MiniMaxChatModel chatModel = new MiniMaxChatModel(miniMaxApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return new MiniMaxChatModel(miniMaxApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);
return chatModel;
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
package org.springframework.ai.autoconfigure.mistralai;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.mistralai.MistralAiChatModel;
Expand Down Expand Up @@ -81,12 +84,15 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro
var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(),
chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder, responseErrorHandler);

MistralAiChatModel chatModel = new MistralAiChatModel(mistralAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return new MistralAiChatModel(mistralAiApi, chatProperties.getOptions(), functionCallbackContext,
retryTemplate);
return chatModel;
}

private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
package org.springframework.ai.autoconfigure.openai;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.model.function.FunctionCallback;
Expand Down Expand Up @@ -73,12 +76,13 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti
var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
responseErrorHandler, "chat");

OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);
if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream().collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return new OpenAiChatModel(openAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);
return chatModel;
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.vertexai.VertexAI;
Expand Down Expand Up @@ -79,11 +82,15 @@ public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGem

FunctionCallbackContext functionCallbackContext = springAiFunctionManager(context);

VertexAiGeminiChatModel chatModel = new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), functionCallbackContext);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap);
}

return new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), functionCallbackContext);
return chatModel;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import org.springframework.web.client.RestClient;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* @author Geng Rong
Expand All @@ -61,11 +64,14 @@ public ZhiPuAiChatModel zhiPuAiChatModel(ZhiPuAiConnectionProperties commonPrope
var zhiPuAiApi = zhiPuAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);

ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
}
Map<String, FunctionCallback> toolFunctionCallbackMap = toolFunctionCallbacks.stream()
.collect(Collectors.toMap(FunctionCallback::getName, Function.identity(), (a, b) -> b));
chatModel.getFunctionCallbackRegister().putAll(toolFunctionCallbackMap); }

return new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);
return chatModel;
}

@Bean
Expand Down