From 3ef748ca8d7f6478457d23e96d709de752ba604b Mon Sep 17 00:00:00 2001 From: GR Date: Tue, 23 Apr 2024 09:26:32 +0800 Subject: [PATCH] feat: Add ZhiPu AI model client --- models/spring-ai-zhipuai/README.md | 5 + models/spring-ai-zhipuai/pom.xml | 59 ++ .../ai/zhipuai/ZhiPuAiChatClient.java | 369 ++++++++ .../ai/zhipuai/ZhiPuAiChatOptions.java | 487 ++++++++++ .../ai/zhipuai/ZhiPuAiEmbeddingClient.java | 148 ++++ .../ai/zhipuai/ZhiPuAiEmbeddingOptions.java | 67 ++ .../ai/zhipuai/ZhiPuAiImageClient.java | 128 +++ .../ai/zhipuai/ZhiPuAiImageOptions.java | 131 +++ .../ai/zhipuai/aot/ZhiPuAiRuntimeHints.java | 45 + .../ai/zhipuai/api/ApiUtils.java | 37 + .../ai/zhipuai/api/ZhiPuAiApi.java | 838 ++++++++++++++++++ .../ai/zhipuai/api/ZhiPuAiImageApi.java | 129 +++ .../ZhiPuAiStreamFunctionCallingHelper.java | 194 ++++ .../ai/zhipuai/metadata/ZhiPuAiUsage.java | 64 ++ .../resources/META-INF/spring/aot.factories | 2 + .../zhipuai/ChatCompletionRequestTests.java | 144 +++ .../ai/zhipuai/ZhiPuAiTestConfiguration.java | 65 ++ .../ai/zhipuai/api/MockWeatherService.java | 92 ++ .../ai/zhipuai/api/ZhiPuAiApiIT.java | 68 ++ .../api/ZhiPuAiApiToolFunctionCallIT.java | 148 ++++ .../ai/zhipuai/api/ZhiPuAiRetryTests.java | 212 +++++ .../zhipuai/image/ZhiPuAiImageClientIT.java | 56 ++ .../test/resources/prompts/system-message.st | 3 + pom.xml | 2 + spring-ai-bom/pom.xml | 12 + .../src/main/antora/modules/ROOT/nav.adoc | 4 + .../functions/zhipuai-chat-functions.adoc | 226 +++++ .../ROOT/pages/api/chat/zhipuai-chat.adoc | 250 ++++++ .../api/embeddings/zhipuai-embeddings.adoc | 198 +++++ .../ROOT/pages/api/image/zhipuai-image.adoc | 117 +++ spring-ai-spring-boot-autoconfigure/pom.xml | 8 + .../zhipuai/ZhiPuAiAutoConfiguration.java | 128 +++ .../zhipuai/ZhiPuAiChatProperties.java | 62 ++ .../zhipuai/ZhiPuAiConnectionProperties.java | 31 + .../zhipuai/ZhiPuAiEmbeddingProperties.java | 70 ++ .../zhipuai/ZhiPuAiImageProperties.java | 57 ++ .../zhipuai/ZhiPuAiParentProperties.java | 43 + ...ot.autoconfigure.AutoConfiguration.imports | 1 + .../zhipuai/ZhiPuAiAutoConfigurationIT.java | 107 +++ .../zhipuai/ZhiPuAiPropertiesTests.java | 438 +++++++++ .../tool/FunctionCallbackInPromptIT.java | 114 +++ ...nctionCallbackWithPlainFunctionBeanIT.java | 172 ++++ .../tool/FunctionCallbackWrapperIT.java | 120 +++ .../zhipuai/tool/MockWeatherService.java | 94 ++ .../spring-ai-starter-zhipuai/pom.xml | 42 + 45 files changed, 5787 insertions(+) create mode 100644 models/spring-ai-zhipuai/README.md create mode 100644 models/spring-ai-zhipuai/pom.xml create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatClient.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingClient.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageClient.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ApiUtils.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java create mode 100644 models/spring-ai-zhipuai/src/main/resources/META-INF/spring/aot.factories create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageClientIT.java create mode 100644 models/spring-ai-zhipuai/src/test/resources/prompts/system-message.st create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/image/zhipuai-image.adoc create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml diff --git a/models/spring-ai-zhipuai/README.md b/models/spring-ai-zhipuai/README.md new file mode 100644 index 00000000000..167295b8b49 --- /dev/null +++ b/models/spring-ai-zhipuai/README.md @@ -0,0 +1,5 @@ +[ZhiPu AI Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/zhipuai-chat.html) + +[ZhiPu AI Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/zhipuai-embeddings.html) + +[ZhiPu AI Image Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/image/zhipuai-image.html) \ No newline at end of file diff --git a/models/spring-ai-zhipuai/pom.xml b/models/spring-ai-zhipuai/pom.xml new file mode 100644 index 00000000000..0f5d84a1d6c --- /dev/null +++ b/models/spring-ai-zhipuai/pom.xml @@ -0,0 +1,59 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-zhipuai + jar + Spring AI Mistral AI + Mistral AI support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatClient.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatClient.java new file mode 100644 index 00000000000..98a4b72698a --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatClient.java @@ -0,0 +1,369 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.ChatClient; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion.Choice; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionFinishReason; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.MediaContent; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import reactor.core.publisher.Flux; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +/** + * {@link ChatClient} and {@link StreamingChatClient} implementation for + * {@literal ZhiPuAI} backed by {@link ZhiPuAiApi}. + * + * @author Geng Rong + * @see ChatClient + * @see StreamingChatClient + * @see ZhiPuAiApi + */ +public class ZhiPuAiChatClient extends + AbstractFunctionCallSupport> + implements ChatClient, StreamingChatClient { + + private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatClient.class); + + /** + * The default options used for the chat completion requests. + */ + private ZhiPuAiChatOptions defaultOptions; + + /** + * The retry template used to retry the ZhiPuAI API calls. + */ + public final RetryTemplate retryTemplate; + + /** + * Low-level access to the ZhiPuAI API. + */ + private final ZhiPuAiApi zhiPuAiApi; + + /** + * Creates an instance of the ZhiPuAiChatClient. + * @param zhiPuAiApi The ZhiPuAiApi instance to be used for interacting with the + * ZhiPuAI Chat API. + * @throws IllegalArgumentException if zhiPuAiApi is null + */ + public ZhiPuAiChatClient(ZhiPuAiApi zhiPuAiApi) { + this(zhiPuAiApi, + ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()); + } + + /** + * Initializes an instance of the ZhiPuAiChatClient. + * @param zhiPuAiApi The ZhiPuAiApi instance to be used for interacting with the + * ZhiPuAI Chat API. + * @param options The ZhiPuAiChatOptions to configure the chat client. + */ + public ZhiPuAiChatClient(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options) { + this(zhiPuAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the ZhiPuAiChatClient. + * @param zhiPuAiApi The ZhiPuAiApi instance to be used for interacting with the + * ZhiPuAI Chat API. + * @param options The ZhiPuAiChatOptions to configure the chat client. + * @param functionCallbackContext The function callback context. + * @param retryTemplate The retry template. + */ + public ZhiPuAiChatClient(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options, + FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) { + super(functionCallbackContext); + Assert.notNull(zhiPuAiApi, "ZhiPuAiApi must not be null"); + Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.zhiPuAiApi = zhiPuAiApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + + ChatCompletionRequest request = createRequest(prompt, false); + + return this.retryTemplate.execute(ctx -> { + + ResponseEntity completionEntity = this.callWithFunctionSupport(request); + + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List generations = chatCompletion.choices().stream().map(choice -> { + return new Generation(choice.message().content(), toMap(chatCompletion.id(), choice)) + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); + }).toList(); + + return new ChatResponse(generations); + }); + } + + private Map toMap(String id, ChatCompletion.Choice choice) { + Map map = new HashMap<>(); + + var message = choice.message(); + if (message.role() != null) { + map.put("role", message.role().name()); + } + if (choice.finishReason() != null) { + map.put("finishReason", choice.finishReason().name()); + } + map.put("id", id); + return map; + } + + @Override + public Flux stream(Prompt prompt) { + + ChatCompletionRequest request = createRequest(prompt, true); + + return this.retryTemplate.execute(ctx -> { + + Flux completionChunks = this.zhiPuAiApi.chatCompletionStream(request); + + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + + // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse + // the function call handling logic. + return completionChunks.map(chunk -> chunkToChatCompletion(chunk)).map(chatCompletion -> { + try { + chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion))) + .getBody(); + + @SuppressWarnings("null") + String id = chatCompletion.id(); + + List generations = chatCompletion.choices().stream().map(choice -> { + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + String finish = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generation = new Generation(choice.message().content(), + Map.of("id", id, "role", roleMap.get(id), "finishReason", finish)); + if (choice.finishReason() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(choice.finishReason().name(), null)); + } + return generation; + }).toList(); + + return new ChatResponse(generations); + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + + }); + }); + } + + /** + * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. + * @param chunk the ChatCompletionChunk to convert + * @return the ChatCompletion + */ + private ZhiPuAiApi.ChatCompletion chunkToChatCompletion(ZhiPuAiApi.ChatCompletionChunk chunk) { + List choices = chunk.choices() + .stream() + .map(cc -> new Choice(cc.finishReason(), cc.index(), cc.delta(), cc.logprobs())) + .toList(); + + return new ZhiPuAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), + chunk.systemFingerprint(), "chat.completion", null); + } + + /** + * Accessible for testing. + */ + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + + Set functionsForThisRequest = new HashSet<>(); + + List chatCompletionMessages = prompt.getInstructions().stream().map(m -> { + // Add text content. + List contents = new ArrayList<>(List.of(new MediaContent(m.getContent()))); + if (!CollectionUtils.isEmpty(m.getMedia())) { + // Add media content. + contents.addAll(m.getMedia() + .stream() + .map(media -> new MediaContent( + new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData())))) + .toList()); + } + + return new ChatCompletionMessage(contents, ChatCompletionMessage.Role.valueOf(m.getMessageType().name())); + }).toList(); + + ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + ZhiPuAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, ZhiPuAiChatOptions.class); + + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, + IS_RUNTIME_CALL); + functionsForThisRequest.addAll(promptEnabledFunctions); + + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + + if (this.defaultOptions != null) { + + Set defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + + functionsForThisRequest.addAll(defaultEnabledFunctions); + + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); + } + + // Add the enabled functions definitions to the request's tools parameter. + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + + request = ModelOptionsUtils.merge( + ZhiPuAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), + request, ChatCompletionRequest.class); + } + + return request; + } + + private String fromMediaData(MimeType mimeType, Object mediaContentData) { + if (mediaContentData instanceof byte[] bytes) { + // Assume the bytes are an image. So, convert the bytes to a base64 encoded + // following the prefix pattern. + return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); + } + else if (mediaContentData instanceof String text) { + // Assume the text is a URLs or a base64 encoded image prefixed by the user. + return text; + } + else { + throw new IllegalArgumentException( + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); + } + } + + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { + var function = new ZhiPuAiApi.FunctionTool.Function(functionCallback.getDescription(), + functionCallback.getName(), functionCallback.getInputTypeSchema()); + return new ZhiPuAiApi.FunctionTool(function); + }).toList(); + } + + @Override + protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, + ChatCompletionMessage responseMessage, List conversationHistory) { + + // Every tool-call item requires a separate function call and a response (TOOL) + // message. + for (ToolCall toolCall : responseMessage.toolCalls()) { + + var functionName = toolCall.function().name(); + String functionArguments = toolCall.function().arguments(); + + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + + String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); + + // Add the function response to the conversation. + conversationHistory + .add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null)); + } + + // Recursively call chatCompletionWithTools until the model doesn't call a + // functions anymore. + ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false); + newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class); + + return newRequest; + } + + @Override + protected List doGetUserMessages(ChatCompletionRequest request) { + return request.messages(); + } + + @Override + protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity chatCompletion) { + return chatCompletion.getBody().choices().iterator().next().message(); + } + + @Override + protected ResponseEntity doChatCompletion(ChatCompletionRequest request) { + return this.zhiPuAiApi.chatCompletionEntity(request); + } + + @Override + protected boolean isToolFunctionCall(ResponseEntity chatCompletion) { + var body = chatCompletion.getBody(); + if (body == null) { + return false; + } + + var choices = body.choices(); + if (CollectionUtils.isEmpty(choices)) { + return false; + } + + var choice = choices.get(0); + return !CollectionUtils.isEmpty(choice.message().toolCalls()) + && choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS; + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java new file mode 100644 index 00000000000..fa4feb9b37a --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -0,0 +1,487 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest.ResponseFormat; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.FunctionTool; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + +import java.util.*; + +/** + * @author Geng Rong + */ +@JsonInclude(Include.NON_NULL) +public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Float frequencyPenalty; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; + /** + * How many chat completion choices to generate for each input message. Note that you will be charged based + * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + */ + private @JsonProperty("n") Integer n; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Float presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") ResponseFormat responseFormat; + /** + * This feature is in Beta. If specified, our system will make a best effort to sample + * deterministically, such that repeated requests with the same seed and parameters should return the same result. + * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor + * changes in the backend. + */ + private @JsonProperty("seed") Integer seed; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + @NestedConfigurationProperty + private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + private @JsonProperty("temperature") Float temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Float topP; + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + * provide a list of functions the model may generate JSON inputs for. + */ + @NestedConfigurationProperty + private @JsonProperty("tools") List tools; + /** + * Controls which (if any) function is called by the model. none means the model will not call a + * function and instead generates a message. auto means the model can pick between generating a message or calling a + * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces + * the model to call that function. none is the default when no functions are present. auto is the default if + * functions are present. Use the {@link ZhiPuAiApi.ChatCompletionRequest.ToolChoiceBuilder} to create a tool choice object. + */ + private @JsonProperty("tool_choice") String toolChoice; + /** + * A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. + * ID length requirement: minimum of 6 characters, maximum of 128 characters. + */ + private @JsonProperty("user_id") String user; + + /** + * ZhiPuAI Tool Function Callbacks to register with the ChatClient. + * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution. + * For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions + * from the registry to be used by the ChatClient chat completion requests. + */ + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. + * Functions with those names must exist in the functionCallbacks registry. + * The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. + * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. + */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected ZhiPuAiChatOptions options; + + public Builder() { + this.options = new ZhiPuAiChatOptions(); + } + + public Builder(ZhiPuAiChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withN(Integer n) { + this.options.n = n; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withSeed(Integer seed) { + this.options.seed = seed; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.options.topP = topP; + return this; + } + + public Builder withTools(List tools) { + this.options.tools = tools; + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withUser(String user) { + this.options.user = user; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public ZhiPuAiChatOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public Float getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + public Float getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + public String getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return functions; + } + + public void setFunctions(Set functionNames) { + this.functions = functionNames; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((n == null) ? 0 : n.hashCode()); + result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((seed == null) ? 0 : seed.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((user == null) ? 0 : user.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; + if (this.model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) + return false; + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + return false; + if (this.maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!this.maxTokens.equals(other.maxTokens)) + return false; + if (this.n == null) { + if (other.n != null) + return false; + } + else if (!this.n.equals(other.n)) + return false; + if (this.presencePenalty == null) { + if (other.presencePenalty != null) + return false; + } + else if (!this.presencePenalty.equals(other.presencePenalty)) + return false; + if (this.responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!this.responseFormat.equals(other.responseFormat)) + return false; + if (this.seed == null) { + if (other.seed != null) + return false; + } + else if (!this.seed.equals(other.seed)) + return false; + if (this.stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (this.temperature == null) { + if (other.temperature != null) + return false; + } + else if (!this.temperature.equals(other.temperature)) + return false; + if (this.topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + if (this.tools == null) { + if (other.tools != null) + return false; + } + else if (!tools.equals(other.tools)) + return false; + if (this.toolChoice == null) { + if (other.toolChoice != null) + return false; + } + else if (!toolChoice.equals(other.toolChoice)) + return false; + if (this.user == null) { + if (other.user != null) + return false; + } + else if (!this.user.equals(other.user)) + return false; + return true; + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @JsonIgnore + public void setTopK(Integer topK) { + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingClient.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingClient.java new file mode 100644 index 00000000000..3cd284fe839 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingClient.java @@ -0,0 +1,148 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.*; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * ZhiPuAI Embedding Client implementation. + * + * @author Geng Rong + */ +public class ZhiPuAiEmbeddingClient extends AbstractEmbeddingClient { + + private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiEmbeddingClient.class); + + private final ZhiPuAiEmbeddingOptions defaultOptions; + + private final RetryTemplate retryTemplate; + + private final ZhiPuAiApi zhiPuAiApi; + + private final MetadataMode metadataMode; + + /** + * Constructor for the ZhiPuAiEmbeddingClient class. + * @param zhiPuAiApi The ZhiPuAiApi instance to use for making API requests. + */ + public ZhiPuAiEmbeddingClient(ZhiPuAiApi zhiPuAiApi) { + this(zhiPuAiApi, MetadataMode.EMBED); + } + + /** + * Initializes a new instance of the ZhiPuAiEmbeddingClient class. + * @param zhiPuAiApi The ZhiPuAiApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + */ + public ZhiPuAiEmbeddingClient(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode) { + this(zhiPuAiApi, metadataMode, + ZhiPuAiEmbeddingOptions.builder().withModel(ZhiPuAiApi.DEFAULT_EMBEDDING_MODEL).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the ZhiPuAiEmbeddingClient class. + * @param zhiPuAiApi The ZhiPuAiApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + * @param zhiPuAiEmbeddingOptions The options for ZhiPuAI embedding. + */ + public ZhiPuAiEmbeddingClient(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, + ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions) { + this(zhiPuAiApi, metadataMode, zhiPuAiEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the ZhiPuAiEmbeddingClient class. + * @param zhiPuAiApi - The ZhiPuAiApi instance to use for making API requests. + * @param metadataMode - The mode for generating metadata. + * @param options - The options for ZhiPuAI embedding. + * @param retryTemplate - The RetryTemplate for retrying failed API requests. + */ + public ZhiPuAiEmbeddingClient(ZhiPuAiApi zhiPuAiApi, MetadataMode metadataMode, ZhiPuAiEmbeddingOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(zhiPuAiApi, "ZhiPuAiApi must not be null"); + Assert.notNull(metadataMode, "metadataMode must not be null"); + Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + + this.zhiPuAiApi = zhiPuAiApi; + this.metadataMode = metadataMode; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public List embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent(this.metadataMode)); + } + + @SuppressWarnings("unchecked") + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + + return this.retryTemplate.execute(ctx -> { + + ZhiPuAiApi.EmbeddingRequest> apiRequest = (this.defaultOptions != null) + ? new ZhiPuAiApi.EmbeddingRequest<>(request.getInstructions(), this.defaultOptions.getModel()) + : new ZhiPuAiApi.EmbeddingRequest<>(request.getInstructions(), ZhiPuAiApi.DEFAULT_EMBEDDING_MODEL); + + if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) { + apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, + ZhiPuAiApi.EmbeddingRequest.class); + } + + ZhiPuAiApi.EmbeddingList apiEmbeddingResponse = this.zhiPuAiApi.embeddings(apiRequest) + .getBody(); + + if (apiEmbeddingResponse == null) { + logger.warn("No embeddings returned for request: {}", request); + return new EmbeddingResponse(List.of()); + } + + var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage()); + + List embeddings = apiEmbeddingResponse.data() + .stream() + .map(e -> new Embedding(e.embedding(), e.index())) + .toList(); + + return new EmbeddingResponse(embeddings, metadata); + + }); + } + + private EmbeddingResponseMetadata generateResponseMetadata(String model, ZhiPuAiApi.Usage usage) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); + metadata.put("model", model); + metadata.put("prompt-tokens", usage.promptTokens()); + metadata.put("completion-tokens", usage.completionTokens()); + metadata.put("total-tokens", usage.totalTokens()); + return metadata; + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java new file mode 100644 index 00000000000..92c818b6978 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java @@ -0,0 +1,67 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * @author Geng Rong + */ +@JsonInclude(Include.NON_NULL) +public class ZhiPuAiEmbeddingOptions implements EmbeddingOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected ZhiPuAiEmbeddingOptions options; + + public Builder() { + this.options = new ZhiPuAiEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public ZhiPuAiEmbeddingOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageClient.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageClient.java new file mode 100644 index 00000000000..f81118d7c3b --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageClient.java @@ -0,0 +1,128 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.image.*; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * ZhiPuAiImageClient is a class that implements the ImageClient interface. It provides a + * client for calling the ZhiPuAI image generation API. + * + * @author Geng Rong + */ +public class ZhiPuAiImageClient implements ImageClient { + + private final static Logger logger = LoggerFactory.getLogger(ZhiPuAiImageClient.class); + + private final ZhiPuAiImageOptions defaultOptions; + + private final ZhiPuAiImageApi zhiPuAiImageApi; + + public final RetryTemplate retryTemplate; + + public ZhiPuAiImageClient(ZhiPuAiImageApi zhiPuAiImageApi) { + this(zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public ZhiPuAiImageClient(ZhiPuAiImageApi zhiPuAiImageApi, ZhiPuAiImageOptions defaultOptions, + RetryTemplate retryTemplate) { + Assert.notNull(zhiPuAiImageApi, "ZhiPuAiImageApi must not be null"); + Assert.notNull(defaultOptions, "defaultOptions must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + this.zhiPuAiImageApi = zhiPuAiImageApi; + this.defaultOptions = defaultOptions; + this.retryTemplate = retryTemplate; + } + + public ZhiPuAiImageOptions getDefaultOptions() { + return this.defaultOptions; + } + + @Override + public ImageResponse call(ImagePrompt imagePrompt) { + return this.retryTemplate.execute(ctx -> { + + String instructions = imagePrompt.getInstructions().get(0).getText(); + + ZhiPuAiImageApi.ZhiPuAiImageRequest imageRequest = new ZhiPuAiImageApi.ZhiPuAiImageRequest(instructions, + ZhiPuAiImageApi.DEFAULT_IMAGE_MODEL); + + if (this.defaultOptions != null) { + imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest, + ZhiPuAiImageApi.ZhiPuAiImageRequest.class); + } + + if (imagePrompt.getOptions() != null) { + imageRequest = ModelOptionsUtils.merge(toZhiPuAiImageOptions(imagePrompt.getOptions()), imageRequest, + ZhiPuAiImageApi.ZhiPuAiImageRequest.class); + } + + // Make the request + ResponseEntity imageResponseEntity = this.zhiPuAiImageApi + .createImage(imageRequest); + + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(imageResponseEntity, imageRequest); + }); + } + + private ImageResponse convertResponse(ResponseEntity imageResponseEntity, + ZhiPuAiImageApi.ZhiPuAiImageRequest zhiPuAiImageRequest) { + ZhiPuAiImageApi.ZhiPuAiImageResponse imageApiResponse = imageResponseEntity.getBody(); + if (imageApiResponse == null) { + logger.warn("No image response returned for request: {}", zhiPuAiImageRequest); + return new ImageResponse(List.of()); + } + + List imageGenerationList = imageApiResponse.data() + .stream() + .map(entry -> new ImageGeneration(new Image(entry.url(), null))) + .toList(); + + return new ImageResponse(imageGenerationList); + } + + /** + * Convert the {@link ImageOptions} into {@link ZhiPuAiImageOptions}. + * @param runtimeImageOptions the image options to use. + * @return the converted {@link ZhiPuAiImageOptions}. + */ + private ZhiPuAiImageOptions toZhiPuAiImageOptions(ImageOptions runtimeImageOptions) { + ZhiPuAiImageOptions.Builder zhiPuAiImageOptionsBuilder = ZhiPuAiImageOptions.builder(); + if (runtimeImageOptions != null) { + if (runtimeImageOptions.getModel() != null) { + zhiPuAiImageOptionsBuilder.withModel(runtimeImageOptions.getModel()); + } + if (runtimeImageOptions instanceof ZhiPuAiImageOptions runtimeZhiPuAiImageOptions) { + if (runtimeZhiPuAiImageOptions.getUser() != null) { + zhiPuAiImageOptionsBuilder.withUser(runtimeZhiPuAiImageOptions.getUser()); + } + } + } + return zhiPuAiImageOptionsBuilder.build(); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java new file mode 100644 index 00000000000..4c2bd886e48 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -0,0 +1,131 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; + +import java.util.Objects; + +/** + * ZhiPuAI Image API options. ZhiPuAiImageOptions.java + * + * @author Geng Rong + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ZhiPuAiImageOptions implements ImageOptions { + + /** + * The model to use for image generation. + */ + @JsonProperty("model") + private String model = ZhiPuAiImageApi.DEFAULT_IMAGE_MODEL; + + /** + * A unique identifier representing your end-user, which can help ZhiPuAI to monitor + * and detect abuse. User ID length requirement: minimum of 6 characters, maximum of + * 128 characters + */ + @JsonProperty("user_id") + private String user; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final ZhiPuAiImageOptions options; + + private Builder() { + this.options = new ZhiPuAiImageOptions(); + } + + public Builder withModel(String model) { + options.setModel(model); + return this; + } + + public Builder withUser(String user) { + options.setUser(user); + return this; + } + + public ZhiPuAiImageOptions build() { + return options; + } + + } + + @Override + public Integer getN() { + return null; + } + + @Override + public String getModel() { + return this.model; + } + + @Override + public Integer getWidth() { + return null; + } + + @Override + public Integer getHeight() { + return null; + } + + @Override + public String getResponseFormat() { + return null; + } + + public void setModel(String model) { + this.model = model; + } + + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof ZhiPuAiImageOptions that)) + return false; + return Objects.equals(model, that.model) && Objects.equals(user, that.user); + } + + @Override + public int hashCode() { + return Objects.hash(model, user); + } + + @Override + public String toString() { + return "ZhiPuAiImageOptions{model='" + model + '\'' + ", user='" + user + '\'' + '}'; + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java new file mode 100644 index 00000000000..b1213dae4c2 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.aot; + +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * The ZhiPuAiRuntimeHints class is responsible for registering runtime hints for ZhiPu AI + * API classes. + * + * @author Geng Rong + */ +public class ZhiPuAiRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiApi.class)) + hints.reflection().registerType(tr, mcs); + for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiImageApi.class)) + hints.reflection().registerType(tr, mcs); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ApiUtils.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ApiUtils.java new file mode 100644 index 00000000000..6c7cfac9df2 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ApiUtils.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +import java.util.function.Consumer; + +/** + * @author Geng Rong + */ +public class ApiUtils { + + public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas"; + + public static Consumer getJsonContentHeaders(String apiKey) { + return (headers) -> { + headers.setBearerAuth(apiKey); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + }; + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java new file mode 100644 index 00000000000..852b5a5af59 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -0,0 +1,838 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.boot.context.properties.bind.ConstructorBinding; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; + +// @formatter:off +/** + * Single class implementation of the ZhiPuAI Chat Completion API: https://open.bigmodel.cn/dev/api#http and + * ZhiPuAI Embedding API: https://open.bigmodel.cn/dev/api#text_embedding. + * + * @author Geng Rong + */ +public class ZhiPuAiApi { + + public static final String DEFAULT_CHAT_MODEL = ChatModel.GLM_3_Turbo.getValue(); + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.Embedding_2.getValue(); + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final RestClient restClient; + + private final WebClient webClient; + + /** + * Create a new chat completion api with default base URL. + * + * @param zhiPuAiToken ZhiPuAI apiKey. + */ + public ZhiPuAiApi(String zhiPuAiToken) { + this(ApiUtils.DEFAULT_BASE_URL, zhiPuAiToken); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param zhiPuAiToken ZhiPuAI apiKey. + */ + public ZhiPuAiApi(String baseUrl, String zhiPuAiToken) { + this(baseUrl, zhiPuAiToken, RestClient.builder()); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param zhiPuAiToken ZhiPuAI apiKey. + * @param restClientBuilder RestClient builder. + */ + public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder) { + this(baseUrl, zhiPuAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param zhiPuAiToken ZhiPuAI apiKey. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public ZhiPuAiApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + + this.restClient = restClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders(zhiPuAiToken)) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = WebClient.builder() + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders(zhiPuAiToken)) + .build(); + } + + /** + * ZhiPuAI Chat Completion Models: + * ZhiPuAI Model. + */ + public enum ChatModel { + GLM_4("GLM-4"), + GLM_3_Turbo("GLM-3-Turbo"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** + * Represents a tool the model may call. Currently, only functions are supported as a tool. + * + * @param type The type of the tool. Currently, only 'function' is supported. + * @param function The function definition. + */ + @JsonInclude(Include.NON_NULL) + public record FunctionTool( + @JsonProperty("type") Type type, + @JsonProperty("function") Function function) { + + /** + * Create a tool of type 'function' and the given function definition. + * @param function function definition. + */ + @ConstructorBinding + public FunctionTool(Function function) { + this(Type.FUNCTION, function); + } + + /** + * Create a tool of type 'function' and the given function definition. + */ + public enum Type { + /** + * Function tool type. + */ + @JsonProperty("function") FUNCTION + } + + /** + * Function definition. + * + * @param description A description of what the function does, used by the model to choose when and how to call + * the function. + * @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, + * with a maximum length of 64. + * @param parameters The parameters the functions accepts, described as a JSON Schema object. To describe a + * function that accepts no parameters, provide the value {"type": "object", "properties": {}}. + */ + public record Function( + @JsonProperty("description") String description, + @JsonProperty("name") String name, + @JsonProperty("parameters") Map parameters) { + + /** + * Create tool function definition. + * + * @param description tool function description. + * @param name tool function name. + * @param jsonSchema tool function schema as json. + */ + @ConstructorBinding + public Function(String description, String name, String jsonSchema) { + this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema)); + } + } + } + + /** + * Creates a model response for the given chat conversation. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + * @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + * @param n How many chat completion choices to generate for each input message. Note that you will be charged based + * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + * @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + * @param responseFormat An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + * @param seed This feature is in Beta. If specified, our system will make a best effort to sample + * deterministically, such that repeated requests with the same seed and parameters should return the same result. + * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor + * changes in the backend. + * @param stop Up to 4 sequences where the API will stop generating further tokens. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events as + * they become available, with the stream terminated by a data: [DONE] message. + * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + * @param topP An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + * provide a list of functions the model may generate JSON inputs for. + * @param toolChoice Controls which (if any) function is called by the model. none means the model will not call a + * function and instead generates a message. auto means the model can pick between generating a message or calling a + * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces + * the model to call that function. none is the default when no functions are present. auto is the default if + * functions are present. Use the {@link ToolChoiceBuilder} to create the tool choice value. + * @param user A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. + * + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionRequest ( + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("frequency_penalty") Float frequencyPenalty, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("n") Integer n, + @JsonProperty("presence_penalty") Float presencePenalty, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("seed") Integer seed, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Float temperature, + @JsonProperty("top_p") Float topP, + @JsonProperty("tools") List tools, + @JsonProperty("tool_choice") Object toolChoice, + @JsonProperty("user") String user) { + + /** + * Shortcut constructor for a chat completion request with the given messages and model. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0 and 1. + */ + public ChatCompletionRequest(List messages, String model, Float temperature) { + this(messages, model, null, null, null, null, + null, null, null, false, temperature, null, + null, null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, model and control for streaming. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0 and 1. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events + * as they become available, with the stream terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, String model, Float temperature, boolean stream) { + this(messages, model, null, null, null, null, + null, null, null, stream, temperature, null, + null, null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. + * Streaming is set to false, temperature to 0.8 and all other parameters are null. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param tools A list of tools the model may call. Currently, only functions are supported as a tool. + * @param toolChoice Controls which (if any) function is called by the model. + */ + public ChatCompletionRequest(List messages, String model, + List tools, Object toolChoice) { + this(messages, model, null, null, null, null, + null, null, null, false, 0.8f, null, + tools, toolChoice, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. + * Streaming is set to false, temperature to 0.8 and all other parameters are null. + * + * @param messages A list of messages comprising the conversation so far. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events + * as they become available, with the stream terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, Boolean stream) { + this(messages, null, null, null, null, null, + null, null, null, stream, null, null, + null, null, null); + } + + /** + * Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name. + */ + public static class ToolChoiceBuilder { + /** + * Model can pick between generating a message or calling a function. + */ + public static final String AUTO = "auto"; + /** + * Model will not call a function and instead generates a message + */ + public static final String NONE = "none"; + + /** + * Specifying a particular function forces the model to call that function. + */ + public static Object FUNCTION(String functionName) { + return Map.of("type", "function", "function", Map.of("name", functionName)); + } + } + + /** + * An object specifying the format that the model must output. + * @param type Must be one of 'text' or 'json_object'. + */ + @JsonInclude(Include.NON_NULL) + public record ResponseFormat( + @JsonProperty("type") String type) { + } + } + + /** + * Message comprising the conversation. + * + * @param rawContent The contents of the message. Can be either a {@link MediaContent} or a {@link String}. + * The response message content is always a {@link String}. + * @param role The role of the messages author. Could be one of the {@link Role} types. + * @param name An optional name for the participant. Provides the model information to differentiate between + * participants of the same role. In case of Function calling, the name is the function name that the message is + * responding to. + * @param toolCallId Tool call that this message is responding to. Only applicable for the {@link Role#TOOL} role + * and null otherwise. + * @param toolCalls The tool calls generated by the model, such as function calls. Applicable only for + * {@link Role#ASSISTANT} role and null otherwise. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionMessage( + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") List toolCalls) { + + /** + * Get message content as String. + */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * Create a chat completion message with the given content and role. All other fields are null. + * @param content The contents of the message. + * @param role The role of the author of this message. + */ + public ChatCompletionMessage(Object content, Role role) { + this(content, role, null, null, null); + } + + /** + * The role of the author of this message. + */ + public enum Role { + /** + * System message. + */ + @JsonProperty("system") SYSTEM, + /** + * User message. + */ + @JsonProperty("user") USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") ASSISTANT, + /** + * Tool message. + */ + @JsonProperty("tool") TOOL + } + + /** + * An array of content parts with a defined type. + * Each MediaContent can be of either "text" or "image_url" type. Not both. + * + * @param type Content type, each can be of type text or image_url. + * @param text The text content of the message. + * @param imageUrl The image content of the message. You can pass multiple + * images by adding multiple image_url content parts. Image input is only + * supported when using the glm-4v model. + */ + @JsonInclude(Include.NON_NULL) + public record MediaContent( + @JsonProperty("type") String type, + @JsonProperty("text") String text, + @JsonProperty("image_url") ImageUrl imageUrl) { + + /** + * @param url Either a URL of the image or the base64 encoded image data. + * The base64 encoded image data must have a special prefix in the following format: + * "data:{mimetype};base64,{base64-encoded-image-data}". + * @param detail Specifies the detail level of the image. + */ + @JsonInclude(Include.NON_NULL) + public record ImageUrl( + @JsonProperty("url") String url, + @JsonProperty("detail") String detail) { + + public ImageUrl(String url) { + this(url, null); + } + } + + /** + * Shortcut constructor for a text content. + * @param text The text content of the message. + */ + public MediaContent(String text) { + this("text", text, null); + } + + /** + * Shortcut constructor for an image content. + * @param imageUrl The image content of the message. + */ + public MediaContent(ImageUrl imageUrl) { + this("image_url", null, imageUrl); + } + } + /** + * The relevant tool call. + * + * @param id The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the + * Submit tool outputs to run endpoint. + * @param type The type of tool call the output is required for. For now, this is always function. + * @param function The function definition. + */ + @JsonInclude(Include.NON_NULL) + public record ToolCall( + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("function") ChatCompletionFunction function) { + } + + /** + * The function definition. + * + * @param name The name of the function. + * @param arguments The arguments that the model expects you to pass to the function. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionFunction( + @JsonProperty("name") String name, + @JsonProperty("arguments") String arguments) { + } + } + + public static String getTextContent(List content) { + return content.stream() + .filter(c -> "text".equals(c.type())) + .map(ChatCompletionMessage.MediaContent::text) + .reduce("", (a, b) -> a + b); + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") CONTENT_FILTER, + /** + * The model called a tool. + */ + @JsonProperty("tool_calls") TOOL_CALLS, + /** + * (deprecated) The model called a function. + */ + @JsonProperty("function_call") FUNCTION_CALL, + /** + * Only for compatibility with Mistral AI API. + */ + @JsonProperty("tool_call") TOOL_CALL + } + + /** + * Represents a chat completion response returned by model, based on the provided input. + * + * @param id A unique identifier for the chat completion. + * @param choices A list of chat completion choices. Can be more than one if n is greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was created. + * @param model The model used for the chat completion. + * @param systemFingerprint This fingerprint represents the backend configuration that the model runs with. Can be + * used in conjunction with the seed request parameter to understand when backend changes have been made that might + * impact determinism. + * @param object The object type, which is always chat.completion. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletion( + @JsonProperty("id") String id, + @JsonProperty("choices") List choices, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("object") String object, + @JsonProperty("usage") Usage usage) { + + /** + * Chat completion choice. + * + * @param finishReason The reason the model stopped generating tokens. + * @param index The index of the choice in the list of choices. + * @param message A chat completion message generated by the model. + * @param logprobs Log probability information for the choice. + */ + @JsonInclude(Include.NON_NULL) + public record Choice( + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("message") ChatCompletionMessage message, + @JsonProperty("logprobs") LogProbs logprobs) { + + } + } + + /** + * Log probability information for the choice. + * + * @param content A list of message content tokens with log probability information. + */ + @JsonInclude(Include.NON_NULL) + public record LogProbs( + @JsonProperty("content") List content) { + + /** + * Message content tokens with log probability information. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes representation + * of the token. Useful in instances where characters are represented by multiple + * tokens and their byte representations must be combined to generate the correct + * text representation. Can be null if there is no bytes representation for the token. + * @param topLogprobs List of the most likely tokens and their log probability, + * at this token position. In rare cases, there may be fewer than the number of + * requested top_logprobs returned. + */ + @JsonInclude(Include.NON_NULL) + public record Content( + @JsonProperty("token") String token, + @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes, + @JsonProperty("top_logprobs") List topLogprobs) { + + /** + * The most likely tokens and their log probability, at this token position. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes representation + * of the token. Useful in instances where characters are represented by multiple + * tokens and their byte representations must be combined to generate the correct + * text representation. Can be null if there is no bytes representation for the token. + */ + @JsonInclude(Include.NON_NULL) + public record TopLogProbs( + @JsonProperty("token") String token, + @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes) { + } + } + } + + /** + * Usage statistics for the completion request. + * + * @param completionTokens Number of tokens in the generated completion. Only applicable for completion requests. + * @param promptTokens Number of tokens in the prompt. + * @param totalTokens Total number of tokens used in the request (prompt + completion). + */ + @JsonInclude(Include.NON_NULL) + public record Usage( + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + + } + + /** + * Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param choices A list of chat completion choices. Can be more than one if n is greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same + * timestamp. + * @param model The model used for the chat completion. + * @param systemFingerprint This fingerprint represents the backend configuration that the model runs with. Can be + * used in conjunction with the seed request parameter to understand when backend changes have been made that might + * impact determinism. + * @param object The object type, which is always 'chat.completion.chunk'. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionChunk( + @JsonProperty("id") String id, + @JsonProperty("choices") List choices, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("object") String object) { + + /** + * Chat completion choice. + * + * @param finishReason The reason the model stopped generating tokens. + * @param index The index of the choice in the list of choices. + * @param delta A chat completion delta generated by streamed model responses. + * @param logprobs Log probability information for the choice. + */ + @JsonInclude(Include.NON_NULL) + public record ChunkChoice( + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ChatCompletionMessage delta, + @JsonProperty("logprobs") LogProbs logprobs) { + } + } + + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false."); + + return this.restClient.post() + .uri("/v4/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + private final ZhiPuAiStreamFunctionCallingHelper chunkMerger = new ZhiPuAiStreamFunctionCallingHelper(); + + /** + * Creates a streaming chat response for the given chat conversation. + * + * @param chatRequest The chat completion request. Must have the stream property set to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + return this.webClient.post() + .uri("/v4/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)) + .map(chunk -> { + if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { + isInsideTool.set(true); + } + return chunk; + }) + .windowUntil(chunk -> { + if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + .concatMapIterable(window -> { + Mono monoChunk = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null), + this.chunkMerger::merge); + return List.of(monoChunk); + }) + .flatMap(mono -> mono); + } + + /** + * ZhiPuAI Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 1024 + */ + Embedding_2("Embedding-2"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** + * Represents an embedding vector returned by embedding endpoint. + * + * @param index The index of the embedding in the list of embeddings. + * @param embedding The embedding vector, which is a list of floats. The length of vector depends on the model. + * @param object The object type, which is always 'embedding'. + */ + @JsonInclude(Include.NON_NULL) + public record Embedding( + @JsonProperty("index") Integer index, + @JsonProperty("embedding") List embedding, + @JsonProperty("object") String object) { + + /** + * Create an embedding with the given index, embedding and object type set to 'embedding'. + * + * @param index The index of the embedding in the list of embeddings. + * @param embedding The embedding vector, which is a list of floats. The length of vector depends on the model. + */ + public Embedding(Integer index, List embedding) { + this(index, embedding, "embedding"); + } + } + + /** + * Creates an embedding vector representing the input text. + * + * @param input Input text to embed, encoded as a string or array of tokens. + * @param model ID of the model to use. + */ + @JsonInclude(Include.NON_NULL) + public record EmbeddingRequest( + @JsonProperty("input") T input, + @JsonProperty("model") String model) { + + + /** + * Create an embedding request with the given input. Encoding model is set to 'embedding-2'. + * @param input Input text to embed. + */ + public EmbeddingRequest(T input) { + this(input, DEFAULT_EMBEDDING_MODEL); + } + } + + /** + * List of multiple embedding responses. + * + * @param Type of the entities in the data list. + * @param object Must have value "list". + * @param data List of entities. + * @param model ID of the model to use. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(Include.NON_NULL) + public record EmbeddingList( + @JsonProperty("object") String object, + @JsonProperty("data") List data, + @JsonProperty("model") String model, + @JsonProperty("usage") Usage usage) { + } + + /** + * Creates an embedding vector representing the input text or token array. + * + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + * @param Type of the entity in the data list. Can be a {@link String} or {@link List} of tokens (e.g. + * Integers). For embedding multiple inputs in a single request, You can pass a {@link List} of {@link String} or + * {@link List} of {@link List} of tokens. For example: + * + *
{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} 
+ */ + public ResponseEntity> embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.input(), "The input can not be null."); + Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, + "The input must be either a String, or a List of Strings or List of List of integers."); + + if (embeddingRequest.input() instanceof List list) { + Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty."); + Assert.isTrue(list.size() <= 512, "The list must be 512 dimensions or less"); + Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer + || list.get(0) instanceof List, + "The input must be either a String, or a List of Strings or list of list of integers."); + } + + return this.restClient.post() + .uri("/v4/embeddings") + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + +} +// @formatter:on diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java new file mode 100644 index 00000000000..e8564e63735 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java @@ -0,0 +1,129 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import java.util.List; + +/** + * ZhiPuAI Image API. + * + * @see CogView Images + * @author Geng Rong + */ +public class ZhiPuAiImageApi { + + public static final String DEFAULT_IMAGE_MODEL = ImageModel.CogView_3.getValue(); + + private final RestClient restClient; + + /** + * Create a new ZhiPuAI Image api with base URL set to https://api.ZhiPuAI.com + * @param zhiPuAiToken ZhiPuAI apiKey. + */ + public ZhiPuAiImageApi(String zhiPuAiToken) { + this(ApiUtils.DEFAULT_BASE_URL, zhiPuAiToken, RestClient.builder()); + } + + /** + * Create a new ZhiPuAI Image API with the provided base URL. + * @param baseUrl the base URL for the ZhiPuAI API. + * @param zhiPuAiToken ZhiPuAI apiKey. + * @param restClientBuilder the rest client builder to use. + */ + public ZhiPuAiImageApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder) { + this(baseUrl, zhiPuAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new ZhiPuAI Image API with the provided base URL. + * @param baseUrl the base URL for the ZhiPuAI API. + * @param zhiPuAiToken ZhiPuAI apiKey. + * @param restClientBuilder the rest client builder to use. + * @param responseErrorHandler the response error handler to use. + */ + public ZhiPuAiImageApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders(zhiPuAiToken)) + .defaultStatusHandler(responseErrorHandler) + .build(); + } + + /** + * ZhiPuAI Image API model. + * CogView + */ + public enum ImageModel { + + CogView_3("cogview-3"); + + private final String value; + + ImageModel(String model) { + this.value = model; + } + + public String getValue() { + return this.value; + } + + } + + // @formatter:off + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ZhiPuAiImageRequest ( + @JsonProperty("prompt") String prompt, + @JsonProperty("model") String model, + @JsonProperty("user_id") String user) { + + public ZhiPuAiImageRequest(String prompt, String model) { + this(prompt, model, null); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ZhiPuAiImageResponse( + @JsonProperty("created") Long created, + @JsonProperty("data") List data) { + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Data( + @JsonProperty("url") String url) { + } + // @formatter:onn + + public ResponseEntity createImage(ZhiPuAiImageRequest zhiPuAiImageRequest) { + Assert.notNull(zhiPuAiImageRequest, "Image request cannot be null."); + Assert.hasLength(zhiPuAiImageRequest.prompt(), "Prompt cannot be empty."); + + return this.restClient.post() + .uri("/v4/images/generations") + .body(zhiPuAiImageRequest) + .retrieve() + .toEntity(ZhiPuAiImageResponse.class); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java new file mode 100644 index 00000000000..fc9507d499a --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java @@ -0,0 +1,194 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.*; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion.Choice; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionChunk.ChunkChoice; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ChatCompletionFunction; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.util.CollectionUtils; + +import java.util.ArrayList; +import java.util.List; + +/** + * Helper class to support Streaming function calling. It can merge the streamed + * ChatCompletionChunk in case of function calling message. + * + * @author Geng Rong + */ +public class ZhiPuAiStreamFunctionCallingHelper { + + /** + * Merge the previous and current ChatCompletionChunk into a single one. + * @param previous the previous ChatCompletionChunk + * @param current the current ChatCompletionChunk + * @return the merged ChatCompletionChunk + */ + public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) { + + if (previous == null) { + return current; + } + + String id = (current.id() != null ? current.id() : previous.id()); + Long created = (current.created() != null ? current.created() : previous.created()); + String model = (current.model() != null ? current.model() : previous.model()); + String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint() + : previous.systemFingerprint()); + String object = (current.object() != null ? current.object() : previous.object()); + + ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0)); + ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0)); + + ChunkChoice choice = merge(previousChoice0, currentChoice0); + List chunkChoices = choice == null ? List.of() : List.of(choice); + return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object); + } + + private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { + if (previous == null) { + return current; + } + + ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason() + : previous.finishReason()); + Integer index = (current.index() != null ? current.index() : previous.index()); + + ChatCompletionMessage message = merge(previous.delta(), current.delta()); + + LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs()); + return new ChunkChoice(finishReason, index, message, logprobs); + } + + private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { + String content = (current.content() != null ? current.content() + : (previous.content() != null) ? previous.content() : ""); + Role role = (current.role() != null ? current.role() : previous.role()); + role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null + String name = (current.name() != null ? current.name() : previous.name()); + String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId()); + + List toolCalls = new ArrayList<>(); + ToolCall lastPreviousTooCall = null; + if (previous.toolCalls() != null) { + lastPreviousTooCall = previous.toolCalls().get(previous.toolCalls().size() - 1); + if (previous.toolCalls().size() > 1) { + toolCalls.addAll(previous.toolCalls().subList(0, previous.toolCalls().size() - 1)); + } + } + if (current.toolCalls() != null) { + if (current.toolCalls().size() > 1) { + throw new IllegalStateException("Currently only one tool call is supported per message!"); + } + var currentToolCall = current.toolCalls().iterator().next(); + if (currentToolCall.id() != null) { + if (lastPreviousTooCall != null) { + toolCalls.add(lastPreviousTooCall); + } + toolCalls.add(currentToolCall); + } + else { + toolCalls.add(merge(lastPreviousTooCall, currentToolCall)); + } + } + else { + if (lastPreviousTooCall != null) { + toolCalls.add(lastPreviousTooCall); + } + } + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls); + } + + private ToolCall merge(ToolCall previous, ToolCall current) { + if (previous == null) { + return current; + } + String id = (current.id() != null ? current.id() : previous.id()); + String type = (current.type() != null ? current.type() : previous.type()); + ChatCompletionFunction function = merge(previous.function(), current.function()); + return new ToolCall(id, type, function); + } + + private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) { + if (previous == null) { + return current; + } + String name = (current.name() != null ? current.name() : previous.name()); + StringBuilder arguments = new StringBuilder(); + if (previous.arguments() != null) { + arguments.append(previous.arguments()); + } + if (current.arguments() != null) { + arguments.append(current.arguments()); + } + return new ChatCompletionFunction(name, arguments.toString()); + } + + /** + * @param chatCompletion the ChatCompletionChunk to check + * @return true if the ChatCompletionChunk is a streaming tool function call. + */ + public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) { + return false; + } + + var choice = chatCompletion.choices().get(0); + if (choice == null || choice.delta() == null) { + return false; + } + return !CollectionUtils.isEmpty(choice.delta().toolCalls()); + } + + /** + * @param chatCompletion the ChatCompletionChunk to check + * @return true if the ChatCompletionChunk is a streaming tool function call and it is + * the last one. + */ + public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) { + + if (chatCompletion == null || CollectionUtils.isEmpty(chatCompletion.choices())) { + return false; + } + + var choice = chatCompletion.choices().get(0); + if (choice == null || choice.delta() == null) { + return false; + } + return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS; + } + + /** + * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. + * @param chunk the ChatCompletionChunk to convert + * @return the ChatCompletion + */ + public ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) { + List choices = chunk.choices() + .stream() + .map(chunkChoice -> new Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(), + chunkChoice.logprobs())) + .toList(); + + return new ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.systemFingerprint(), + "chat.completion", null); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java new file mode 100644 index 00000000000..d01d01490e3 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/metadata/ZhiPuAiUsage.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.util.Assert; + +/** + * {@link Usage} implementation for {@literal ZhiPuAI}. + * + * @author Geng Rong + */ +public class ZhiPuAiUsage implements Usage { + + public static ZhiPuAiUsage from(ZhiPuAiApi.Usage usage) { + return new ZhiPuAiUsage(usage); + } + + private final ZhiPuAiApi.Usage usage; + + protected ZhiPuAiUsage(ZhiPuAiApi.Usage usage) { + Assert.notNull(usage, "ZhiPuAI Usage must not be null"); + this.usage = usage; + } + + protected ZhiPuAiApi.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return getUsage().completionTokens().longValue(); + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-zhipuai/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..5fc5cb0a7d1 --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.zhipuai.aot.ZhiPuAiRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java new file mode 100644 index 00000000000..abe4b9be31a --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java @@ -0,0 +1,144 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.zhipuai.api.MockWeatherService; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +public class ChatCompletionRequestTests { + + @Test + public void createRequestWithChatOptions() { + + var client = new ZhiPuAiChatClient(new ZhiPuAiApi("TEST"), + ZhiPuAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.temperature()).isEqualTo(66.6f); + + request = client.createRequest(new Prompt("Test message content", + ZhiPuAiChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + assertThat(request.temperature()).isEqualTo(99.9f); + } + + @Test + public void promptOptionsTools() { + + final String TOOL_FUNCTION_NAME = "CurrentWeather"; + + var client = new ZhiPuAiChatClient(new ZhiPuAiApi("TEST"), + ZhiPuAiChatOptions.builder().withModel("DEFAULT_MODEL").build()); + + var request = client.createRequest(new Prompt("Test message content", + ZhiPuAiChatOptions.builder() + .withModel("PROMPT_MODEL") + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName(TOOL_FUNCTION_NAME) + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build()), + false); + + assertThat(client.getFunctionCallbackRegister()).hasSize(1); + assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).function().name()).isEqualTo(TOOL_FUNCTION_NAME); + } + + @Test + public void defaultOptionsTools() { + + final String TOOL_FUNCTION_NAME = "CurrentWeather"; + + var client = new ZhiPuAiChatClient(new ZhiPuAiApi("TEST"), + ZhiPuAiChatOptions.builder() + .withModel("DEFAULT_MODEL") + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName(TOOL_FUNCTION_NAME) + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(client.getFunctionCallbackRegister()).hasSize(1); + assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + assertThat(client.getFunctionCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) + .isEqualTo("Get the weather in location"); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + + assertThat(request.tools()).as("Default Options callback functions are not automatically enabled!") + .isNullOrEmpty(); + + // Explicitly enable the function + request = client.createRequest(new Prompt("Test message content", + ZhiPuAiChatOptions.builder().withFunction(TOOL_FUNCTION_NAME).build()), false); + + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).function().name()).as("Explicitly enabled function") + .isEqualTo(TOOL_FUNCTION_NAME); + + // Override the default options function with one from the prompt + request = client.createRequest(new Prompt("Test message content", + ZhiPuAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName(TOOL_FUNCTION_NAME) + .withDescription("Overridden function description") + .build())) + .build()), + false); + + assertThat(request.tools()).hasSize(1); + assertThat(request.tools().get(0).function().name()).as("Explicitly enabled function") + .isEqualTo(TOOL_FUNCTION_NAME); + + assertThat(client.getFunctionCallbackRegister()).hasSize(1); + assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + assertThat(client.getFunctionCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) + .isEqualTo("Overridden function description"); + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java new file mode 100644 index 00000000000..fbc53b08cdf --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiTestConfiguration.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai; + +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Geng Rong + */ +@SpringBootConfiguration +public class ZhiPuAiTestConfiguration { + + @Bean + public ZhiPuAiApi zhiPuAiApi() { + return new ZhiPuAiApi(getApiKey()); + } + + @Bean + public ZhiPuAiImageApi zhiPuAiImageApi() { + return new ZhiPuAiImageApi(getApiKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("ZHIPU_AI_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name ZHIPU_AI_API_KEY"); + } + return apiKey; + } + + @Bean + public ZhiPuAiChatClient zhiPuAiChatClient(ZhiPuAiApi api) { + return new ZhiPuAiChatClient(api); + } + + @Bean + public ZhiPuAiImageClient zhiPuAiImageClient(ZhiPuAiImageApi imageApi) { + return new ZhiPuAiImageClient(imageApi); + } + + @Bean + public EmbeddingClient zhiPuAiEmbeddingClient(ZhiPuAiApi api) { + return new ZhiPuAiEmbeddingClient(api); + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java new file mode 100644 index 00000000000..cc4897c685c --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java @@ -0,0 +1,92 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +import java.util.function.Function; + +/** + * @author Geng Rong + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, request.unit); + } + +} \ No newline at end of file diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java new file mode 100644 index 00000000000..4b4fc442438 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.*; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.Role; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Objects; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") +public class ZhiPuAiApiIT { + + ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + + @Test + void chatCompletionEntity() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + ResponseEntity response = zhiPuAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7f, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + Flux response = zhiPuAiApi + .chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7f, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + + @Test + void embeddings() { + ResponseEntity> response = zhiPuAiApi + .embeddings(new ZhiPuAiApi.EmbeddingRequest<>("Hello world")); + + assertThat(response).isNotNull(); + assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1); + assertThat(response.getBody().data().get(0).embedding()).hasSize(1024); + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java new file mode 100644 index 00000000000..154caa85fd5 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java @@ -0,0 +1,148 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.zhipuai.api; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest.ToolChoiceBuilder; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.FunctionTool.Type; +import org.springframework.http.ResponseEntity; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatModel.GLM_4; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") +public class ZhiPuAiApiToolFunctionCallIT { + + private final Logger logger = LoggerFactory.getLogger(ZhiPuAiApiToolFunctionCallIT.class); + + MockWeatherService weatherService = new MockWeatherService(); + + ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + + @SuppressWarnings("null") + @Test + public void toolFunctionCall() { + + // Step 1: send the conversation and available functions to the model + var message = new ChatCompletionMessage("What's the weather like in San Francisco, Tokyo, and Paris?", + Role.USER); + + var functionTool = new ZhiPuAiApi.FunctionTool(Type.FUNCTION, + new ZhiPuAiApi.FunctionTool.Function( + "Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", + ModelOptionsUtils.jsonToMap(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """))); + + List messages = new ArrayList<>(List.of(message)); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, GLM_4.value, + List.of(functionTool), ToolChoiceBuilder.AUTO); + + ResponseEntity chatCompletion = zhiPuAiApi.chatCompletionEntity(chatCompletionRequest); + + assertThat(chatCompletion.getBody()).isNotNull(); + assertThat(chatCompletion.getBody().choices()).isNotEmpty(); + + ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().get(0).message(); + + assertThat(responseMessage.role()).isEqualTo(Role.ASSISTANT); + assertThat(responseMessage.toolCalls()).isNotNull(); + + messages.add(responseMessage); + + // Send the info for each function call and function response to the model. + for (ToolCall toolCall : responseMessage.toolCalls()) { + var functionName = toolCall.function().name(); + if ("getCurrentWeather".equals(functionName)) { + MockWeatherService.Request weatherRequest = fromJson(toolCall.function().arguments(), + MockWeatherService.Request.class); + + MockWeatherService.Response weatherResponse = weatherService.apply(weatherRequest); + + // extend conversation with function response. + messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, + functionName, toolCall.id(), null)); + } + } + + var functionResponseRequest = new ChatCompletionRequest(messages, GLM_4.value, 0.8f); + + ResponseEntity chatCompletion2 = zhiPuAiApi.chatCompletionEntity(functionResponseRequest); + + logger.info("Final response: " + chatCompletion2.getBody()); + + assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); + + assertThat(chatCompletion2.getBody().choices().get(0).message().role()).isEqualTo(Role.ASSISTANT); + assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("San Francisco") + .containsAnyOf("30.0°C", "30°C", "30.0°F", "30°F"); + assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Tokyo") + .containsAnyOf("10.0°C", "10°C", "10.0°F", "10°F"); + assertThat(chatCompletion2.getBody().choices().get(0).message().content()).contains("Paris") + .containsAnyOf("15.0°C", "15°C", "15.0°F", "15°F"); + } + + private static T fromJson(String json, Class targetClass) { + try { + return new ObjectMapper().readValue(json, targetClass); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + +} \ No newline at end of file diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java new file mode 100644 index 00000000000..413e6245ca1 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -0,0 +1,212 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.api; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.image.ImageMessage; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.ai.zhipuai.*; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.*; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi.Data; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi.ZhiPuAiImageRequest; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi.ZhiPuAiImageResponse; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * @author Geng Rong + */ +@SuppressWarnings("unchecked") +@ExtendWith(MockitoExtension.class) +public class ZhiPuAiRetryTests { + + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + private @Mock ZhiPuAiApi zhiPuAiApi; + + private @Mock ZhiPuAiImageApi zhiPuAiImageApi; + + private ZhiPuAiChatClient chatClient; + + private ZhiPuAiEmbeddingClient embeddingClient; + + private ZhiPuAiImageClient imageClient; + + @BeforeEach + public void beforeEach() { + retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + chatClient = new ZhiPuAiChatClient(zhiPuAiApi, ZhiPuAiChatOptions.builder().build(), null, retryTemplate); + embeddingClient = new ZhiPuAiEmbeddingClient(zhiPuAiApi, MetadataMode.EMBED, + ZhiPuAiEmbeddingOptions.builder().build(), retryTemplate); + imageClient = new ZhiPuAiImageClient(zhiPuAiImageApi, ZhiPuAiImageOptions.builder().build(), retryTemplate); + } + + @Test + public void zhiPuAiChatTransientError() { + + var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", Role.ASSISTANT), null); + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, + new ZhiPuAiApi.Usage(10, 10, 10)); + + when(zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = chatClient.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void zhiPuAiChatNonTransientError() { + when(zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text"))); + } + + @Test + public void zhiPuAiChatStreamTransientError() { + + var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", Role.ASSISTANT), null); + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, + null); + + when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(Flux.just(expectedChatCompletion)); + + var result = chatClient.stream(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void zhiPuAiChatStreamNonTransientError() { + when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text"))); + } + + @Test + public void zhiPuAiEmbeddingTransientError() { + + EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", + List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new ZhiPuAiApi.Usage(10, 10, 10)); + + when(zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + + var result = embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8)); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void zhiPuAiEmbeddingNonTransientError() { + when(zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + } + + @Test + public void zhiPuAiImageTransientError() { + + var expectedResponse = new ZhiPuAiImageResponse(678l, List.of(new Data("url678"))); + + when(zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); + + var result = imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void zhiPuAiImageNonTransientError() { + when(zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + .thenThrow(new RuntimeException("Transient Error 1")); + assertThrows(RuntimeException.class, + () -> imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + +} diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageClientIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageClientIT.java new file mode 100644 index 00000000000..b99877b6b7d --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/image/ZhiPuAiImageClientIT.java @@ -0,0 +1,56 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.zhipuai.image; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.image.*; +import org.springframework.ai.zhipuai.ZhiPuAiTestConfiguration; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = ZhiPuAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".+") +public class ZhiPuAiImageClientIT { + + @Autowired + protected ImageClient imageClient; + + @Test + void imageAsUrlTest() { + var options = ImageOptionsBuilder.builder().withHeight(1024).withWidth(1024).build(); + + var instructions = """ + A light cream colored mini golden doodle with a sign that contains the message "I'm on my way to BARCADE!"."""; + + ImagePrompt imagePrompt = new ImagePrompt(instructions, options); + + ImageResponse imageResponse = imageClient.call(imagePrompt); + + assertThat(imageResponse.getResults()).hasSize(1); + + ImageResponseMetadata imageResponseMetadata = imageResponse.getMetadata(); + assertThat(imageResponseMetadata.created()).isPositive(); + + var generation = imageResponse.getResult(); + Image image = generation.getOutput(); + assertThat(image.getUrl()).isNotEmpty(); + assertThat(image.getB64Json()).isNull(); + } + +} diff --git a/models/spring-ai-zhipuai/src/test/resources/prompts/system-message.st b/models/spring-ai-zhipuai/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..579febd8d9b --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/resources/prompts/system-message.st @@ -0,0 +1,3 @@ +You are an AI assistant that helps people find information. +Your name is {name}. +You should reply to the user's request with your name and also in the style of a {voice}. \ No newline at end of file diff --git a/pom.xml b/pom.xml index 1aa98332298..9c842b5515e 100644 --- a/pom.xml +++ b/pom.xml @@ -28,6 +28,7 @@ models/spring-ai-vertex-ai-gemini models/spring-ai-anthropic models/spring-ai-watsonx-ai + models/spring-ai-zhipuai spring-ai-test spring-ai-spring-boot-autoconfigure spring-ai-spring-boot-starters/spring-ai-starter-openai @@ -74,6 +75,7 @@ vector-stores/spring-ai-elasticsearch-store spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-elasticsearch-store + spring-ai-spring-boot-starters/spring-ai-starter-zhipuai diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index d1c857f9da9..be3643fb9e8 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -302,6 +302,12 @@ ${project.version} + + org.springframework.ai + spring-ai-zhipuai + ${project.version} + + org.springframework.ai spring-ai-pinecone-store-spring-boot-starter @@ -384,6 +390,12 @@ spring-ai-elasticsearch-store-spring-boot-starter ${project.version} + + + org.springframework.ai + spring-ai-zhipuai-spring-boot-starter + ${project.version} + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 7d8e11bb98f..0c5c9893705 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -22,6 +22,8 @@ ***** xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Function Calling] *** xref:api/chat/mistralai-chat.adoc[Mistral AI] **** xref:api/chat/functions/mistralai-chat-functions.adoc[Function Calling] +*** xref:api/chat/zhipuai-chat.adoc[ZhiPu AI] +**** xref:api/chat/functions/zhipuai-chat-functions.adoc[Function Calling] *** xref:api/chat/anthropic-chat.adoc[Anthropic 3] **** xref:api/chat/functions/anthropic-chat-functions.adoc[Function Calling] *** xref:api/chat/watsonx-ai-chat.adoc[Watsonx.AI] @@ -36,9 +38,11 @@ **** xref:api/embeddings/bedrock-titan-embedding.adoc[Titan] *** xref:api/embeddings/onnx.adoc[Transformers (ONNX)] *** xref:api/embeddings/mistralai-embeddings.adoc[Mistral AI] +*** xref:api/embeddings/zhipuai-embeddings.adoc[ZhiPu AI] ** xref:api/imageclient.adoc[] *** xref:api/image/openai-image.adoc[OpenAI] *** xref:api/image/stabilityai-image.adoc[Stability] +*** xref:api/image/zhipuai-image.adoc[ZhiPuAI] ** xref:api/audio[Audio API] *** xref:api/audio/transcriptions.adoc[] **** xref:api/audio/transcriptions/openai-transcriptions.adoc[OpenAI] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc new file mode 100644 index 00000000000..338505b78b5 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/zhipuai-chat-functions.adoc @@ -0,0 +1,226 @@ += Function Calling + +You can register custom Java functions with the `ZhiPuAiChatClient` and have the ZhiPuAI model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This allows you to connect the LLM capabilities with external tools and APIs. +The ZhiPuAI models are trained to detect when a function should be called and to respond with JSON that adheres to the function signature. + +The ZhiPuAI API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +Spring AI provides flexible and user-friendly ways to register and call custom functions. +In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. The `description` helps the model to understand when to call the function. + +As a developer, you need to implement a functions that takes the function call arguments sent from the AI model, and respond with the result back to the model. Your function can in turn invoke other 3rd party services to provide the results. + +Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatClient`. + +Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. +The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java[FunctionCallbackWrapper.java] utility class to simplify the implementation and registration of Java callback functions. + +// Additionally, the Auto-Configuration provides a way to auto-register any Function beans definition as function calling candidates in the `ChatClient`. + + +== How it works + +Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. + +We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. + +For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function. +The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response. + +The model-client interaction is illustrated in the <> diagram. + +Spring AI greatly simplifies code you need to write to support function invocation. +It brokers the function invocation conversation for you. +You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. +You can also reference multiple function bean names in your prompt. + +== Quick Start + +Let's create a chatbot that answer questions by calling our own function. +To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. + +When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON. + +Our function calls some SaaS based weather service API and returns the weather response back to the model to complete the conversation. In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations. + +The following `MockWeatherService.java` represents the weather service API: + +[source,java] +---- +public class MockWeatherService implements Function { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response(30.0, Unit.C); + } +} +---- + +=== Registering Functions as Beans + +With the link:../zhipuai-chat.html#_auto_configuration[ZhiPuAiChatClient Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. + +We start with describing the most POJO friendly options. + + +==== Plain Java Functions + +In this approach you define `@Beans` in your application context as you would any other Spring managed object. + +Internally, Spring AI `ChatClient` will create an instance of a `FunctionCallbackWrapper` wrapper that adds the logic for it being invoked via the AI model. +The name of the `@Bean` is passed as a `ChatOption`. + + +[source,java] +---- +@Configuration +static class Config { + + @Bean + @Description("Get the weather in location") // function description + public Function weatherFunction1() { + return new MockWeatherService(); + } + ... +} +---- + +The `@Description` annotation is optional and provides a function description (2) that helps the model to understand when to call the function. It is an important property to set to help the AI model determine what client side function to invoke. + +Another option to provide the description of the function is to the `@JacksonDescription` annotation on the `MockWeatherService.Request` to provide the function description: + +[source,java] +---- + +@Configuration +static class Config { + + @Bean + public Function currentWeather3() { // (1) bean name as function name. + return new MockWeatherService(); + } + ... +} + +@JsonClassDescription("Get the weather in location") // (2) function description +public record Request(String location, Unit unit) {} +---- + +It is a best practice to annotate the request object with information such that the generates JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java[FunctionCallbackWithPlainFunctionBeanIT.java] demonstrates this approach. + + +==== FunctionCallback Wrapper + +Another way register a function is to create `FunctionCallbackWrapper` wrapper like this: + +[source,java] +---- +@Configuration +static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return new FunctionCallbackWrapper<>("CurrentWeather", // (1) function name + "Get the weather in location", // (2) function description + (response) -> "" + response.temp() + response.unit(), // (3) Response Converter + new MockWeatherService()); // function code + } + ... +} +---- + +It wraps the 3rd party, `MockWeatherService` function and registers it as a `CurrentWeather` function with the `ZhiPuAiChatClient`. +It also provides a description (2) and an optional response converter (3) to convert the response into a text as expected by the model. + +NOTE: By default, the response converter does a JSON serialization of the Response object. + +NOTE: The `FunctionCallbackWrapper` internally resolves the function call signature based on the `MockWeatherService.Request` class. + +=== Specifying functions in Chat Options + +To let the model know and call your `CurrentWeather` function you need to enable it in your prompt requests: + +[source,java] +---- +ZhiPuAiChatClient chatClient = ... + +UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + +ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), + ZhiPuAiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +// NOTE: You can can have multiple functions registered in your `ChatClient` but only those enabled in the prompt request will be considered for the function calling. + +Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and the final response will be something like this: + +---- +Here is the current weather for the requested cities: +- San Francisco, CA: 30.0°C +- Tokyo, Japan: 10.0°C +- Paris, France: 15.0°C +---- + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java[FunctionCallbackWrapperIT.java] test demo this approach. + + +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests: + +[source,java] +---- +ZhiPuAiChatClient chatClient = ... + +UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + +var promptOptions = ZhiPuAiChatOptions.builder() + .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( + "CurrentWeather", // name + "Get the weather in location", // function description + new MockWeatherService()))) // function code + .build(); + +ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), promptOptions)); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows to dynamically chose different functions to be called based on the user input. + +The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java[FunctionCallbackInPromptIT.java] integration test provides a complete example of how to register a function with the `ZhiPuAiChatClient` and use it in a prompt request. +// +// === Register Functions with Default Options +// +// You can programmatically register functions with the `ZhiPuAiChatClient` using the `ZhiPuAiChatOptions#withFunctionCallbacks`: +// +// [source,java] +// ---- +// +// ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(apiKey); +// +// var defaultOptions = ZhiPuAiChatOptions.builder() +// .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( +// "CurrentWeather", // name +// "Get the weather in location", // function description +// new MockWeatherService()))) // function code +// .build(); +// +// ZhiPuAiChatClient chatClient = new ZhiPuAiChatClient(zhiPuAiApi, defaultOptions); +// +// UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); +// +// ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), +// ZhiPuAiChatOptions.builder().withFunction("CurrentWeather").build())); // Enable the function +// ---- +// +// NOTE: Functions are registered when ZhiPuAiChatClient is created, by you must enable in the Prompt the functions to be used in the request. \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc new file mode 100644 index 00000000000..1200f7cfce5 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc @@ -0,0 +1,250 @@ += ZhiPu AI Chat + +Spring AI supports the various AI language models from ZhiPu AI. You can interact with ZhiPu AI language models and create a multilingual conversational assistant based on ZhiPuAI models. + +== Prerequisites + +You will need to create an API with ZhiPuAI to access ZhiPu AI language models. + +Create an account at https://open.bigmodel.cn/login[ZhiPu AI registration page] and generate the token on the https://open.bigmodel.cn/usercenter/apikeys[API Keys page]. +The Spring AI project defines a configuration property named `spring.ai.zhipuai.api-key` that you should set to the value of the `API Key` obtained from https://open.bigmodel.cn/usercenter/apikeys[API Keys page]. +Exporting an environment variable is one way to set that configuration property: + +[source,shell] +---- +export SPRING_AI_ZHIPU_AI_API_KEY= +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the ZhiPuAI Chat Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-zhipuai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-zhipuai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the ZhiPu AI Chat client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.zhiPu` is used as the property prefix that lets you connect to ZhiPuAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.zhipuai.base-url | The URL to connect to | https://open.bigmodel.cn/api/paas +| spring.ai.zhipuai.api-key | The API Key | - +|==== + +==== Configuration Properties + +The prefix `spring.ai.zhipuai.chat` is the property prefix that lets you configure the chat client implementation for ZhiPuAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.zhipuai.chat.enabled | Enable ZhiPuAI chat client. | true +| spring.ai.zhipuai.chat.base-url | Optional overrides the spring.ai.zhipuai.base-url to provide chat specific url | https://open.bigmodel.cn/api/paas +| spring.ai.zhipuai.chat.api-key | Optional overrides the spring.ai.zhipuai.api-key to provide chat specific api-key | - +| spring.ai.zhipuai.chat.options.model | This is the ZhiPuAI Chat model to use | `GLM-3-Turbo` (the `GLM-3-Turbo`, `GLM-4`, and `GLM-4V` point to the latest model versions) +| spring.ai.zhipuai.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - +| spring.ai.zhipuai.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict. | 0.7 +| spring.ai.zhipuai.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | 1.0 +| spring.ai.zhipuai.chat.options.n | How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Default value is 1 and cannot be greater than 5. Specifically, when the temperature is very small and close to 0, we can only return 1 result. If n is already set and>1 at this time, service will return an illegal input parameter (invalid_request_error) | 1 +| spring.ai.zhipuai.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | 0.0f +| spring.ai.zhipuai.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0f +| spring.ai.zhipuai.chat.options.stop | The model will stop generating characters specified by stop, and currently only supports a single stop word in the format of ["stop_word1"] | - +| spring.ai.zhipuai.chat.options.user | A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. | - +|==== + +NOTE: You can override the common `spring.ai.zhipuai.base-url` and `spring.ai.zhipuai.api-key` for the `ChatClient` implementations. +The `spring.ai.zhipuai.chat.base-url` and `spring.ai.zhipuai.chat.api-key` properties if set take precedence over the common properties. +This is useful if you want to use different ZhiPuAI accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.zhipuai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java[ZhiPuAiChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. + +On start-up, the default options can be configured with the `ZhiPuAiChatClient(api, options)` constructor or the `spring.ai.zhipuai.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. +For example to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatClient.call( + new Prompt( + "Generate the names of 5 famous pirates.", + ZhiPuAiChatOptions.builder() + .withModel(ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue()) + .withTemperature(0.5f) + .build() + )); +---- + +TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java[ZhiPuAiChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Sample Controller (Auto-configuration) + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-zhipuai-spring-boot-starter` to your pom (or gradle) dependencies. + +Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the ZhiPuAi Chat client: + +[source,application.properties] +---- +spring.ai.zhipuai.api-key=YOUR_API_KEY +spring.ai.zhipuai.chat.options.model=glm-3-turbo +spring.ai.zhipuai.chat.options.temperature=0.7 +---- + +TIP: replace the `api-key` with your ZhiPuAI credentials. + +This will create a `ZhiPuAiChatClient` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the chat client for text generations. + +[source,java] +---- +@RestController +public class ChatController { + + private final ZhiPuAiChatClient chatClient; + + @Autowired + public ChatController(ZhiPuAiChatClient chatClient) { + this.chatClient = chatClient; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatClient.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var prompt = new Prompt(new UserMessage(message)); + return chatClient.stream(prompt); + } +} +---- + +== Manual Configuration + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatClient.java[ZhiPuAiChatClient] implements the `ChatClient` and `StreamingChatClient` and uses the <> to connect to the ZhiPuAI service. + +Add the `spring-ai-zhipuai` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-zhipuai + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-zhipuai' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create a `ZhiPuAiChatClient` and use it for text generations: + +[source,java] +---- +var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + +var chatClient = new ZhiPuAiChatClient(zhiPuAiApi, ZhiPuAiChatOptions.builder() + .withModel(ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue()) + .withTemperature(0.4f) + .withMaxTokens(200) + .build()); + +ChatResponse response = chatClient.call( + new Prompt("Generate the names of 5 famous pirates.")); + +// Or with streaming responses +Flux streamResponse = chatClient.stream( + new Prompt("Generate the names of 5 famous pirates.")); +---- + +The `ZhiPuAiChatOptions` provides the configuration information for the chat requests. +The `ZhiPuAiChatOptions.Builder` is fluent options builder. + +=== Low-level ZhiPuAiApi Client [[low-level-api]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java[ZhiPuAiApi] provides is lightweight Java client for link:https://open.bigmodel.cn/dev/api[ZhiPu AI API]. + +Here is a simple snippet how to use the api programmatically: + +[source,java] +---- +ZhiPuAiApi zhiPuAiApi = + new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + +ChatCompletionMessage chatCompletionMessage = + new ChatCompletionMessage("Hello world", Role.USER); + +// Sync request +ResponseEntity response = zhiPuAiApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7f, false)); + +// Streaming request +Flux streamResponse = zhiPuAiApi.chatCompletionStream( + new ChatCompletionRequest(List.of(chatCompletionMessage), ZhiPuAiApi.ChatModel.GLM_3_Turbo.getValue(), 0.7f, true)); +---- + +Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java[ZhiPuAiApi.java]'s JavaDoc for further information. + +==== ZhiPuAiApi Samples +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java[ZhiPuAiApiIT.java] test provides some general examples how to use the lightweight library. \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc new file mode 100644 index 00000000000..a189b448ac7 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/zhipuai-embeddings.adoc @@ -0,0 +1,198 @@ += ZhiPuAI Embeddings + +Spring AI supports the ZhiPuAI's text embeddings models. +ZhiPuAI’s text embeddings measure the relatedness of text strings. +An embedding is a vector (list) of floating point numbers. The distance between two vectors measures their relatedness. Small distances suggest high relatedness and large distances suggest low relatedness. + +== Prerequisites + +You will need to create an API with ZhiPuAI to access ZhiPu AI language models. + +Create an account at https://open.bigmodel.cn/login[ZhiPu AI registration page] and generate the token on the https://open.bigmodel.cn/usercenter/apikeys[API Keys page]. +The Spring AI project defines a configuration property named `spring.ai.zhipu.api-key` that you should set to the value of the `API Key` obtained from https://open.bigmodel.cn/usercenter/apikeys[API Keys page]. +Exporting an environment variable is one way to set that configuration property: + +[source,shell] +---- +export SPRING_AI_ZHIPU_AI_API_KEY= +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the Azure ZhiPuAI Embedding Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-zhipuai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-zhipuai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Embedding Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the ZhiPuAI Embedding client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.zhipuai` is used as the property prefix that lets you connect to ZhiPuAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.zhipuai.base-url | The URL to connect to | https://open.bigmodel.cn/api/paas +| spring.ai.zhipuai.api-key | The API Key | - +|==== + +==== Configuration Properties + +The prefix `spring.ai.zhipuai.embedding` is property prefix that configures the `EmbeddingClient` implementation for ZhiPuAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.zhipuai.embedding.enabled | Enable ZhiPuAI embedding client. | true +| spring.ai.zhipuai.embedding.base-url | Optional overrides the spring.ai.zhipuai.base-url to provide embedding specific url | - +| spring.ai.zhipuai.embedding.api-key | Optional overrides the spring.ai.zhipuai.api-key to provide embedding specific api-key | - +| spring.ai.zhipuai.embedding.options.model | The model to use | embedding-2 +|==== + +NOTE: You can override the common `spring.ai.zhipuai.base-url` and `spring.ai.zhipuai.api-key` for the `ChatClient` and `EmbeddingClient` implementations. +The `spring.ai.zhipuai.embedding.base-url` and `spring.ai.zhipuai.embedding.api-key` properties if set take precedence over the common properties. +Similarly, the `spring.ai.zhipuai.embedding.base-url` and `spring.ai.zhipuai.embedding.api-key` properties if set take precedence over the common properties. +This is useful if you want to use different ZhiPuAI accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.zhipuai.embedding.options` can be overridden at runtime by adding a request specific <> to the `EmbeddingRequest` call. + +== Runtime Options [[embedding-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingOptions.java[ZhiPuAiEmbeddingOptions.java] provides the ZhiPuAI configurations, such as the model to use and etc. + +The default options can be configured using the `spring.ai.zhipuai.embedding.options` properties as well. + +At start-time use the `ZhiPuAiEmbeddingClient` constructor to set the default options used for all embedding requests. +At run-time you can override the default options, using a `ZhiPuAiEmbeddingOptions` instance as part of your `EmbeddingRequest`. + +For example to override the default model name for a specific request: + +[source,java] +---- +EmbeddingResponse embeddingResponse = embeddingClient.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + ZhiPuAiEmbeddingOptions.builder() + .withModel("Different-Embedding-Model-Deployment-Name") + .build())); +---- + +== Sample Controller + +This will create a `EmbeddingClient` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the `EmbeddingClient` implementation. + +[source,application.properties] +---- +spring.ai.zhipuai.api-key=YOUR_API_KEY +spring.ai.zhipuai.embedding.options.model=embedding-2 +---- + +[source,java] +---- +@RestController +public class EmbeddingController { + + private final EmbeddingClient embeddingClient; + + @Autowired + public EmbeddingController(EmbeddingClient embeddingClient) { + this.embeddingClient = embeddingClient; + } + + @GetMapping("/ai/embedding") + public Map embed(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + EmbeddingResponse embeddingResponse = this.embeddingClient.embedForResponse(List.of(message)); + return Map.of("embedding", embeddingResponse); + } +} +---- + +== Manual Configuration + +If you are not using Spring Boot, you can manually configure the ZhiPuAI Embedding Client. +For this add the `spring-ai-zhipuai` dependency to your project's Maven `pom.xml` file: +[source, xml] +---- + + org.springframework.ai + spring-ai-zhipuai + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-zhipuai' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +NOTE: The `spring-ai-zhipuai` dependency provides access also to the `ZhiPuAiChatClient`. +For more information about the `ZhiPuAiChatClient` refer to the link:../chat/zhipuai-chat.html[ZhiPuAI Chat Client] section. + +Next, create an `ZhiPuAiEmbeddingClient` instance and use it to compute the similarity between two input texts: + +[source,java] +---- +var zhiPuAiApi = new ZhiPuAiApi(System.getenv("ZHIPU_AI_API_KEY")); + +var embeddingClient = new ZhiPuAiEmbeddingClient(zhiPuAiApi) + .withDefaultOptions(ZhiPuAiChatOptions.build() + .withModel("embedding-2") + .build()); + +EmbeddingResponse embeddingResponse = embeddingClient + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); +---- + +The `ZhiPuAiEmbeddingOptions` provides the configuration information for the embedding requests. +The options class offers a `builder()` for easy options creation. + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/image/zhipuai-image.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/image/zhipuai-image.adoc new file mode 100644 index 00000000000..b6a02d3bdaf --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/image/zhipuai-image.adoc @@ -0,0 +1,117 @@ += ZhiPuAI Image Generation + + +Spring AI supports CogView, the Image generation model from ZhiPuAI. + +== Prerequisites + +You will need to create an API with ZhiPuAI to access ZhiPu AI language models. + +Create an account at https://open.bigmodel.cn/login[ZhiPu AI registration page] and generate the token on the https://open.bigmodel.cn/usercenter/apikeys[API Keys page]. +The Spring AI project defines a configuration property named `spring.ai.zhipuai.api-key` that you should set to the value of the `API Key` obtained from https://open.bigmodel.cn/usercenter/apikeys[API Keys page]. +Exporting an environment variable is one way to set that configuration property: + +[source,shell] +---- +export SPRING_AI_ZHIPU_AI_API_KEY= +---- +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the ZhiPuAI Chat Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-zhipuai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-zhipuai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Image Generation Properties + +The prefix `spring.ai.zhipuai.image` is the property prefix that lets you configure the `ImageClient` implementation for ZhiPuAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default +| spring.ai.zhipuai.image.enabled | Enable ZhiPuAI image client. | true +| spring.ai.zhipuai.image.base-url | Optional overrides the spring.ai.zhipuai.base-url to provide chat specific url | - +| spring.ai.zhipuai.image.api-key | Optional overrides the spring.ai.zhipuai.api-key to provide chat specific api-key | - +| spring.ai.zhipuai.image.options.model | The model to use for image generation. | cogview-3 +| spring.ai.zhipuai.image.options.user | A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. | - +|==== + +==== Connection Properties + +The prefix `spring.ai.zhipuai` is used as the property prefix that lets you connect to ZhiPuAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default +| spring.ai.zhipuai.base-url | The URL to connect to | https://open.bigmodel.cn/api/paas +| spring.ai.zhipuai.api-key | The API Key | - +|==== + +==== Configuration Properties + + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the ZhiPuAI Image client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + + +== Runtime Options [[image-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java[ZhiPuAiImageOptions.java] provides model configurations, such as the model to use, the quality, the size, etc. + +On start-up, the default options can be configured with the `ZhiPuAiImageClient(ZhiPuAiImageApi zhiPuAiImageApi)` constructor and the `withDefaultOptions(ZhiPuAiImageOptions defaultOptions)` method. Alternatively, use the `spring.ai.zhipuai.image.options.*` properties described previously. + +At runtime you can override the default options by adding new, request specific, options to the `ImagePrompt` call. +For example to override the ZhiPuAI specific options such as quality and the number of images to create, use the following code example: + +[source,java] +---- +ImageResponse response = zhiPuAiImageClient.call( + new ImagePrompt("A light cream colored mini golden doodle", + ZhiPuAiImageOptions.builder() + .withQuality("hd") + .withN(4) + .withHeight(1024) + .withWidth(1024).build()) + +); +---- + +TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java[ZhiPuAiImageOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptions.java[ImageOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java[ImageOptionsBuilder#builder()]. diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index d48275f3d83..62067578fd2 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -267,6 +267,14 @@ true + + + org.springframework.ai + spring-ai-zhipuai + ${project.parent.version} + true + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java new file mode 100644 index 00000000000..e2c8b1126cf --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfiguration.java @@ -0,0 +1,128 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.zhipuai.ZhiPuAiChatClient; +import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingClient; +import org.springframework.ai.zhipuai.ZhiPuAiImageClient; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import java.util.List; + +/** + * @author Geng Rong + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@ConditionalOnClass(ZhiPuAiApi.class) +@EnableConfigurationProperties({ ZhiPuAiConnectionProperties.class, ZhiPuAiChatProperties.class, + ZhiPuAiEmbeddingProperties.class, ZhiPuAiImageProperties.class }) +public class ZhiPuAiAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = ZhiPuAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public ZhiPuAiChatClient zhiPuAiChatClient(ZhiPuAiConnectionProperties commonProperties, + ZhiPuAiChatProperties chatProperties, RestClient.Builder restClientBuilder, + List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { + + var zhiPuAiApi = zhiPuAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), + chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + } + + return new ZhiPuAiChatClient(zhiPuAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = ZhiPuAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public ZhiPuAiEmbeddingClient zhiPuAiEmbeddingClient(ZhiPuAiConnectionProperties commonProperties, + ZhiPuAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { + + var zhiPuAiApi = zhiPuAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), + embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); + + return new ZhiPuAiEmbeddingClient(zhiPuAiApi, embeddingProperties.getMetadataMode(), + embeddingProperties.getOptions(), retryTemplate); + } + + private ZhiPuAiApi zhiPuAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, + RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + + String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; + Assert.hasText(resolvedBaseUrl, "ZhiPuAI base URL must be set"); + + String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; + Assert.hasText(resolvedApiKey, "ZhiPuAI API key must be set"); + + return new ZhiPuAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = ZhiPuAiImageProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public ZhiPuAiImageClient zhiPuAiImageClient(ZhiPuAiConnectionProperties commonProperties, + ZhiPuAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler) { + + String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey() + : commonProperties.getApiKey(); + + String baseUrl = StringUtils.hasText(imageProperties.getBaseUrl()) ? imageProperties.getBaseUrl() + : commonProperties.getBaseUrl(); + + Assert.hasText(apiKey, "ZhiPuAI API key must be set"); + Assert.hasText(baseUrl, "ZhiPuAI base URL must be set"); + + var zhiPuAiImageApi = new ZhiPuAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler); + + return new ZhiPuAiImageClient(zhiPuAiImageApi, imageProperties.getOptions(), retryTemplate); + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java new file mode 100644 index 00000000000..1ec554b44dc --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiChatProperties.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Geng Rong + */ +@ConfigurationProperties(ZhiPuAiChatProperties.CONFIG_PREFIX) +public class ZhiPuAiChatProperties extends ZhiPuAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.zhipuai.chat"; + + public static final String DEFAULT_CHAT_MODEL = ZhiPuAiApi.ChatModel.GLM_3_Turbo.value; + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + /** + * Enable ZhiPuAI chat client. + */ + private boolean enabled = true; + + @NestedConfigurationProperty + private ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder() + .withModel(DEFAULT_CHAT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); + + public ZhiPuAiChatOptions getOptions() { + return options; + } + + public void setOptions(ZhiPuAiChatOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java new file mode 100644 index 00000000000..6d850f3d75f --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiConnectionProperties.java @@ -0,0 +1,31 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(ZhiPuAiConnectionProperties.CONFIG_PREFIX) +public class ZhiPuAiConnectionProperties extends ZhiPuAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.zhipuai"; + + public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas"; + + public ZhiPuAiConnectionProperties() { + super.setBaseUrl(DEFAULT_BASE_URL); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java new file mode 100644 index 00000000000..78357ceff50 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiEmbeddingProperties.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingOptions; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Geng Rong + */ +@ConfigurationProperties(ZhiPuAiEmbeddingProperties.CONFIG_PREFIX) +public class ZhiPuAiEmbeddingProperties extends ZhiPuAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.zhipuai.embedding"; + + public static final String DEFAULT_EMBEDDING_MODEL = ZhiPuAiApi.EmbeddingModel.Embedding_2.value; + + /** + * Enable ZhiPuAI embedding client. + */ + private boolean enabled = true; + + private MetadataMode metadataMode = MetadataMode.EMBED; + + @NestedConfigurationProperty + private ZhiPuAiEmbeddingOptions options = ZhiPuAiEmbeddingOptions.builder() + .withModel(DEFAULT_EMBEDDING_MODEL) + .build(); + + public ZhiPuAiEmbeddingOptions getOptions() { + return this.options; + } + + public void setOptions(ZhiPuAiEmbeddingOptions options) { + this.options = options; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java new file mode 100644 index 00000000000..19ce4f7bb4c --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiImageProperties.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.springframework.ai.zhipuai.ZhiPuAiImageOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Geng Rong + */ +@ConfigurationProperties(ZhiPuAiImageProperties.CONFIG_PREFIX) +public class ZhiPuAiImageProperties extends ZhiPuAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.zhipuai.image"; + + /** + * Enable ZhiPuAI Image client. + */ + private boolean enabled = true; + + /** + * Options for ZhiPuAI Image API. + */ + @NestedConfigurationProperty + private ZhiPuAiImageOptions options = ZhiPuAiImageOptions.builder().build(); + + public ZhiPuAiImageOptions getOptions() { + return options; + } + + public void setOptions(ZhiPuAiImageOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java new file mode 100644 index 00000000000..70d43d77092 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiParentProperties.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +/** + * @author Geng Rong + */ +class ZhiPuAiParentProperties { + + private String apiKey; + + private String baseUrl; + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 59f104b6825..637e9087cd7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -32,3 +32,4 @@ org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration org.springframework.ai.autoconfigure.watsonxai.WatsonxAiAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.elasticsearch.ElasticsearchVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.cassandra.CassandraVectorStoreAutoConfiguration +org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java new file mode 100644 index 00000000000..073c81ffce9 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.zhipuai.ZhiPuAiChatClient; +import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingClient; +import org.springframework.ai.zhipuai.ZhiPuAiImageClient; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".*") +public class ZhiPuAiAutoConfigurationIT { + + private static final Log logger = LogFactory.getLog(ZhiPuAiAutoConfigurationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.apiKey=" + System.getenv("ZHIPU_AI_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)); + + @Test + void generate() { + contextRunner.run(context -> { + ZhiPuAiChatClient client = context.getBean(ZhiPuAiChatClient.class); + String response = client.call("Hello"); + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + void generateStreaming() { + contextRunner.run(context -> { + ZhiPuAiChatClient client = context.getBean(ZhiPuAiChatClient.class); + Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); + String response = responseFlux.collectList().block().stream().map(chatResponse -> { + return chatResponse.getResults().get(0).getOutput().getContent(); + }).collect(Collectors.joining()); + + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + void embedding() { + contextRunner.run(context -> { + ZhiPuAiEmbeddingClient embeddingClient = context.getBean(ZhiPuAiEmbeddingClient.class); + + EmbeddingResponse embeddingResponse = embeddingClient + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); + + assertThat(embeddingClient.dimensions()).isEqualTo(1536); + }); + } + + @Test + void generateImage() { + contextRunner.withPropertyValues("spring.ai.zhipuai.image.options.size=1024x1024").run(context -> { + ZhiPuAiImageClient client = context.getBean(ZhiPuAiImageClient.class); + ImageResponse imageResponse = client.call(new ImagePrompt("forest")); + assertThat(imageResponse.getResults()).hasSize(1); + assertThat(imageResponse.getResult().getOutput().getUrl()).isNotEmpty(); + logger.info("Generated image: " + imageResponse.getResult().getOutput().getUrl()); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java new file mode 100644 index 00000000000..121473dfdc6 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiPropertiesTests.java @@ -0,0 +1,438 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai; + +import org.junit.jupiter.api.Test; +import org.skyscreamer.jsonassert.JSONAssert; +import org.skyscreamer.jsonassert.JSONCompareMode; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.zhipuai.ZhiPuAiChatClient; +import org.springframework.ai.zhipuai.ZhiPuAiEmbeddingClient; +import org.springframework.ai.zhipuai.ZhiPuAiImageClient; +import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link ZhiPuAiConnectionProperties}, {@link ZhiPuAiChatProperties} and + * {@link ZhiPuAiEmbeddingProperties}. + * + * @author Geng Rong + */ +public class ZhiPuAiPropertiesTests { + + @Test + public void chatProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.api-key=abc123", + "spring.ai.zhipuai.chat.options.model=MODEL_XYZ", + "spring.ai.zhipuai.chat.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(ZhiPuAiChatProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(chatProperties.getApiKey()).isNull(); + assertThat(chatProperties.getBaseUrl()).isNull(); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + + @Test + public void chatOverrideConnectionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.api-key=abc123", + "spring.ai.zhipuai.chat.base-url=TEST_BASE_URL2", + "spring.ai.zhipuai.chat.api-key=456", + "spring.ai.zhipuai.chat.options.model=MODEL_XYZ", + "spring.ai.zhipuai.chat.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(ZhiPuAiChatProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(chatProperties.getApiKey()).isEqualTo("456"); + assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + + @Test + public void embeddingProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.api-key=abc123", + "spring.ai.zhipuai.embedding.options.model=MODEL_XYZ") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var embeddingProperties = context.getBean(ZhiPuAiEmbeddingProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(embeddingProperties.getApiKey()).isNull(); + assertThat(embeddingProperties.getBaseUrl()).isNull(); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void embeddingOverrideConnectionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.api-key=abc123", + "spring.ai.zhipuai.embedding.base-url=TEST_BASE_URL2", + "spring.ai.zhipuai.embedding.api-key=456", + "spring.ai.zhipuai.embedding.options.model=MODEL_XYZ") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var embeddingProperties = context.getBean(ZhiPuAiEmbeddingProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(embeddingProperties.getApiKey()).isEqualTo("456"); + assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void imageProperties() { + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.api-key=abc123", + "spring.ai.zhipuai.image.options.model=MODEL_XYZ") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var imageProperties = context.getBean(ZhiPuAiImageProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(imageProperties.getApiKey()).isNull(); + assertThat(imageProperties.getBaseUrl()).isNull(); + + assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void imageOverrideConnectionProperties() { + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.api-key=abc123", + "spring.ai.zhipuai.image.base-url=TEST_BASE_URL2", + "spring.ai.zhipuai.image.api-key=456", + "spring.ai.zhipuai.image.options.model=MODEL_XYZ") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var imageProperties = context.getBean(ZhiPuAiImageProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(imageProperties.getApiKey()).isEqualTo("456"); + assertThat(imageProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void chatOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.api-key=API_KEY", + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + + "spring.ai.zhipuai.chat.options.model=MODEL_XYZ", + "spring.ai.zhipuai.chat.options.frequencyPenalty=-1.5", + "spring.ai.zhipuai.chat.options.logitBias.myTokenId=-5", + "spring.ai.zhipuai.chat.options.maxTokens=123", + "spring.ai.zhipuai.chat.options.n=10", + "spring.ai.zhipuai.chat.options.presencePenalty=0", + "spring.ai.zhipuai.chat.options.responseFormat.type=json", + "spring.ai.zhipuai.chat.options.seed=66", + "spring.ai.zhipuai.chat.options.stop=boza,koza", + "spring.ai.zhipuai.chat.options.temperature=0.55", + "spring.ai.zhipuai.chat.options.topP=0.56", + + // "spring.ai.zhipuai.chat.options.toolChoice.functionName=toolChoiceFunctionName", + "spring.ai.zhipuai.chat.options.toolChoice=" + ModelOptionsUtils.toJsonString(ZhiPuAiApi.ChatCompletionRequest.ToolChoiceBuilder.FUNCTION("toolChoiceFunctionName")), + + "spring.ai.zhipuai.chat.options.tools[0].function.name=myFunction1", + "spring.ai.zhipuai.chat.options.tools[0].function.description=function description", + "spring.ai.zhipuai.chat.options.tools[0].function.jsonSchema=" + """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["c", "f"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """, + "spring.ai.zhipuai.chat.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(ZhiPuAiChatProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + var embeddingProperties = context.getBean(ZhiPuAiEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("Embedding-2"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getN()).isEqualTo(10); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0); + assertThat(chatProperties.getOptions().getResponseFormat()) + .isEqualTo(new ZhiPuAiApi.ChatCompletionRequest.ResponseFormat("json")); + assertThat(chatProperties.getOptions().getSeed()).isEqualTo(66); + assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56f); + + JSONAssert.assertEquals("{\"type\":\"function\",\"function\":{\"name\":\"toolChoiceFunctionName\"}}", + chatProperties.getOptions().getToolChoice(), JSONCompareMode.LENIENT); + + assertThat(chatProperties.getOptions().getUser()).isEqualTo("userXYZ"); + + assertThat(chatProperties.getOptions().getTools()).hasSize(1); + var tool = chatProperties.getOptions().getTools().get(0); + assertThat(tool.type()).isEqualTo(ZhiPuAiApi.FunctionTool.Type.FUNCTION); + var function = tool.function(); + assertThat(function.name()).isEqualTo("myFunction1"); + assertThat(function.description()).isEqualTo("function description"); + assertThat(function.parameters()).isNotEmpty(); + }); + } + + @Test + public void embeddingOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.api-key=API_KEY", + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + + "spring.ai.zhipuai.embedding.options.model=MODEL_XYZ", + "spring.ai.zhipuai.embedding.options.encodingFormat=MyEncodingFormat", + "spring.ai.zhipuai.embedding.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + var embeddingProperties = context.getBean(ZhiPuAiEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void imageOptionsTest() { + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.zhipuai.api-key=API_KEY", + "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.image.options.model=MODEL_XYZ", + "spring.ai.zhipuai.image.options.user=userXYZ" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + var imageProperties = context.getBean(ZhiPuAiImageProperties.class); + var connectionProperties = context.getBean(ZhiPuAiConnectionProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(imageProperties.getOptions().getUser()).isEqualTo("userXYZ"); + }); + } + + @Test + void embeddingActivation() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.embedding.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiEmbeddingClient.class)).isEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiEmbeddingClient.class)).isNotEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.embedding.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiEmbeddingClient.class)).isNotEmpty(); + }); + } + + @Test + void chatActivation() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiChatClient.class)).isEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiChatClient.class)).isNotEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiChatClient.class)).isNotEmpty(); + }); + + } + + @Test + void imageActivation() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.image.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiImageClient.class)).isEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiImageClient.class)).isNotEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.api-key=API_KEY", "spring.ai.zhipuai.base-url=TEST_BASE_URL", + "spring.ai.zhipuai.image.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(ZhiPuAiImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(ZhiPuAiImageClient.class)).isNotEmpty(); + }); + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java new file mode 100644 index 00000000000..29a71041f33 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.zhipuai.ZhiPuAiChatClient; +import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".*") +public class FunctionCallbackInPromptIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.apiKey=" + System.getenv("ZHIPU_AI_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)); + + @Test + void functionCallTest() { + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + var promptOptions = ZhiPuAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30.0", "10.0", "15.0"); + }); + } + + @Test + void streamingFunctionCallTest() { + + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + var promptOptions = ZhiPuAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + Flux response = chatClient.stream(new Prompt(List.of(userMessage), promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + }); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java new file mode 100644 index 00000000000..2d7c121d11c --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -0,0 +1,172 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.zhipuai.ZhiPuAiChatClient; +import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".*") +class FunctionCallbackWithPlainFunctionBeanIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.apiKey=" + System.getenv("ZHIPU_AI_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), + ZhiPuAiChatOptions.builder().withFunction("weatherFunction").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + // Test weatherFunctionTwo + response = chatClient.call(new Prompt(List.of(userMessage), + ZhiPuAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + }); + } + + @Test + void functionCallWithPortableFunctionCallingOptions() { + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("weatherFunction") + .build(); + + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), functionOptions)); + + logger.info("Response: {}", response); + }); + } + + @Test + void streamFunctionCallTest() { + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + Flux response = chatClient.stream(new Prompt(List.of(userMessage), + ZhiPuAiChatOptions.builder().withFunction("weatherFunction").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + + // Test weatherFunctionTwo + response = chatClient.stream(new Prompt(List.of(userMessage), + ZhiPuAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + + content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + }); + } + + @Configuration + static class Config { + + @Bean + @Description("Get the weather in location") + public Function weatherFunction() { + return new MockWeatherService(); + } + + // Relies on the Request's JsonClassDescription annotation to provide the + // function description. + @Bean + public Function weatherFunctionTwo() { + MockWeatherService weatherService = new MockWeatherService(); + return (weatherService::apply); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java new file mode 100644 index 00000000000..fefd8132607 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java @@ -0,0 +1,120 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.zhipuai.ZhiPuAiChatClient; +import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "ZHIPU_AI_API_KEY", matches = ".*") +public class FunctionCallbackWrapperIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallbackWrapperIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.zhipuai.apiKey=" + System.getenv("ZHIPU_AI_API_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, ZhiPuAiAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + ChatResponse response = chatClient.call( + new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("WeatherInfo").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30.0", "10.0", "15.0"); + + }); + } + + @Test + void streamFunctionCallTest() { + contextRunner.withPropertyValues("spring.ai.zhipuai.chat.options.model=glm-4").run(context -> { + + ZhiPuAiChatClient chatClient = context.getBean(ZhiPuAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + Flux response = chatClient.stream( + new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("WeatherInfo").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).containsAnyOf("30.0", "30"); + assertThat(content).containsAnyOf("10.0", "10"); + assertThat(content).containsAnyOf("15.0", "15"); + + }); + } + + @Configuration + static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("WeatherInfo") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build(); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java new file mode 100644 index 00000000000..61f6d6c2db7 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java @@ -0,0 +1,94 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.zhipuai.tool; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +import java.util.function.Function; + +/** + * Mock 3rd party weather service. + * + * @author Geng Rong + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty(required = true, value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, + Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml new file mode 100644 index 00000000000..060ac290810 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-zhipuai/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-zhipuai-spring-boot-starter + jar + Spring AI Starter - ZhiPuAI + Spring AI ZhiPuAI Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-zhipuai + ${project.parent.version} + + + +